# Source code for neat.trees.sovtree

"""
File contains:

    - `neat.SOVNode`
    - `neat.SomaSOVNode`
    - `neat.SOVTree`

Author: W. Wybo
"""

import numpy as np

import itertools
import copy

from . import morphtree
from .morphtree import MorphLoc
from .phystree import PhysNode, PhysTree
from .netree import NETNode, NET, Kernel

from ..tools.fittools import zerofinding as zf
from ..tools.fittools import histogramsegmentation as hs


def _consecutive(data, stepsize=1):
    return np.split(data, np.where(np.diff(data) != stepsize)[0]+1)


class SOVNode(PhysNode):
    """
    Node that defines functions and stores quantities to implement the
    separation of variables calculation (Major, 1993).
    """
    def __init__(self, index, p3d=None):
        super().__init__(index, p3d)

    def _setSOV(self, channel_storage, tau_0=0.02):
        """
        Set the cable quantities of this segment used by the SOV algorithm.

        Parameters
        ----------
        channel_storage: dict
            storage of the ion channel objects, forwarded to `self.getGTot()`
        tau_0: float
            time-scale [s] by which membrane time-scales are normalized
        """
        self.counter = 0
        # segment parameters
        self.g_m = self.getGTot(channel_storage) # uS/cm^2
        # parameters for SOV approach
        self.R_sov = self.R * 1e-4 # convert um to cm
        self.L_sov = self.L * 1e-4 # convert um to cm
        self.tau_m = self.c_m / self.g_m # s
        self.eps_m = self.tau_m / tau_0 # dimensionless
        self.lambda_m = np.sqrt(self.R_sov / (2.*self.g_m*self.r_a)) # cm
        self.tau_0 = tau_0 # s
        self.z_a = self.r_a / (np.pi * self.R_sov**2) # MOhm/cm
        self.g_inf_m = 1. / (self.z_a * self.lambda_m) # uS
        # segment amplitude information, filled in later by
        # `_setKappaFactors()`, `_setMuVals()` and `_setQVals()`
        # (fix: `np.NaN` was removed in NumPy 2.0, `np.nan` is the
        # canonical spelling)
        self.kappa_m = np.nan
        self.mu_vals_m = np.nan
        self.q_vals_m = np.nan

    def q_m(self, x):
        """Dimensionless spatial frequency of this segment at `x`."""
        return np.sqrt(self.eps_m*x**2 - 1.)

    def dq_dp_m(self, x):
        """Derivative of `q_m` with respect to p = -x**2 / tau_0."""
        return -self.tau_m / (2.*self.q_m(x))

    def mu_m(self, x):
        """
        Admittance ratio at the distal end of this segment, computed
        recursively from the child segments (a leaf sees only its shunt).
        """
        cns = self.child_nodes
        if len(cns) == 0:
            # leaf: only the distal shunt conductance loads the segment
            return self.z_a * self.lambda_m / self.q_m(x) * self.g_shunt
        else:
            x = zf._to_complex(x)
            mu_ds = [cn.mu_m(x) for cn in cns]
            q_ds = [cn.q_m(x) for cn in cns]
            # recursive branch-point condition (Major, 1993)
            return ( self.g_shunt - np.sum([
                         cn.g_inf_m*q_ds[i] *
                         ( 1.-mu_ds[i]/np.tan(q_ds[i]*cn.L_sov/cn.lambda_m) ) /
                         ( 1./np.tan(q_ds[i]*cn.L_sov/cn.lambda_m)+mu_ds[i])
                         for i, cn in enumerate(cns)], 0)
                   ) / (self.g_inf_m * self.q_m(x))

    def dmu_dp_m(self, x):
        """Derivative of `mu_m` with respect to p = -x**2 / tau_0."""
        cns = self.child_nodes
        if len(cns) == 0:
            return -self.dq_dp_m(x) * self.mu_m(x) / self.q_m(x)
        else:
            x = zf._to_complex(x)
            mu_ds = [cn.mu_m(x) for cn in cns]
            dmu_dp_ds = [cn.dmu_dp_m(x) for cn in cns]
            q_ds = [cn.q_m(x) for cn in cns]
            dq_dp_ds = [cn.dq_dp_m(x) for cn in cns]
            return (-self.dq_dp_m(x) * self.mu_m(x)
                    - np.sum([ cn.g_inf_m * (
                          dq_dp_ds[i] *
                          ( 1.-mu_ds[i]/np.tan(q_ds[i]*cn.L_sov/cn.lambda_m) ) /
                          ( 1./np.tan(q_ds[i]*cn.L_sov/cn.lambda_m)+mu_ds[i]) +
                          q_ds[i] *
                          ( (1.+mu_ds[i]**2) * dq_dp_ds[i] * cn.L_sov/cn.lambda_m
                            - dmu_dp_ds[i] ) /
                          ( np.cos(q_ds[i]*cn.L_sov/cn.lambda_m) +
                            mu_ds[i]*np.sin(q_ds[i]*cn.L_sov/cn.lambda_m) )**2
                      ) for i, cn in enumerate(cns)], 0) / self.g_inf_m
                   ) / self.q_m(x)

    def _setKappaFactors(self, xzeros):
        """Propagate the amplitude factor from the parent to this segment."""
        xzeros = zf._to_complex(xzeros)
        self.kappa_m = self.parent_node.kappa_m / \
                       (np.cos(self.q_m(xzeros)*self.L_sov/self.lambda_m) +
                        self.mu_m(xzeros)*np.sin(self.q_m(xzeros)*self.L_sov/self.lambda_m))

    def _setMuVals(self, xzeros):
        """Evaluate and store `mu_m` at the zeros of the transcendental equation."""
        xzeros = zf._to_complex(xzeros)
        self.mu_vals_m = self.mu_m(xzeros)

    def _setQVals(self, xzeros):
        """Evaluate and store `q_m` at the zeros of the transcendental equation."""
        xzeros = zf._to_complex(xzeros)
        self.q_vals_m = self.q_m(xzeros)

    def _findLocalPoles(self, maxspace_freq=500):
        """
        Return the poles (and their multiplicities) contributed by this
        segment alone, up to the spatial-frequency cutoff `maxspace_freq`.
        """
        poles = []
        pmultiplicities = []
        n = 0
        val = 0.
        while val < maxspace_freq:
            poles.append(np.sqrt((1.+val**2)/self.eps_m))
            # the zeroth-order pole counts with half multiplicity
            if val == 0:
                pmultiplicities.append(.5)
            else:
                pmultiplicities.append(1.)
            n += 1
            val = n*np.pi * self.lambda_m / self.L_sov
        return poles, pmultiplicities

    def _setZerosPoles(self, maxspace_freq=500, pprint=False):
        """
        Find the zeros of cot(qL/l) + mu for this segment; they are the poles
        at the next (more proximal) level of the recursion.
        """
        cns = self.child_nodes
        # find the poles of (1 + mu*cot(qL/l))/(cot(qL/l) + mu)
        lpoles, lpmultiplicities = self._findLocalPoles(maxspace_freq)
        for cn in cns:
            lpoles.extend(cn.poles)
            lpmultiplicities.extend(cn.pmultiplicities)
        inds = np.argsort(lpoles)
        lpoles = np.array(lpoles)[inds]
        lpmultiplicities = np.array(lpmultiplicities)[inds]
        # construct the function cot(qL/l) + mu and its derivative
        f = lambda x: 1./np.tan(self.q_m(x)*self.L_sov/self.lambda_m) + self.mu_m(x)
        dfdx = lambda x: -2.*x * ( -(self.L_sov/self.lambda_m)
                                   / np.sin(self.q_m(x)*self.L_sov/self.lambda_m)**2
                                   * self.dq_dp_m(x)
                                   + self.dmu_dp_m(x) ) / self.tau_0
        # find its zeros, these are the poles of the next level; start the
        # search above the largest local branch cut among self and children
        xval = 1.5/np.sqrt(self.eps_m)
        for cn in cns:
            c_eps_m = cn.eps_m
            xval_ = 1.5/np.sqrt(c_eps_m)
            if xval_ > xval:
                xval = xval_
        # nudge the starting point away from an accidental zero
        if np.abs(f(xval)) < 1e-20:
            xval = (xval+lpoles[1])/2.
        if pprint:
            print('')
            print('xval: ', xval)
        # find zeros larger than xval
        if pprint: print('finding real poles')
        PF = zf.poleFinder(fun=f, dfun=dfdx,
                           global_poles={'poles': lpoles,
                                         'pmultiplicities': lpmultiplicities})
        poles, pmultiplicities = PF.find_real_zeros(vmin=xval)
        # find the first zero
        if pprint: print('finding first pole')
        p1 = []; pm1 = []
        zf.find_zeros_on_segment(p1, pm1, 0., xval, f, dfdx,
                                 lpoles, lpmultiplicities, pprint=pprint)
        self.poles = np.concatenate((p1, poles)).real
        self.pmultiplicities = np.concatenate((pm1, pmultiplicities)).real
