# Source code for neat.trees.greenstree

"""
File contains:

    - `neat.GreensNode`
    - `neat.SomaGreensNode`
    - `neat.GreensTree`

Author: W. Wybo
"""

import numpy as np

import copy

from . import morphtree
from .morphtree import MorphLoc
from .phystree import PhysNode, PhysTree


class GreensNode(PhysNode):
    '''
    Node that stores quantities and defines functions to implement the
    impedance matrix calculation based on Koch's algorithm (Koch & Poggio,
    1985).

    Attributes
    ----------
    expansion_points: dict {str: dict}
        Stores, per ion channel name, the state-variable expansion point
        (passed as keyword arguments to the channel's linearization) for
        this segment.
    '''
    def __init__(self, index, p3d):
        super().__init__(index, p3d)
        self.expansion_points = {}

    def _rescaleLengthRadius(self):
        """
        Store radius and length of the segment in cm as `self.R_` and
        `self.L_`, the units used by the impedance calculation.
        """
        self.R_ = self.R * 1e-4 # convert to cm
        self.L_ = self.L * 1e-4 # convert to cm

    def setExpansionPoint(self, channel_name, statevar):
        """
        Set the choice for the state variables of the ion channel around which
        to linearize.

        Note that when adding an ion channel to the node, the default expansion
        point setting is to linearize around the asymptotic values for the
        state variables at the equilibrium potential stored in `self.e_eq`.
        Hence, this function only needs to be called to change that setting.

        Parameters
        ----------
        channel_name: string
            the name of the ion channel
        statevar: dict or None
            The expansion points for each of the ion channel state variables.
            ``None`` is stored as an empty dict.
        """
        if statevar is None:
            statevar = {}
        self.expansion_points[channel_name] = statevar

    def getExpansionPoint(self, channel_name):
        """
        Return the expansion point stored for `channel_name`.

        When no expansion point has been stored yet, an empty dict is stored
        for the channel and returned.
        """
        return self.expansion_points.setdefault(channel_name, {})

    def _iterActiveChannels(self, channel_storage):
        """
        Yield ``(channel_name, g, e, channel)`` for every non-leak current of
        this node whose conductance exceeds 1e-10; `channel` is looked up in
        `channel_storage`.
        """
        for channel_name, (g, e) in self.currents.items():
            if channel_name != 'L' and g > 1e-10:
                yield channel_name, g, e, channel_storage[channel_name]

    def _calcMembraneImpedance(self, freqs, channel_storage, use_conc=False):
        """
        Compute the impedance of the membrane at the node

        Parameters
        ----------
        freqs: `np.ndarray` (``dtype=complex``, ``ndim=1``)
            The frequencies at which the impedance is to be evaluated
        channel_storage: dict of ion channels
            Maps channel names to the already initialized ion channel objects.
            Must contain every channel in `self.currents` (except the leak
            ``'L'``) whose conductance exceeds 1e-10.
        use_conc: bool
            if True, also uses concentration mechanisms to compute linearized
            membrane impedance

        Returns
        -------
        `np.ndarray` (``dtype=complex``, ``ndim=1``)
            The membrane impedance
        """
        if use_conc:
            # linearized ionic conductance, only for ions that have a
            # concentration mechanism; read by concentration-dependent channels
            g_m_ions = {ion: np.zeros_like(freqs) for ion in self.concmechs}
        g_m_aux = self.c_m * freqs + self.currents['L'][0]

        active_channels = list(self._iterActiveChannels(channel_storage))

        # first pass: channels that do not read concentrations
        for channel_name, g, e, channel in active_channels:
            if len(channel.conc) != 0:
                continue
            # expansion point for the linearization
            sv = self.getExpansionPoint(channel_name)
            # add channel contribution to membrane impedance
            g_ = g * channel.computeLinSum(self.e_eq, freqs, e, **sv)
            g_m_aux -= g_
            # ions without a concentration mechanism have a fixed
            # concentration, so their current does not need to be tracked
            if use_conc and channel.ion in g_m_ions:
                g_m_ions[channel.ion] += g_

        # second pass: channels that do read concentrations (needs the fully
        # accumulated `g_m_ions` from the first pass)
        for channel_name, g, e, channel in active_channels:
            if len(channel.conc) == 0:
                continue
            sv = self.getExpansionPoint(channel_name)
            g_m_aux -= g * channel.computeLinSum(self.e_eq, freqs, e, **sv)
            if not use_conc:
                continue
            for ion in channel.conc:
                # fixed concentration -> no linear concentration contribution
                if ion not in self.concmechs:
                    continue
                g_m_aux -= g * channel.computeLinConc(self.e_eq, freqs, e, ion) * \
                           self.concmechs[ion].computeLinear(freqs) * \
                           g_m_ions[ion]

        return 1. / (2. * np.pi * self.R_ * g_m_aux)

    def _setImpedance(self, freqs, channel_storage, use_conc=False):
        """
        Reset the sweep counter and compute the membrane impedance `z_m`,
        the axial impedance `z_a`, the propagation constant `gamma` and the
        characteristic impedance `z_c` of the segment.
        """
        self.counter = 0
        self.z_m = self._calcMembraneImpedance(freqs, channel_storage,
                                               use_conc=use_conc)
        self.z_a = self.r_a / (np.pi * self.R_**2)
        self.gamma = np.sqrt(self.z_a / self.z_m)
        self.z_c = self.z_a / self.gamma

    def _setImpedanceDistal(self):
        """
        Set the boundary condition at the distal end of the segment
        """
        if len(self.child_nodes) == 0:
            if self.g_shunt < 1e-10:
                # sealed end: infinite impedance (`np.infty` no longer
                # exists in NumPy >= 2.0, `np.inf` is the canonical name)
                self.z_distal = np.full(len(self.z_m), np.inf)
            else:
                self.z_distal = 1. / self.g_shunt
        else:
            # parallel combination of all child branches and the shunt
            self.z_distal = 1. / (np.sum([1. / cnode._collapseBranchToRoot()
                                          for cnode in self.child_nodes], 0) +
                                  self.g_shunt)

    def _setImpedanceProximal(self):
        """
        Set the boundary condition at the proximal end of the segment

        The proximal impedance is the parallel combination of the parent
        collapsed towards this node, the parent's shunt and all sister
        branches. A node without parent gets a sealed (infinite impedance)
        proximal end.
        """
        if self.parent_node is None:
            self.z_proximal = np.full(len(self.z_m), np.inf)
            return

        # child nodes of parent node without the current node
        sister_nodes = list(self.parent_node.child_nodes)
        sister_nodes.remove(self)

        val = 1. / self.parent_node._collapseBranchToLeaf() + \
              self.parent_node.g_shunt
        for snode in sister_nodes:
            val += 1. / snode._collapseBranchToRoot()
        self.z_proximal = 1. / val

    def _collapseBranchToLeaf(self):
        """
        Input impedance at the distal end of the segment, with the proximal
        boundary `z_proximal`.
        """
        gl = self.gamma * self.L_
        return self.z_c * (self.z_proximal * np.cosh(gl) + self.z_c * np.sinh(gl)) / \
               (self.z_proximal * np.sinh(gl) + self.z_c * np.cosh(gl))

    def _collapseBranchToRoot(self):
        """
        Input impedance at the proximal end of the segment, with the distal
        boundary `z_distal`.
        """
        gl = self.gamma * self.L_
        z_cd = self.z_c / self.z_distal
        return self.z_c * (np.cosh(gl) + z_cd * np.sinh(gl)) / \
               (np.sinh(gl) + z_cd * np.cosh(gl))

    def _setImpedanceArrays(self):
        """
        Precompute the quantities used by `_calcZF`; requires `z_proximal`
        and `z_distal` to be set.
        """
        self.gammaL = self.gamma * self.L_
        self.z_cp = self.z_c / self.z_proximal
        self.z_cd = self.z_c / self.z_distal
        # common denominator of all transfer impedances on this segment
        # (attribute name kept for backward compatibility)
        self.wrongskian = np.cosh(self.gammaL) / self.z_c * \
                          (self.z_cp + self.z_cd + \
                           (1. + self.z_cp * self.z_cd) * np.tanh(self.gammaL))
        self.z_00 = (self.z_cd * np.sinh(self.gammaL) + np.cosh(self.gammaL)) / \
                    self.wrongskian
        self.z_11 = (self.z_cp * np.sinh(self.gammaL) + np.cosh(self.gammaL)) / \
                    self.wrongskian
        self.z_01 = 1. / self.wrongskian

    def _calcZF(self, x1, x2):
        """
        Transfer impedance between two points on this segment.

        `x1` and `x2` are normalized positions along the segment, with 0 the
        proximal end (attached to the parent) and 1 the distal end. Positions
        within 1e-3 of an end use the precomputed end-point values.
        """
        tol = 1e-3
        if x1 < tol and x2 < tol:
            return self.z_00
        elif x1 > 1. - tol and x2 > 1. - tol:
            return self.z_11
        elif (x1 < tol and x2 > 1. - tol) or (x1 > 1. - tol and x2 < tol):
            return self.z_01
        # x_p is the more proximal point, x_d the more distal one
        x_p, x_d = (x1, x2) if x1 < x2 else (x2, x1)
        return (self.z_cp * np.sinh(self.gammaL * x_p) + np.cosh(self.gammaL * x_p)) * \
               (self.z_cd * np.sinh(self.gammaL * (1. - x_d)) + np.cosh(self.gammaL * (1. - x_d))) / \
               self.wrongskian
