"""
File contains:
- `neat.MorphLoc`
- `neat.MorphNode`
- `neat.MorphTree`
Authors: B. Torben-Nielsen (legacy code) and W. Wybo
"""
import numpy as np
import matplotlib.patheffects as patheffects
import matplotlib.patches as patches
import matplotlib.cm as cm
import matplotlib.pyplot as pl
from mpl_toolkits.axes_grid1 import make_axes_locatable
import warnings
import copy
from functools import reduce
from .stree import SNode, STree
from .compartmenttree import CompartmentNode, CompartmentTree
def originalTreetypeDecorator(fun):
    """
    Decorator that provides the safety that the treetype is set to
    'original' inside the functions it decorates.

    The treetype that was active before the call is restored afterwards,
    also when the decorated function raises an exception.

    Parameters
    ----------
    fun: callable
        Method whose first argument is an object exposing a `treetype`
        attribute

    Returns
    -------
    callable
        The wrapped method, carrying the name and docstring of `fun`
    """
    from functools import wraps

    # wrapper to access self
    @wraps(fun)
    def wrapped(self, *args, **kwargs):
        current_treetype = self.treetype
        self.treetype = 'original'
        try:
            return fun(self, *args, **kwargs)
        finally:
            # restore the caller's treetype, even if `fun` raised
            self.treetype = current_treetype

    return wrapped
def computationalTreetypeDecorator(fun):
    """
    Decorator that provides the safety that the treetype is set to
    'computational' inside the functions it decorates. This decorator also
    checks if a computational tree has been defined.

    The treetype that was active before the call is restored afterwards,
    also when the decorated function raises an exception.

    Parameters
    ----------
    fun: callable
        Method whose first argument is an object exposing the attributes
        `treetype` and `_computational_root`

    Returns
    -------
    callable
        The wrapped method, carrying the name and docstring of `fun`

    Raises
    ------
    AttributeError
        If this function is called and no computational tree has been
        defined
    """
    from functools import wraps

    # wrapper to access self
    @wraps(fun)
    def wrapped(self, *args, **kwargs):
        if self._computational_root is None:
            raise AttributeError('No computational tree has been defined, '
                                 'and this function requires one. Use '
                                 '`MorphTree.setCompTree()` or its '
                                 'overwritten version in one of the derived '
                                 'classes')
        current_treetype = self.treetype
        self.treetype = 'computational'
        try:
            return fun(self, *args, **kwargs)
        finally:
            # restore the caller's treetype, even if `fun` raised
            self.treetype = current_treetype

    return wrapped
class MorphLoc(object):
    """
    Stores a location on the morphology. The location is initialized starting
    from a node and x-value on the real morphology. The location is also
    stored in the coordinates of the computational morphology. To toggle
    between coordinates, the class stores a reference to the morphology tree
    on which the location is defined, and returns either the original
    coordinate or the coordinate on the computational tree, depending on which
    tree is active.

    Initialized based on either a tuple or a dict where one entry specifies the
    node index and the other entry the x-coordinate specifying the location
    between parent node (x=0) or the node indicated by the index (x=1), or on
    a `neat.MorphLoc`.

    Instances define `__eq__` without `__hash__` and are therefore not
    hashable.

    Parameters
    ----------
    loc: tuple or dict or `neat.MorphLoc`
        if tuple: (node index, x-value)
        if dict: {'node': node index, 'x': x-value}
        The coordinates are copied, the argument itself is never stored.
    reftree: `neat.MorphTree`
    set_as_comploc: bool
        if True, assumes the parameters provided in `loc` are coordinates
        on the computational tree. Doing this while no computational tree
        has been initialized in `reftree` will result in an error.
        Defaults to False. Ignored when `loc` is a `neat.MorphLoc`.

    Raises
    ------
    ValueError
        If x-coordinate of location is not in ``[0,1]``
    TypeError
        If `loc` is not a tuple, a dict or a `neat.MorphLoc`
    """

    def __init__(self, loc, reftree, set_as_comploc=False):
        self.reftree = reftree

        if isinstance(loc, MorphLoc):
            # copy the original coordinates so that both locations do not
            # share (and accidentally co-modify) the same dict
            self.loc = dict(loc.loc)
            return

        if isinstance(loc, tuple):
            node_index, x = loc[0], loc[1]
        elif isinstance(loc, dict):
            node_index, x = loc['node'], loc['x']
        else:
            raise TypeError('Not a valid location type, should be tuple or dict')

        x = float(x)
        if x > 1. or x < 0.:
            raise ValueError('x-value should be in [0,1]')

        # always build a fresh dict, a dict passed by the caller is not aliased
        coords = {'node': int(node_index), 'x': x}
        if set_as_comploc:
            self.comp_loc = coords
            self._setOriginalLoc()
        else:
            self.loc = coords

    def __getitem__(self, key):
        """
        Return the node index (key ``0`` or ``'node'``) or the x-coordinate
        (key ``1`` or ``'x'``) in the coordinates of the tree that is
        currently active in `reftree`.

        Raises
        ------
        IndexError
            If an integer key other than 0 or 1 is given
        KeyError
            If the key is neither an integer nor one of 'node' and 'x'
        """
        if isinstance(key, (int, np.integer)):
            if key not in (0, 1):
                raise IndexError('MorphLoc index should be 0 (node) or 1 (x)')
            key = 'node' if key == 0 else 'x'
        if not isinstance(key, str):
            raise KeyError(key)

        if self.reftree.treetype == 'computational':
            try:
                return self.comp_loc[key]
            except AttributeError:
                # computational coordinates are computed lazily on first use
                self._setComputationalLoc()
                return self.comp_loc[key]
        return self.loc[key]

    def __eq__(self, other_loc):
        """
        Compare the original coordinates with a dict, tuple or `MorphLoc`.
        Locations on the soma (node index 1) are equal irrespective of their
        x-coordinate.
        """
        if isinstance(other_loc, MorphLoc):
            other_node, other_x = other_loc.loc['node'], other_loc.loc['x']
        elif isinstance(other_loc, dict):
            other_node, other_x = other_loc['node'], other_loc['x']
        elif isinstance(other_loc, tuple):
            other_node, other_x = other_loc[0], other_loc[1]
        else:
            return NotImplemented

        if other_node != self.loc['node']:
            return False
        if self.loc['node'] == 1:
            return True
        return bool(np.allclose(other_x, self.loc['x']))

    def __ne__(self, other_loc):
        result = self.__eq__(other_loc)
        if result is NotImplemented:
            return result
        return not result

    # legacy name, kept so that explicit calls to `__neq__` keep working
    __neq__ = __ne__

    def keys(self):
        return ['node', 'x']

    def __iter__(self):
        yield self['node']
        yield self['x']

    def __copy__(self):
        """
        Customization of the copy function so that `loc` and `comp_loc`
        attributes are deep copied and `reftree` attribute still refers to the
        original tree
        """
        new_loc = type(self)(copy.deepcopy(self.loc), self.reftree)
        if hasattr(self, 'comp_loc'):
            new_loc.__dict__.update({'comp_loc': copy.deepcopy(self.comp_loc)})
        return new_loc

    def __str__(self):
        return '{\'node\': %d, \'x\': %.2f }'%(self.loc['node'], self.loc['x'])

    def __repr__(self):
        return str(self)

    def _setComputationalLoc(self):
        """
        Compute and store `comp_loc`, the coordinates of this location on the
        computational tree of `reftree`, from the original coordinates.
        """
        if self.loc['node'] == 1:
            self.comp_loc = copy.deepcopy(self.loc)
            return

        tree = self.reftree
        current_treetype = tree.treetype
        try:
            tree.treetype = 'original'
            node = tree[self.loc['node']]
            # find the computational nodes that are resp. up and down from
            # the node
            node_start = tree._findCompnodeUp(node.parent_node)
            node_stop = tree._findCompnodeDown(node)
            # length between loc and parent computational node
            L = tree.pathLength({'node': node_start.index, 'x': 1.}, self.loc)
            # get the computational node's length
            tree.treetype = 'computational'
            L_cn = tree[node_stop.index].L
        finally:
            # reset treetype to its former value, also on failure
            tree.treetype = current_treetype
        self.comp_loc = {'node': node_stop.index, 'x': L/L_cn}

    def _setOriginalLoc(self):
        """
        Compute and store `loc`, the coordinates of this location on the
        original tree of `reftree`, from the computational coordinates.

        Raises
        ------
        ValueError
            If no original node contains the location or if the resulting
            x-coordinate is not in ``[0,1]``
        """
        if self.comp_loc['node'] == 1:
            self.loc = copy.deepcopy(self.comp_loc)
            return

        tree = self.reftree
        current_treetype = tree.treetype
        try:
            tree.treetype = 'computational'
            compnode = tree[self.comp_loc['node']]
            tree.treetype = 'original'
            node = tree[self.comp_loc['node']]
            # first computational node on the way to the root
            pcnode = tree._findCompnodeUp(node.parent_node)
            path = tree.pathBetweenNodes(pcnode, node)
        finally:
            # reset treetype to its former value, also on failure
            tree.treetype = current_treetype

        # distance of the location from the parent computational node
        Lloc = self.comp_loc['x'] * compnode.L
        if Lloc == 0.:
            # nudge into the first original node, intervals are open at 0
            Lloc += 1e-7

        # walk down the path until the original node containing Lloc is found
        L0 = 0.
        found = None
        for pathnode in path[1:]:
            L1 = L0 + pathnode.L
            if Lloc > L0 and Lloc <= L1:
                found = {'node': pathnode.index,
                         'x': (Lloc-L0-1e-8) / pathnode.L}
                break
            L0 = L1

        if found is None:
            raise ValueError('location on the computational tree could not '
                             'be mapped to the original tree')
        if found['x'] > 1. or found['x'] < 0.:
            raise ValueError('x-value should be in [0,1]')
        self.loc = found
class MorphNode(SNode):
    """
    Node associated with `neat.MorphTree`. Stores the geometrical information
    associated with a point on the tree morphology

    Parameters
    ----------
    index: int
        The index of the node
    p3d: tuple (optional)
        ``(xyz, R, swc_type)``, forwarded to `setP3D()`. If ``None``,
        placeholder values are set that are meant to be overwritten.

    Attributes
    ----------
    xyz: numpy.array of floats
        The xyz-coordinates associated with the node (um)
    R: float
        The radius of the node (um)
    swc_type: int
        The type of node, according to the .swc file format convention:
        ``1`` is soma, ``2`` is axon, ``3`` is basal dendrite and ``4``
        is apical dendrite.
    L: float
        The length of the node (um). Only initialized by the constructor
        when no `p3d` is given, otherwise set through `setLength()`.
    """

    def __init__(self, index, p3d=None):
        super().__init__(index)
        # identity test: `p3d` holds a numpy array, so avoid rich comparison
        if p3d is not None:
            self.setP3D(*p3d)
        else:
            # bogus values, to overwrite
            self.setP3D(np.array([0.,0.,0.]), 1., 1)
            self.L = 1.
            self.R = 1.

    def setP3D(self, xyz, R, swc_type):
        """
        Set the 3d parameters of the node

        Parameters
        ----------
        xyz: `np.array`
            3D location (um)
        R: float
            Radius of the segment (um)
        swc_type: int
            Type associated with the segment according to SWC standards
        """
        # morphology parameters
        self.xyz = xyz
        self.R = R
        self.swc_type = swc_type
        # auxiliary variable
        self.used_in_comp_tree = False

    def setLength(self, L):
        """
        Set the length of the segment represented by the node

        Parameters
        ----------
        L: float
            the length of the segment (um)
        """
        self.L = L

    def setRadius(self, R):
        """
        Set the radius of the segment represented by the node

        Parameters
        ----------
        R: float
            the radius of the segment (um)
        """
        self.R = R

    def getChildNodes(self, skip_inds=(2,3)):
        """
        Get the `child_nodes` of this node. Indices ``2`` and ``3`` are skipped
        by default (3-point soma convention)

        Parameters
        ----------
        skip_inds: list or tuple of ints
            Node indices of child nodes that are not added to the returned list

        Returns
        -------
        list of `neat.MorphNode`
            The child nodes
        """
        return [cnode for cnode in self._child_nodes
                if cnode.index not in skip_inds]

    def setChildNodes(self, cnodes):
        return super().setChildNodes(cnodes)

    # reading `child_nodes` applies the default `skip_inds` of the getter
    child_nodes = property(getChildNodes, setChildNodes)

    def __str__(self, **kwarg):
        return super().__str__(**kwarg)
[docs]class MorphTree(STree):
"""
Subclass of simple tree that implements neuronal morphologies. Reads in
trees from '.swc' files (http://neuromorpho.org/).
Neural morphologies are assumed to follow the three-point soma conventions.
Internally however, the soma is represented as a sphere. Hence nodes with
indices 2 and 3 do not represent anything and are skipped in iterations and
getters.
Can also store a simplified version of the original tree, where only nodes
are retained that should hold computational parameters - the root, the
bifurcation nodes and the leafs at least, although the user can also
specify additional nodes. One tree is set as primary by changing the
`treetype` attribute (select 'original' for the original morphology and
'computational' for the computational morphology). Lookup operations will
often use the primary tree. Using nodes from the other tree for lookup
operations is unsafe and should be avoided, it is better to set the proper
tree to primary first.
For computational efficiency, it is possible to store sets of locations on
the morphology, under user-specified names. These sets are stored as
lists of `neat.MorphLoc`, and associated arrays are stored that contain the
corresponding node indices of the locations, their x-coordinates, their
distances to the soma and their distances to the nearest bifurcation in the
'up'-direction
Parameters
----------
file_n: str (optional)
the file name of the morphology file. Assumed to follow the '.swc' format.
Default is ``None``, which initialized an empty tree
types: list of int (optional)
The list of node types to be included. As per the '.swc' convention,
``1`` is soma, ``2`` is axon, ``3`` is basal dendrite and ``4`` apical
dendrite. Default is ``[1,3,4]``.
Attributes
----------
root: `neat.MorphNode` instance
The root of the tree.
locs: dict {str: list of `neat.MorphLoc`}
Stored sets of locations; the key is the user-specified name of the
set of locations. Initialized as empty dict.
nids: dict {str: np.array of int}
Node indices of locations. Initialized as empty dict.
xs: dict {str: np.array of float}
x-coordinates of locations. Initialized as empty dict.
d2s: dict {str: np.array of float}
Distances to soma of locations. Initialized as empty dict.
d2b: dict {str: np.array of float}
distances to nearest bifurcation in 'up' direction of locations.
Initialized as empty dict.
"""
def __init__(self, file_n=None, types=[1,3,4]):
    """
    Set up the tree, either empty or loaded from an SWC file.

    Parameters
    ----------
    file_n: str (optional)
        Name of the morphology file, handed to `readSWCTreeFromFile()`.
        If ``None``, the tree starts without a root.
    types: list of int (optional)
        Node types to load, handed to `readSWCTreeFromFile()`. The default
        list is only read, never modified.
    """
    # the active tree, alternative is 'computational'
    self._treetype = 'original'

    if file_n is None:
        self._original_root = None
    else:
        self.readSWCTreeFromFile(file_n, types=types)
    self._computational_root = None

    # containers for the named sets of locations on the morphology
    self.locs = {}
    self._nids_orig = {}
    self._nids_comp = {}
    self._xs_orig = {}
    self._xs_comp = {}
    self.d2s = {}
    self.d2b = {}
    self.leafinds = {}
