Source code for neat.trees.netree

"""
File contains:

    - :class:`Kernel`
    - :class:`NETNode`
    - :class:`NET`

Author: W. Wybo
"""


import copy
import warnings

import numpy as np
import matplotlib.pyplot as pl

from .stree import STree, SNode


class Kernel(object):
    r"""
    Implements a kernel as a superposition of exponentials:

    .. math:: k(t) = \sum_n c_n e^{ - a_n t}

    Kernels can be added and subtracted, as this class overloads the
    `__add__` and `__sub__` operators. They can be evaluated as a function
    of time by calling the object with a time array (or scalar). They can
    be evaluated in the Fourier domain with `Kernel.ft`.

    Parameters
    ----------
    kernel: dict, float, int, neat.Kernel, tuple or list
        If dict, has the form {'a': `np.array`, 'c': `np.array`}.
        If float or int (Python or numpy scalar), sets `c` as a single
        exponential prefactor and assumes `a` is 1 kHz.
        If `neat.Kernel`, copies the object.
        If tuple or list, sets 'a' as the first element and 'c' as the
        second element.

    Attributes
    ----------
    a: np.array of float or complex
        The exponential coefficients (kHz)
    c: np.array of float or complex
        The exponential prefactors
    """
    def __init__(self, kernel):
        # set kernel time scales and exponential prefactors
        if isinstance(kernel, dict):
            a, c = kernel['a'], kernel['c']
        elif isinstance(kernel, (int, float, np.integer, np.floating)):
            a, c = 1., kernel
        elif isinstance(kernel, Kernel):
            a, c = kernel.a, kernel.c
        else:
            a, c = kernel[0], kernel[1]
        self.a = self._as_1d_array(a)
        self.c = self._as_1d_array(c)

    @staticmethod
    def _as_1d_array(values):
        """
        Return a copy of `values` as a numpy array with at least one
        dimension. Integer/boolean data is promoted to float so that
        in-place scaling with float factors (``kernel.c *= f``) works.
        """
        arr = np.atleast_1d(np.array(values, copy=True))
        if arr.dtype.kind in 'biu':
            arr = arr.astype(float)
        return arr

    def __getitem__(self, ind):
        if ind in (0, 'a', 'alphas'):
            return self.a
        if ind in (1, 'c', 'gammas'):
            return self.c
        raise IndexError('Index should be 0, \'a\' or \'alphas\' (for `a`) '
                         'or 1, \'c\' or \'gammas\' (for `c`)')

    def __call__(self, t_arr):
        t_arr = np.asarray(t_arr)
        # one exponential per trailing-axis entry, contracted with `c`
        return (np.exp(-t_arr[..., np.newaxis] * self.a) @ self.c).real

    def _combine(self, kernel, sign):
        """
        Return ``self + sign * kernel``. Prefactors are summed when both
        kernels share the same exponents, otherwise the terms are stacked.
        """
        if kernel.a.shape[0] == self.a.shape[0] and \
           np.allclose(kernel.a, self.a):
            a = self.a.copy()
            c = self.c + sign * kernel.c
        else:
            a = np.concatenate((self.a, kernel.a))
            c = np.concatenate((self.c, sign * kernel.c))
        return Kernel((a, c))

    def __add__(self, kernel):
        return self._combine(kernel, 1.)

    def __sub__(self, kernel):
        return self._combine(kernel, -1.)

    def getKBar(self):
        """
        The total surface under the kernel
        """
        return np.sum(self.c / self.a).real

    def setKBar(self, kk):
        raise AttributeError('`k_bar` is a read-only attribute, adjust attribute `c` ' + \
                             'by multiplying with a factor to change `k_bar`')

    k_bar = property(getKBar, setKBar)

    def __str__(self, as_timescale=False):
        if as_timescale:
            return 't = ' + np.array2string(1./self.a, precision=4, max_line_width=1000) + '\n' + \
                   'c = ' + np.array2string(self.c, precision=4, max_line_width=1000)
        else:
            return 'a = ' + np.array2string(self.a, precision=4, max_line_width=1000) + '\n' + \
                   'c = ' + np.array2string(self.c, precision=4, max_line_width=1000)

    def t(self, t_arr):
        """
        Evaluates the kernel in the time domain

        Parameters
        ----------
        t_arr: np.array of float (or scalar)
            the time array at which the kernel is evaluated

        Returns
        -------
        np.array of float
            the temporal kernel, with the same shape as `t_arr`
        """
        return self(t_arr)

    def ft(self, s_arr):
        """
        Evaluates the kernel in the Fourier domain

        Parameters
        ----------
        s_arr: np.array of complex (or scalar)
            The frequencies (Hz) at which the kernel is to be evaluated

        Returns
        -------
        np.array of complex
            The Fourier transform of the kernel, with the same shape as
            `s_arr`
        """
        s_arr = np.asarray(s_arr)
        # `a` is in kHz, `s` in Hz: convert `a` (and `c`) to Hz
        terms = self.c * 1e3 / (self.a * 1e3 + s_arr[..., np.newaxis])
        return np.sum(terms, axis=-1)