class SomaSOVNode(SOVNode):
    """
    Subclass of SOVNode to treat the special case of the soma

    The following member functions are not supposed to work properly,
    calling them may result in errors:

    `neat.SOVNode._setKappaFactors()`
    `neat.SOVNode._setMuVals()`
    `neat.SOVNode._setQVals()`
    `neat.SOVNode._findLocalPoles()`
    """
    def __init__(self, index, p3d=None):
        super().__init__(index, p3d)

    def _setSOV(self, channel_storage, tau_0=0.02):
        """
        Set the somatic quantities used by the SOV algorithm (the soma is a
        lumped sphere, not a cable segment).
        """
        self.counter = 0
        # convert to cm
        self.R_sov = self.R * 1e-4 # convert um to cm
        self.L_sov = self.L * 1e-4 # convert um to cm
        # surface
        self.A = 4.0*np.pi*self.R_sov**2 # cm^2
        # total conductance
        self.g_m = self.getGTot(channel_storage=channel_storage) # uS/cm^2
        # parameters for the SOV approach
        self.tau_m = self.c_m / self.g_m # s
        self.eps_m = self.tau_m / tau_0 # dimensionless
        self.g_s = self.g_m*self.A + self.g_shunt # uS
        self.c_s = self.c_m*self.A # uF
        self.tau_0 = tau_0 # s
        # segment amplitude factors (soma is the reference, hence unity)
        self.kappa_m = 1.

    def f_transc(self, x):
        """
        The somatic transcendental equation; its zeros are the inverse
        time-scales of the model.
        """
        cns = self.child_nodes
        x = zf._to_complex(x)
        mu_ds = [cn.mu_m(x) for cn in cns]
        q_ds = [cn.q_m(x) for cn in cns]
        return self.g_s * (1.-self.eps_m*x**2) - np.sum([
                   cn.g_inf_m*q_ds[i] *
                   ( 1.-mu_ds[i]/np.tan(q_ds[i]*cn.L_sov/cn.lambda_m) ) /
                   ( 1./np.tan(q_ds[i]*cn.L_sov/cn.lambda_m)+mu_ds[i])
                   for i, cn in enumerate(cns)], 0)

    def dN_dp(self, x):
        """
        Derivative of the somatic transcendental equation with respect to
        p = -x**2 / tau_0; also yields the mode prefactors at the zeros.
        """
        cns = self.child_nodes
        x = zf._to_complex(x)
        mu_ds = [cn.mu_m(x) for cn in cns]
        dmu_dp_ds = [cn.dmu_dp_m(x) for cn in cns]
        q_ds = [cn.q_m(x) for cn in cns]
        dq_dp_ds = [cn.dq_dp_m(x) for cn in cns]
        return self.c_s - np.sum([ cn.g_inf_m * (
                   dq_dp_ds[i] *
                   ( 1.-mu_ds[i]/np.tan(q_ds[i]*cn.L_sov/cn.lambda_m) ) /
                   ( 1./np.tan(q_ds[i]*cn.L_sov/cn.lambda_m)+mu_ds[i]) +
                   q_ds[i] *
                   ( (1.+mu_ds[i]**2) * dq_dp_ds[i] * cn.L_sov/cn.lambda_m
                     - dmu_dp_ds[i] ) /
                   ( np.cos(q_ds[i]*cn.L_sov/cn.lambda_m) +
                     mu_ds[i]*np.sin(q_ds[i]*cn.L_sov/cn.lambda_m) )**2
               ) for i, cn in enumerate(cns)], 0)

    def _setZerosPoles(self, maxspace_freq=500, pprint=False):
        """
        Find the zeros of the somatic transcendental equation; they are the
        inverse time-scales of the full model. Stores `zeros`,
        `zmultiplicities` and the mode `prefactors` on this node.
        """
        # find the poles of cot(qL/l) + mu
        lpoles = []; lpmultiplicities = []
        for cn in self.child_nodes:
            lpoles.extend(cn.poles)
            lpmultiplicities.extend(cn.pmultiplicities)
        inds = np.argsort(lpoles)
        lpoles = np.array(lpoles)[inds]
        lpmultiplicities = np.array(lpmultiplicities)[inds]
        # construct the function cot(qL/l) + mu
        f = lambda x: self.f_transc(x)
        dfdx = lambda x: -2.*x * self.dN_dp(x) / self.tau_0
        # find its zeros, these are the inverse timescales of the model;
        # start the search above the largest branch cut among the children
        xval = 1.5/np.sqrt(self.eps_m)
        for cn in self.child_nodes:
            c_eps_m = cn.eps_m
            xval_ = 1.5/np.sqrt(c_eps_m)
            if xval_ > xval:
                xval = xval_
        # nudge the starting point away from an accidental zero
        if np.abs(f(xval)) < 1e-20:
            xval = (xval+lpoles[1])/2.
        if pprint: print('xval: ', xval)
        # find zeros larger than xval
        PF = zf.poleFinder(fun=f, dfun=dfdx,
                           global_poles={'poles': lpoles,
                                         'pmultiplicities': lpmultiplicities})
        zeros, multiplicities = PF.find_real_zeros(vmin=xval)
        # find the first zero
        z1 = []; zm1 = []
        zf.find_zeros_on_segment(z1, zm1, 0., xval, f, dfdx,
                                 lpoles, lpmultiplicities, pprint=pprint)
        self.zeros = np.concatenate((z1, zeros)).real
        self.zmultiplicities = np.concatenate((zm1, multiplicities)).real
        self.prefactors = self.dN_dp(self.zeros).real