def __getitem__(self, index, skip_inds=(2,3)):
    """
    Look up a node by its index in the currently active tree.

    Parameters
    ----------
    index: int
        the index of the node to be found
    skip_inds: tuple of ints
        Indices of child nodes that are not searched, forwarded to
        `_findNode()`. Defaults to ``(2,3)``.

    Returns
    -------
    `neat.MorphNode` or None
        The result of `_findNode()`, started at the root of the active tree
    """
    search_start = self.root
    match = self._findNode(search_start, index, skip_inds=skip_inds)
    return match
def _findNode(self, node, index, skip_inds=(2,3)):
    """
    Stack-based iteration to replace the recursive call. Traverses the
    subtree of `node` until it finds the node with the requested index.

    Parameters
    ----------
    node: :class:`SNode` or None
        node where the search is started. If ``None`` (empty tree), nothing
        is found.
    index: int
        the index of the node to be found
    skip_inds: tuple of ints
        Indices of child nodes that are not visited. Defaults to ``(2,3)``.
        The start node itself is always checked.

    Returns
    -------
    :class:`SNode` or None
        The node with the given index, ``None`` when not found
    """
    if node is None:
        return None

    # explicit stack of nodes that still have to be inspected; popping from
    # the end avoids modifying the list while iterating over it
    pending = [node]
    while pending:
        candidate = pending.pop()
        if candidate.index == index:
            return candidate
        pending.extend(candidate.getChildNodes(skip_inds=skip_inds))
    return None # Not found!
def __iter__(self, node=None, skip_inds=(2,3)):
    """
    Depth-first iterator over the active tree that leaves out the nodes
    whose index is in `skip_inds`.

    Parameters
    ----------
    node: `neat.MorphNode`
        The starting node. Defaults to the root
    skip_inds: tuple of ints
        Indices of the nodes that are skipped by the iterator. Defaults
        to ``(2,3)``, the nodes that contain extra geometrical
        information on the soma.

    Yields
    ------
    `neat.MorphNode`
        Nodes in the tree
    """
    start = self.root if node is None else node
    if start is None:
        # empty tree
        return

    if start.index in skip_inds:
        # a skipped start node hides its whole subtree
        return

    yield start
    for child in start.getChildNodes(skip_inds=skip_inds):
        yield from self.__iter__(child, skip_inds=skip_inds)
def getRoot(self):
    """
    Returns the root of the original or the computational tree, depending
    on which `treetype` is active.
    """
    if self.treetype == 'original':
        return self._original_root
    return self._computational_root

def setRoot(self, node):
    """
    Set the root of the tree that is currently active and detach the new
    root from its parent.

    Parameters
    ----------
    node: `neat.MorphNode` or None
        The new root. ``None`` resets the root of the active tree, this
        happens when the treetype of an empty tree is set.
    """
    if node is not None:
        node.parent_node = None
    if self.treetype == 'original':
        self._original_root = node
    else:
        self._computational_root = node

root = property(getRoot, setRoot)
def getNodes(self, recompute_flag=0, skip_inds=(2,3)):
    """
    Return the cached node list of the active `treetype`, building it first
    when it does not exist yet or when a recomputation is requested.

    Parameters
    ----------
    recompute_flag: bool
        whether or not to re-evaluate the node list. Defaults to False.
    skip_inds: tuple of ints
        Indices of the nodes that are left out when the list is built.
        Defaults to ``(2,3)``, the nodes that contain extra geometrical
        information on the soma. Not applied to an already cached list.

    Returns
    -------
    list of `neat.MorphNode`
    """
    # each treetype has its own cache attribute
    if self.treetype == 'original':
        cache_name = '_nodes_orig'
    else:
        cache_name = '_nodes_comp'

    if recompute_flag or not hasattr(self, cache_name):
        gathered = []
        setattr(self, cache_name, gathered)
        self._gatherNodes(self.root, gathered, skip_inds=skip_inds)
    return getattr(self, cache_name)

def setNodes(self, illegal):
    """
    Setter of the `nodes` property, always raises since it is read-only.
    """
    raise AttributeError("`nodes` is a read-only attribute")

nodes = property(getNodes, setNodes)
def _gatherNodes(self, node, node_list=None, skip_inds=(2,3)):
    """
    Overloaded gathering function that avoids appending nodes with index 2
    or 3 to the list.

    Parameters
    ----------
    node: `neat.MorphNode`
        Root of the subtree whose nodes are gathered
    node_list: list of `neat.MorphNode` (optional)
        List to which the nodes are appended in place. A new list is
        created when omitted.
    skip_inds: tuple of ints
        Indices of the nodes that are not gathered. Defaults to ``(2,3)``.

    Returns
    -------
    list of `neat.MorphNode`
        `node_list`, extended with the gathered nodes
    """
    # no mutable default: a shared default list would accumulate (and keep
    # alive) the nodes of every call that omits `node_list`
    if node_list is None:
        node_list = []
    if node.index not in skip_inds:
        node_list.append(node)
    for cnode in node.getChildNodes(skip_inds=skip_inds):
        self._gatherNodes(cnode, node_list=node_list, skip_inds=skip_inds)
    return node_list
def getLeafs(self, recompute_flag=0):
    """
    Return the cached list of leafs of the active `treetype`, building it
    first when it does not exist yet or when a recomputation is requested.

    Parameters
    ----------
    recompute_flag: bool
        Whether to force recomputing the leaf list. Defaults to 0.

    Returns
    -------
    list of `neat.MorphNode`
        The nodes of the active tree for which `isLeaf()` is true
    """
    # each treetype has its own cache attribute
    if self.treetype == 'original':
        cache_name = '_leafs_orig'
    else:
        cache_name = '_leafs_comp'

    if recompute_flag or not hasattr(self, cache_name):
        leaf_nodes = []
        for node in self:
            if self.isLeaf(node):
                leaf_nodes.append(node)
        setattr(self, cache_name, leaf_nodes)
    return getattr(self, cache_name)

def setLeafs(self, illegal):
    """
    Setter of the `leafs` property, always raises since it is read-only.
    """
    raise AttributeError("`leafs` is a read-only attribute")

leafs = property(getLeafs, setLeafs)
def getNodesInBasalSubtree(self):
    """
    Collect the nodes of the active tree that have SWC type ``3``
    (basal dendrite)

    Returns
    -------
    list of `neat.MorphNode`
        List of all nodes in the basal subtree
    """
    basal_types = (3,)
    basal_nodes = []
    for node in self:
        if node.swc_type in basal_types:
            basal_nodes.append(node)
    return basal_nodes
def getNodesInApicalSubtree(self):
    """
    Collect the nodes of the active tree that have SWC type ``4``
    (apical dendrite)

    Returns
    -------
    list of `neat.MorphNode`
        List of all nodes in the apical subtree
    """
    apical_types = (4,)
    apical_nodes = []
    for node in self:
        if node.swc_type in apical_types:
            apical_nodes.append(node)
    return apical_nodes
def getNodesInAxonalSubtree(self):
    """
    Collect the nodes of the active tree that have SWC type ``2`` (axon)

    Returns
    -------
    list of `neat.MorphNode`
        List of all nodes in the axonal subtree
    """
    axonal_types = (2,)
    axonal_nodes = []
    for node in self:
        if node.swc_type in axonal_types:
            axonal_nodes.append(node)
    return axonal_nodes
def setTreetype(self, treetype):
    """
    Set the active tree

    Parameters
    ----------
    treetype: 'original' or 'computational'
        the treetype that is set to active

    Raises
    ------
    ValueError
        If `treetype` is not one of the two allowed values, or if
        'computational' is requested while no computational tree has been
        defined. The active treetype is left unchanged in both cases.
    """
    if treetype == 'original':
        self._treetype = treetype
        self.root = self._original_root
    elif treetype == 'computational':
        # identity test, so that no rich comparison is run on the root node
        if self._computational_root is None:
            raise ValueError('no computational tree has been defined, '
                             '`treetype` can only be \'original\'')
        self._treetype = treetype
        self.root = self._computational_root
    else:
        raise ValueError('`treetype` can be \'original\' or \'computational\'')

def getTreetype(self):
    """
    Return the name of the active tree, 'original' or 'computational'
    """
    return self._treetype

treetype = property(getTreetype, setTreetype)
def _createCorrespondingNode(self, node_index, p3d=None):
    """
    Build a node of the type that belongs to this tree class, here a
    `MorphNode`.

    Parameters
    ----------
    node_index: int
        index of the new node
    p3d: tuple (optional)
        passed on unchanged to the `MorphNode` constructor

    Returns
    -------
    `MorphNode`
        The newly created node
    """
    new_node = MorphNode(node_index, p3d=p3d)
    return new_node
def readSWCTreeFromFile(self, file_n, types=[1,3,4]):
    """
    Non-specific for a "tree data structure"
    Read and load a morphology from an SWC file and parse it into
    an `neat.MorphTree` object.

    On the NeuroMorpho.org website, 5 types of soma descriptions are
    considered (http://neuromorpho.org/neuroMorpho/SomaFormat.html).
    The "3-point soma" is the standard and most files are converted
    to this format during a curation step. `neat` follows this default
    specification and the *internal structure of `neat` implements
    the 3-point soma*. Additionally multi-cylinder descriptions with more
    than three nodes are also supported, but are converted to the standard
    three point description.

    Additionally, the root node of the tree must have ``index == 1``,
    ``swc_type == 1`` and occur first in the SWC file.

    Empty lines and comment lines (first non-blank character ``#``) are
    ignored.

    Parameters
    -----------
    file_n: str
        name of the file to open
    types: list of ints
        NeuroMorpho.org segment types to be loaded. Only read, never
        modified.

    Returns
    -------
    `neat.MorphTree`
        The tree itself

    Examples
    --------
    The three point description is

    .. code-block:: python

        1 1 x y   z r -1
        2 1 x y-r z r 1
        3 1 x y+r z r 1

    with `x,y,z` the coordinates of the soma center and `r` the soma radius

    This is a valid three point description

    .. code-block:: python

        # start of file
        1 1 45.3625 18.6775 -50.25 10.1267403895 -1
        2 1 45.3625 8.55075961052 -50.25 10.1267403895 1
        3 1 45.3625 28.8042403895 -50.25 10.1267403895 1
        # dendrite nodes
        4 3 37.76 12.99 -46.08 0.29 1
        5 3 26.7068019951 8.26344199599 -36.9426896493 0.795614809475 4
        # ...

    This is a valid multi-cylinder description

    .. code-block:: python

        # start of file
        1 1 1066.38 399.67 157.0 4.9215 -1
        2 1 1071.3 399.67 157.0 4.9215 1
        3 1 1076.22 399.67 157.0 4.9215 2
        4 1 1066.5 402.83 157.0 11.494 2
        5 1 1062.4 405.5 157.0 15.308 4
        6 1 1056.6 410.25 158.0 20.536 5
        7 1 1056.6 410.25 158.0 20.536 6
        8 1 1070.0 427.75 161.0 2.305 7
        # dendrite nodes
        9 3 1070.0 427.75 161.0 0.886 8
        # ...

    Raises
    ------
    ValueError
        If the SWC file is not consistent with the aforementioned conventions
    """
    # check soma-representation: 3-point soma or a non-standard representation
    soma_type = self.determineSomaType(file_n)

    # index -> (swc_type, node, parent_index), in file order
    all_nodes = dict()
    with open(file_n, 'r') as swc_file:
        for line in swc_file:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            split = line.split()
            index = int(split[0])
            swc_type = int(split[1])
            x = float(split[2])
            y = float(split[3])
            z = float(split[4])
            radius = float(split[5])
            parent_index = int(split[6])

            # create the nodes
            if swc_type in types:
                p3d = (np.array([x,y,z]), radius, swc_type)
                node = self._createCorrespondingNode(index, p3d)
                all_nodes[index] = (swc_type, node, parent_index)

    # check if node with index 1 is soma node (swc_type == 1)
    if 1 not in all_nodes or all_nodes[1][0] != 1:
        raise ValueError('Node with index 1 should be soma-type, i.e. swc_type == 1')

    # standard three point soma representation
    if soma_type == 1:
        for index, (swc_type, node, parent_index) in all_nodes.items():
            if index == 1:
                self.setRoot(node)
            elif index in (2,3):
                # the 3-point soma representation
                # (http://neuromorpho.org/neuroMorpho/SomaFormat.html)
                somanode = all_nodes[1][1]
                self.addNodeWithParent(node, somanode)
            else:
                parent_node = all_nodes[parent_index][1]
                self.addNodeWithParent(node, parent_node)

        # check if soma follows three point convention
        try:
            soma_1, soma_2, soma_3 = (all_nodes[ii][1] for ii in (1, 2, 3))
        except KeyError:
            raise ValueError('Nodes with indices 1, 2 and 3 should describe '
                             'the three-point soma')
        radius_arr = np.array([soma_1.R,
                               soma_2.R,
                               soma_3.R,
                               np.linalg.norm(soma_2.xyz - soma_1.xyz),
                               np.linalg.norm(soma_3.xyz - soma_1.xyz)])
        if not np.allclose(np.abs(radius_arr - radius_arr[0]),
                           np.zeros_like(radius_arr), atol=2e-2):
            raise ValueError('Soma radii not consistent with three-point convention')

    # IF multiple cylinder soma representation
    elif soma_type == 2:
        self.setRoot(all_nodes[1][1])

        # get all soma info
        soma_cylinders = []
        connected_to_root = set()
        for index, (swc_type, node, parent_index) in all_nodes.items():
            if swc_type == 1 and index != 1:
                soma_cylinders.append((node, parent_index))
                if index > 1:
                    connected_to_root.add(index)

        # make soma
        s_node_2, s_node_3 = \
                self._makeSomaFromCylinders(soma_cylinders, all_nodes)

        # add soma
        self.root.R = s_node_2.R
        self.addNodeWithParent(s_node_2, self.root)
        self.addNodeWithParent(s_node_3, self.root)

        # add the other points
        for index, (swc_type, node, parent_index) in all_nodes.items():
            if swc_type == 1:
                continue
            parent_node = all_nodes[parent_index][1]
            if parent_node.index in connected_to_root:
                # the parent was a soma cylinder, which is now part of the root
                self.addNodeWithParent(node, self.root)
            else:
                self.addNodeWithParent(node, parent_node)

    # set the lengths of the nodes
    for node in self:
        if node.parent_node is not None:
            L = np.sqrt(np.sum((node.parent_node.xyz - node.xyz)**2))
        else:
            L = 0.
        node.setLength(L)

    return self