class NETNode(SNode):
    """
    Node associated with `neat.NET`.

    Attributes
    ----------
    loc_inds: list of int
        The indices of locations which the node integrates
    newloc_inds: list of int
        The locations for which the node is the most local component to
        integrate them
    z_kernel: `neat.Kernel`
        The impedance kernel with which the node integrates inputs. Passing
        ``None`` yields an all-zero kernel.
    z_bar: float
        The steady state impedance associated with the impedance kernel
    """
    def __init__(self, index, loc_inds, newloc_inds=None, z_kernel=None):
        super().__init__(index)
        # location indices that node integrates
        self.loc_inds = loc_inds
        # fresh list per node: a shared mutable default would leak
        # modifications between nodes
        self.newloc_inds = [] if newloc_inds is None else newloc_inds
        # kernel associated with node
        self.z_kernel = z_kernel

    def __str__(self):
        if self.parent_node is None:
            parent_str = 'None'
        else:
            parent_str = str(self.parent_node.index)
        return 'NETNode ' + str(self.index) + \
               ', loc inds: ' + str(self.loc_inds) + \
               ', newloc inds: ' + str(self.newloc_inds) + \
               ', parent: ' + parent_str + \
               ', z_bar (MOhm) = ' + str(self.z_bar)

    def setZKernel(self, z_kernel):
        # `Kernel(None)` is not a valid construction; map the default
        # `None` to a zero kernel (a = 1 kHz, c = 0)
        if z_kernel is None:
            z_kernel = 0.
        self._z_kernel = Kernel(z_kernel)

    def getZKernel(self):
        return self._z_kernel

    def getZ(self):
        return self._z_kernel.k_bar

    z_kernel = property(getZKernel, setZKernel)
    z_bar = property(getZ, setZKernel)

    def __contains__(self, loc_ind):
        return loc_ind in self.loc_inds

    def _setCompartmentData(self, node_list, z_root_list, z_comp_list, Iz=5.):
        """
        Store, for the ancestor entries in `node_list`, the impedances for
        which ``z_comp / z_root > Iz``, together with the matching node
        indices.
        """
        # NOTE(review): `None` entries are filtered here but not from the
        # impedance lists, which would misalign indices; the visible callers
        # never pass `None`
        node_inds = [node.index for node in node_list if node is not None]
        z_root = np.array(z_root_list)
        z_comp = np.array(z_comp_list)
        # a zero root impedance gives inf (x/0) or nan (0/0); both compare
        # correctly below, so silence the floating point warnings
        with np.errstate(divide='ignore', invalid='ignore'):
            comp_inds = np.where(z_comp / z_root > Iz)[0]
        # store the relevant quantities
        self._z_root = z_root[comp_inds]
        self._z_comp = z_comp[comp_inds]
        self._node_inds = [node_inds[ind] for ind in comp_inds]

    def _setTentativeCompartments(self, comps):
        self._comps = comps

    def _setSharedRootInd(self, ind):
        self._root_ind = self._node_inds.index(ind)
