Source code for neat.trees.netree

# -*- coding: utf-8 -*-
#
# netree.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

import numpy as np
import matplotlib.pyplot as pl

from .stree import STree, SNode
from ..tools.kernelextraction import Kernel


import copy
import warnings


[docs] class NETNode(SNode): """ Node associated with `neat.NET`. Attributes ---------- loc_idxs: list of int The inidices of locations which the node integrates newloc_idxs: list of int The locations for which the node is the most local component to integrate them z_kernel: `neat.Kernel` The impedance kernel with which the node integrates inputs z_bar: float The steady state impedance associated with the impedance kernel """ def __init__(self, index, loc_idxs, newloc_idxs=[], z_kernel=None): super().__init__(index) # location indices that node integrates self.loc_idxs = loc_idxs self.newloc_idxs = newloc_idxs # kernel associated with node self.z_kernel = z_kernel def __str__(self): if self.parent_node is not None: return ( "NETNode " + str(self.index) + ", loc inds: " + str(self.loc_idxs) + ", newloc inds: " + str(self.newloc_idxs) + ", parent: " + str(self.parent_node.index) + ", z_bar (MOhm) = " + str(self.z_bar) ) else: return "NETNode " + str(self.index) + ", loc inds: " + str( self.loc_idxs ) + ", newloc inds: " + str( self.newloc_idxs ) + ", parent: None" ", z_bar (MOhm) = " + str( self.z_bar ) def set_z_kernel(self, z_kernel): self._z_kernel = Kernel(z_kernel) def get_z_kernel(self): return self._z_kernel def get_z(self): return self._z_kernel.k_bar z_kernel = property(get_z_kernel, set_z_kernel) z_bar = property(get_z, set_z_kernel) def __contains__(self, loc_idx): return loc_idx in self.loc_idxs def _set_compartment_data(self, node_list, z_root_list, z_comp_list, Iz=5.0): node_inds = [node.index for node in node_list if node != None] z_root = np.array(z_root_list) z_comp = np.array(z_comp_list) comp_inds = np.where(z_comp / z_root > Iz)[0] # store the relevant quantities self._z_root = z_root[comp_inds] self._z_comp = z_comp[comp_inds] self._node_inds = [node_inds[ind] for ind in comp_inds] def _set_tentative_compartments(self, comps): self._comps = comps def _set_shared_root_idx(self, ind): self._root_ind = self._node_inds.index(ind) def __str__(self, with_parent=True, with_morph_info=False): node_str = super().__str__(with_parent=with_parent) node_str += ( f" --- " f" loc inds: {str(self.loc_idxs)}" f", newloc inds: {str(self.newloc_idxs)}" f", z_bar = {self.z_bar} MOhm" ) return node_str def _get_repr_dict(self): repr_dict = super()._get_repr_dict() repr_dict.update( { "loc_idxs": self.loc_idxs, "newloc_idxs": self.newloc_idxs, "z_kernel": repr(self.z_kernel), } ) return repr_dict def __repr__(self): return repr(self._get_repr_dict())
[docs] class NET(STree): """ Abstract tree class that implements the Neural Evaluation Tree (Wybo et al., 2019), representing the spatial voltage as a number of voltage components present at different spatial scales. """ def __init__(self, root=None): super().__init__(root)
[docs] def create_corresponding_node(self, node_index): """ Creates a node with the given index corresponding to the tree class. Parameters ---------- node_index: int index of the new node """ return NETNode(node_index, [])
[docs] def get_loc_idxs(self, sroot=None): """ Get the indices of the locations a subtree integrates Parameters ---------- sroot: `neat.NETNode`, int or None Root of the subtree, or index of the root. If ``None``, subtree is the whole tree. Returns ------- loc_idxs: indices of locations """ if isinstance(sroot, int): sroot = self[sroot] elif sroot is None: sroot = self.root return sroot.loc_idxs
[docs] def get_leaf_loc_node(self, loc_idx): """ Get the node for which ``loc_idx`` is a new location Parameters ---------- loc_idx: int index of the location Returns ------- :obj:`NETNode` """ for node in self: if loc_idx in node.newloc_idxs: return node
[docs] def set_new_loc_idxs(self): """ Set the new location indices in a tree """ for node in self: cloc_idxs = set() for cnode in node.child_nodes: cloc_idxs = cloc_idxs.union(set(cnode.loc_idxs)) node.newloc_idxs = list(set(node.loc_idxs) - cloc_idxs)
[docs] def get_reduced_tree(self, loc_idxs, indexing="NET eval"): """ Construct a reduced tree where only the locations index by ``loc_idxs'' are retained Parameters ---------- loc_idxs : iterable of ints the indices of the locations that are to be retained indexing : 'NET eval' or 'locs' if 'NET eval', indexing of ``NETNode.loc_idxs`` will be taken to be the indices of locations for which the full NET is evaluated. Otherwise will be indices of the input ``loc_idxs`` """ loc_idxs_newtree = list( {loc_idx for loc_idx in loc_idxs if loc_idx in self.root} ) if loc_idxs_newtree: new_root = NETNode(0, loc_idxs_newtree, z_kernel=self.root.z_kernel) new_tree = NET(new_root) for cnode in self.root.child_nodes: if cnode is not None: self._construct_reduced_tree( cnode, loc_idxs_newtree, new_root, new_tree ) new_tree.set_new_loc_idxs() if indexing == "NET eval": return new_tree else: for node in new_tree: # node.loc_idxs = [np.where(loc_idxs == ind)[0][0] for ind in node.loc_idxs] # node.loc_idxs = sum([np.where(loc_idxs == ind)[0].tolist() for ind in set(node.loc_idxs)], []) node.loc_idxs = sum( [ np.where(loc_idxs == ind)[0].tolist() for ind in node.loc_idxs ], [], ) new_tree.set_new_loc_idxs() return new_tree else: return None
def _construct_reduced_tree(self, node, loc_idxs, node_newtree, new_tree): loc_idxs_subtree = list({loc_idx for loc_idx in loc_idxs if loc_idx in node}) if len(loc_idxs_subtree) > 0: if loc_idxs_subtree == loc_idxs: node_newtree.z_kernel += node.z_kernel else: newnode_newtree = NETNode( len(new_tree), loc_idxs_subtree, z_kernel=node.z_kernel ) new_tree.add_node_with_parent(newnode_newtree, node_newtree) node_newtree = newnode_newtree for cnode in node.child_nodes: if cnode is not None: self._construct_reduced_tree( cnode, loc_idxs_subtree, node_newtree, new_tree ) # def matchInputImpedance(self, z_input): # assert imp_mat.shape[0] == imp_mat.shape[1] # assert imp_mat.shape[0] == len(self.root.loc_idxs) # for node in self: # if self.is_leaf(node): # if len(node.loc_idxs) == 1: # p_imp = self.calc_total_impedance(node.parent_node) # node.z_kernel.c *= (z_input[node.locs_inds[0]] - p_imp) / node.z_kernel.k_bar # else: # for loc_idx in node.loc_idxs: # new_node = NETNode(len(tree), [loc_idx]) # self.add_node_with_parent
[docs] def calc_total_impedance(self, node): """ Compute the total impedance associated with a node. I.e. the sum of all impedances on the path from node to root Parameters ---------- node : :class:`SNode` Returns ------- float total impedance """ return np.sum([node_.z_bar for node_ in self.path_to_root(node)])
def calc_total_kernel(self, node): """ Compute the total impedance kernel associated with a node. I.e. the sum of all impedance kernels on the path from node to root Parameters ---------- node : :class:`SNode` Returns ------- :class:`Kernel` """ z_k = copy.deepcopy(node.z_kernel) if node.parent_node is not None: for pn in self.path_to_root(node.parent_node): z_k += pn.z_kernel return z_k
[docs] def calc_i_z(self, loc_idxs): """ compute I_Z between any pair of locations in ``loc_idxs`` Parameters ---------- loc_idxs : iterable of ints the indices of locations between which I_Z has to be evaluated Returns ------- float or dict of tuple : float Returns a float if the number of location indices is two, otherwise a dictionary with location pairs (smallest is listed first) as keys and I_Z values as values """ Iz_dict = {} for ii, loc_idx0 in enumerate(loc_idxs): for jj, loc_idx1 in enumerate(loc_idxs): if jj < ii: net_red = self.get_reduced_tree([loc_idx0, loc_idx1]) key = ( (loc_idx0, loc_idx1) if loc_idx0 < loc_idx1 else (loc_idx1, loc_idx0) ) n0 = net_red.get_leaf_loc_node(loc_idx0) z0 = n0.z_bar if n0 != net_red.root else 0.0 n1 = net_red.get_leaf_loc_node(loc_idx1) z1 = n1.z_bar if n1 != net_red.root else 0.0 Iz_dict[key] = (z0 + z1) / (2.0 * net_red.root.z_bar) else: break if len(loc_idxs) == 2: return list(Iz_dict.values())[0] else: return Iz_dict
[docs] def calc_i_z_matrix(self): """ compute the Iz matrix for all locations present in the tree Returns ------- np.ndarray of float The Iz matrix """ z_mat = self.calc_impedance_matrix() z_in = np.diag(z_mat) return (z_in[:, np.newaxis] + z_in[np.newaxis, :]) / (2.0 * z_mat) - 1.0
[docs] def calc_impedance_matrix(self): """ Compute the impedance matrix approximation associated with the NET Returns ------- np.ndarray (ndim = 2) the impedance matrix approximation """ n_loc = len(self.root.loc_idxs) loc_map = { loc_idx: map_ind for map_ind, loc_idx in enumerate(self.root.loc_idxs) } z_mat = np.zeros((n_loc, n_loc)) self._add_node_to_imp_mat(self.root, z_mat, loc_map) return z_mat
def _add_node_to_imp_mat(self, node, z_mat, loc_map): inds = np.array([loc_map[loc_idx] for loc_idx in node.loc_idxs]) z_mat[np.tile(inds, len(inds)), np.repeat(inds, len(inds))] += node.z_bar for cnode in node.child_nodes: self._add_node_to_imp_mat(cnode, z_mat, loc_map)
[docs] def calc_compartmentalization(self, Iz, returntype="node index"): """ Returns a compartmentalization for the NET tree where each pair of compartments is separated by an Iz of at least ``Iz``. The compartmentalization is coded as a list of list, each sublist representing a the nodes closest to the root associated with the compartment. Parameters ---------- Iz : float the minimum Iz separating the compartments returntype: str ('node index', 'node') either returns the node indices or the node objects Returns ------- list of lists the compartments """ self._compute_tentative_compartments(Iz=Iz) # determine the nodes that contain the eventual compartments and # remove the rest net = copy.deepcopy(self) self._remove_non_compartments(net.leafs, net=net) # get the compartment nodes comp_nodes = self._set_compartments_leafbased(net.leafs, net) if returntype == "node index": comp_inds = [] for node in comp_nodes: inds = node._comps[node._root_ind] comp_inds.append(inds) return comp_inds elif returntype == "node": comp_nodes_ = [] for node in comp_nodes: inds = node._comps[node.rootind] comp_nodes_.append([self[ind] for ind in inds]) return comp_nodes_
def _set_compartments_leafbased(self, leafs, net): comp_nodes = [] for ii, leaf in enumerate(leafs): root, _, _ = net.sister_leafs(leaf) new_leaf = leaf comp_bool = False while root.index in new_leaf._node_inds: comp_bool = True old_leaf = new_leaf new_leaf = old_leaf.parent_node if comp_bool: # mark the old_leaf as the compartment indexing node old_leaf._set_shared_root_idx(root.index) comp_nodes.append(old_leaf) return comp_nodes def _remove_non_compartments(self, leafs, net=None, n_count=0): if net is None: warnings.warn("Modifying original tree") net = self # count number of leafs n_leaf = len(leafs) leaf = leafs[0] # shuffle list del leafs[0] leafs = leafs + [leaf] # leaf is not highest order common_root, sister_leafs, corresponding_children = net.sister_leafs(leaf) if common_root.index == 0: pass if len(sister_leafs) == len(corresponding_children): # find the compartments with maximal size and closest to common root sleafs_comp = [] sinds_comp = [] for ii, leaf in enumerate(sister_leafs): newleaf = leaf comp_bool = False while common_root.index in newleaf._node_inds: comp_bool = True oldleaf = newleaf newleaf = oldleaf.parent_node if comp_bool: sinds_comp.append(ii) sleafs_comp.append(oldleaf) # delete the leafs that are not in compartments if len(sleafs_comp) <= 1 and not net.is_root(common_root): # if at most one is compartment, we retain only the largest one ind = np.argmax( [self.calc_total_impedance(node) for node in sister_leafs] ) newleaf = sister_leafs[ind] for ii, cnode in enumerate(corresponding_children): if ii != ind: net.soft_remove_node(cnode) leafs.remove(sister_leafs[ii]) else: # if more can be compartments, we retain all those for ii, cnode in enumerate(corresponding_children): if not ii in sinds_comp: net.soft_remove_node(cnode) leafs.remove(sister_leafs[ii]) if n_leaf != len(leafs) and len(leafs) > 0: self._remove_non_compartments(leafs, net=net, n_count=0) elif n_count < len(leafs): self._remove_non_compartments(leafs, net=net, n_count=n_count + 1) elif n_count < len(leafs) and len(leafs) > 0: self._remove_non_compartments(leafs, net=net, n_count=n_count + 1) def _compute_tentative_compartments(self, Iz=5.0): # set the prerequisite impedances self._set_compartment_info(Iz=Iz) # set the tentative compartments for node in self: self._set_compartments_relative(node) def _set_compartment_info( self, Iz=5.0, node=None, z_p=0.0, node_list=[], z_root_list=[], z_comp_list=[] ): if node != None: # list of dependent impedances try: z_root_list.append(z_root_list[-1] + z_p) except IndexError: z_root_list.append(z_p) # list of independent impedances z_comp_list.append(0.0) z_comp_list = [node.z_bar + z_c for z_c in z_comp_list] # list or nodes node_list.append(node.parent_node) # store the compartment information node._set_compartment_data(node_list, z_root_list, z_comp_list, Iz=Iz) else: node = self.root # compute node impedance self.root._set_compartment_data([], [], [], Iz=0.0) # recurse to child nodes for cnode in node.child_nodes: self._set_compartment_info( Iz=Iz, node=cnode, z_p=node.z_bar, node_list=copy.copy(node_list), z_root_list=copy.copy(z_root_list), z_comp_list=copy.copy(z_comp_list), ) def _set_compartments_relative(self, node): z_target = node._z_comp node_comps = [] for z_t in z_target: comp = [node.index] node_comps.append(comp) node._set_tentative_compartments(node_comps) def compute_cond_rescale(self, gs): assert len(gs) == len(self.root.loc_idxs) # array for storing shunt factors sfs = np.ones_like(gs) # counter for recursion algorithm for node in self: node.counter = 0 # recursive algorithm to compute shunt factors self._sweep(self.leafs[0], self.leafs[1:], sfs, gs) # clean for node in self: node.counter = 0 return sfs def _sweep(self, node, leafs, sfs, gs): node.counter += 1 if node.counter >= len(node.child_nodes): if not self.is_root(node): # compute the rescaled shunt factors denom = 1.0 + node.z_bar * np.sum( sfs[node.loc_idxs] * gs[node.loc_idxs] ) sfs[node.loc_idxs] = sfs[node.loc_idxs] / denom # further recursion self._sweep(node.parent_node, leafs, sfs, gs) else: self._sweep(leafs[0], leafs[1:], sfs, gs) def improve_input_resistance(self, z_mat): nmaxind = np.max([n.index for n in self]) for node in self.get_nodes(): if len(node.loc_idxs) == 1: ind = node.loc_idxs[0] # recompute the kernel of this single loc layer if node.parent_node is not None: p_k = self.calc_total_kernel(node.parent_node) else: p_k = Kernel((node.z_kernel.a, np.zeros_like(node.z_kernel.a))) f_z = (z_mat[ind, ind] - p_k.k_bar) / node.z_bar node.z_kernel.c *= f_z elif len(node.newloc_idxs) > 0: z_k_approx = self.calc_total_kernel(node) # add new input nodes for the nodes that don't have one tbr_inds = [] for ind in node.newloc_idxs: nmaxind += 1 f_z = z_mat[ind, ind] - z_k_approx.k_bar if np.abs(f_z) > 1e-7: f_z /= node.z_bar z_k_real = Kernel( dict(a=node.z_kernel.a, c=node.z_kernel.c * f_z) ) # add node newnode = NETNode(nmaxind, [ind], z_kernel=z_k_real) newnode.newloc_idxs = [ind] self.add_node_with_parent(newnode, node) tbr_inds.append(ind) for ind in tbr_inds: node.newloc_idxs.remove(ind) # empty the new indices node.newloc_idxs = [] self.set_new_loc_idxs()
[docs] def plot_dendrogram( self, ax, plotargs={}, labelargs={}, textargs={}, incolors={}, inlabels={}, nodelabels={}, cs_comp={}, cmap=None, z_max=None, add_scalebar=True, ): """ Generate a dendrogram of the NET Parameters ---------- ax: :class:`matplotlib.axes` the axes object in which the plot will be made plotargs : dict (string : value) keyword args for the matplotlib plot function, specifies the line properties of the dendrogram labelargs : dict (string : value) keyword args for the matplotlib plot function, specifies the marker properties for the node points. Or dict with keys node indices, and with values dicts with keyword args for the matplotlib function that specify the marker properties for specific node points. The entry under key -1 specifies the properties for all nodes not explicitly in the keys. textargs : dict (string : value) keyword args for matplotlib textproperties incolors : dict (int : string) dict with locinds as keys and colors as values inlabels : dict (int : string) dict with locinds as keys and label strings as values nodelabels: dict (int: string) or None labels of the nodes. If None, nodes are named by default according to their location indices. If empty dict, no labels are added. cs_comp : dict (int : float) dict with node inds as keys and compartment colors as values z_max: float or None specifies the y-scale. If None, the scale is computed from ``self`` add_scalebar: bool whether or not to add a scale bar """ if cs_comp: # compute the compartmental colormap if necessary arr = np.array([list(cs_comp.values())]) max_cs = np.max(arr) min_cs = np.min(arr) norm_cs = (max_cs - min_cs) * (1.0 + 1.0 / 100.0) for key, val in cs_comp.items(): cs_comp[key] = (cs_comp[key] - min_cs) / norm_cs if cmap is None: cmap = pl.get_cmap("jet") cs_comp["cm"] = cmap Z = [[0, 0], [0, 0]] levels = np.linspace(min_cs, max_cs, 100) CS3 = pl.contourf(Z, levels, cmap=cmap) # get the number of leafs to determine the dendrogram spacing rnode = self.root n_branch = self.degree_of_node(rnode) l_spacing = np.linspace(0.0, 1.0, n_branch + 1) # determine input inpedances to fix the y scale if z_max == None: z_dict = {} for node in self.nodes: for ind in node.loc_idxs: try: z_dict[ind] += node.z_bar except KeyError: z_dict[ind] = node.z_bar z_max = max(z_dict.values()) # plot the dendrogram self._expand_dendrogram( rnode, 0.5, 0.0, l_spacing, z_max, ax, plotargs=plotargs, labelargs=labelargs, textargs=textargs, incolors=incolors, inlabels=inlabels, nodelabels=nodelabels, cs_comp=cs_comp, ) # limits ax.set_ylim((0.0, 1.2 * z_max)) ax.set_xlim((0.0, 1.0)) # scalebar if add_scalebar: sblength = np.around(z_max // 5, -2) if sblength < 0.1: sblength += np.around(z_max % 5, -1) if sblength < 0.1: sblength += np.around(z_max // 5, 0) sbwidth = plotargs["lw"] * 3 if "lw" in plotargs else 3.0 sbtsize = textargs["size"] if "size" in textargs else "small" ax.plot([0.0, 0.0], [0.0, sblength], "k-", lw=sbwidth) ax.annotate( r"%.0f M$\Omega$" % sblength, xy=(0.0, sblength / 2.0), xytext=(-0.04, sblength / 2.0), size=sbtsize, rotation=90, ha="center", va="center", ) ax.axes.get_xaxis().set_visible(False) ax.axes.get_yaxis().set_visible(False) ax.axison = False return z_max
def _expand_dendrogram( self, node, x0, y0, l_spacing, z_max, ax, plotargs={}, labelargs={}, textargs={}, incolors={}, inlabels={}, nodelabels={}, cs_comp={}, ): # check if part of compartment if cs_comp: if node.index in list(cs_comp.keys()): plotargs = copy.deepcopy(plotargs) plotargs["color"] = cs_comp["cm"](cs_comp[node.index]) # impedance of layer ynew = y0 + node.z_bar # plot vertical connection line ax.vlines(x0, y0, ynew, **plotargs) # get the child nodes for recursion l0 = 0 for i, cnode in enumerate(node.child_nodes): # attribute space on xaxis deg = self.degree_of_node(cnode) l1 = l0 + deg # new quantities xnew = (l_spacing[l0] + l_spacing[l1]) / 2.0 # horizontal connection line limits if i == 0: xnew0 = xnew if i == len(node.child_nodes) - 1: xnew1 = xnew # recursion self._expand_dendrogram( cnode, xnew, ynew, l_spacing[l0 : l1 + 1], z_max, ax, plotargs=plotargs, labelargs=labelargs, textargs=textargs, incolors=incolors, inlabels=inlabels, nodelabels=nodelabels, cs_comp=cs_comp, ) # next index l0 = l1 # plot horizontal connection line if l0 > 0: ax.hlines(ynew, xnew0, xnew1, **plotargs) # add label and maybe text annotation to node if node.index in labelargs: ax.plot([x0], [ynew], **labelargs[node.index]) elif -1 in labelargs: ax.plot([x0], [ynew], **labelargs[-1]) else: try: ax.plot([x0], [ynew], **labelargs) except TypeError as e: pass if textargs: if nodelabels != None: if node.index in nodelabels: if labelargs == {}: ax.plot([x0], [ynew], **nodelabels[node.index][1]) ax.annotate( nodelabels[node.index][0], xy=(x0, ynew), xytext=(x0 + 0.04, ynew + z_max * 0.04), bbox=dict( boxstyle="round", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8) ), **textargs, ) else: ax.annotate( nodelabels[node.index], xy=(x0, ynew), xytext=(x0 + 0.04, ynew + z_max * 0.04), bbox=dict( boxstyle="round", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8) ), **textargs, ) else: ax.annotate( r"$N=" + "".join([str(ind) for ind in node.loc_idxs]) + "$", xy=(x0, ynew), xytext=(x0 + 0.04, ynew + z_max * 0.04), bbox=dict(boxstyle="round", ec=(1.0, 0.5, 0.5), fc=(1.0, 0.8, 0.8)), **textargs, ) # add input label if self.is_leaf(node): if inlabels != None: lwidth = plotargs["lw"] if "lw" in plotargs else 1.0 ax.vlines( x0, ynew + z_max * 0.04, z_max * 1.1, lw=lwidth, linestyle=":", color="k", ) if node.loc_idxs[0] in incolors: bboxdict = dict( boxstyle="round", ec=incolors[node.loc_idxs[0]], fc=incolors[node.loc_idxs[0]], alpha=0.5, ) else: bboxdict = dict( boxstyle="round", ec=(0.5, 0.5, 1.0), fc=(0.8, 0.8, 1.0) ) if node.loc_idxs[0] in inlabels: textstr = inlabels[node.loc_idxs[0]] else: textstr = r"$" + str(node.loc_idxs[0]) + "$" ax.annotate( textstr, xy=(x0, z_max * 1.1), xytext=(x0, z_max * 1.14), ha="center", bbox=bboxdict, **textargs, )