def _makeSomaFromCylinders(self, soma_cylinders, all_nodes):
    """
    Construct 3-point soma
    Step 1: calculate surface of all cylinders
    Step 2: make 3-point representation with the same surface

    The coordinates of the root are replaced by the mean of the coordinates
    of the root and of all soma cylinder nodes.

    Parameters
    ----------
    soma_cylinders: list of tuple
        ``(node, parent_index)`` for every soma node except the root
    all_nodes: dict
        maps node index to ``(swc_type, node, parent_index)``

    Returns
    -------
    s_node_2, s_node_3
        The nodes with indices 2 and 3 of the three-point soma
    """
    total_surf = 0
    # accumulate in a copy: summing in place into `self.root.xyz` would
    # corrupt the root coordinates, which are still needed as parent
    # coordinates for every cylinder attached to the root
    xyz_sum = np.array(self.root.xyz, dtype=float)
    for (node, parent_index) in soma_cylinders:
        parent = all_nodes[parent_index][1]
        # cylinder height is the distance between node and parent
        H = np.sqrt(np.sum((node.xyz - parent.xyz)**2))
        total_surf += 2 * np.pi * parent.R * H
        xyz_sum += node.xyz

    # define appropriate radius: sphere with the same surface
    radius = np.sqrt(total_surf / (4.*np.pi))
    rp = xyz_sum / (len(soma_cylinders)+1.)
    rp2 = np.array([rp[0], rp[1] - radius, rp[2]])
    rp3 = np.array([rp[0], rp[1] + radius, rp[2]])
    self.root.xyz = rp

    # create the soma nodes
    s_node_2 = self._createCorrespondingNode(2, (rp2, radius, 1))
    s_node_3 = self._createCorrespondingNode(3, (rp3, radius, 1))
    return s_node_2, s_node_3
def determineSomaType(self, file_n):
    """
    Determine the soma type used in the SWC file.
    This method searches the whole file for soma entries.

    Only the standard three-point soma type and a multi-cylinder description
    are supported.

    Furthermore, the root node of the tree must have ``index == 1``,
    ``swc_type == 1`` and occur first in the SWC file.

    Empty lines and comment lines (first non-blank character ``#``) are
    ignored.

    Parameters
    ----------
    file_n: string
        Name of the file containing the SWC description

    Returns
    -------
    soma_type: int
        Integer indicating one of the supported SWC soma formats.
        1: Default three-point soma,
        2: multiple cylinder description

    Raises
    ------
    ValueError
        If soma type is not supported (less than three nodes have soma
        type)
    """
    # count the entries of soma type (second column equal to 1)
    somas = 0
    with open(file_n, 'r') as swc_file:
        for line in swc_file:
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if int(line.split()[1]) == 1:
                somas += 1

    if somas == 3:
        return 1
    elif somas < 3:
        raise ValueError('Soma description not supported, use 3-point or multi-cylinder description')
    else:
        return 2
def _evaluateCompCriteria(self, node, eps=1e-8, rbool=False):
    """
    Return ``True`` if the node is the root, if it does not have exactly one
    child node (leaf or bifurcation), or if the relative difference between
    the node radius and the radius of its only child node is larger than
    margin ``eps``.

    Parameters
    ----------
    node: `neat.MorphNode`
        node that is evaluated
    eps: float (optional, default ``1e-8``)
        the margin
    rbool: bool (optional, default ``False``)
        if true, it is returned as is without evaluating the node

    Returns
    -------
    bool
    """
    if rbool:
        return rbool
    # root node
    if node.parent_node is None:
        return True
    # leaf or bifurcation
    child_nodes = node.getChildNodes()
    if len(child_nodes) != 1:
        return True
    # radius changes between the node and its only child
    cnode = child_nodes[0]
    return bool(np.abs(node.R - cnode.R) > eps * np.max([node.R, cnode.R]))
def setCompTree(self, compnodes=None, set_as_primary_tree=False, eps=1e-8):
    """
    Sets the nodes that contain computational parameters. These are a priori
    either bifurcations, leafs, the root or nodes where the neurons'
    relevant parameters change.

    Any existing computational tree is removed first, which also sets the
    original tree as the active tree.

    Parameters
    ----------
    compnodes: list of ::class::`MorphNode`
        list of nodes that should be retained in the computational tree.
        Note that specifying bifurcations, leafs or the root is
        superfluous, since they are part of the computational tree by
        default. The list itself is not modified.
    set_as_primary_tree: bool (default ``False``)
        if True, sets the computational tree as the primary tree
    eps: float (default ``1e-8``)
        relative margin for parameter change
    """
    self.removeCompTree()
    # work on a copy, the list handed in by the caller is left untouched
    compnodes = [] if compnodes is None else list(compnodes)
    compnodes += [node for node in self if self._evaluateCompCriteria(node, eps=eps)]
    compnode_indices = {node.index for node in compnodes}

    # the computational tree is built by pruning a deep copy of the nodes
    nodes = copy.deepcopy(self.nodes)
    for node in nodes:
        if node.index not in compnode_indices:
            self.removeSingleNode(node)
            continue

        orig_node = self[node.index]
        if node.parent_node is not None:
            # after the removals above, the parent of the copied node is
            # the nearest retained node in the direction of the root
            parent_compnode = node.parent_node
            L, R = self.pathLength({'node': parent_compnode.index, 'x': 1.},
                                   {'node': orig_node.index, 'x': 1.},
                                   compute_radius=1)
            node.setLength(L)
            node.setRadius(R)
        node.used_in_comp_tree = True
        orig_node.used_in_comp_tree = True

    self._computational_root = \
            next(node for node in nodes if node.index == 1)
    self._leafs_comp = [node for node in nodes if self.isLeaf(node)]
    self._nodes_comp = []
    self._gatherNodes(self._computational_root, self._nodes_comp)

    if set_as_primary_tree:
        self.treetype = 'computational'

    # create conversion of all coordinate arrays
    for name in self.locs:
        self._storeCompLocs(name)
def _findCompnodeUp(self, node):
    """
    !!! Computational tree has to be initialized, otherwise may result in
    error !!!

    Walk from the input node towards the root and return the first node
    that is flagged as used in the computational tree. The input node is
    returned itself if it carries that flag.

    Parameters
    ----------
    node: `neat.MorphNode` instance
        the input node

    Returns
    -------
    `neat.MorphNode` instance
    """
    current = node
    while not current.used_in_comp_tree:
        # step to the parent until a flagged node is reached
        current = current.parent_node
    return current
def _findCompnodeDown(self, node):
    """
    !!! Computational tree has to be initialized, otherwise may result in
    error !!!

    Walk from the input node away from the root, always following the first
    child node, and return the first node that is flagged as used in the
    computational tree. The input node is returned itself if it carries
    that flag.

    Parameters
    ----------
    node: `neat.MorphNode` instance
        the input node

    Returns
    -------
    `neat.MorphNode` instance
    """
    current = node
    while not current.used_in_comp_tree:
        # step to the first child until a flagged node is reached
        current = current.child_nodes[0]
    return current
def removeCompTree(self):
    """
    Discard the computational tree: reset its root, drop its cached node
    and leaf lists, make the original tree active and clear the
    `used_in_comp_tree` flag of every node in that tree.
    """
    self._computational_root = None

    # the cached lists may never have been created, that is fine
    for cache_name in ("_nodes_comp", "_leafs_comp"):
        try:
            delattr(self, cache_name)
        except AttributeError:
            pass

    self.treetype = 'original'
    for node in self:
        node.used_in_comp_tree = False
def _convertLocArgToLocs(self, locarg):
    """
    Turn a locations argument into a list of `neat.MorphLoc`.

    Parameters
    ----------
    locarg: list of dictionaries, tuples or `neat.MorphLoc`, or string
        * If list, every entry is wrapped in a new `neat.MorphLoc` that
          references this tree
        * If string, it is checked with `_tryName()` and the locations are
          fetched with `getLocs()`

    Returns
    -------
    list of `neat.MorphLoc`

    Raises
    ------
    IOError
        If `locarg` is neither a list nor a string
    """
    if isinstance(locarg, list):
        return [MorphLoc(loc, self) for loc in locarg]
    if isinstance(locarg, str):
        self._tryName(locarg)
        return self.getLocs(locarg)
    raise IOError('`locarg` should be list of locs or string')
def _convertNodeArgToNodes(self, node_arg):
    """
    Converts a node argument to a list of nodes. Behaviour depends on the
    type of argument.

    Parameters
    ----------
    node_arg: ``None``, `neat.MorphNode`, {'apical', 'basal', 'axonal'} or iterable collection of instances of `neat.MorphNode`
        * `None`: returns all nodes
        * `neat.MorphNode`: returns list of nodes in the subtree of the given node
        * {'apical', 'basal', 'axonal'}: returns list of nodes in the apical, basal or axonal subtree
        * iterable collection of `neat.MorphNode`: returns the same list of nodes

        If an iterable collection of original nodes is given, and the treetype
        is computational, a reduced list is returned where only the corresponding
        computational nodes are included. If an iterable collection of
        computational nodes is given, and the treetype is original, a list of
        corresponding original nodes is given, but the in between nodes are not
        added.

    Returns
    -------
    list of `neat.MorphNode`

    Raises
    ------
    ValueError
        If `node_arg` is none of the supported argument types
    """
    error_msg = ('input should be (i) `None`, (ii) an instance of '
                 '`neat.MorphNode`, (iii) one of the following 3 strings '
                 '\'apical\', \'basal\' or \'axonal\' or (iv) an iterable '
                 'collection of instances of :class:MorphNode')

    # identity test: `== None` is ambiguous when an array is passed
    if node_arg is None:
        return self.nodes

    if isinstance(node_arg, MorphNode):
        if self.treetype == 'computational':
            # assure that a list of computational nodes is returned
            node_arg = self._findCompnodeDown(node_arg)
        # fetch the node with that index from the active tree
        node_arg = self[node_arg.index]
        return self.gatherNodes(node_arg)

    if isinstance(node_arg, str):
        if node_arg == 'apical':
            return self.getNodesInApicalSubtree()
        if node_arg == 'basal':
            return self.getNodesInBasalSubtree()
        if node_arg == 'axonal':
            return self.getNodesInAxonalSubtree()
        raise ValueError(error_msg)

    try:
        candidates = iter(node_arg)
    except TypeError:
        raise ValueError(error_msg)

    nodes = []
    for node in candidates:
        # explicit check instead of `assert`, which is stripped under -O
        if not isinstance(node, MorphNode):
            raise ValueError(error_msg)
        if self.treetype == 'computational':
            # assure that a list of computational nodes is returned
            node_ = self._findCompnodeDown(node)
            compnode = self[node_.index]
            if compnode not in nodes:
                nodes.append(compnode)
        else:
            # assure that a list of original nodes is returned
            nodes.append(self[node.index])
    return nodes
def pathLength(self, loc1, loc2, compute_radius=0):
    """
    Find the length of the direct path between loc1 and loc2

    Parameters
    ----------
    loc1: dict, tuple or `neat.MorphLoc`
        one location
    loc2: dict, tuple or `neat.MorphLoc`
        other location
    compute_radius: bool
        if True, also computes the average weighted radius of the path

    Returns
    -------
    L, R (optional)
        L: float
            length of path, in micron
        R: float
            weighted average radius of path, in micron
    """
    # define location objects
    if isinstance(loc1, (dict, tuple)):
        loc1 = MorphLoc(loc1, self)
    if isinstance(loc2, (dict, tuple)):
        loc2 = MorphLoc(loc2, self)

    # start path length calculation
    if loc1['node'] == loc2['node']:
        node = self[loc1['node']]
        if node.index == 1:
            L = 0. # soma is spherical and has no length
        else:
            L = node.L * np.abs(loc1['x'] - loc2['x'])
        if compute_radius:
            R = node.R
    else:
        node1 = self[loc1['node']]
        node2 = self[loc2['node']]
        # paths from the root down to both nodes
        path1 = self.pathToRoot(node1)[::-1]
        path2 = self.pathToRoot(node2)[::-1]
        # index of the first node where both paths diverge
        n_shared_max = min(len(path1), len(path2))
        ind = next((ii for ii in range(n_shared_max)
                    if path1[ii] != path2[ii]), n_shared_max)

        # fraction of both end nodes that is part of the path, and the
        # lists of nodes that lie completely on the path
        if path1[ind-1] == node1:
            # node1 lies on the path from the root to node2
            frac1, frac2 = 1. - loc1['x'], loc2['x']
            full_segments = (path2[ind:-1],)
        elif path2[ind-1] == node2:
            # node2 lies on the path from the root to node1
            frac1, frac2 = loc1['x'], 1. - loc2['x']
            full_segments = (path1[ind:-1],)
        else:
            # the path runs over the last common ancestor of both nodes
            frac1, frac2 = loc1['x'], loc2['x']
            full_segments = (path1[ind:-1], path2[ind:-1])

        L = node1.L * frac1
        for segment in full_segments:
            L += sum(node.L for node in segment)
        L += node2.L * frac2

        if compute_radius:
            # radius averaged over the path, weighted by length
            R = node1.R * node1.L * frac1
            for segment in full_segments:
                R += sum(node.R * node.L for node in segment)
            R += node2.R * node2.L * frac2
            R /= L

    if compute_radius:
        return L, R
    else:
        return L