class NET(STree):
    """
    Abstract tree class that implements the Neural Evaluation Tree (Wybo et
    al., 2019), representing the spatial voltage as a number of voltage
    components present at different spatial scales.
    """
    def __init__(self, root=None):
        super().__init__(root)

    def __str__(self):
        return '\n'.join(['NET'] + [' > ' + str(node) for node in self]) + '\n'

    def _createCorrespondingNode(self, node_index):
        """
        Creates a node with the given index corresponding to the tree class.

        Parameters
        ----------
        node_index: int
            index of the new node
        """
        return NETNode(node_index, [])

    def getLocInds(self, sroot=None):
        """
        Get the indices of the locations a subtree integrates

        Parameters
        ----------
        sroot: `neat.NETNode`, int or None
            Root of the subtree, or index of the root. If ``None``, subtree
            is the whole tree.

        Returns
        -------
        loc_inds: indices of locations
        """
        if isinstance(sroot, (int, np.integer)):
            sroot = self[sroot]
        elif sroot is None:
            sroot = self.root
        return sroot.loc_inds

    def getLeafLocNode(self, loc_ind):
        """
        Get the node for which ``loc_ind`` is a new location

        Parameters
        ----------
        loc_ind: int
            index of the location

        Returns
        -------
        :obj:`NETNode` or None
            ``None`` if no node has ``loc_ind`` as a new location
        """
        return next((node for node in self if loc_ind in node.newloc_inds),
                    None)

    def setNewLocInds(self):
        """
        Set the new location indices in a tree, i.e. for each node the
        locations it integrates that none of its children integrate
        """
        for node in self:
            child_loc_inds = set()
            for cnode in node.child_nodes:
                child_loc_inds.update(cnode.loc_inds)
            node.newloc_inds = list(set(node.loc_inds) - child_loc_inds)

    def getReducedTree(self, loc_inds, indexing='NET eval'):
        """
        Construct a reduced tree where only the locations indexed by
        ``loc_inds`` are retained

        Parameters
        ----------
        loc_inds : iterable of ints
            the indices of the locations that are to be retained
        indexing : 'NET eval' or 'locs'
            if 'NET eval', indexing of ``NETNode.loc_inds`` will be taken
            to be the indices of locations for which the full NET is
            evaluated. Otherwise will be indices of the input ``loc_inds``

        Returns
        -------
        :class:`NET` or None
            ``None`` if none of ``loc_inds`` is integrated by the root
        """
        # materialize once: `loc_inds` is traversed more than once and may
        # be a one-shot iterator
        loc_inds = list(loc_inds)
        loc_inds_newtree = sorted({loc_ind for loc_ind in loc_inds
                                   if loc_ind in self.root})
        if not loc_inds_newtree:
            return None

        new_root = NETNode(0, loc_inds_newtree, z_kernel=self.root.z_kernel)
        new_tree = NET(new_root)
        for cnode in self.root.child_nodes:
            if cnode is not None:
                self._constructReducedTree(cnode, loc_inds_newtree,
                                           new_root, new_tree)
        new_tree.setNewLocInds()

        if indexing == 'NET eval':
            return new_tree

        # re-index: every retained location index is replaced by all the
        # positions at which it occurs in the input `loc_inds` (converted
        # to an array so that the comparison is elementwise)
        loc_inds_arr = np.asarray(loc_inds)
        for node in new_tree:
            node.loc_inds = [pos for ind in node.loc_inds
                             for pos in np.flatnonzero(loc_inds_arr == ind).tolist()]
        new_tree.setNewLocInds()
        return new_tree

    def _constructReducedTree(self, node, loc_inds, node_newtree, new_tree):
        """
        Recursively add the part of the subtree of `node` that integrates
        any of `loc_inds` to `new_tree`, below `node_newtree`
        """
        loc_inds_subtree = sorted({loc_ind for loc_ind in loc_inds
                                   if loc_ind in node})
        if not loc_inds_subtree:
            return
        if set(loc_inds_subtree) == set(loc_inds):
            # same locations as the parent in the reduced tree: merge the
            # kernels instead of adding a redundant node (compare as sets,
            # list order carries no meaning)
            node_newtree.z_kernel += node.z_kernel
        else:
            newnode_newtree = NETNode(len(new_tree), loc_inds_subtree,
                                      z_kernel=node.z_kernel)
            new_tree.addNodeWithParent(newnode_newtree, node_newtree)
            node_newtree = newnode_newtree
        for cnode in node.child_nodes:
            if cnode is not None:
                self._constructReducedTree(cnode, loc_inds_subtree,
                                           node_newtree, new_tree)

    def calcTotalImpedance(self, node):
        """
        Compute the total impedance associated with a node. I.e. the sum of
        all impedances on the path from node to root

        Parameters
        ----------
        node : :class:`SNode`

        Returns
        -------
        float
            total impedance
        """
        return np.sum([node_.z_bar for node_ in self.pathToRoot(node)])

    def calcTotalKernel(self, node):
        """
        Compute the total impedance kernel associated with a node. I.e. the
        sum of all impedance kernels on the path from node to root

        Parameters
        ----------
        node : :class:`SNode`

        Returns
        -------
        :class:`Kernel`
        """
        z_k = copy.deepcopy(node.z_kernel)
        if node.parent_node is not None:
            for pn in self.pathToRoot(node.parent_node):
                z_k += pn.z_kernel
        return z_k

    def calcIZ(self, loc_inds):
        """
        compute I_Z between any pair of locations in ``loc_inds``

        Parameters
        ----------
        loc_inds : iterable of ints
            the indices of locations between which I_Z has to be evaluated

        Returns
        -------
        float or dict of tuple : float
            Returns a float if the number of location indices is two,
            otherwise a dictionary with location pairs (smallest is listed
            first) as keys and I_Z values as values
        """
        loc_inds = list(loc_inds)
        Iz_dict = {}
        for ii, loc_ind0 in enumerate(loc_inds):
            for loc_ind1 in loc_inds[:ii]:
                net_red = self.getReducedTree([loc_ind0, loc_ind1])
                key = tuple(sorted((loc_ind0, loc_ind1)))
                # impedances of the location-specific branches; a location
                # that maps onto the root has no specific branch
                n0 = net_red.getLeafLocNode(loc_ind0)
                z0 = n0.z_bar if n0 is not net_red.root else 0.
                n1 = net_red.getLeafLocNode(loc_ind1)
                z1 = n1.z_bar if n1 is not net_red.root else 0.
                Iz_dict[key] = (z0 + z1) / (2. * net_red.root.z_bar)
        if len(loc_inds) == 2:
            return next(iter(Iz_dict.values()))
        return Iz_dict

    def calcIZMatrix(self):
        """
        compute the Iz matrix for all locations present in the tree

        Returns
        -------
        np.ndarray of float
            The Iz matrix
        """
        z_mat = self.calcImpedanceMatrix()
        z_in = np.diag(z_mat)
        return (z_in[:, np.newaxis] + z_in[np.newaxis, :]) / (2. * z_mat) - 1.

    def calcImpedanceMatrix(self):
        """
        Compute the impedance matrix approximation associated with the NET

        Returns
        -------
        np.ndarray (ndim = 2)
            the impedance matrix approximation
        """
        return self.calcImpMat()

    def calcImpMat(self):
        """
        Compute the impedance matrix approximation associated with the NET

        Returns
        -------
        np.ndarray (ndim = 2)
            the impedance matrix approximation
        """
        n_loc = len(self.root.loc_inds)
        loc_map = {loc_ind: map_ind
                   for map_ind, loc_ind in enumerate(self.root.loc_inds)}
        z_mat = np.zeros((n_loc, n_loc))
        self._addNodeToImpMat(self.root, z_mat, loc_map)
        return z_mat

    def _addNodeToImpMat(self, node, z_mat, loc_map):
        """
        Add the impedance of `node` (and recursively of its descendants) to
        every pair of locations it integrates
        """
        # explicit int dtype: an empty `loc_inds` would otherwise produce a
        # float array, which is not a valid index
        inds = np.array([loc_map[loc_ind] for loc_ind in node.loc_inds],
                        dtype=int)
        z_mat[np.ix_(inds, inds)] += node.z_bar
        for cnode in node.child_nodes:
            self._addNodeToImpMat(cnode, z_mat, loc_map)

    def getCompartmentalization(self, Iz, returntype='node index'):
        """
        Returns a compartmentalization for the NET tree where each pair of
        compartments is separated by an Iz of at least ``Iz``. The
        compartmentalization is coded as a list of list, each sublist
        representing the nodes closest to the root associated with the
        compartment.

        Parameters
        ----------
        Iz : float
            the minimum Iz separating the compartments
        returntype: str ('node index', 'node')
            either returns the node indices or the node objects

        Returns
        -------
        list of lists
            the compartments

        Raises
        ------
        ValueError
            if `returntype` is not 'node index' or 'node'
        """
        if returntype not in ('node index', 'node'):
            raise ValueError("`returntype` should be 'node index' or 'node', "
                             "got %r" % (returntype,))
        self._computeTentativeCompartments(Iz=Iz)
        # determine the nodes that contain the eventual compartments and
        # remove the rest (on a copy, the original tree stays intact)
        net = copy.deepcopy(self)
        self._removeNonCompartments(net.leafs, net=net)
        # get the compartment nodes
        comp_nodes = self._setCompartmentsLeafbased(net.leafs, net)
        comp_inds = [node._comps[node._root_ind] for node in comp_nodes]
        if returntype == 'node index':
            return comp_inds
        return [[self[ind] for ind in inds] for inds in comp_inds]

    @staticmethod
    def _highestCompartmentNode(leaf, root_index):
        """
        Walk from `leaf` towards the root as long as the node lists
        `root_index` among its compartment node indices; return the last
        such node, or ``None`` if `leaf` itself does not qualify
        """
        comp_node = None
        node = leaf
        while root_index in node._node_inds:
            comp_node = node
            node = node.parent_node
        return comp_node

    def _setCompartmentsLeafbased(self, leafs, net):
        """
        For each leaf of `net`, find the node closest to the common root of
        its sister leafs that forms a compartment and mark it as such
        """
        comp_nodes = []
        for leaf in leafs:
            root, _, _ = net.sisterLeafs(leaf)
            comp_node = self._highestCompartmentNode(leaf, root.index)
            if comp_node is not None:
                # mark the node as the compartment indexing node
                comp_node._setSharedRootInd(root.index)
                comp_nodes.append(comp_node)
        return comp_nodes

    def _removeNonCompartments(self, leafs, net=None, n_count=0):
        """
        Soft-remove from `net` the child branches that cannot form
        compartments. Leafs are processed cyclically until a full pass
        over the remaining leafs removes nothing.

        (Iterative implementation; an earlier recursive version could
        exceed the recursion limit on large trees.)
        """
        if net is None:
            warnings.warn('Modifying original tree')
            net = self
        # work on a copy; the caller's list is left untouched
        leafs = list(leafs)
        while leafs:
            n_leaf = len(leafs)
            # rotate: process the first leaf and move it to the back
            leaf = leafs[0]
            leafs = leafs[1:] + [leaf]
            common_root, sister_leafs, corresponding_children = \
                net.sisterLeafs(leaf)
            # only act when `leaf` belongs to a highest order branch point
            if len(sister_leafs) == len(corresponding_children):
                # sister leafs that can form a compartment below common root
                sinds_comp = [
                    ii for ii, sleaf in enumerate(sister_leafs)
                    if self._highestCompartmentNode(sleaf, common_root.index)
                    is not None
                ]
                if len(sinds_comp) <= 1 and not net.isRoot(common_root):
                    # if at most one is compartment, we retain only the one
                    # with the largest total impedance
                    ind = int(np.argmax([net.calcTotalImpedance(sleaf)
                                         for sleaf in sister_leafs]))
                    keep = {ind}
                else:
                    # if more can be compartments, we retain all those
                    keep = set(sinds_comp)
                for ii, cnode in enumerate(corresponding_children):
                    if ii not in keep:
                        net.softRemoveNode(cnode)
                        leafs.remove(sister_leafs[ii])
            if n_leaf != len(leafs) and len(leafs) > 0:
                # something was removed: start a new pass
                n_count = 0
            elif n_count < len(leafs):
                n_count += 1
            else:
                break

    def _computeTentativeCompartments(self, Iz=5.):
        # set the prerequisite impedances
        self._setCompartmentInfo(Iz=Iz)
        # set the tentative compartments
        for node in self:
            self._setCompartmentsRelative(node)

    def _setCompartmentInfo(self, Iz=5., node=None, z_p=0.,
                            node_list=None, z_root_list=None, z_comp_list=None):
        """
        Recursively store, in every node, the cumulative impedances with
        respect to each of its ancestors (see
        `NETNode._setCompartmentData`). Called without `node`, starts at
        the root.
        """
        # copy (never mutate) the lists handed down by the caller
        node_list = [] if node_list is None else list(node_list)
        z_root_list = [] if z_root_list is None else list(z_root_list)
        z_comp_list = [] if z_comp_list is None else list(z_comp_list)
        if node is not None:
            # list of dependent impedances
            z_root_list.append(z_root_list[-1] + z_p if z_root_list else z_p)
            # list of independent impedances
            z_comp_list.append(0.)
            z_comp_list = [node.z_bar + z_c for z_c in z_comp_list]
            # list of nodes
            node_list.append(node.parent_node)
            # store the compartment information
            node._setCompartmentData(node_list, z_root_list, z_comp_list, Iz=Iz)
        else:
            node = self.root
            self.root._setCompartmentData([], [], [], Iz=0.)
        # recurse to child nodes (the callee copies the lists)
        for cnode in node.child_nodes:
            self._setCompartmentInfo(Iz=Iz, node=cnode, z_p=node.z_bar,
                                     node_list=node_list,
                                     z_root_list=z_root_list,
                                     z_comp_list=z_comp_list)

    def _setCompartmentsRelative(self, node):
        node._setTentativeCompartments([[node.index] for _ in node._z_comp])

    def computeCondRescale(self, gs):
        """
        Compute the shunt factors for the conductances `gs`

        Parameters
        ----------
        gs : array-like of float
            one conductance per location of the root node

        Returns
        -------
        np.ndarray of float
            the shunt factors, one per location

        Raises
        ------
        ValueError
            if the number of conductances differs from the number of
            locations of the root node
        """
        gs = np.asarray(gs)
        # integer conductances would make the factors an int array and
        # truncate the rescaled values
        if gs.dtype.kind in 'biu':
            gs = gs.astype(float)
        if len(gs) != len(self.root.loc_inds):
            raise ValueError('`gs` should contain one conductance per '
                             'location of the root node')
        # array for storing shunt factors
        sfs = np.ones_like(gs)
        # counter for the sweep algorithm
        for node in self:
            node.counter = 0
        try:
            leafs = self.leafs
            self._sweep(leafs[0], leafs[1:], sfs, gs)
        finally:
            # clean, also when the sweep fails
            for node in self:
                node.counter = 0
        return sfs

    def _sweep(self, node, leafs, sfs, gs):
        """
        Sweep from the leafs to the root, rescaling the shunt factors of
        the locations of a node once all of its children have been
        visited. Continues from the next leaf in `leafs` whenever a node
        still has unvisited children.
        """
        next_leaf = 0
        while True:
            node.counter += 1
            if node.counter >= len(node.child_nodes):
                if self.isRoot(node):
                    return
                # compute the rescaled shunt factors
                denom = 1. + node.z_bar * \
                    np.sum(sfs[node.loc_inds] * gs[node.loc_inds])
                sfs[node.loc_inds] = sfs[node.loc_inds] / denom
                node = node.parent_node
            else:
                node = leafs[next_leaf]
                next_leaf += 1

    def improveInputImpedance(self, z_mat):
        """
        Rescale/extend the tree so that the total impedance at each
        location matches the diagonal of `z_mat`. Single-location nodes get
        their kernel rescaled; nodes with new locations get a new
        single-location child node when the mismatch exceeds 1e-7.

        Parameters
        ----------
        z_mat : np.ndarray (ndim = 2)
            impedance matrix whose diagonal holds the target input
            impedances, indexed by location index
        """
        nmaxind = max(n.index for n in self)
        for node in self.getNodes():
            if len(node.loc_inds) == 1:
                ind = node.loc_inds[0]
                # recompute the kernel of this single loc layer
                if node.parent_node is not None:
                    z_parent = self.calcTotalKernel(node.parent_node).k_bar
                else:
                    z_parent = 0.
                f_z = (z_mat[ind, ind] - z_parent) / node.z_bar
                node.z_kernel.c *= f_z
            elif len(node.newloc_inds) > 0:
                z_k_approx = self.calcTotalKernel(node)
                # add new input nodes for the locations that don't have one
                for ind in node.newloc_inds:
                    nmaxind += 1
                    f_z = z_mat[ind, ind] - z_k_approx.k_bar
                    if np.abs(f_z) > 1e-7:
                        f_z /= node.z_bar
                        z_k_real = Kernel(dict(a=node.z_kernel.a,
                                               c=node.z_kernel.c * f_z))
                        newnode = NETNode(nmaxind, [ind], z_kernel=z_k_real)
                        newnode.newloc_inds = [ind]
                        self.addNodeWithParent(newnode, node)
        # the new location indices are recomputed from the tree structure
        self.setNewLocInds()

    def plotDendrogram(self, ax,
                       plotargs={}, labelargs={}, textargs={},
                       incolors={}, inlabels={}, nodelabels={},
                       cs_comp={}, cmap=None,
                       z_max=None, add_scalebar=True):
        """
        Generate a dendrogram of the NET

        None of the dict arguments are modified by this function.

        Parameters
        ----------
        ax: :class:`matplotlib.axes`
            the axes object in which the plot will be made
        plotargs : dict (string : value)
            keyword args for the matplotlib plot function, specifies the
            line properties of the dendrogram
        labelargs : dict (string : value)
            keyword args for the matplotlib plot function, specifies the
            marker properties for the node points. Or dict with keys node
            indices, and with values dicts with keyword args for the
            matplotlib function that specify the marker properties for
            specific node points. The entry under key -1 specifies the
            properties for all nodes not explicitly in the keys.
        textargs : dict (string : value)
            keyword args for matplotlib textproperties
        incolors : dict (int : string)
            dict with locinds as keys and colors as values
        inlabels : dict (int : string)
            dict with locinds as keys and label strings as values
        nodelabels: dict (int: string) or None
            labels of the nodes. If None, nodes are named by default
            according to their location indices. If empty dict, no labels
            are added.
        cs_comp : dict (int : float)
            dict with node inds as keys and compartment color values as
            values (mapped through `cmap`)
        cmap : matplotlib colormap or None
            colormap for `cs_comp`, defaults to 'jet'
        z_max: float or None
            specifies the y-scale. If None, the scale is computed from
            ``self``
        add_scalebar: bool
            whether or not to add a scale bar

        Returns
        -------
        float
            the y-scale `z_max` that was used
        """
        if cs_comp:
            # normalize a copy: doing it in place corrupted the caller's
            # dict (and broke a second call with the same dict)
            cs_comp = dict(cs_comp)
            arr = np.array(list(cs_comp.values()))
            max_cs = np.max(arr)
            min_cs = np.min(arr)
            norm_cs = (max_cs - min_cs) * (1. + 1./100.)
            if norm_cs == 0.:
                # all values equal: avoid a division by zero
                norm_cs = 1.
            for key in cs_comp:
                cs_comp[key] = (cs_comp[key] - min_cs) / norm_cs
            if cmap is None:
                cmap = pl.get_cmap('jet')
            cs_comp['cm'] = cmap
            if max_cs > min_cs:
                # NOTE(review): result unused; presumably registers a
                # mappable on the current pyplot axes for a later colorbar
                # -- confirm. Contour levels must be increasing.
                Z = [[0, 0], [0, 0]]
                levels = np.linspace(min_cs, max_cs, 100)
                pl.contourf(Z, levels, cmap=cmap)
        # get the number of leafs to determine the dendrogram spacing
        rnode = self.root
        n_branch = self.degreeOfNode(rnode)
        l_spacing = np.linspace(0., 1., n_branch+1)
        # determine input impedances to fix the y scale
        if z_max is None:
            z_dict = {}
            for node in self.nodes:
                for ind in node.loc_inds:
                    z_dict[ind] = z_dict.get(ind, 0.) + node.z_bar
            z_max = max(z_dict.values())
        # plot the dendrogram
        self._expandDendrogram(rnode, 0.5, 0., l_spacing,
                               z_max, ax,
                               plotargs=plotargs, labelargs=labelargs,
                               textargs=textargs,
                               incolors=incolors, inlabels=inlabels,
                               nodelabels=nodelabels,
                               cs_comp=cs_comp)
        # limits
        ax.set_ylim((0.0, 1.2*z_max))
        ax.set_xlim((0., 1.))
        # scalebar
        if add_scalebar:
            sblength = np.around(z_max // 5, -2)
            if sblength < .1:
                sblength += np.around(z_max % 5, -1)
            if sblength < .1:
                sblength += np.around(z_max // 5, 0)
            sbwidth = plotargs.get('lw', 1.) * 3
            sbtsize = textargs.get('size', 'small')
            ax.plot([0., 0.], [0., sblength], 'k-', lw=sbwidth)
            ax.annotate(r'%.0f M$\Omega$' % sblength,
                        xy=(0., sblength/2.), xytext=(-0.04, sblength/2.),
                        size=sbtsize, rotation=90, ha='center', va='center')
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
        ax.set_axis_off()
        return z_max

    def _expandDendrogram(self, node, x0, y0, l_spacing, z_max, ax,
                          plotargs={}, labelargs={}, textargs={},
                          incolors={}, inlabels={}, nodelabels={},
                          cs_comp={}):
        """
        Recursively draw the branch of `node` (vertical line from `y0` to
        ``y0 + node.z_bar`` at `x0`), its subtree, and the node/input
        labels. `l_spacing` holds the x-positions available to the subtree.
        """
        # color the branch (and, through `plotargs`, its subtree) if the
        # node is part of a compartment
        if cs_comp and node.index in cs_comp:
            plotargs = copy.deepcopy(plotargs)
            plotargs['color'] = cs_comp['cm'](cs_comp[node.index])
        # impedance of layer
        ynew = y0 + node.z_bar
        # plot vertical connection line
        ax.vlines(x0, y0, ynew, **plotargs)
        # recurse into the child nodes, each getting a share of the x-axis
        # proportional to its degree
        n_child = len(node.child_nodes)
        l0 = 0
        for i, cnode in enumerate(node.child_nodes):
            l1 = l0 + self.degreeOfNode(cnode)
            xnew = (l_spacing[l0] + l_spacing[l1]) / 2.
            # horizontal connection line limits
            if i == 0:
                xnew0 = xnew
            if i == n_child - 1:
                xnew1 = xnew
            self._expandDendrogram(cnode, xnew, ynew, l_spacing[l0:l1+1],
                                   z_max, ax,
                                   plotargs=plotargs, labelargs=labelargs,
                                   textargs=textargs,
                                   incolors=incolors, inlabels=inlabels,
                                   nodelabels=nodelabels,
                                   cs_comp=cs_comp)
            l0 = l1
        # plot horizontal connection line
        if l0 > 0:
            ax.hlines(ynew, xnew0, xnew1, **plotargs)
        # add label and maybe text annotation to node
        if node.index in labelargs:
            ax.plot([x0], [ynew], **labelargs[node.index])
        elif -1 in labelargs:
            ax.plot([x0], [ynew], **labelargs[-1])
        else:
            try:
                ax.plot([x0], [ynew], **labelargs)
            except TypeError:
                # `labelargs` keyed by node index cannot be used as kwargs
                pass
        if textargs:
            text_xy = (x0 + 0.04, ynew + z_max * 0.04)
            if nodelabels is not None:
                if node.index in nodelabels:
                    if labelargs == {}:
                        ax.plot([x0], [ynew], **nodelabels[node.index][1])
                        ax.annotate(nodelabels[node.index][0],
                                    xy=(x0, ynew), xytext=text_xy,
                                    bbox=dict(boxstyle='round', ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8)),
                                    **textargs)
                    else:
                        ax.annotate(nodelabels[node.index],
                                    xy=(x0, ynew), xytext=text_xy,
                                    bbox=dict(boxstyle='round', ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8)),
                                    **textargs)
            else:
                ax.annotate(r'$N=' + ''.join(str(ind) for ind in node.loc_inds) + '$',
                            xy=(x0, ynew), xytext=text_xy,
                            bbox=dict(boxstyle='round', ec=(1., 0.5, 0.5), fc=(1., 0.8, 0.8)),
                            **textargs)
        # add input label
        if self.isLeaf(node) and inlabels is not None:
            loc_ind = node.loc_inds[0]
            lwidth = plotargs.get('lw', 1.)
            ax.vlines(x0, ynew + z_max * 0.04, z_max * 1.1,
                      lw=lwidth, linestyle=':', color='k')
            if loc_ind in incolors:
                bboxdict = dict(boxstyle='round', ec=incolors[loc_ind],
                                fc=incolors[loc_ind], alpha=0.5)
            else:
                bboxdict = dict(boxstyle='round', ec=(0.5, 0.5, 1.),
                                fc=(0.8, 0.8, 1.))
            if loc_ind in inlabels:
                textstr = inlabels[loc_ind]
            else:
                textstr = r'$' + str(loc_ind) + '$'
            ax.annotate(textstr, xy=(x0, z_max * 1.1),
                        xytext=(x0, z_max * 1.14),
                        ha='center', bbox=bboxdict, **textargs)