class SomaGreensNode(GreensNode):
    """
    Root node of the Green's function tree.

    Unlike a cable segment, the soma has a single, location-independent
    input impedance: `_calcZF` returns `z_in` for any pair of positions.
    """
    def _calcMembraneImpedance(self, freqs, channel_storage, use_conc=False):
        """
        Membrane impedance of the soma, including its shunt conductance.

        The cylinder impedance computed by `GreensNode` is rescaled for the
        soma surface and combined in parallel with `self.g_shunt`.
        """
        cylinder_z_m = super()._calcMembraneImpedance(
            freqs, channel_storage, use_conc=use_conc
        )
        # rescale for soma surface instead of cylinder radius
        soma_admittance = 2. * self.R_ / cylinder_z_m + self.g_shunt
        return 1. / soma_admittance

    def _setImpedance(self, freqs, channel_storage, use_conc=False):
        """Reset the sweep counter and store the soma impedance as `z_soma`."""
        self.counter = 0
        self.z_soma = self._calcMembraneImpedance(
            freqs, channel_storage, use_conc=use_conc
        )

    def _collapseBranchToLeaf(self):
        """Impedance seen from a child looking towards the soma."""
        return self.z_soma

    def _setImpedanceArrays(self):
        """Store the total input impedance: soma in parallel with all children."""
        admittance = 1. / self.z_soma
        for child in self.child_nodes:
            admittance += 1. / child._collapseBranchToRoot()
        self.z_in = 1. / admittance

    def _calcZF(self, x1, x2):
        """Transfer impedance on the soma; independent of `x1` and `x2`."""
        return self.z_in