@originalTreetypeDecorator
def storeLocs(self, locs, name, warn=True):
    """
    Store locations under a specified name

    Parameters
    ----------
    locs: list of dicts, tuples or `neat.MorphLoc`
        the locations to be stored
    name: string
        name under which these locations are stored
    warn: bool (default ``True``)
        raise a `UserWarning` if two or more locations in `locs` refer
        to the soma. Choose ``False`` if this is desired to remove
        the warning.
    """
    # copy list and store in MorphLoc if necessary
    locs_ = [MorphLoc(loc, self) for loc in locs]
    n_soma = sum(1 for loc in locs_ if loc['node'] == 1)
    if n_soma > 1 and warn:
        warnings.warn('There are multiple locations on the soma in this set ' + \
                      'locations, this can cause issues in certain functions', UserWarning)

    # replace whatever was stored under this name before
    self.removeLocs(name)
    self.locs[name] = locs_
    self._nids_orig[name] = np.array([loc['node'] for loc in locs_])
    self._xs_orig[name] = np.array([loc['x'] for loc in locs_])

    # identity test, so that no rich comparison is run on the root node
    if self._computational_root is not None:
        self._storeCompLocs(name)
@computationalTreetypeDecorator
def _storeCompLocs(self, name):
    """
    Fill the computational-tree index and x-coordinate arrays of the set
    stored under `name`.

    The decorator switches `self.treetype` to 'computational' for the
    duration of the call and raises `AttributeError` if no computational
    tree has been defined.
    """
    stored = self.locs[name]
    # presumably `loc['node']` and `loc['x']` resolve against the active
    # treetype of the tree the loc refers to -- verify against `MorphLoc`
    node_indices = [loc['node'] for loc in stored]
    x_coords = [loc['x'] for loc in stored]
    self._nids_comp[name] = np.array(node_indices)
    self._xs_comp[name] = np.array(x_coords)
@originalTreetypeDecorator
def addLoc(self, loc, name):
    """
    Add location to set of locations of given name

    Cached results for the set (distances to soma, distances to
    bifurcation and leaf indices) are discarded, since they no longer
    match the extended set.

    Parameters
    ----------
    loc: dict, tuple or `neat.MorphLoc`
        the location to be added
    name: str
        the name of the set of locations to which the location is added

    Raises
    ------
    KeyError
        If no set of locations is stored under `name`
    """
    loc = MorphLoc(loc, self)
    self.locs[name].append(loc)
    self._nids_orig[name] = np.concatenate((self._nids_orig[name], [loc['node']]))
    self._xs_orig[name] = np.concatenate((self._xs_orig[name], [loc['x']]))
    if self._computational_root is not None:
        self._addCompLoc(loc, name)
    # per-set caches were computed for the shorter list; drop them so the
    # next query recomputes instead of returning stale results
    for cache in (self.d2s, self.d2b, self.leafinds):
        cache.pop(name, None)
@computationalTreetypeDecorator
def _addCompLoc(self, loc, name):
    """
    Append the node index and x-coordinate of `loc` to the
    computational-tree arrays of the set stored under `name`.

    Runs with `self.treetype` set to 'computational' by the decorator.
    """
    node_index = loc['node']
    x_coord = loc['x']
    self._nids_comp[name] = np.concatenate((self._nids_comp[name], [node_index]))
    self._xs_comp[name] = np.concatenate((self._xs_comp[name], [x_coord]))
def clearLocs(self):
    """
    Remove all sets of locs stored in the tree, together with all cached
    quantities derived from them.
    """
    # every container is reset to its own fresh, empty dict
    for attr in ('locs',
                 '_nids_orig', '_nids_comp',
                 '_xs_orig', '_xs_comp',
                 'd2s', 'd2b', 'leafinds'):
        setattr(self, attr, {})
def removeLocs(self, name):
    """
    Remove a set of locations of a given name

    Every per-set container is cleaned independently, so an entry missing
    from one of them (e.g. no computational coordinates were stored) does
    not prevent the remaining entries from being removed. Nothing happens
    if `name` is not in use.

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored
    """
    per_set_containers = (
        self.locs,
        self._nids_orig, self._nids_comp,
        self._xs_orig, self._xs_comp,
        # cached derived quantities
        self.d2s, self.d2b, self.leafinds,
    )
    for container in per_set_containers:
        container.pop(name, None)
def _tryName(self, name):
    """
    Tests if the name is in use. Raises a KeyError when it is not in use;
    the error message lists the names that are in use.

    Parameters
    ----------
    name: string
        name of the desired list of locations

    Raises
    ------
    KeyError
        If 'name' does not refer to a set of locations in use
    """
    try:
        self.locs[name]
    except KeyError as err:
        # `str(name)` keeps the message identical for string names while
        # avoiding a `TypeError` when a non-string key is passed
        err.args = ('\'' + str(name) \
                    + '\' name not in use. Possible names are ' \
                    + str(list(self.locs.keys())),)
        raise
def getLocs(self, name):
    """
    Returns a set of locations of a specified name

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored

    Returns
    -------
    list of `neat.MorphLoc`
        The stored list itself, not a copy

    Raises
    ------
    KeyError
        If `name` is not in use (raised by `_tryName()`)
    """
    self._tryName(name)
    stored = self.locs[name]
    return stored
def getNodeIndices(self, name):
    """
    Returns an array of node indices of locations of a specified name

    The array is read through the `nids` property, i.e. it is the one
    that belongs to the currently active treetype.

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored

    Returns
    -------
    numpy.array of ints
    """
    self._tryName(name)
    node_indices = self.nids
    return node_indices[name]
def getNids(self):
    # node-index arrays of the active treetype: the 'original' ones when
    # `treetype` is 'original', the computational ones in every other case
    return self._nids_orig if self.treetype == 'original' else self._nids_comp

def setNids(self, nids):
    # write to whichever dict the active treetype selects
    target = '_nids_orig' if self.treetype == 'original' else '_nids_comp'
    setattr(self, target, nids)

nids = property(getNids, setNids)
def getXCoords(self, name):
    """
    Returns an array of x-values of locations of a specified name

    The array is read through the `xs` property, i.e. it is the one that
    belongs to the currently active treetype.

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored
    """
    self._tryName(name)
    x_coords = self.xs
    return x_coords[name]
def getXs(self):
    # x-coordinate arrays of the active treetype: the 'original' ones when
    # `treetype` is 'original', the computational ones in every other case
    return self._xs_orig if self.treetype == 'original' else self._xs_comp

def setXs(self, xs):
    # write to whichever dict the active treetype selects
    target = '_xs_orig' if self.treetype == 'original' else '_xs_comp'
    setattr(self, target, xs)

xs = property(getXs, setXs)
def getLocindsOnNode(self, name, node):
    """
    Returns a list of the indices of locations in the list of a given name
    that are on the input node, ordered for increasing x

    Parameters
    ----------
    name: string
        which list of locations to consider
    node: `neat.MorphNode`
        the node to consider. Only its `index` is used, which is compared
        with the node indices of the active treetype

    Returns
    -------
    list of ints
        indices of locations on the node
    """
    self._tryName(name)
    node_indices = self.nids[name]
    x_coords = self.xs[name]
    # positions in the stored set that lie on `node`
    on_node = np.flatnonzero(node_indices == node.index)
    # reorder those positions by their x-coordinate
    order = np.argsort(x_coords[on_node])
    return on_node[order].tolist()
def getLocindsOnNodes(self, name, node_arg):
    """
    Returns a list of the indices of locations in the list of a given name
    that are on one of the nodes specified in the node list. Within each
    node, locations are ordered for increasing x

    Parameters
    ----------
    name: string
        which list of locations to consider
    node_arg:
        see documentation of `MorphTree._convertNodeArgToNodes`

    Returns
    -------
    list of ints
        indices of locations on the nodes, grouped per node in the order
        in which `_convertNodeArgToNodes` yields the nodes
    """
    return [locind
            for node in self._convertNodeArgToNodes(node_arg)
            for locind in self.getLocindsOnNode(name, node)]