class SOVTree(PhysTree):
    """
    Class that computes the separation of variables time scales and spatial
    mode functions for a given morphology and electrical parameter set.
    Employs the algorithm by (Major, 1994). This tree defines a special
    `neat.SomaSOVNode` as a derived class from `neat.SOVNode`, as some
    functions required for the SOV calculation are different and thus
    overwritten.

    The SOV calculation proceeds on the computational tree (see docstring of
    `neat.MorphNode`). Thus it makes no sense to look for sov quantities in
    the original tree.
    """
    def __init__(self, file_n=None, types=[1,3,4]):
        # NOTE(review): the mutable default `types=[1,3,4]` is kept for
        # interface compatibility; it is only forwarded, never mutated here
        super().__init__(file_n=file_n, types=types)

    def _createCorrespondingNode(self, node_index, p3d=None):
        """
        Creates a node with the given index corresponding to the tree class.

        Parameters
        ----------
        node_index: int
            index of the new node
        """
        # index 1 is the soma by convention and needs the special node class
        if node_index == 1:
            return SomaSOVNode(node_index, p3d=p3d)
        else:
            return SOVNode(node_index, p3d=p3d)

    @morphtree.computationalTreetypeDecorator
    def getSOVMatrices(self, locarg):
        """
        Returns the alphas, the reciprocals of the mode time scales [1/ms],
        as well as the spatial mode functions evaluated at ``locs``.

        Parameters
        ----------
        locarg: see :func:`neat.MorphTree._parseLocArg()`
            the locations at which to evaluate the SOV matrices

        Returns
        -------
        alphas: np.ndarray of complex (ndim = 1)
            the reciprocals of mode time-scales (kHz)
        gammas: np.ndarray of complex (ndim = 2)
            the spatial function associated with each mode, evaluated at each
            location. Dimension 0 is number of modes and dimension 1 number
            of locations
        """
        locs = self._parseLocArg(locarg)
        if len(self) > 1:
            # set up the matrices from the zeros found by `calcSOVEquations`
            zeros = self.root.zeros
            prefactors = self.root.prefactors
            alphas = zeros**2 / (self.tau_0*1e3)
            gammas = np.zeros((len(alphas), len(locs)), dtype=complex)
            # fill the matrix of prefactors
            for ii, loc in enumerate(locs):
                if loc['node'] == 1:
                    # somatic locations are evaluated at the proximal end of
                    # the first child segment
                    x = 0.
                    node = self.root.child_nodes[0]
                else:
                    x = loc['x']
                    node = self[loc['node']]
                # fill a column of the matrix, corresponding to current loc
                gammas[:, ii] = node.kappa_m * \
                    (np.cos(node.q_vals_m*(1.-x)*node.L_sov/node.lambda_m) +
                     node.mu_vals_m * np.sin(node.q_vals_m*(1.-x)*node.L_sov/node.lambda_m)) / \
                    np.sqrt(prefactors*1e3)
        else:
            # single-compartment (soma-only) tree: one passive mode
            alphas = np.array([1e-3 / self.root.tau_m])
            gammas = np.array([[np.sqrt(alphas[0] / self.root.g_s)]])
        return alphas, gammas

    @morphtree.computationalTreetypeDecorator
    def calcSOVEquations(self, maxspace_freq=500., pprint=False):
        """
        Calculate the time-scales and spatial functions of the separation of
        variables approach, using the algorithm by (Major, 1993).

        The reciprocals of the time-scales (i.e. the roots of the
        transcendental equation) are stored in the soma node. The spatial
        factors are stored in each (computational) node.

        Parameters
        ----------
        maxspace_freq: float (default is 500)
            roughly corresponds to the maximal spatial frequency of the
            smallest time-scale mode
        pprint: bool
            whether to print progress information
        """
        # normalization constant for the time-scales
        self.tau_0 = np.pi
        for node in self:
            node._setSOV(self.channel_storage, tau_0=self.tau_0)
        if len(self) > 1:
            # start the recursion through the tree
            self._SOVFromLeaf(self.leafs[0], self.leafs[1:],
                              maxspace_freq=maxspace_freq, pprint=pprint)
            # zeros are now found, set the kappa factors
            zeros = self.root.zeros
            self._SOVFromRoot(self.root, zeros)
            # clean the recursion counters
            for node in self:
                node.counter = 0
        else:
            self[1]._setSOV(self.channel_storage, tau_0=self.tau_0)

    def _SOVFromLeaf(self, node, leafs, count=0,
                     maxspace_freq=500., pprint=False):
        """
        Upward sweep: set zeros/poles at `node` once all its subtrees have
        been processed, otherwise restart from the next unvisited leaf.
        """
        if pprint:
            print('Forward sweep: ' + str(node))
        pnode = node.parent_node
        # log how many times the recursion has passed at node
        if not self.isLeaf(node):
            node.counter += 1
        # if the number of child nodes of node equals the number of times the
        # recursion has passed node, the mu functions can be set; otherwise
        # we start a new recursion at another leaf
        if node.counter == len(node.child_nodes):
            node._setZerosPoles(maxspace_freq=maxspace_freq)
            if not self.isRoot(node):
                self._SOVFromLeaf(pnode, leafs, count=count+1,
                                  maxspace_freq=maxspace_freq, pprint=pprint)
        elif len(leafs) > 0:
            self._SOVFromLeaf(leafs[0], leafs[1:], count=count+1,
                              maxspace_freq=maxspace_freq, pprint=pprint)

    def _SOVFromRoot(self, node, zeros):
        """Downward sweep: set amplitude factors on every node below `node`."""
        for cnode in node.child_nodes:
            cnode._setKappaFactors(zeros)
            cnode._setMuVals(zeros)
            cnode._setQVals(zeros)
            self._SOVFromRoot(cnode, zeros)

    def getModeImportance(self, locarg=None, sov_data=None,
                          importance_type='simple'):
        """
        Gives the overall importance of the SOV modes for a certain set of
        locations.

        Parameters
        ----------
        locarg: None or list of locations
        sov_data: None or tuple of mode matrices
            One of the keyword arguments ``locarg`` or ``sov_data`` must not
            be ``None``. If ``locarg`` is not ``None``, the importance is
            evaluated at these locations (see
            :func:`neat.MorphTree._parseLocArg`). If ``sov_data`` is not
            ``None``, it is a tuple of a vector of the reciprocals of the
            mode timescales and a matrix with the corresponding spatial mode
            functions.
        importance_type: string ('simple' or 'full')
            when 'simple', sums the absolute spatial function values per
            mode; when 'full', uses the full outer product of the spatial
            functions. Defaults to 'simple'.
            (fix: the docstring previously advertised 'relative'/'absolute',
            which the code never accepted)

        Returns
        -------
        np.ndarray (ndim = 1)
            the importances associated with each mode for the provided set
            of locations, normalized so the maximum importance is one
        """
        if locarg is not None:
            locs = self._parseLocArg(locarg)
            alphas, gammas = self.getSOVMatrices(locs)
        elif sov_data is not None:
            alphas = sov_data[0]
            gammas = sov_data[1]
        else:
            raise IOError('One of the kwargs `locarg` or `sov_data` must not be ``None``')
        if importance_type == 'simple':
            absolute_importance = np.sum(np.abs(gammas), 1) / np.abs(alphas)
        elif importance_type == 'full':
            absolute_importance = np.zeros(len(alphas))
            for kk, (alpha, phivec) in enumerate(zip(alphas, gammas)):
                absolute_importance[kk] = np.sqrt(
                    np.sum(np.abs(np.dot(phivec[:,None], phivec[None,:]))) / np.abs(alpha))
        else:
            raise ValueError('`importance_type` argument can be \'simple\' or \'full\'')
        return absolute_importance / np.max(absolute_importance)

    def getImportantModes(self, locarg=None, sov_data=None, eps=1e-4,
                          sort_type='timescale', return_importance=False):
        """
        Returns the most important eigenmodes (those whose importance is
        above the threshold defined by `eps`).

        Parameters
        ----------
        locarg: None or list of locations
        sov_data: None or tuple of mode matrices
            One of the keyword arguments ``locarg`` or ``sov_data`` must not
            be ``None``. If ``locarg`` is not ``None``, the importance is
            evaluated at these locations (see
            :func:`neat.MorphTree._parseLocArg`). If ``sov_data`` is not
            ``None``, it is a tuple of a vector of the reciprocals of the
            mode timescales and a matrix with the corresponding spatial mode
            functions.
        eps: float
            the cutoff threshold in relative importance below which modes
            are truncated
        sort_type: string ('timescale' or 'importance')
            specifies in which order the modes are returned. If 'timescale',
            modes are sorted in order of decreasing time-scale, if
            'importance', modes are sorted in order of decreasing importance.
        return_importance: bool
            if ``True``, returns the importance metric associated with each
            mode

        Returns
        -------
        alphas: np.ndarray of complex (ndim = 1)
            the reciprocals of mode time-scales ``[kHz]``
        gammas: np.ndarray of complex (ndim = 2)
            the spatial function associated with each mode, evaluated at
            each location. Dimension 0 is number of modes and dimension 1
            number of locations
        importance: np.ndarray (`shape` matches `alphas`, only if
            `return_importance` is ``True``)
            value of the importance metric for each mode
        """
        if locarg is not None:
            locs = self._parseLocArg(locarg)
            alphas, gammas = self.getSOVMatrices(locs)
        elif sov_data is not None:
            alphas = sov_data[0]
            gammas = sov_data[1]
        else:
            raise IOError('One of the kwargs `locarg` or `sov_data` must not be ``None``')
        importance = self.getModeImportance(sov_data=(alphas, gammas),
                                            importance_type='simple')
        # only modes above the importance cutoff are kept
        inds = np.where(importance > eps)[0]
        alphas, gammas, importance = alphas[inds], gammas[inds,:], importance[inds]
        if sort_type == 'timescale':
            inds_sort = np.argsort(np.abs(alphas))
        elif sort_type == 'importance':
            inds_sort = np.argsort(importance)[::-1]
        else:
            raise ValueError('`sort_type` argument can be \'timescale\' or \'importance\'')
        if return_importance:
            return alphas[inds_sort], gammas[inds_sort,:], importance[inds_sort]
        else:
            return alphas[inds_sort], gammas[inds_sort,:]

    def calcImpedanceMatrix(self, locarg=None, sov_data=None, name=None,
                            eps=1e-4, mem_limit=500, freqs=None):
        """
        Compute the impedance matrix for a set of locations.

        Parameters
        ----------
        locarg: None or list of locations
        sov_data: None or tuple of mode matrices
            One of the keyword arguments ``locarg`` or ``sov_data`` must not
            be ``None``. If ``locarg`` is not ``None``, the importance is
            evaluated at these locations (see
            :func:`neat.MorphTree._parseLocArg`). If ``sov_data`` is not
            ``None``, it is a tuple of a vector of the reciprocals of the
            mode timescales and a matrix with the corresponding spatial mode
            functions.
        name: None
            unused; kept for backward compatibility of the signature
        eps: float
            the cutoff threshold in relative importance below which modes
            are truncated
        mem_limit: int
            governs whether the fast (but memory-intensive) method or the
            slow method is used
        freqs: np.ndarray of complex or None (default)
            if ``None``, returns the steady state impedance matrix, if an
            array of complex numbers, returns the impedance matrix for each
            Fourier frequency in the array

        Returns
        -------
        np.ndarray of floats (ndim = 2 or 3)
            the impedance matrix, steady state if `freqs` is ``None``, the
            frequency-dependent impedance matrix if `freqs` is given, with
            the frequency dependence at the first dimension ``[MOhm]``
        """
        if locarg is not None:
            locs = self._parseLocArg(locarg)
            alphas, gammas = self.getSOVMatrices(locs)
        elif sov_data is not None:
            alphas = sov_data[0]
            gammas = sov_data[1]
        else:
            raise IOError('One of the kwargs `locarg` or `sov_data` must not be ``None``')
        n_loc = gammas.shape[1]
        if freqs is None:
            # construct the 2d steady state matrix
            y_activation = 1. / alphas
            # compute the matrix, method depends on the memory limit
            if gammas.shape[1] < mem_limit and gammas.shape[0] < int(mem_limit/2.):
                z_mat = np.sum(gammas[:,:,np.newaxis] *
                               gammas[:,np.newaxis,:] *
                               y_activation[:,np.newaxis,np.newaxis], 0).real
            else:
                z_mat = np.zeros((n_loc, n_loc))
                for ii, jj in itertools.product(range(n_loc), range(n_loc)):
                    z_mat[ii,jj] = np.sum(gammas[:,ii] *
                                          gammas[:,jj] *
                                          y_activation).real
        else:
            # construct the 3d Fourier matrix
            y_activation = 1e3 / (alphas[np.newaxis,:]*1e3 + freqs[:,np.newaxis])
            z_mat = np.zeros((len(freqs), n_loc, n_loc), dtype=complex)
            for ii, jj in itertools.product(range(n_loc), range(n_loc)):
                z_mat[:,ii,jj] = np.sum(gammas[np.newaxis,:,ii] *
                                        gammas[np.newaxis,:,jj] *
                                        y_activation, 1)
        return z_mat

    def constructNET(self, dz=50., dx=10., eps=1e-4,
                     use_hist=False, add_lin_terms=True,
                     improve_input_impedance=False, pprint=False):
        """
        Construct a Neural Evaluation Tree (NET) for this cell.

        Parameters
        ----------
        dz: float
            the impedance step for the NET model derivation
        dx: float
            the distance step to evaluate the impedance matrix
        eps: float
            the cutoff threshold in relative importance below which modes
            are truncated
        use_hist: bool
            whether or not to use histogram segmentations to find well
            separated parts of the dendritic tree (such as the apical tree)
        add_lin_terms: bool
            take into account that the obtained NET will be used in
            conjunction with linear terms
        improve_input_impedance: bool
            whether to correct the input impedances of single-location NET
            leaves
        pprint: bool
            whether to print progress information

        Returns
        -------
        `neat.NETree`
            The neural evaluation tree (Wybo et al., 2019) associated with
            the morphology (and, if `add_lin_terms`, the dict of linear-term
            kernels).
        """
        # create a set of locations at which to evaluate the impedance matrix
        self.distributeLocsUniform(dx=dx, name='net eval')
        # compute the z_mat matrix
        alphas, gammas = self.getImportantModes(locarg='net eval', eps=eps)
        z_mat = self.calcImpedanceMatrix(sov_data=(alphas, gammas))
        # derive the NET
        net = NET()
        self._addLayerA(net, None,
                        z_mat, alphas, gammas,
                        0., 0, np.arange(len(self.getLocs('net eval'))),
                        dz=dz, use_hist=use_hist,
                        add_lin_terms=add_lin_terms, pprint=pprint)
        net.setNewLocInds()
        if improve_input_impedance:
            self._improveInputImpedance(net, alphas, gammas)
        if add_lin_terms:
            lin_terms = self.computeLinTerms(net, sov_data=(alphas, gammas))
            return net, lin_terms
        else:
            return net

    def _addLayerA(self, net, pnode,
                   z_mat, alphas, gammas,
                   z_max_prev, z_ind_0, true_loc_inds,
                   dz=100., use_hist=True, add_lin_terms=False, pprint=False):
        """
        Create the root layer(s) of the NET from a histogram segmentation of
        the transfer impedances to the first location.
        (`z_ind_0` is unused but kept for signature compatibility.)
        """
        # create a histogram of transfer impedances to location 0
        n_bin = 15
        z_hist = np.histogram(z_mat[0,:], n_bin, density=False)
        # find the histogram partition
        h_ftc = hs.histogramSegmentator(z_hist)
        s_inds, p_inds = h_ftc.partition_fine_to_coarse(eps=1.4)
        while len(s_inds) > 3:
            s_inds = np.delete(s_inds, 1)
        # identify the necessary node indices and kernel computation indices
        node_inds = []
        kernel_inds = []
        min_inds = []
        for ii, si in enumerate(s_inds[:-1]):
            if si > 0:
                n_inds = np.where(z_mat[0,:] > z_hist[1][si+1])[0]
                k_inds = np.where(np.logical_and(
                                      z_mat[0,:] > z_hist[1][si+1],
                                      z_mat[0,:] <= z_hist[1][s_inds[ii+1]+1]))[0]
                min_ind = np.argmin(z_mat[0,k_inds])
                min_inds.append(min_ind)
            else:
                n_inds = np.where(z_mat[0,:] >= z_hist[1][0])[0]
                k_inds = np.where(np.logical_and(
                                      z_mat[0,:] >= z_hist[1][0],
                                      z_mat[0,:] <= z_hist[1][s_inds[ii+1]+1]))[0]
                min_ind = np.argmin(z_mat[0,k_inds])
                min_inds.append(min_ind)
            node_inds.append(n_inds)
            kernel_inds.append(k_inds)
        # add NET nodes to the NET tree
        for ii, n_inds in enumerate(node_inds):
            k_inds = kernel_inds[ii]
            if len(k_inds) != 0:
                if add_lin_terms:
                    # get the minimal kernel
                    gammas_avg = gammas[:,0] * \
                                 gammas[:,k_inds[min_inds[ii]]]
                else:
                    # get the average kernel (subsample when very large)
                    if len(k_inds) < 100000:
                        gammas_avg = np.mean(gammas[:,0:1] *
                                             gammas[:,k_inds], 1)
                    else:
                        inds_ = np.random.choice(k_inds, size=100000)
                        gammas_avg = np.mean(gammas[:,0:1] *
                                             gammas[:,inds_], 1)
                z_avg_approx = np.sum(gammas_avg / alphas).real
                self._subtractParentKernels(gammas_avg, pnode)
                # add a node to the tree
                node = NETNode(len(net), true_loc_inds[n_inds],
                               z_kernel=(alphas, gammas_avg))
                if pnode != None:
                    net.addNodeWithParent(node, pnode)
                else:
                    net.root = node
                # set new pnode
                pnode = node
                # print stuff
                if pprint:
                    print(node)
                    print('n_loc =', len(node.loc_inds))
                    print('(locind0, size) = ', (k_inds[0], z_mat.shape[0]))
                    print('')
                if k_inds[0] == 0:
                    # start new branches, split where they originate from the
                    # soma by checking where the input impedance is close to
                    # the somatic transfer impedance
                    z_max = z_hist[1][s_inds[ii+1]]
                    # check where new dendritic branches start
                    z_diag = z_mat[k_inds, k_inds]
                    z_x0 = z_mat[k_inds, 0]
                    b_inds = np.where(np.abs(z_diag - z_x0) < dz / 2.)[0][1:].tolist()
                    if len(b_inds) > 0:
                        if b_inds[0] != 1:
                            b_inds = [1] + b_inds
                        # drop indices that merely continue the previous run
                        kk = len(b_inds)-1
                        while kk > 0:
                            if b_inds[kk]-1 == b_inds[kk-1]:
                                del b_inds[kk]
                            kk -= 1
                    else:
                        b_inds = [1]
                    for jj, i0 in enumerate(b_inds):
                        # make new z_mat matrix
                        i1 = len(k_inds) if i0 == b_inds[-1] else b_inds[jj+1]
                        inds = np.meshgrid(k_inds[i0:i1], k_inds[i0:i1],
                                           indexing='ij')
                        z_mat_new = copy.deepcopy(z_mat[inds[0], inds[1]])
                        # move further into the tree
                        self._addLayerB(net, node,
                                        z_mat_new, alphas, gammas,
                                        z_max, k_inds[i0:i1], dz=dz,
                                        use_hist=use_hist,
                                        add_lin_terms=add_lin_terms)
                else:
                    # make new z_mat matrices, one per consecutive run
                    k_seqs = _consecutive(k_inds)
                    if pprint:
                        print('\n>>> consecutive')
                        print('nseq:', len(k_seqs))
                        for k_seq in k_seqs:
                            print('sequence:', k_seq)
                    for k_seq in k_seqs:
                        inds = np.meshgrid(k_seq, k_seq, indexing='ij')
                        z_mat_new = copy.deepcopy(z_mat[inds[0], inds[1]])
                        z_max = z_mat[0,0]+1
                        # move further into the tree
                        self._addLayerB(net, node,
                                        z_mat_new, alphas, gammas,
                                        z_max, k_seq, dz=dz, pprint=pprint,
                                        use_hist=use_hist,
                                        add_lin_terms=add_lin_terms)

    def _addLayerB(self, net, pnode,
                   z_mat, alphas, gammas,
                   z_max_prev, true_loc_inds,
                   dz=100., use_hist=True, pprint=False, add_lin_terms=False):
        """
        Recursively add deeper NET layers by thresholding the impedance
        matrix at increasing impedance levels.
        """
        if pprint:
            # fix: the original printed `node._index` here, but `node` is
            # only created further down in this function, which raised a
            # NameError whenever pprint was True; print parent info only
            if pnode != None:
                print('>>> parent index = ', pnode._index)
            else:
                print('>>> start')
        # get the diagonal (input impedances)
        z_diag = np.diag(z_mat)
        if true_loc_inds[0] == 0 and z_mat[0,0] > z_max_prev:
            n_bins = 'soma'
            z_max = z_mat[0,0] + 1.
            z_min = z_max_prev
        else:
            # histogram of the GF values
            n_bins = max(int(z_mat.size/50.),
                         int((np.max(z_mat) - np.min(z_mat))/dz))
            if n_bins > 1:
                if np.all(np.diff(z_diag) > 0):
                    z_min = z_max_prev
                    z_max = z_min + dz
                    if pprint: print('--> +', dz)
                elif use_hist:
                    z_hist = np.histogram(z_mat.flatten(), n_bins,
                                          density=False)
                    # find the histogram partition
                    h_ftc = hs.histogramSegmentator(z_hist)
                    s_ind, p_ind = h_ftc.partition_fine_to_coarse()
                    # get the new min and max values
                    z_histx = z_hist[1]
                    z_min = z_max_prev
                    z_max = z_histx[s_ind[1]]
                    ii = 1
                    while np.min(z_diag) > z_histx[s_ind[ii]]:
                        ii += 1
                        z_max = z_histx[s_ind[ii]]
                    ii = np.argmax(z_hist[0][s_ind[0]:s_ind[ii]])
                    z_avg = z_hist[0][ii]
                    if z_max - z_min > dz:
                        z_max = z_min + dz
                    if pprint: print('--> hist: +', str(z_max - z_min))
                else:
                    z_min = z_max_prev
                    z_max = z_min + dz
                    if pprint: print('--> +', dz)
            else:
                z_min = z_max_prev
                z_max = np.max(z_mat)
                if pprint: print('--> all: +', str(z_max - z_min))
        d_inds = np.where(z_diag <= z_max+1e-15)[0]
        # make sure that there is at least one element in the layer
        while len(d_inds) == 0:
            z_max += dz
            d_inds = np.where(z_diag <= z_max+1e-15)[0]
        # identify the different domains for the next recursion level
        if add_lin_terms and true_loc_inds[0] == 0:
            t0 = np.array([1]); t1 = np.array([len(z_diag)])
        else:
            # upward threshold crossings start a domain ...
            t0 = np.where(np.logical_and(z_diag[:-1] < z_max+1e-15,
                                         z_diag[1:] >= z_max+1e-15))[0]
            if len(t0) > 0:
                t0 += 1
            if z_diag[0] >= z_max+1e-15:
                t0 = np.concatenate(([0], t0))
            # ... and downward crossings end it
            t1 = np.where(np.logical_and(z_diag[:-1] >= z_max+1e-15,
                                         z_diag[1:] < z_max+1e-15))[0]
            if len(t1) > 0:
                t1 += 1
            if z_diag[-1] >= z_max+1e-15:
                t1 = np.concatenate((t1, [len(z_diag)]))
        # identify where the kernels are within the interval
        l_inds = np.where(z_mat <= z_max+1e-15)
        # get the average kernel (subsample when very large)
        if l_inds[0].size < 100000:
            gammas_avg = np.mean(gammas[:,true_loc_inds[l_inds[0]]] *
                                 gammas[:,true_loc_inds[l_inds[1]]], 1)
        else:
            inds_ = np.random.randint(l_inds[0].size, size=100000)
            gammas_avg = np.mean(gammas[:,true_loc_inds[l_inds[0]][inds_]] *
                                 gammas[:,true_loc_inds[l_inds[1]][inds_]], 1)
        self._subtractParentKernels(gammas_avg, pnode)
        # add a node to the tree
        node = NETNode(len(net), true_loc_inds,
                       z_kernel=(alphas, gammas_avg))
        if pnode != None:
            net.addNodeWithParent(node, pnode)
        else:
            net.root = node
        if pprint:
            print('(locind0, size) = ', (true_loc_inds[0], z_mat.shape[0]))
            print('(zmin, zmax, n_bins) = ', (z_min, z_max, n_bins))
            print('')
        # move on to the next layers
        if len(d_inds) < len(z_diag):
            for jj, ind0 in enumerate(t0):
                ind1 = t1[jj]
                z_mat_new = copy.deepcopy(z_mat[ind0:ind1,ind0:ind1])
                true_loc_inds_new = true_loc_inds[ind0:ind1]
                self._addLayerB(net, node,
                                z_mat_new, alphas, gammas,
                                z_max, true_loc_inds_new,
                                dz=dz, use_hist=use_hist, pprint=pprint)

    def _subtractParentKernels(self, gammas, pnode):
        """Subtract (in place) the kernels of all ancestors of `pnode`."""
        if pnode != None:
            gammas -= pnode.z_kernel['c']
            self._subtractParentKernels(gammas, pnode.parent_node)

    def _improveInputImpedance(self, net, alphas, gammas):
        """
        Correct leaf kernels so that the NET input impedances match the SOV
        input impedances exactly.
        """
        nmaxind = np.max([n.index for n in net])
        for node in net:
            if len(node.loc_inds) == 1:
                # recompute the kernel of this single-loc layer
                if node.parent_node is not None:
                    p_kernel = net.calcTotalKernel(node.parent_node)
                    p_k_c = p_kernel.c
                else:
                    # NOTE(review): `gammas` is 2d here, so this zero array
                    # broadcasts against the 1d `gammas_real`; only reachable
                    # when the root itself has a single location — confirm
                    p_k_c = np.zeros_like(gammas)
                gammas_real = gammas[:,node.loc_inds[0]]**2
                node.z_kernel.c = gammas_real - p_k_c
            elif len(node.newloc_inds) > 0:
                z_k_approx = net.calcTotalKernel(node)
                # add new input nodes for the locations that don't have one
                for ind in node.newloc_inds:
                    nmaxind += 1
                    gammas_real = gammas[:,ind]**2
                    z_k_real = Kernel(dict(a=alphas, c=gammas_real))
                    # add node
                    newnode = NETNode(nmaxind, [ind],
                                      z_kernel=z_k_real-z_k_approx)
                    newnode.newloc_inds = [ind]
                    net.addNodeWithParent(newnode, node)
                # empty the new indices
                node.newloc_inds = []
        net.setNewLocInds()

    def computeLinTerms(self, net, sov_data=None, eps=1e-4):
        """
        Construct linear terms for `net` so that the transfer impedance to
        the soma is exactly matched.

        Parameters
        ----------
        net: `neat.NETree`
            the neural evaluation tree (NET)
        sov_data: None or tuple of mode matrices
            If ``sov_data`` is not ``None``, it is a tuple of a vector of
            the reciprocals of the mode timescales and a matrix with the
            corresponding spatial mode functions.
        eps: float
            the cutoff threshold in relative importance below which modes
            are truncated

        Returns
        -------
        lin_terms: dict of {int: `neat.Kernel`}
            the kernels associated with linear terms of the NET, keys are
            the indices of their corresponding locations stored under
            'net eval'
        """
        if sov_data != None:
            alphas = sov_data[0]
            gammas = sov_data[1]
        else:
            alphas, gammas = self.getImportantModes(locarg='net eval', eps=eps)
        lin_terms = {}
        for ii, loc in enumerate(self.getLocs('net eval')):
            if not self.isRoot(self[loc['node']]):
                # create the true kernel
                z_k_true = Kernel((alphas, gammas[:,ii] * gammas[:,0]))
                # compute the NET approximation kernel
                z_k_net = net.getReducedTree([0, ii]).getRoot().z_kernel
                # compute the lin term
                lin_terms[ii] = z_k_true - z_k_net
        return lin_terms