class GreensTree(PhysTree):
    """
    Class that computes the Green's function in the Fourier domain of a
    given neuronal morphology (Koch, 1985). This tree defines a special
    `neat.SomaGreensNode` as a derived class from `neat.GreensNode` as some
    functions required for Green's function calculation are different and
    thus overwritten.

    The calculation proceeds on the computational tree (see docstring of
    `neat.MorphNode`). Thus it makes no sense to look for Green's function
    related quantities in the original tree.

    Attributes
    ----------
    freqs: np.array of complex
        Frequencies at which impedances are evaluated ``[Hz]``
    """
    def __init__(self, file_n=None, types=[1,3,4]):
        super().__init__(file_n=file_n, types=types)
        self.freqs = None

    def _createCorrespondingNode(self, node_index, p3d=None):
        """
        Creates a node with the given index corresponding to the tree class.

        The node with index 1 becomes a `SomaGreensNode`, every other node a
        `GreensNode`.

        Parameters
        ----------
        node_index: `int`
            index of the new node
        p3d: optional
            passed unchanged to the node constructor
        """
        if node_index == 1:
            return SomaGreensNode(node_index, p3d)
        return GreensNode(node_index, p3d)

    def removeExpansionPoints(self):
        """
        Remove expansion points from all nodes in the tree
        """
        for node in self:
            node.expansion_points = {}

    @morphtree.computationalTreetypeDecorator
    def setImpedance(self, freqs, use_conc=False, pprint=False):
        """
        Set the boundary impedances for each node in the tree

        Parameters
        ----------
        freqs: `np.ndarray` (``dtype=complex``, ``ndim=1``)
            frequencies at which the impedances will be evaluated ``[Hz]``
        use_conc: bool
            whether or not to incorporate concentrations in the calculation
        pprint: bool (default ``False``)
            whether or not to print info on the progression of the algorithm
        """
        self.freqs = freqs
        # set the node specific impedances
        for node in self:
            node._rescaleLengthRadius()
            node._setImpedance(freqs, self.channel_storage, use_conc=use_conc)
        # leaf -> root sweep sets the distal boundaries, root -> leaf sweep
        # then sets the proximal boundaries
        if len(self) > 1:
            self._impedanceFromLeaf(self.leafs[0], self.leafs[1:], pprint=pprint)
            self._impedanceFromRoot(self.root)
        # reset the sweep counters and precompute the per-node arrays
        for node in self:
            node.counter = 0
            node._setImpedanceArrays()

    def _impedanceFromLeaf(self, node, leafs, pprint=False):
        """
        Set the distal impedance of every node, sweeping from the leafs to
        the root.

        Starting at `node`, walk towards the root. Each non-leaf node counts
        how often the walk reached it; once that equals its number of
        children, its distal impedance is set and the walk continues to its
        parent. Otherwise the walk restarts at the next leaf in `leafs`.
        Implemented iteratively so the stack depth does not grow with the
        number of nodes in the tree.
        """
        remaining_leafs = iter(leafs)
        current = node
        while current is not None:
            if pprint:
                print('Forward sweep: ' + str(current))
            # log how many times the sweep has passed at this node
            if not self.isLeaf(current):
                current.counter += 1
            if current.counter == len(current.child_nodes):
                # all child subtrees are done: distal impedance can be set
                current._setImpedanceDistal()
                current = None if self.isRoot(current) else current.parent_node
            else:
                # subtree not complete yet: start from the next leaf
                current = next(remaining_leafs, None)

    def _impedanceFromRoot(self, node):
        """
        Set the proximal impedance of every node below `node` in pre-order,
        so that each parent is handled before its children. The root itself
        is skipped.
        """
        stack = [node]
        while stack:
            current = stack.pop()
            if current != self.root:
                current._setImpedanceProximal()
            # reversed so children are processed in their original order
            stack.extend(reversed(current.child_nodes))

    @morphtree.computationalTreetypeDecorator
    def calcZF(self, loc1, loc2):
        """
        Computes the transfer impedance between two locations for all
        frequencies in `self.freqs`.

        Parameters
        ----------
        loc1: dict, tuple or `:class:MorphLoc`
            One of two locations between which the transfer impedance is
            computed
        loc2: dict, tuple or `:class:MorphLoc`
            One of two locations between which the transfer impedance is
            computed

        Returns
        -------
        nd.ndarray (dtype = complex, ndim = 1)
            The transfer impedance ``[MOhm]`` as a function of frequency
        """
        # cast to morphlocs
        loc1 = MorphLoc(loc1, self)
        loc2 = MorphLoc(loc2, self)
        # the path between the nodes
        path = self.pathBetweenNodes(self[loc1['node']], self[loc2['node']])
        # compute the kernel
        z_f = np.ones_like(self.freqs)
        if len(path) == 1:
            # both locations are on same node
            z_f *= path[0]._calcZF(loc1['x'], loc2['x'])
            return z_f

        # first node: leave through the proximal end (towards the parent) or
        # through the distal end (towards a child). A junction at a node's
        # distal end is divided out by that node; a junction at a proximal end
        # is divided out by the parent further along the path.
        if path[1] == self[loc1['node']].parent_node:
            z_f *= path[0]._calcZF(loc1['x'], 0.)
        else:
            z_f *= path[0]._calcZF(loc1['x'], 1.)
            z_f /= path[0]._calcZF(1., 1.)
        # last node: enter through the proximal or the distal end
        if path[-2] == self[loc2['node']].parent_node:
            z_f *= path[-1]._calcZF(loc2['x'], 0.)
        else:
            z_f *= path[-1]._calcZF(loc2['x'], 1.)
            z_f /= path[-1]._calcZF(1., 1.)
        # nodes within the path
        for ll, node in enumerate(path[1:-1], start=1):
            z_f /= node._calcZF(1., 1.)
            # the segment is only traversed when the path does not turn around
            # at its distal end (i.e. unless both neighbours are its children)
            if path[ll-1] not in node.child_nodes or \
               path[ll+1] not in node.child_nodes:
                z_f *= node._calcZF(0., 1.)
        return z_f

    @morphtree.computationalTreetypeDecorator
    def calcImpedanceMatrix(self, locarg, explicit_method=True):
        """
        Computes the impedance matrix of a given set of locations for each
        frequency stored in `self.freqs`.

        Parameters
        ----------
        locarg: `list` of locations or string
            if `list` of locations, specifies the locations for which the
            impedance matrix is evaluated, if ``string``, specifies the
            name under which a set of location is stored
        explicit_method: bool, optional (default ``True``)
            if ``False``, will use the transitivity property of the impedance
            matrix to further optimize the computation.

        Returns
        -------
        `np.ndarray` (``dtype = self.freqs.dtype``, ``ndim = 3``)
            the impedance matrix, first dimension corresponds to the
            frequency, second and third dimensions contain the impedance
            matrix ``[MOhm]`` at that frequency

        Raises
        ------
        IOError
            if `locarg` is neither a list nor a string
        """
        if isinstance(locarg, list):
            locs = [MorphLoc(loc, self) for loc in locarg]
        elif isinstance(locarg, str):
            locs = self.getLocs(locarg)
        else:
            raise IOError('`locarg` should be list of locs or string')
        z_mat = np.zeros((len(self.freqs), len(locs), len(locs)),
                         dtype=self.freqs.dtype)

        if explicit_method:
            # fill the lower triangle and mirror it (matrix is symmetric)
            for ii, loc0 in enumerate(locs):
                for jj in range(ii):
                    z_f = self.calcZF(loc0, locs[jj])
                    z_mat[:,ii,jj] = z_f
                    z_mat[:,jj,ii] = z_f
                z_mat[:,ii,ii] = self.calcZF(loc0, loc0)
        else:
            for ii in range(len(locs)):
                self._calcImpedanceMatrixFromNode(ii, locs, z_mat)

        return z_mat

    def _calcImpedanceMatrixFromNode(self, ii, locs, z_mat):
        """
        Fill row/column `ii` of `z_mat` (entries ``jj >= ii``) by propagating
        the transfer impedance from ``locs[ii]`` through the whole tree.
        """
        node = self[locs[ii]['node']]
        x_ii = locs[ii]['x']
        for jj, loc in enumerate(locs):
            if loc['node'] == node.index and jj >= ii:
                z_new = node._calcZF(x_ii, loc['x'])
                z_mat[:,ii,jj] = z_new
                z_mat[:,jj,ii] = z_new

        # move down: transfer impedance to the distal end is loop-invariant
        z_down = node._calcZF(x_ii, 1.)
        for c_node in node.child_nodes:
            self._calcImpedanceMatrixDown(ii, z_down, c_node, locs, z_mat)

        if node.parent_node is not None:
            z_up = node._calcZF(x_ii, 0.)
            # move to sister nodes
            for c_node in set(node.parent_node.child_nodes) - {node}:
                self._calcImpedanceMatrixDown(ii, z_up, c_node, locs, z_mat)
            # move up
            self._calcImpedanceMatrixUp(ii, z_up, node.parent_node, locs, z_mat)

    def _calcImpedanceMatrixUp(self, ii, z_0, node, locs, z_mat):
        """
        Propagate the transfer impedance `z_0` (from ``locs[ii]`` to the
        distal end of `node`) into `node`, its sisters and its ancestors.
        """
        # compute impedances
        z_in = node._calcZF(1.,1.)
        for jj, loc in enumerate(locs):
            if jj > ii and loc['node'] == node.index:
                z_new = z_0 / z_in * node._calcZF(1.,loc['x'])
                z_mat[:,ii,jj] = z_new
                z_mat[:,jj,ii] = z_new

        if node.parent_node is not None:
            z_new = z_0 / z_in * node._calcZF(0., 1.)
            # move to sister nodes
            for c_node in set(node.parent_node.child_nodes) - {node}:
                self._calcImpedanceMatrixDown(ii, z_new, c_node, locs, z_mat)
            # move to parent node
            self._calcImpedanceMatrixUp(ii, z_new, node.parent_node, locs, z_mat)

    def _calcImpedanceMatrixDown(self, ii, z_0, node, locs, z_mat):
        """
        Propagate the transfer impedance `z_0` (from ``locs[ii]`` to the
        proximal end of `node`) into `node` and its descendants.
        """
        # compute impedances
        z_in = node._calcZF(0.,0.)
        for jj, loc in enumerate(locs):
            if jj > ii and loc['node'] == node.index:
                z_new = z_0 / z_in * node._calcZF(0., loc['x'])
                z_mat[:,ii,jj] = z_new
                z_mat[:,jj,ii] = z_new

        # recurse to child nodes
        z_new = z_0 / z_in * node._calcZF(0., 1.)
        for c_node in node.child_nodes:
            self._calcImpedanceMatrixDown(ii, z_new, c_node, locs, z_mat)