def getLocindsOnPath(self, name, node0, node1, xstart=0., xstop=1.):
    """
    Returns a list of the indices of locations in the list of a given name
    that are on the given path. The path is taken to start at the input
    x-start coordinate of the first node in the list and to stop at the
    given x-stop coordinate of the last node in the list

    The indices are collected node by node in path order. Within a node
    they are sorted along the direction in which the path traverses it
    (decreasing x when running towards the root, increasing x otherwise).
    For a node without parent, all locations on that node are included,
    regardless of `xstart` and `xstop`.

    Parameters
    ----------
    name: string
        which list of locations to consider
    node0: :class:`SNode`
        start node of path
    node1: :class:`SNode`
        stop node of path
    xstart: float (in ``[0,1]``)
        starting coordinate on `node0`
    xstop: float (in ``[0,1]``)
        stopping coordinate on `node1`

    Returns
    -------
    list of ints
        Indices of locations on the path (the elements are numpy integers
        taken from index arrays). If path is empty, an empty list is
        returned.
    """
    self._tryName(name)
    locs = self.locs[name]
    xs = self.xs[name]
    # find the path
    path = self.pathBetweenNodes(node0, node1)
    # find the location indices
    locinds = []
    if len(path) > 1:
        # first node in path
        node = path[0]
        ninds = np.array(self.getLocindsOnNode(name, node)).astype(int)
        if node.parent_node == None:
            # node without parent: x-coordinates are not used to filter
            locinds.extend(ninds)
        else:
            if node.parent_node == path[1]:
                # path runs towards root: keep x <= xstart, decreasing x
                inds = np.where(xs[ninds] <= xstart)[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            else:
                # path goes away from root: keep x >= xstart, increasing x
                inds = np.where(xs[ninds] >= xstart)[0]
                sortinds = np.argsort(xs[ninds][inds])
            locinds.extend(ninds[inds][sortinds])
        # middle nodes in path; `node` is `path[ii+1]`, so `path[ii]` is its
        # predecessor and `path[ii+2]` its successor on the path
        for ii, node in enumerate(path[1:-1]):
            ninds = np.array(self.getLocindsOnNode(name, node)).astype(int)
            if node.parent_node == None:
                locinds.extend(ninds)
            elif path[ii+2] == node.parent_node:
                # path goes towards root
                sortinds = np.argsort(xs[ninds])
                locinds.extend(ninds[sortinds[::-1]])
            elif path[ii] == node.parent_node:
                # path goes away from root
                sortinds = np.argsort(xs[ninds])
                locinds.extend(ninds[sortinds])
            else:
                # turning point (path only goes on this node at x=1)
                inds = np.where((1. - xs[ninds]) < 1e-4)[0]
                if len(inds) > 0:
                    locinds.extend(ninds[inds])
        # last node in path
        node = path[-1]
        ninds = np.array(self.getLocindsOnNode(name, node)).astype(int)
        if node.parent_node == None:
            locinds.extend(ninds)
        else:
            if node.parent_node == path[-2]:
                # path goes away from root: keep x <= xstop, increasing x
                inds = np.where(xs[ninds] <= xstop)[0]
                sortinds = np.argsort(xs[ninds][inds])
            else:
                # path goes towards root: keep x >= xstop, decreasing x
                inds = np.where(xs[ninds] >= xstop)[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            locinds.extend(ninds[inds][sortinds])
    elif len(path) == 1:
        # start and stop on the same node: keep the locations between the
        # two x-coordinates, ordered from `xstart` to `xstop`
        node = path[0]
        ninds = np.array(self.getLocindsOnNode(name, node)).astype(int)
        if node.parent_node == None:
            locinds.extend(ninds)
        else:
            if xstart < xstop:
                inds = np.where(np.logical_and(xs[ninds]>=xstart, xs[ninds]<=xstop))[0]
                sortinds = np.argsort(xs[ninds][inds])
            else:
                inds = np.where(np.logical_and(xs[ninds]>=xstop, xs[ninds]<=xstart))[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            locinds.extend(ninds[inds][sortinds])
    return locinds
def getNearestLocinds(self, locs, name, direction=0, check_siblings=True, pprint=False):
    """
    For each location in the input location list, find the index of the
    closest location in a set of locations stored under a given name. The
    search can be restricted to one side of the location or cover both
    sides.

    Parameters
    ----------
    locs: list of dicts, tuples or `neat.MorphLoc`
        the locations for which the nearest location index has to be
        found
    name: string
        name under which the reference list is stored
    direction: int
        flag to indicate where to search: on both sides (0), only with
        `_findLocsDown()` (1), i.e. at smaller x on the same node and
        further on the parent node, or only with `_findLocsUp()` (2),
        i.e. at larger x on the same node and further on the child nodes.
    check_siblings: bool
        forwarded to `_findLocsDown()`; if ``True``, the other child
        nodes of the parent node are searched as well
    pprint: bool
        not used by this function

    Returns
    -------
    loc_indices: list of ints
        indices of the locations closest to the given locs. An entry is
        ``None`` when no location was found for the corresponding loc.
    """
    self._tryName(name)
    # create the locs in a desirable format
    locs = [MorphLoc(loc, self) for loc in locs]
    stored = self.locs[name]
    # look for the location indices
    loc_indices = []
    for loc in locs:
        loc_ind1 = None
        loc_ind2 = None
        # find the location indices if necessary
        if direction == 0 or direction == 1:
            loc_ind1 = self._findLocsDown(loc, name, check_siblings=check_siblings)
        if direction == 0 or direction == 2:
            loc_ind2 = self._findLocsUp(loc, name)
        # save the index of the closest location, if it exists and
        # if it is asked for
        if loc_ind1 is None and (direction == 0 or direction == 2):
            loc_indices.append(loc_ind2)
        elif loc_ind2 is None and (direction == 0 or direction == 1):
            loc_indices.append(loc_ind1)
        else:
            # candidates on both sides: keep the one with the shortest
            # path, ties go to the `_findLocsUp()` candidate
            L1 = self.pathLength(loc, stored[loc_ind1])
            L2 = self.pathLength(loc, stored[loc_ind2])
            if L1 >= L2:
                loc_indices.append(loc_ind2)
            else:
                loc_indices.append(loc_ind1)
    return loc_indices
def _findLocsUp(self, loc, name):
    """
    Find the index of the nearest location of the set `name` at an
    x-coordinate larger than or equal to that of `loc` on the same node
    or, if there is none, on the child nodes (searched recursively).

    Parameters
    ----------
    loc: dict or `neat.MorphLoc`
        the reference location, read through its 'node' and 'x' entries
    name: string
        name of the set of locations to search

    Returns
    -------
    int or None
        index in the set, ``None`` if no location was found
    """
    # look if there are locs on the same node
    n_inds = np.where(loc['node'] == self.nids[name])[0]
    if len(n_inds) > 0:
        if loc['node'] == 1:
            # on the soma (index 1), x is not compared: first loc found
            return n_inds[0]
        x_inds = np.where(loc['x'] <= self.xs[name][n_inds])[0]
        if len(x_inds) != 0:
            # the smallest x that is not smaller than the x of `loc`
            ind = np.argmin(self.xs[name][n_inds][x_inds])
            return n_inds[x_inds[ind]]
    # no suitable loc on the same node, proceed to the child nodes
    node = self[loc['node']]
    loc_inds = []
    for cnode in node.getChildNodes():
        cloc_ind = self._findLocsUp({'node': cnode.index, 'x': 0.}, name)
        if cloc_ind is not None:
            loc_inds.append(cloc_ind)
    # get the one that is closest, if they exist; an infinite start value
    # guarantees that the minimum is found for arbitrarily long paths
    pl_aux = np.inf
    ind_loc = 0
    for i, l_i in enumerate(loc_inds):
        pl = self.pathLength({'node': loc['node'], 'x': 1.}, self.locs[name][l_i])
        if pl < pl_aux:
            pl_aux = pl
            ind_loc = i
    if pl_aux > 0. and len(loc_inds) > 0:
        return loc_inds[ind_loc]
    if pl_aux == 0. and node.index == 1:
        # a candidate at zero path length is only accepted when the search
        # started on the soma
        return loc_inds[ind_loc]
    return None
def _findLocsDown(self, loc, name, check_siblings=True):
    """
    Find the index of the nearest location of the set `name` at an
    x-coordinate smaller than or equal to that of `loc` on the same node
    or, if there is none, on the parent node (searched recursively) and,
    optionally, on the other child nodes of the parent node.

    Parameters
    ----------
    loc: dict or `neat.MorphLoc`
        the reference location, read through its 'node' and 'x' entries
    name: string
        name of the set of locations to search
    check_siblings: bool
        if ``True``, the other child nodes of the parent node are
        searched with `_findLocsUp()`

    Returns
    -------
    int or None
        index in the set, ``None`` if no location was found
    """
    # look if there are locs on the same node
    n_inds = np.where(loc['node'] == self.nids[name])[0]
    if len(n_inds) > 0:
        if loc['node'] == 1:
            # on the soma (index 1), x is not compared: first loc found
            return n_inds[0]
        x_inds = np.where(loc['x'] >= self.xs[name][n_inds])[0]
        if len(x_inds) != 0:
            # the largest x that is not larger than the x of `loc`
            ind = np.argmax(self.xs[name][n_inds][x_inds])
            return n_inds[x_inds[ind]]
    # no suitable loc on the same node, proceed to resp. parent and child
    # nodes
    node = self[loc['node']]
    pnode = node.getParentNode()
    loc_inds = []
    # check parent node
    if pnode is not None:
        ploc_ind = self._findLocsDown({'node': pnode.index, 'x': 1.}, name,
                                      check_siblings=check_siblings)
        if ploc_ind is not None:
            loc_inds.append(ploc_ind)
    # check other child nodes of parent node
    if pnode is not None and check_siblings:
        ocnodes = copy.copy(pnode.getChildNodes())
        ocnodes.remove(node)
    else:
        ocnodes = []
    for cnode in ocnodes:
        cloc_ind = self._findLocsUp({'node': cnode.index, 'x': 0.}, name)
        if cloc_ind is not None:
            loc_inds.append(cloc_ind)
    # get the one that is closest, if they exist; an infinite start value
    # guarantees that the minimum is found for arbitrarily long paths
    # NOTE(review): distances are measured from x=1 of the node of `loc`,
    # not from `loc` itself -- confirm that this is intended
    pl_aux = np.inf
    ind_loc = 0
    for i, l_i in enumerate(loc_inds):
        pl = self.pathLength({'node': loc['node'], 'x': 1.}, self.locs[name][l_i])
        if pl < pl_aux:
            pl_aux = pl
            ind_loc = i
    if pl_aux > 0. and len(loc_inds) > 0:
        return loc_inds[ind_loc]
    return None
def getNearestNeighbourLocinds(self, loc, locarg):
    """
    Search nearest neighbours to `loc` in `locarg`.

    When `locarg` is not a string, the locations are stored temporarily
    under the name 'nn aux' (replacing a set that may already exist under
    that name); the temporary set is removed again before returning, also
    when the search raises.

    Parameters
    ----------
    loc: tuple, dict or `neat.MorphLoc`
        The location for which nearest neighbours have to be found
    locarg: str or list of locs
        See documentation of `MorphTree._parseLocArg`, the set of locations
        within which to look for nearest neighbours

    Returns
    -------
    list of ints
        Indices of nearest neighbours of `loc` in `locarg`, without
        duplicates and in no particular order
    """
    # preprocess locarg
    loc = MorphLoc(loc, self)
    temporary = not isinstance(locarg, str)
    if temporary:
        name = 'nn aux'
        self.storeLocs(locarg, name=name)
    else:
        name = locarg
        # only called to check that the name is in use
        self._parseLocArg(locarg)
    nns = []
    try:
        # search for nearest neighbours
        node = self[loc['node']]
        locinds_aux = np.where(node.index == self.nids[name])[0]
        if len(locinds_aux) > 0:
            dx = self.xs[name][locinds_aux] - loc['x']
            # locs on the node at larger or equal x; if there are none,
            # continue the search in the child nodes
            inds_down = np.where(dx >= 0)[0]
            if len(inds_down) > 0:
                ind_aux = np.argmin(dx[inds_down])
                nns.append(locinds_aux[inds_down][ind_aux])
            else:
                for c_node in node.child_nodes:
                    self._searchNNDown(c_node, nns, name)
            # locs on the node at smaller or equal x; if there are none,
            # continue the search via the parent node
            inds_up = np.where(dx <= 0)[0]
            if len(inds_up) > 0:
                ind_aux = np.argmax(dx[inds_up])
                nns.append(locinds_aux[inds_up][ind_aux])
            else:
                self._searchNNUp(node, nns, name)
        else:
            for c_node in node.child_nodes:
                self._searchNNDown(c_node, nns, name)
            self._searchNNUp(node, nns, name)
    finally:
        # only a set created by this call is removed, never a set that the
        # caller referred to by name
        if temporary:
            self.removeLocs(name)
    return list(set(nns))
def _searchNNUp(self, node, nns, name):
    """
    Collect, in the list `nns`, nearest-neighbour indices of the set
    `name` found via the parent of `node`.

    If the parent node carries locations, the one with the largest x is
    appended; otherwise the search recurses towards the root. Unless that
    location sits at the very end of the parent node (x within 1e-5 of 1),
    the other child nodes of the parent are searched with
    `_searchNNDown()` as well. Nothing happens for a node without parent.
    """
    parent = node.parent_node
    if parent is None:
        return
    # locations on the parent node
    on_parent = np.where(parent.index == self.nids[name])[0]
    xval = 0.
    if len(on_parent) > 0:
        locind = on_parent[np.argmax(self.xs[name][on_parent])]
        nns.append(locind)
        xval = self.xs[name][locind]
    else:
        self._searchNNUp(parent, nns, name)
    # sibling branches
    if xval < 1.-1e-5:
        for sibling in set(parent.child_nodes) - {node}:
            self._searchNNDown(sibling, nns, name)
def _searchNNDown(self, node, nns, name):
    """
    Collect, in the list `nns`, nearest-neighbour indices of the set
    `name` found in the subtree of `node`.

    If `node` carries locations, the one with the smallest x is appended
    and the search stops there; otherwise every child node is searched
    recursively.
    """
    on_node = np.where(node.index == self.nids[name])[0]
    if len(on_node) == 0:
        for child in node.child_nodes:
            self._searchNNDown(child, nns, name)
        return
    first = np.argmin(self.xs[name][on_node])
    nns.append(on_node[first])
def getLeafLocinds(self, name, recompute=False):
    """
    Find the indices in the desired location list that are 'leafs', i.e.
    locations for which no other location exists that is farther from the
    root

    The result is cached per name and reused on later calls.

    Parameters
    ----------
    name: string
        name of the desired set of locations
    recompute: bool (optional, default ``False``)
        whether or not to force recomputing the leaf indices

    Returns
    -------
    list of inds
        the indices of the 'leaf' locations
    """
    if recompute or name not in self.leafinds:
        self._tryName(name)
        leafinds = self.leafinds[name] = []
        # a loc is a leaf when `_hasLocUp()` finds nothing beyond it
        leafinds.extend(ind for ind, loc in enumerate(self.locs[name])
                        if not self._hasLocUp(loc, name))
    return self.leafinds[name]
def _hasLocUp(self, loc, name):
    """
    Check whether the set `name` contains a location beyond `loc`, i.e.
    at a strictly larger x on the same node or anywhere in the subtrees of
    the child nodes.

    For a `loc` on the soma (index 1), only the child nodes are searched.

    Parameters
    ----------
    loc: dict or `neat.MorphLoc`
        the reference location, read through its 'node' and 'x' entries
    name: string
        name of the set of locations to search

    Returns
    -------
    bool
    """
    # look if there are locs farther along the same node
    if loc['node'] != 1:
        n_inds = np.where(loc['node'] == self.nids[name])[0]
        if len(n_inds) > 0 and np.any(loc['x'] < self.xs[name][n_inds]):
            return True
    # nothing on the same node, proceed to the child nodes; `any()` stops
    # at the first subtree that contains a loc
    node = self[loc['node']]
    return any(self._hasLocUp({'node': cnode.index, 'x': 0.}, name)
               for cnode in node.child_nodes)
def distancesToSoma(self, locarg, recompute=False):
    """
    Compute the distance of each location in a given set to the soma

    For a stored set (string argument) the distances are cached under its
    name; for a list of locations they are always computed and never
    cached.

    Parameters
    ----------
    locarg: list of locations or string
        if list of locations, specifies the locations, if str,
        specifies the name under which the set of locations is stored
    recompute: bool (optional)
        whether or not to force recomputing the distances

    Returns
    -------
    np.array of float
        the distances to the soma of the corresponding locations

    Raises
    ------
    IOError
        If `locarg` is neither a list nor a string
    """
    # process input argument
    if isinstance(locarg, list):
        name = None
        locs = [MorphLoc(loc, self) for loc in locarg]
    elif isinstance(locarg, str):
        name = locarg
        self._tryName(name)
        locs = self.getLocs(name)
    else:
        raise IOError('`locarg` should be list of locs or string')
    use_cache = name is not None
    if use_cache and not recompute and name in self.d2s:
        distances = self.d2s[name]
    else:
        # path length from the soma (node index 1) to every location
        soma_loc = {'node': 1, 'x': 0.}
        distances = np.array([self.pathLength({'node': 1, 'x': 0.}, loc)
                              for loc in locs])
    if use_cache:
        self.d2s[name] = distances
    return distances
def distancesToBifurcation(self, name, recompute=False):
    """
    Compute the distance of each location to the nearest bifurcation in
    the 'up' direction (towards root)

    The bifurcation node is obtained from `upBifurcationNode()` and the
    distance is the path length from x=1 on that node to the location.
    Locations on the soma (node index 1) get distance 0. The result is
    cached per name, and only once it has been computed completely.

    Parameters
    ----------
    name: str
        name of the set of locations
    recompute: bool (optional, default ``False``)
        whether or not to force recomputing the distances

    Returns
    -------
    list of floats
        the distances to the nearest bifurcation of the corresponding
        locations
    """
    if not recompute and name in self.d2b:
        return self.d2b[name]
    self._tryName(name)
    d2b = []
    # the bifurcation node is reused for consecutive locs on one node; the
    # index it belongs to is tracked explicitly, so that the first loc is
    # never compared with the last one of the list
    bnode = None
    bnode_for = None
    for loc in self.locs[name]:
        index = loc['node']
        if index == 1:
            d2b.append(0.)
            continue
        if bnode is None or index != bnode_for:
            bnode, _ = self.upBifurcationNode(self[index])
            bnode_for = index
        d2b.append(self.pathLength({'node': bnode.index, 'x': 1.}, loc))
    self.d2b[name] = d2b
    return d2b
def distributeLocsOnNodes(self, d2s, node_arg=None, name='dont save'):
    """
    Distributes locs on a given set of nodes at specified distances from the
    soma. If none of the specified distances fall on the specified nodes,
    the list of locations will be empty. The locations are stored if the
    name is set to be something other than 'dont save'. On each node,
    locations are ordered from low to high x-values.

    Parameters
    ----------
    d2s: numpy.array, list or scalar of floats
        the distances from the soma at which to put the locations (micron)
    node_arg:
        see documentation of `MorphTree._convertNodeArgToNodes`
    name: string
        the name under which the locations are stored. Defaults to 'dont save'
        which means the locations are not stored

    Returns
    -------
    list of `neat.MorphLoc`
        the list of locations
    """
    # comparisons and fancy indexing below require an array
    d2s = np.atleast_1d(d2s)
    # distribute the locations
    locs = []
    for node in self._convertNodeArgToNodes(node_arg):
        if node.parent_node is not None:
            # distances from the soma to both ends of the node
            L0 = self.pathLength({'node': 1, 'x': 0.5},
                                 {'node': node.index, 'x': 0.})
            L1 = self.pathLength({'node': 1, 'x': 0.5},
                                 {'node': node.index, 'x': 1.})
            # requested distances in the half-open interval (L0, L1]
            inds = np.where(np.logical_and(L0 < d2s, d2s <= L1))[0]
            Ls = np.sort(d2s[inds])
            # the x-coordinate is the relative position between both ends
            locs.extend([MorphLoc((node.index, (L-L0)/(L1-L0)), self) \
                         for L in Ls if L > 1e-12])
        elif np.any(np.abs(d2s) <= 1e-12):
            # node is soma, append a location on the soma
            locs.append(MorphLoc((node.index, 0.5), self))
    if name != 'dont save':
        self.storeLocs(locs, name=name)
    return locs
def distributeLocsRandom(self, num, dx=0.001, node_arg=None,
                         add_soma=True, name='dont save', seed=None):
    """
    Returns a list of input locations randomly distributed on the tree

    Nodes are drawn uniformly from the nodes that have not been tagged
    yet; after each draw, the nodes within `dx` of the drawn node are
    tagged by `_tagNodesDown()` and `_tagNodesUp()`. Sampling stops early
    when no untagged node is left. All 'tag' entries are removed from the
    node contents before returning, also when an error occurs.

    Note that the global numpy random number generator is re-seeded with
    `seed` on every call.

    Parameters
    ----------
    num: int
        number of locations to draw. A soma location added through
        `add_soma` comes on top of these.
    dx: float (optional)
        minimal or given distance between input locations (micron)
    node_arg (optional):
        see documentation of `MorphTree._convertNodeArgToNodes`
    add_soma: bool (optional)
        whether or not to start the list with a location on the soma
    name: string (optional)
        the name under which the locations are stored. Defaults to 'dont save'
        which means the locations are not stored
    seed: int (optional)
        Seed for numpy random number generator

    Returns
    -------
    list of `neat.MorphLoc`
        the locations
    """
    np.random.seed(seed)
    # use the requested subset of nodes, without the soma
    nodes = [node for node in self._convertNodeArgToNodes(node_arg)
             if node.index != 1]
    try:
        # initialize the loclist with or without soma
        if add_soma:
            locs = [MorphLoc({'node': 1, 'x': 0.}, self)]
            self.root.content['tag'] = 1
        else:
            locs = []
        # add the nodes
        for _ in range(num):
            nodes_left = [node.index for node in nodes
                          if 'tag' not in node.content]
            if len(nodes_left) < 1:
                break
            index = np.random.choice(nodes_left)
            x = np.random.random()
            locs.append(MorphLoc((index, x), self))
            # exclude the neighbourhood of the drawn node from later draws
            node = self[index]
            self._tagNodesDown(node, node, dx=dx)
            self._tagNodesUp(node, node, dx=dx)
    finally:
        # leftover tags would silently exclude nodes in later calls
        self._removeTags()
    # store the locations
    if name != 'dont save':
        self.storeLocs(locs, name=name)
    return locs
def _tagNodesDown(self, start_node, node, dx=0.001):
    """
    Tag `node` and, recursively, its child nodes as long as they lie
    within a path length `dx` of the end (x=1) of `start_node`.

    Nodes that already carry a 'tag' entry are skipped, together with
    their subtree.
    """
    if 'tag' in node.content:
        return
    if node.index == start_node.index:
        length = 0.
    else:
        # from the end of the start node to the start of this node
        length = self.pathLength({'node': start_node.index, 'x': 1.},
                                 {'node': node.index, 'x': 0.})
    if not length < dx:
        return
    node.content['tag'] = 1
    for child in node.child_nodes:
        self._tagNodesDown(start_node, child, dx=dx)
def _tagNodesUp(self, start_node, node, cnode=None, dx=0.001):
    # Tag `node` if its end (x=1) lies within a path length `dx` of the end
    # of `start_node`, then continue with the sibling branches and with the
    # parent node. `cnode` is the child node from which the recursion
    # arrived; it is ``None`` for the initial call.
    # NOTE(review): the nesting below was reconstructed from a listing that
    # lost its indentation -- confirm against the repository
    if node.index == start_node.index:
        length = 0.
    else:
        length = self.pathLength({'node': start_node.index, 'x': 1.},
                                 {'node': node.index, 'x': 1.})
    if length < dx:
        node.content['tag'] = 1
        cnodes = node.child_nodes
        if len(cnodes) > 1:
            # at a bifurcation, tag the branches other than the one the
            # recursion came from
            if cnode != None:
                cnodes = list(set(cnodes) - set([cnode]))
            for cn in cnodes:
                self._tagNodesDown(start_node, cn, dx=dx)
        # continue towards the root
        pnode = node.getParentNode()
        if pnode != None:
            self._tagNodesUp(start_node, pnode, node, dx=dx)
def _removeTags(self):
    """
    Delete the 'tag' entry from the content of every node of the tree
    that has one.
    """
    for node in self:
        node.content.pop('tag', None)
def _parseLocArg(self, loc_arg):
    """
    Convert a location argument to a list of locations.

    Parameters
    ----------
    loc_arg: list of locs or string
        if list, every element is wrapped in a new `neat.MorphLoc`; if
        string, the set of locations stored under that name is returned

    Returns
    -------
    list of `neat.MorphLoc`

    Raises
    ------
    IOError
        If `loc_arg` is neither a list nor a string
    KeyError
        If `loc_arg` is a string that is not in use
    """
    if isinstance(loc_arg, list):
        return [MorphLoc(loc, self) for loc in loc_arg]
    if isinstance(loc_arg, str):
        self._tryName(loc_arg)
        return self.getLocs(loc_arg)
    raise IOError('invalid type for `loc_arg`, should be list or string')
def extendWithBifurcationLocs(self, loc_arg, name='dont save'):
    """
    Extends input loc_arg with the intermediate bifurcations. They are
    appended to the end of the list, after which duplicates are removed

    Parameters
    ----------
    loc_arg: list of `neat.MorphLoc` or string
        the locations
    name: string (optional)
        The name under which the extended list of locs will be stored.
        Defaults to 'dont save' which means they are not stored.

    Returns
    -------
    list of `neat.MorphLoc`
        the unique input locs, followed by the bifurcation locs that were
        not yet among them
    """
    input_locs = self._parseLocArg(loc_arg)
    # bifurcation nodes between the nodes that carry the input locations
    loc_nodes = [self[loc['node']] for loc in input_locs]
    bifurcation_locs = []
    for bnode in self.getBifurcationNodes(loc_nodes):
        # a bifurcation is located at the end (x=1) of its node
        bifurcation_locs.append(MorphLoc((bnode.index, 1.), self))
    # retain unique locs
    all_locs = self.uniqueLocs(input_locs + bifurcation_locs)
    # store the locations
    if name != 'dont save':
        self.storeLocs(all_locs, name=name)
    return all_locs
def uniqueLocs(self, loc_arg, name='dont save'):
    """
    Gets the unique locations in the provided locs

    The first occurrence of every location is kept, in the original
    order. Locations are compared with ``in``, i.e. by equality.

    Parameters
    ----------
    loc_arg: list of `neat.MorphLoc` or string
        the locations
    name: string (optional)
        The name under which the list of unique locs will be stored.
        Defaults to 'dont save' which means they are not stored.

    Returns
    -------
    list of `neat.MorphLoc`
        the unique locs
    """
    unique = []
    for loc in self._parseLocArg(loc_arg):
        if loc not in unique:
            unique.append(loc)
    if name != 'dont save':
        self.storeLocs(unique, name=name)
    return unique
def makeXAxis(self, dx=10., node_arg=None, loc_arg=None):
    """
    Create a set of locs suitable for serving as the x-axis for 1D plotting.
    The neuron is put on a 1D axis with a depth-first ordering.

    The locations are stored under the name 'xaxis' and the resulting
    coordinates in the attribute `xaxis`. When the locations are
    distributed by this function (`loc_arg` is ``None``), the node colors
    are set as well, for both the original and the computational tree.

    Parameters
    ----------
    dx: float
        target separation between the plot points (micron)
    node_arg:
        see documentation of `MorphTree._convertNodeArgToNodes`
        The nodes on which the locations for the x-axis are distributed.
        When this is given as a list of nodes, assumes a depth first
        ordering.
    loc_arg: list of locs or string
        if list of locs, these locs will be used as x-axis, if string, name
        of set of locs on the morphology that will be used as x-axis

    Raises
    ------
    ValueError
        If the first node of `node_arg` is a child of another node in
        `node_arg`
    IOError
        If `loc_arg` is neither ``None``, a list nor a string
    """
    if loc_arg is None:
        # if comptree has not been set, create a basic one for plotting
        if self._computational_root is None:
            self.setCompTree()
        # distribute the x-axis locations
        self.distributeLocsUniform(dx, node_arg=node_arg, name='xaxis')
        # get the root node
        nodes = self._convertNodeArgToNodes(node_arg)
        # check that first node is root
        for node in nodes:
            if nodes[0] in node.child_nodes:
                raise ValueError('Input `node_arg` is not a depth-first ordered'
                                 ' list of nodes.')
        # set the node colors for both trees; the treetype is toggled to
        # reach the other tree and always restored afterwards
        current_treetype = self.treetype
        try:
            if current_treetype == 'original':
                rootnode_orig = nodes[0]
                tempnode = self._findCompnodeDown(nodes[0])
                self.setNodeColors(rootnode_orig)
                self.treetype = 'computational'
                rootnode_comp = self[tempnode.index]
                self.setNodeColors(rootnode_comp)
            else:
                rootnode_comp = nodes[0]
                self.setNodeColors(rootnode_comp)
                self.treetype = 'original'
                # same index, looked up in the original tree
                rootnode_orig = self[rootnode_comp.index]
                self.setNodeColors(rootnode_orig)
        finally:
            self.treetype = current_treetype
    else:
        if isinstance(loc_arg, list):
            self.storeLocs(loc_arg, name='xaxis')
        elif isinstance(loc_arg, str):
            self.storeLocs(self.getLocs(loc_arg), name='xaxis')
        else:
            raise IOError('`loc_arg` should be string or list of locs')
    # compute the x-axis 1D array: the locations up to the first leaf are
    # placed at their distance to the soma, every following branch is
    # appended where the previous one ended
    pinds = self.getLeafLocinds('xaxis')
    d2s = self.distancesToSoma('xaxis')
    xaxis = d2s[0:pinds[0]+1].tolist()
    d_add = d2s[pinds[0]]
    for ii in range(0, len(pinds)-1):
        xaxis.extend((d_add + d2s[pinds[ii]+1:pinds[ii+1]+1] \
                            - d2s[pinds[ii]+1]).tolist())
        d_add += d2s[pinds[ii+1]] - d2s[pinds[ii]+1]
    self.xaxis = np.array(xaxis)
def setNodeColors(self, startnode=None):
    """
    Set the color code for the nodes for 1D plotting

    The 'color' entry of every node of the tree is first reset to 0, then
    the subtree of `startnode` is colored by `_setColorsDown()`.

    Parameters
    ----------
    startnode: int or `neat.MorphNode`
        index of the node or node whose subtree will be colored. Defaults
        to the root
    """
    if startnode is None:
        startnode = self.root
    elif isinstance(startnode, (int, np.integer)):
        # an index was given, look up the corresponding node
        startnode = self[startnode]
    for node in self:
        node.content['color'] = 0.
    # trick to pass the pointer and not the number itself
    self.node_color = [0.]
    self._setColorsDown(startnode)
def _setColorsDown(self, node):
    """
    Recursively assign the current value of the color counter
    `self.node_color[0]` to `node` and its subtree. The counter is
    incremented by one after every leaf, so that the nodes visited up to
    and including a leaf share one color value.
    """
    counter = self.node_color
    node.content['color'] = counter[0]
    if self.isLeaf(node):
        counter[0] += 1.
    for child in node.child_nodes:
        self._setColorsDown(child)
def getXValues(self, locs):
    """
    Get the corresponding location on the x-axis of the input locations

    Every input location is mapped to the nearest location of the set
    stored under 'xaxis', whose x-axis coordinate is returned.

    Parameters
    ----------
    locs: list of tuples, dicts or `neat.MorphLoc`
        list of the locations

    Returns
    -------
    numpy.array
        the entries of `self.xaxis` at the nearest 'xaxis' locations
    """
    nearest = self.getNearestLocinds(locs, 'xaxis')
    locinds = np.array(nearest).astype(int)
    return self.xaxis[locinds]
def plot1D(self, ax, parr, *args, **kwargs):
    """
    Plot an array where each element corresponds to the matching location on
    the x-axis with a depth-first ordering on a 1D plot

    One line is drawn per stretch of locations that ends on a leaf
    location. A 'label' keyword is only passed to the first line.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be made
    parr: numpy.array of floats
        the array that will be plotted
    args, kwargs:
        arguments for `matplotlib.pyplot.plot`

    Returns
    -------
    lines: list of `matplotlib.lines.Line2D` instances
        the line segments corresponding to the value of the plotted array
        in each branch

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    assert len(parr) == len(self.locs['xaxis'])
    pinds = self.getLeafLocinds('xaxis')
    # the returned distances are not needed for this plot
    self.distancesToSoma('xaxis')
    # first stretch, from the first location up to the first leaf
    first = slice(0, pinds[0]+1)
    lines = [ax.plot(self.xaxis[first], parr[first], *args, **kwargs)[0]]
    # only the first line carries the label
    if 'label' in kwargs:
        kwargs = copy.deepcopy(kwargs)
        del kwargs['label']
    # remaining stretches, each running from just after one leaf up to the
    # next leaf
    for prev_leaf, next_leaf in zip(pinds[:-1], pinds[1:]):
        stretch = slice(prev_leaf+1, next_leaf+1)
        line = ax.plot(self.xaxis[stretch], parr[stretch], *args, **kwargs)
        lines.append(line[0])
    return lines
def setLineData(self, lines, parr):
    """
    Update the line objects with new data

    The data is split over the lines in the same way as in `plot1D()`:
    one line per stretch of locations that ends on a leaf location.

    Parameters
    ----------
    lines: list of `matplotlib.lines.Line2D` instance
        the line segments of which the data has to be updated
    parr: numpy.array of floats
        the array that will be put in the line segments

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    assert len(parr) == len(self.locs['xaxis'])
    pinds = self.getLeafLocinds('xaxis')
    # the returned distances are not needed here
    self.distancesToSoma('xaxis')
    first = slice(0, pinds[0]+1)
    lines[0].set_data(self.xaxis[first], parr[first])
    # line `ll` holds the stretch between leaf `ll-1` and leaf `ll`
    for ll, (prev_leaf, next_leaf) in enumerate(zip(pinds[:-1], pinds[1:]), start=1):
        stretch = slice(prev_leaf+1, next_leaf+1)
        lines[ll].set_data(self.xaxis[stretch], parr[stretch])
def plotTrueD2S(self, ax, parr, cmap=None, **kwargs):
    """
    Plot an array where each element corresponds to the matching location in
    the x-axis location list. Now all locations are plotted at their true
    distance from the soma.

    A 'label' keyword is only passed to the first line. The 'color' entry
    of the node contents has to be set for every node.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be made
    parr: numpy.array of floats
        the array that will be plotted
    cmap: `matplotlib.colors.Colormap` instance
        If provided, the lines will be colored according to the branch
        to which they belong, in colors specified by the colormap. A
        'color' keyword is discarded in that case.
    kwargs:
        keyword arguments for `matplotlib.pyplot.plot`

    Returns
    -------
    lines: list of `matplotlib.lines.Line2D`
        the line segments corresponding to the value of the plotted array
        in each branch

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    assert len(parr) == len(self.locs['xaxis'])
    locs = self.locs['xaxis']
    pinds = self.getLeafLocinds('xaxis')
    d2s = self.distancesToSoma('xaxis')
    # color value of every node and of every x-axis location
    cs = {node.index: node.content['color'] for node in self}
    cplot = [cs[loc['node']] for loc in locs]
    max_cs = max(cplot)
    min_cs = min(cplot)
    # range used to map the color values onto the colormap
    if np.abs(max_cs - min_cs) < 1e-12:
        norm_cs = max_cs + 1e-2
    else:
        norm_cs = (max_cs - min_cs) * (1. + 1./100.)
    use_cmap = cmap != None
    # first stretch, from the first location up to the first leaf
    if use_cmap:
        kwargs['c'] = cmap((cplot[0]-min_cs)/norm_cs)
        if 'color' in kwargs:
            del kwargs['color']
    first = slice(0, pinds[0]+1)
    lines = [ax.plot(d2s[first], parr[first], **kwargs)[0]]
    # only the first line carries the label
    if 'label' in kwargs:
        del kwargs['label']
    # remaining stretches, colored after the node of their first location
    for prev_leaf, next_leaf in zip(pinds[:-1], pinds[1:]):
        if use_cmap:
            kwargs['c'] = cmap((cs[locs[prev_leaf+1]['node']]-min_cs)/norm_cs)
        stretch = slice(prev_leaf+1, next_leaf+1)
        line = ax.plot(d2s[stretch], parr[stretch], **kwargs)
        lines.append(line[0])
    return lines
def _addScalebar(self, ax, borderpad=-1.8, sep=2):
    """
    Add a scalebar labeled in micron to `ax` and hide the tick labels of
    its x-axis.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object to which the scalebar is added
    borderpad: float
        forwarded to `scalebars.addScalebar()`
    sep: float
        forwarded to `scalebars.addScalebar()`
    """
    # imported locally, as in the original implementation
    from neat.tools.plottools import scalebars
    # raw string: '\m' is not a valid escape sequence; the label text
    # itself is unchanged
    scalebars.addScalebar(ax, hidex=False, hidey=False, matchy=False,
                          labelx=r'$\mu$m',
                          loc=8, borderpad=borderpad, sep=sep)
    ax.set_xticklabels([])
def colorXAxis(self, ax, cmap, addScalebar=1, borderpad=-1.8):
    """
    Color the x-axis of a plot according to the morphology.

    !!! Has to be called after all lines are plotted !!!

    Furthermore, node colors have to be set first. This can be done with
    `MorphTree.setNodeColors()` or manually by adding a 'color' entry
    to the ``MorphNode.content`` dictionary

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object of which the x-axis will be colored
    cmap: `matplotlib.colors.Colormap` instance
        Colormap that determines the color of each branch
    addScalebar: bool
        If ``True``, a scalebar is added to `ax` through
        `self._addScalebar()`
    borderpad: float
        Borderpad of scalebar

    Raises
    ------
    AssertionError
        When there are no leaf locations in the x-axis location list
    """
    locs = self.locs['xaxis']
    # colors of the nodes on which the x-axis locations are situated
    cs = {node.index: node.content['color'] for node in self}
    cplot = [cs[loc['node']] for loc in locs]
    max_cs = max(cplot)
    min_cs = min(cplot)
    if np.abs(max_cs - min_cs) < 1e-12:
        # all colors are identical, every normalized value is zero and any
        # non-zero divisor will do
        norm_cs = 1.
    else:
        norm_cs = (max_cs - min_cs) * (1. + 1./100.)

    pinds = self.getLeafLocinds('xaxis')
    assert len(pinds) > 0

    # the colored bar is drawn just above the lower y-limit, which is why
    # all other lines have to be plotted first
    ylim = np.array(ax.get_ylim())
    y_bar = ylim[0] + 1e-9

    # one thick line per stretch of locations that ends at a leaf location
    start = 0
    for leaf_ind in pinds:
        stop = leaf_ind + 1
        xvals = self.xaxis[start:stop]
        # a segment takes the color of the node of its first location
        color = cmap((cs[locs[start]['node']] - min_cs) / norm_cs)
        ax.plot(xvals, [y_bar for _ in xvals], c=color, lw=10)
        start = stop

    # plotting may have changed the limits, restore them
    ax.set_ylim((ylim[0], ylim[1]))

    # add scalebar
    if addScalebar:
        self._addScalebar(ax, borderpad=borderpad)
    ax.axes.get_xaxis().set_visible(False)
def plot2DMorphology(self, ax, node_arg=None, cs=None, cminmax=None, cmap=None,
                     use_radius=1, draw_soma_circle=1,
                     plotargs=None, textargs=None,
                     marklocs=None, locargs=None,
                     marklabels=None, labelargs=None,
                     cb_draw=0, cb_orientation='vertical', cb_label='',
                     sb_draw=1, sb_scale=100, sb_width=5.,
                     set_lims=True, lims_margin=.1):
    """
    Plot the morphology projected on the x,y-plane

    None of the dictionaries or lists passed as arguments are modified.
    The locations in `marklocs` are stored on the tree under the name
    'plotlocs'.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be drawn
    node_arg:
        see documentation of `MorphTree._convertNodeArgToNodes`
    cs: dict {int: float}, None, 'x_color' or 'node_color'
        If dict, node indices are keys and the float value will
        correspond to the plotted color. If None, the color of the tree
        will be the one specified in ``plotargs``. Note that the dict
        does not have to contain all node indices. The ones that are not
        featured in the dict are plot in the color specified in ``plotargs``.
        If 'x_color' or 'node_color', colors will be those stored on the
        nodes. Note that choosing this option when there are nodes
        without 'color' as an entry in ``node.content`` will result in
        an error. Node colors can be set with `MorphTree.setNodeColor()``
    cminmax: (float, float) or None (default)
        The min and max values of the color scale (if cs is provided).
        If None, the min and max values of cs are used.
    cmap: `matplotlib.colors.Colormap` instance
        colormap from which colors in ``cs`` are taken. Defaults to 'jet'.
    use_radius: bool
        If ``True``, uses the swc radius for the width of the line
        segments
    draw_soma_circle: bool
        If ``True``, draws the soma as a circle, otherwise doesn't draw
        soma
    plotargs: dict or None
        `kwargs` for `matplotlib.pyplot.plot`. 'c'- or 'color'-
        argument will be overwritten when cs is defined. 'lw'- or
        'linewidth' argument will be multiplied with the swc radius of
        the node if `use_radius` is ``True``. Defaults to a black line
        of width 1.
    textargs: dict or None
        text properties for various labels in the plot
    marklocs: list of tuples, dicts or instances of `neat.MorphLoc`
        Location that will be plotted on the morphology
    locargs: dict or list of dict
        `kwargs` for `matplotlib.pyplot.plot` for the location.
        Use only point markers and no lines! When it is a single dict
        all location will have the same marker. When it is a list it
        should have the same length as `marklocs`.
    marklabels: dict {int: string}
        Keys are indices of locations in `marklocs`, values are strings
        that are used to annotate the corresponding locations
    labelargs: dict or None
        text properties for the location annotation
    cb_draw: bool
        Whether or not to draw a `matplotlib.pyplot.colorbar()`
        instance.
    cb_orientation: string, 'vertical' or 'horizontal'
        The colorbars' orientation
    cb_label: string
        The label of the colorbar
    sb_draw: bool
        Whether or not to draw a scale bar
    sb_scale: float
        Length of the scale bar (micron)
    sb_width: float
        Width of the scale bar
    set_lims: bool (optional, default ``True``)
        set ``ax`` limits based on the morphology
    lims_margin: float
        the margin, as fraction of total width and height of tree, at
        which the limits are placed

    Raises
    ------
    AssertionError
        When `locargs` is a list whose length differs from `marklocs`,
        or when a key of `marklabels` is not a valid index in `marklocs`
    """
    # work on private copies, so that neither the caller's dictionaries nor
    # shared default values are modified by the per-node updates below
    plotargs = {} if plotargs is None else dict(plotargs)
    textargs = {} if textargs is None else textargs
    marklocs = [] if marklocs is None else marklocs
    locargs = {} if locargs is None else locargs
    marklabels = {} if marklabels is None else marklabels
    labelargs = {} if labelargs is None else labelargs

    # default cmap
    if cmap is None:
        cmap = pl.get_cmap('jet')

    # ensure color is indicated by the 'c'-parameter in `plotargs`
    if 'color' in plotargs:
        plotargs['c'] = plotargs.pop('color')
    elif 'c' not in plotargs:
        plotargs['c'] = 'k'
    # ensure linewidth is indicated as 'lw' in plotargs
    if 'linewidth' in plotargs:
        plotargs['lw'] = plotargs.pop('linewidth')
    elif 'lw' not in plotargs:
        plotargs['lw'] = 1.
    # values to fall back on for nodes without a color in `cs`, and to
    # scale with the node radius
    base_color = plotargs['c']
    base_lw = plotargs['lw']

    # define a norm for the colors, if defined
    if isinstance(cs, str) and cs in ('x_color', 'node_color'):
        cs = {node.index: node.content['color'] for node in self}
    if cs is not None:
        if cminmax is None:
            if len(cs) > 0:
                # largest resp. smallest value stored in `cs`
                max_cs = cs[max(cs, key=cs.__getitem__)]
                min_cs = cs[min(cs, key=cs.__getitem__)]
            else:
                min_cs, max_cs = 0., 1.
        else:
            min_cs = cminmax[0]
            max_cs = cminmax[1]
        norm = pl.Normalize(vmin=min_cs, vmax=max_cs)

    # locargs can be dictionary, so that the same properties hold for every
    # markloc, or can be list with the same size as marklocs, so that every
    # marker has different properties. `zorder` of the markers is also set
    # very high so that they are always in the foreground
    self.storeLocs(marklocs, 'plotlocs')
    xs = self.xs['plotlocs']
    if isinstance(locargs, dict):
        shared_locargs = dict(locargs)
        shared_locargs.setdefault('zorder', 1e4)
        locargs = [shared_locargs for _ in marklocs]
    else:
        assert len(locargs) == len(marklocs)
        locargs = [dict(locarg) for locarg in locargs]
        for locarg in locargs:
            locarg.setdefault('zorder', 1e4)

    # `marklabels` is a dictionary with as keys the index of the loc in
    # `marklocs` to which the label belongs. `labelargs` is the same for
    # every label
    for ind in marklabels:
        assert ind < len(marklocs)

    # plot the tree, keeping track of its bounding box (which always
    # contains the origin)
    xlim = [0., 0.]
    ylim = [0., 0.]
    for node in self._convertNodeArgToNodes(node_arg):
        if node.xyz[0] < xlim[0]: xlim[0] = node.xyz[0]
        if node.xyz[0] > xlim[1]: xlim[1] = node.xyz[0]
        if node.xyz[1] < ylim[0]: ylim[0] = node.xyz[1]
        if node.xyz[1] > ylim[1]: ylim[1] = node.xyz[1]

        # find the locations that are on the current node
        inds = self.getLocindsOnNode('plotlocs', node)

        if cs is not None and node.index in cs:
            node_color = cmap(norm(cs[node.index]))
        else:
            node_color = base_color

        if node.parent_node is None:
            # node is soma, draw as circle if necessary
            if draw_soma_circle:
                circ = patches.Circle(node.xyz[0:2], node.R,
                                      color=node_color)
                ax.add_patch(circ)
            # locations on the soma are plotted at its center
            for ind in inds:
                self._plotLoc(ax, ind, node.xyz[0], node.xyz[1],
                              locargs, marklabels, labelargs)
        else:
            # plot line segment associated with node
            nxyz = node.xyz
            pxyz = node.parent_node.xyz
            plotargs['c'] = node_color
            if use_radius:
                plotargs['lw'] = base_lw * node.R
            ax.plot([pxyz[0], nxyz[0]], [pxyz[1], nxyz[1]], **plotargs)
            # plot the locations, interpolated between parent and node
            for ind in inds:
                locxyz = pxyz + (nxyz - pxyz) * xs[ind]
                self._plotLoc(ax, ind, locxyz[0], locxyz[1],
                              locargs, marklabels, labelargs)

    # margins
    dx = xlim[1] - xlim[0]
    dy = ylim[1] - ylim[0]
    xlim[0] -= dx*lims_margin; xlim[1] += dx*lims_margin
    ylim[0] -= dy*lims_margin; ylim[1] += dy*lims_margin

    # draw a scale bar in the lower left corner
    if sb_draw:
        scale = sb_scale
        # height including the margins, used to offset the label
        dy = ylim[1] - ylim[0]
        ax.plot([xlim[0], xlim[0]+scale],
                [ylim[0], ylim[0]],
                'k', linewidth=sb_width, zorder=1e5)
        txt = ax.annotate(str(scale) + r' $\mu$m',
                          xy=(xlim[0]+scale/2., ylim[0]),
                          xycoords='data',
                          xytext=(xlim[0]+scale/2., ylim[0]-dy/200.),
                          ha='center', va='top',
                          **textargs)
        txt.set_path_effects([patheffects.withStroke(foreground="w",
                                                     linewidth=2)])

    if set_lims:
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
    ax.set_aspect('equal', 'datalim')

    if cs is not None and cb_draw:
        # create colorbar ax
        divider = make_axes_locatable(ax)
        if cb_orientation == 'horizontal':
            cax = divider.append_axes("bottom", "5%", pad="3%")
        else:
            cax = divider.append_axes("right", "5%", pad="3%")
        # create a mappable without data, only to carry cmap and norm
        sm = cm.ScalarMappable(cmap=cmap, norm=norm)
        sm.set_array([])
        # create the colorbar
        cb = pl.colorbar(sm, cax=cax, orientation=cb_orientation)
        ticks_cb = np.round(np.linspace(min_cs, max_cs, 7), decimals=1)
        cb.set_ticks(ticks_cb)
        if cb_orientation == 'horizontal':
            cb.ax.xaxis.set_ticks_position('bottom')
        else:
            cb.ax.yaxis.set_ticks_position('right')
        cb.set_label(cb_label, **textargs)

    # hide the frame and the ticks
    ax.spines['top'].set_color('none')
    ax.spines['bottom'].set_color('none')
    ax.spines['right'].set_color('none')
    ax.spines['left'].set_color('none')
    ax.draw_frame = False
    ax.set_xticks([])
    ax.set_yticks([])
def _plotLoc(self, ax, ind, xval, yval, locargs, marklabels, labelargs):
    """
    Draw the marker of location `ind` at (`xval`, `yval`) on `ax`, using the
    plot properties in ``locargs[ind]``. If `ind` is a key of `marklabels`,
    the corresponding string is added as an annotation next to the marker,
    with text properties `labelargs`.
    """
    marker_kwargs = locargs[ind]
    ax.plot(xval, yval, **marker_kwargs)

    # locations without a label only get a marker
    if ind not in marklabels:
        return

    annotation = ax.annotate(marklabels[ind],
                             xy=(xval, yval), xycoords='data',
                             xytext=(5, 5), textcoords='offset points',
                             **labelargs)
    # white outline around the text
    outline = patheffects.withStroke(foreground="w", linewidth=2)
    annotation.set_path_effects([outline])
def plotMorphologyInteractive(self, node_arg=None,
                              use_radius=1, draw_soma_circle=1,
                              plotargs=None,
                              project3d=False):
    """
    Show the morphology either in 3d or projected on the x,y-plane. When
    a line segment is clicked, the associated node is printed, together
    with its distance to the soma.

    Opens a figure named 'Morphology interactive' and calls
    `matplotlib.pyplot.show()`.

    Parameters
    ----------
    node_arg:
        see documentation of `MorphTree._convertNodeArgToNodes`
    use_radius: bool
        If ``True``, uses the swc radius for the width of the line
        segments
    draw_soma_circle: bool
        Currently not used by this function, the soma is never drawn.
        Kept for compatibility.
    plotargs: dict or None
        `kwargs` for `matplotlib.pyplot.plot`. 'lw'- or 'linewidth'
        argument will be multiplied with the swc radius of the node if
        `use_radius` is ``True``. Defaults to a black line of width 1.
        The dictionary is not modified.
    project3d: bool
        If ``True``, the morphology is shown in 3d, otherwise it is
        projected on the x,y-plane
    """
    # work on a private copy, so that neither the caller's dictionary nor a
    # shared default is modified by the per-node linewidth updates
    plotargs = {'c': 'k', 'lw': 1.} if plotargs is None else dict(plotargs)
    # use the short aliases throughout, so that a color or linewidth is
    # never passed twice to `plot`
    if 'color' in plotargs:
        plotargs['c'] = plotargs.pop('color')
    plotargs.setdefault('c', 'k')
    if 'linewidth' in plotargs:
        plotargs['lw'] = plotargs.pop('linewidth')
    plotargs.setdefault('lw', 1.)
    base_lw = plotargs['lw']

    fig = pl.figure('Morphology interactive')
    if project3d:
        # registers the '3d' projection on older matplotlib versions
        from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = pl.gca()

    # plot the tree, the label of each line identifies the node it belongs
    # to when the line is picked
    node_line_associators = {}
    for ii, node in enumerate(self._convertNodeArgToNodes(node_arg)):
        if node.parent_node is None:
            # the root has no line segment
            continue
        # plot line segment associated with node
        nxyz = node.xyz
        pxyz = node.parent_node.xyz
        if use_radius:
            plotargs['lw'] = base_lw * node.R
        coords = [[pxyz[0], nxyz[0]], [pxyz[1], nxyz[1]]]
        if project3d:
            coords.append([pxyz[2], nxyz[2]])
        ax.plot(*coords, label=str(ii), picker=True, pickradius=2.,
                **plotargs)
        node_line_associators[str(ii)] = node

    ax.axes.get_xaxis().set_visible(0)
    ax.axes.get_yaxis().set_visible(0)
    ax.axison = 0

    # define the clickevent action
    def onPick(event):
        line = event.artist
        node = node_line_associators[line.get_label()]
        # print the associated node
        print('\n>>> line segment at ' + str(node) + \
              ', distance to soma (um) = ' + \
              str(self.pathLength({'node': node.index, 'x': 1},
                                  {'node': 1, 'x': 0.})))

    # show morphology
    fig.canvas.mpl_connect('pick_event', onPick)
    pl.show()
@originalTreetypeDecorator
def findCommonRoot(self, name):
    """
    Find the node of highest order that lies on the path to the root of
    every location stored under `name`.

    Parameters
    ----------
    name: string
        the name under which the locations are stored

    Returns
    -------
    node
        The common node of highest order. When several common nodes have
        the same order, the one with the longest path to the root is
        returned.
    """
    self._tryName(name)
    # get the node indices of nodes
    node_inds = self.getNodeIndices(name)
    # find the paths to the root
    paths = [set(self.pathToRoot(self[node_ind])) for node_ind in node_inds]
    # possible roots are the nodes shared by all paths
    roots = list(set.intersection(*paths))

    def _rank(node):
        # Nodes can share the same order, in which case picking by order
        # alone depends on the arbitrary iteration order of the set. The
        # length of the path to the root breaks such ties.
        return (self.orderOfNode(node), len(list(self.pathToRoot(node))))

    # return the node of highest order
    return max(roots, key=_rank)
@originalTreetypeDecorator
def createNewTree(self, name, fake_soma=False, store_loc_inds=False):
    """
    Creates a new tree where the locs of a given 'name' are now the nodes.

    The length of each new node is set to the euclidean distance between
    its coordinates and those of its parent; the root gets length zero.

    Parameters
    ----------
    name: string
        the name under which the locations are stored that should be
        used to create the new tree
    fake_soma: bool (default `False`)
        if `True`, finds the common root of the set of locations and
        uses that as the soma of the new tree. If `False`, the real soma
        is used.
    store_loc_inds: bool (default `False`)
        store the index of each location in the `content` attribute of the
        new node (under the key 'loc ind'). For the root of the new tree,
        this is the index of the first location on the node with index
        1, or ``None`` if there is no such location.

    Returns
    -------
    `neat.MorphTree`
        The new tree, an instance of the same class as ``self``.
    """
    self._tryName(name)
    nids = self.getNodeIndices(name)
    # create new tree
    new_tree = self.__class__()

    if fake_soma:
        # find the common root of the set of locations
        snode = self.findCommonRoot(name)
    else:
        # use the soma as root
        snode = self[1]
    p3d = (snode.xyz, snode.R, snode.swc_type)
    new_snode = new_tree._createCorrespondingNode(1, p3d)
    new_snode.L = snode.L
    new_tree.setRoot(new_snode)
    new_nodes = [new_snode]
    if store_loc_inds:
        new_snode.content['loc ind'] = None if 1 not in nids else \
                                       np.where(nids == 1)[0][0]

    # make two other soma nodes
    if fake_soma:
        # copies of the new root, with indices 2 and 3
        for index in [2, 3]:
            new_cnode = new_tree._createCorrespondingNode(index, p3d)
            new_tree.addNodeWithParent(new_cnode, new_snode)
            new_nodes.append(new_cnode)
    else:
        # copies of the children of the soma with indices 2 and 3
        for cnode in snode.getChildNodes(skip_inds=[]):
            if cnode.index in [2, 3]:
                p3d = (cnode.xyz, cnode.R, cnode.swc_type)
                new_cnode = new_tree._createCorrespondingNode(cnode.index, p3d)
                new_tree.addNodeWithParent(new_cnode, new_snode)
                new_nodes.append(new_cnode)

    # make rest of tree
    for cnode in snode.child_nodes:
        self._addNodesToTree(cnode, new_snode, new_tree, new_nodes, name,
                             store_loc_inds=store_loc_inds)

    # set the lengths of the nodes
    for new_node in new_tree:
        if new_node.parent_node is not None:
            L = np.sqrt(np.sum((new_node.parent_node.xyz - new_node.xyz)**2))
        else:
            L = 0.
        new_node.setLength(L)

    return new_tree
def _addNodesToTree(self, node, new_pnode, new_tree, new_nodes, name,
                    store_loc_inds=False):
    """
    Add a node to `new_tree` for every location of `name` in the subtree of
    `node`, starting from `new_pnode` as parent. Nodes in the subtree are
    visited parent first, children in order. Locations on the same node are
    chained in order of increasing x. Every created node is appended to
    `new_nodes`, and gets as index the length of that list plus one.
    """
    loc_xs = self.xs[name]

    # explicit stack of (original node, parent in the new tree) pairs
    pending = [(node, new_pnode)]
    while pending:
        cur_node, attach_to = pending.pop()

        # locations on the current node, sorted by their x-coordinate
        loc_inds = self.getLocindsOnNode(name, cur_node)
        ranking = np.argsort(loc_xs[loc_inds])
        for loc_ind in np.array(loc_inds)[ranking]:
            frac = loc_xs[loc_ind]
            new_index = len(new_nodes) + 1
            # coordinates interpolated between the parent and the node
            xyz = cur_node.parent_node.xyz * (1.-frac) + cur_node.xyz * frac
            if cur_node.parent_node.index == 1:
                # no interpolation of the radius when the parent is node 1
                radius = cur_node.R
            else:
                radius = cur_node.parent_node.R * (1.-frac) + \
                         cur_node.R * frac
            created = new_tree._createCorrespondingNode(
                new_index, (xyz, radius, cur_node.swc_type)
            )
            if store_loc_inds:
                created.content['loc ind'] = loc_ind
            new_tree.addNodeWithParent(created, attach_to)
            new_nodes.append(created)
            # the next node is attached to the one just created
            attach_to = created

        # children are pushed in reverse, so that the first child and its
        # whole subtree are handled before the next child
        pending.extend(
            (child, attach_to) for child in reversed(list(cur_node.child_nodes))
        )
@originalTreetypeDecorator
def createCompartmentTree(self, locarg):
    """
    Creates a new compartment tree where the provided set of locations
    correspond to the nodes.

    The root of the new tree is a location on the common root node of the
    set of locations. If there is no such location, a warning is issued
    and the location at x=1 on the common root node is prepended to the
    stored set of locations, which shifts all location indices by one.

    Parameters
    ----------
    locarg: list of locations or str
        if list of locations, specifies the locations, if str,
        specifies the name under which the set of location is stored
        that should be used to create the new tree. A list of locations
        is stored on the tree under the name 'comp_locs'.

    Returns
    -------
    `CompartmentTree`
        The new tree. Each node holds the index of its location in the
        stored list as `loc_ind`.

    Raises
    ------
    IOError
        When `locarg` is neither a list nor a string
    """
    # process input argument
    if isinstance(locarg, list):
        locs = [MorphLoc(loc, self) for loc in locarg]
        name = 'comp_locs'
        self.storeLocs(locs, name=name)
    elif isinstance(locarg, str):
        name = locarg
        self._tryName(name)
    else:
        raise IOError('`locarg` should be list of locs or string')

    # create new tree
    new_tree = CompartmentTree()
    # find the common root of the set of locations
    snode = self.findCommonRoot(name)
    # check if that root is in set of locations
    possible_loc_inds = self.getLocindsOnNode(name, snode)

    if len(possible_loc_inds) > 0:
        # create the new root node
        new_pnode = CompartmentNode(0, loc_ind=possible_loc_inds[0])
        new_tree.setRoot(new_pnode)
        new_nodes = [new_pnode]
        # the other locations on the root node are chained behind it
        for loc_ind in possible_loc_inds[1:]:
            new_node = CompartmentNode(len(new_nodes), loc_ind=loc_ind)
            new_tree.addNodeWithParent(new_node, new_pnode)
            new_nodes.append(new_node)
            # set new node as next parent node
            new_pnode = new_node
    else:
        warnings.warn('Locations of name `' + name + '` do not define a root - ' + \
                      'adding root to set of locations')
        # the added root becomes the first location of the stored list
        locs = self.getLocs(name)
        locs = [(snode.index, 1.)] + locs
        self.storeLocs(locs, name=name)
        # create the new root node
        new_pnode = CompartmentNode(0, loc_ind=0)
        new_tree.setRoot(new_pnode)
        new_nodes = [new_pnode]

    # make rest of tree
    for cnode in snode.child_nodes:
        self._addCompNodesToTree(cnode, new_pnode, new_tree, new_nodes, name)

    return new_tree
def _addCompNodesToTree(self, node, new_pnode, new_tree, new_nodes, name):
    """
    Add a `CompartmentNode` to `new_tree` for every location of `name` in
    the subtree of `node`, starting from `new_pnode` as parent. Nodes in the
    subtree are visited parent first, children in order. Locations on the
    same node are chained in the order returned by
    `self.getLocindsOnNode()`. Every created node is appended to
    `new_nodes`, and gets as index the length of that list.
    """
    # An explicit stack of (original node, parent in the new tree) pairs is
    # used instead of recursion, so that long branches can not exceed the
    # recursion limit.
    pending = [(node, new_pnode)]
    while pending:
        cur_node, parent = pending.pop()

        for loc_ind in self.getLocindsOnNode(name, cur_node):
            new_node = CompartmentNode(len(new_nodes), loc_ind=loc_ind)
            new_tree.addNodeWithParent(new_node, parent)
            new_nodes.append(new_node)
            # set new node as next parent node
            parent = new_node

        # children are pushed in reverse, so that the first child and its
        # whole subtree are handled before the next child
        pending.extend(
            (child, parent) for child in reversed(list(cur_node.child_nodes))
        )
def __copy__(self, new_tree=None):
    """
    Fill the ``new_tree`` with it's corresponding nodes in the same
    structure as ``self``, and copies all node variables that both tree
    classes have in common. The original tree is copied first, followed
    by the computational tree if there is one.

    Parameters
    ----------
    new_tree: :class:`STree` or derived class (default is ``None``)
        the tree class in which the ``self`` is copied. If ``None``,
        returns a copy of ``self``.

    Returns
    -------
    The new tree instance, with the same treetype as ``self``
    """
    if new_tree is None:
        new_tree = self.__class__()

    current_treetype = self.treetype
    try:
        # copy the original tree
        self.treetype = 'original'
        super().__copy__(new_tree=new_tree)

        try:
            # set the computational tree
            self.treetype = 'computational'
            new_node = new_tree._createCorrespondingNode(self.root.index)
            self.root.__copy__(new_node=new_node)
            new_tree._computational_root = new_node
            new_tree.treetype = 'computational'
            self._recurseCopy(self.root, new_tree)
        except ValueError:
            # NOTE(review): presumably raised when switching to a
            # computational tree that has not been defined, in which case
            # there is nothing more to copy - confirm against the
            # `treetype` setter
            pass
    finally:
        # the treetype of `self` is restored, also when copying fails
        self.treetype = current_treetype

    new_tree.treetype = current_treetype
    return new_tree