# -*- coding: utf-8 -*-
#
# morphtree.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST. If not, see <http://www.gnu.org/licenses/>.
import numpy as np
import matplotlib.patheffects as patheffects
import matplotlib.patches as patches
import matplotlib.cm as cm
import matplotlib.pyplot as pl
from mpl_toolkits.axes_grid1 import make_axes_locatable
import copy
import pathlib
import warnings
from typing import Literal
from functools import reduce
from contextlib import contextmanager
from .stree import SNode, STree
from .compartmenttree import CompartmentNode, CompartmentTree
from ..tools.interval import interval
def computational_tree_decorator(fun):
    """
    Decorator that provides the safety that the active tree is the
    computational one inside the functions it decorates. This decorator also
    checks if a computational tree has been defined.

    The wrapper keeps the name, docstring and other metadata of `fun`
    (via `functools.wraps`).

    Raises
    ------
    AttributeError
        If this function is called and no computational tree has been
        defined
    """
    import functools

    # wrapper to access self
    @functools.wraps(fun)
    def wrapped(self, *args, **kwargs):
        if self._computational_root is None:
            raise AttributeError(
                "No computational tree has been defined, "
                "and this function requires one. Use "
                "`MorphTree.set_comp_tree()` or its "
                "overwritten version in one of the derived "
                "classes"
            )
        # the context manager restores the previously active tree on exit
        with self.as_computational_tree:
            return fun(self, *args, **kwargs)

    return wrapped
def original_tree_decorator(fun):
    """
    Decorator that provides the safety that the active tree is the
    original one inside the functions it decorates.

    The wrapper keeps the name, docstring and other metadata of `fun`
    (via `functools.wraps`).
    """
    import functools

    # wrapper to access self
    @functools.wraps(fun)
    def wrapped(self, *args, **kwargs):
        # the context manager restores the previously active tree on exit
        with self.as_original_tree:
            return fun(self, *args, **kwargs)

    return wrapped
[docs]
class MorphLoc(object):
    """
    Stores a location on the morphology. The location is initialized starting
    from a node and x-value on the real morphology. The location can also be
    stored in the coordinates of the computational morphology. To toggle
    between coordinates, the class stores a reference to the morphology tree
    on which the location is defined, and returns either the original
    coordinate or the coordinate on the computational tree, depending on which
    tree is active.

    Initialized based on either a tuple or a dict where one entry specifies the
    node index and the other entry the x-coordinate specifying the location
    between parent node (x=0) or the node indicated by the index (x=1), or on
    a `neat.MorphLoc`.

    Parameters
    ----------
    loc: tuple or dict or `neat.MorphLoc`
        if tuple: (node index, x-value)
        if dict: {'node': node index, 'x': x-value}
        if `neat.MorphLoc`: its coordinates are copied
    reftree: `neat.MorphTree`
        The tree on which the location is defined
    set_as_comploc: bool
        if True, assumes the parameters provided in `loc` are coordinates
        on the computational tree. Doing this while no computational tree
        has been initialized in `reftree` will result in an error.
        Ignored when `loc` is a `neat.MorphLoc`. Defaults to False

    Raises
    ------
    ValueError
        If x-coordinate of location is not in ``[0,1]``
    TypeError
        If `loc` is not a tuple, dict or `neat.MorphLoc`
    """

    def __init__(self, loc, reftree, set_as_comploc=False):
        self.reftree = reftree
        if isinstance(loc, MorphLoc):
            # copy, so that the two locations do not share coordinate dicts
            self.loc = copy.deepcopy(loc.loc)
            # computational coordinates are only reusable on the same tree
            if loc.reftree is reftree and hasattr(loc, "comp_loc"):
                self.comp_loc = copy.deepcopy(loc.comp_loc)
            return

        if isinstance(loc, tuple):
            node_index, x = loc[0], loc[1]
        elif isinstance(loc, dict):
            node_index, x = loc["node"], loc["x"]
        else:
            raise TypeError("Not a valid location type, should be tuple or dict")

        x = float(x)
        if x > 1.0 or x < 0.0:
            raise ValueError("x-value should be in [0,1]")
        # a fresh, normalized dict: never alias a dict passed by the caller
        coords = {"node": int(node_index), "x": x}
        if set_as_comploc:
            self.comp_loc = coords
            self._set_original_loc()
        else:
            self.loc = coords

    def __getitem__(self, key):
        """
        Return the node index (key ``'node'`` or ``0``) or the x-coordinate
        (key ``'x'`` or ``1``), in the coordinates of the active tree of
        `reftree`.

        Raises
        ------
        KeyError
            For any other key
        """
        if isinstance(key, (int, np.integer)) and key in (0, 1):
            key = "node" if key == 0 else "x"
        if not isinstance(key, str):
            raise KeyError(key)
        if self.reftree.check_computational_tree_active():
            # the computational coordinates are computed lazily
            if not hasattr(self, "comp_loc"):
                self._set_computational_loc()
            return self.comp_loc[key]
        return self.loc[key]

    def __eq__(self, other_loc):
        """
        Check whether `other_loc` denotes the same point on the morphology.

        Both locations are compared in the coordinates of the active tree.
        A location at ``x=0`` on a node equals the location at ``x=1`` on its
        parent, and any ``x=0`` location on a child of the root equals a
        location on the root (index 1).

        Returns ``NotImplemented`` if `other_loc` is not a tuple, dict or
        `neat.MorphLoc`.
        """
        if not isinstance(other_loc, (tuple, dict, MorphLoc)):
            return NotImplemented
        eps = 1e-8
        comp_active = self.reftree.check_computational_tree_active()
        # index `self` rather than `self.loc`, so that both locations are
        # expressed in the coordinates of the active tree
        loc2 = MorphLoc(other_loc, self.reftree, set_as_comploc=comp_active)
        n1, x1 = self["node"], self["x"]
        n2, x2 = loc2["node"], loc2["x"]

        if n1 != 1 and n2 != 1:
            if x1 < eps and (1.0 - x2) < eps:
                # start of node n1 coincides with the end of its parent
                parent = self.reftree[n1].parent_node
                return self.reftree[n2].index == parent.index
            if x2 < eps and (1.0 - x1) < eps:
                parent = self.reftree[n2].parent_node
                return self.reftree[n1].index == parent.index
            return bool(n1 == n2 and np.allclose(x1, x2))
        if n1 != 1:
            # only loc2 is on the root
            return bool(x1 < eps and self.reftree[n1].parent_node.index == 1)
        if n2 != 1:
            # only self is on the root
            return bool(x2 < eps and self.reftree[n2].parent_node.index == 1)
        # both locations are on the root
        return True

    def __ne__(self, other_loc):
        result = self.__eq__(other_loc)
        if result is NotImplemented:
            return result
        return not result

    # backwards-compatible alias of the previous (non-dunder) name
    __neq__ = __ne__

    def keys(self):
        return ["node", "x"]

    def __iter__(self):
        yield self["node"]
        yield self["x"]

    def __copy__(self):
        """
        Customization of the copy function so that `loc` and `comp_loc`
        attributes are deep copied and `reftree` attribute still refers to the
        original tree
        """
        new_loc = type(self)(copy.deepcopy(self.loc), self.reftree)
        if hasattr(self, "comp_loc"):
            new_loc.__dict__.update({"comp_loc": copy.deepcopy(self.comp_loc)})
        return new_loc

    def __str__(self):
        return "{'node': %d, 'x': %.2f }" % (self.loc["node"], self.loc["x"])

    def __repr__(self):
        return str(self)

    def _set_computational_loc(self):
        """
        Compute `comp_loc`, the coordinates of this location on the
        computational tree, from the original coordinates `loc`.
        """
        if self.loc["node"] == 1:
            # node 1 keeps its index in the computational tree
            self.comp_loc = copy.deepcopy(self.loc)
            return
        with self.reftree.as_original_tree:
            node = self.reftree[self.loc["node"]]
            # find the computational nodes that are resp. up and down from the node
            node_start = self.reftree._find_comp_node_to_root(node.parent_node)
            node_stop = self.reftree._find_comp_node_from_root(node)
            # length between the parent computational node and loc
            L = self.reftree.path_length(
                {"node": node_start.index, "x": 1.0}, self.loc
            )
        # get the computational nodes' length
        with self.reftree.as_computational_tree:
            L_cn = self.reftree[node_stop.index].L
        self.comp_loc = {"node": node_stop.index, "x": L / L_cn}

    def _set_original_loc(self):
        """
        Compute `loc`, the coordinates of this location on the original tree,
        from the computational coordinates `comp_loc`.

        Raises
        ------
        ValueError
            If the location cannot be mapped onto a node of the original tree,
            or if the resulting x-coordinate is not in ``[0,1]``
        """
        if self.comp_loc["node"] == 1:
            self.loc = copy.deepcopy(self.comp_loc)
            return
        with self.reftree.as_computational_tree:
            compnode = self.reftree[self.comp_loc["node"]]
        with self.reftree.as_original_tree:
            node = self.reftree[self.comp_loc["node"]]
            # find the computational node that is down from the original node
            pcnode = self.reftree._find_comp_node_to_root(node.parent_node)
            # original nodes between the two computational nodes
            path = self.reftree.path_between_nodes(pcnode, node)

        # distance of the location from the parent computational node; a
        # location at exactly zero is nudged so that it falls in the
        # half-open interval (L0, L1] of the first path node
        Lloc = self.comp_loc["x"] * compnode.L
        if Lloc == 0.0:
            Lloc += 1e-7
        L0 = 0.0
        for pathnode in path[1:]:
            L1 = L0 + pathnode.L
            if L0 < Lloc <= L1:
                self.loc = {
                    "node": pathnode.index,
                    "x": (Lloc - L0 - 1e-8) / pathnode.L,
                }
                break
            L0 = L1
        else:
            raise ValueError(
                "Computational location could not be mapped onto the original tree"
            )
        if self.loc["x"] > 1.0 or self.loc["x"] < 0.0:
            raise ValueError("x-value should be in [0,1]")
[docs]
class MorphNode(SNode):
    """
    Node associated with `neat.MorphTree`. Stores the geometrical information
    associated with a point on the tree morphology

    Attributes
    ----------
    xyz: numpy.array of floats
        The xyz-coordinates associated with the node (um)
    R: float
        The radius of the node (um)
    swc_type: int
        The type of node, according to the .swc file format convention:
        ``1`` is soma, ``2`` is axon, ``3`` is basal dendrite and ``4``
        is apical dendrite.
    L: float
        The length of the node (um)
    """

    def __init__(self, index, p3d=None):
        super().__init__(index)
        if p3d is not None:
            self.set_p3d(*p3d)
        else:
            # bogus values, to overwrite
            self.set_p3d(np.array([0.0, 0.0, 0.0]), 1.0, 1)
            self.L = 1.0
            self.R = 1.0

    def set_p3d(self, xyz, R, swc_type):
        """
        Set the 3d parameters of the node

        Parameters
        ----------
        xyz: `np.array`
            3D location (um)
        R: float
            Radius of the segment (um)
        swc_type: int
            Type associated with the segment according to SWC standards
        """
        # morphology parameters
        self.xyz = xyz
        self.R = R
        self.swc_type = swc_type
        # auxiliary variable, set to True by `MorphTree.set_comp_tree` for
        # nodes that are retained in the computational tree
        self.used_in_comp_tree = False

    def set_length(self, L):
        """
        Set the length of the segment represented by the node

        Parameters
        ----------
        L: float
            the length of the segment (um)
        """
        self.L = L

    def set_radius(self, R):
        """
        Set the radius of the segment represented by the node

        Parameters
        ----------
        R: float
            the radius of the segment (um)
        """
        self.R = R

    def get_child_nodes(self, skip_inds=(2, 3)):
        """
        Get the `child_nodes` of this node. Indices ``2`` and ``3`` are skipped
        by default (3-point soma convention)

        Parameters
        ----------
        skip_inds: list or tuple of ints
            Node indices of child nodes that are not added to the returned list

        Returns
        -------
        list of `neat.MorphNode`
            The child nodes
        """
        return [cnode for cnode in self._child_nodes if cnode.index not in skip_inds]

    def set_child_nodes(self, cnodes):
        return super().set_child_nodes(cnodes)

    child_nodes = property(get_child_nodes, set_child_nodes)

    def __str__(self, **kwargs):
        node_str = super().__str__(**kwargs)
        node_str += (
            f" --- xyz = [{self.xyz[0]:.3f}, {self.xyz[1]:.3f}, {self.xyz[2]:.3f}] um, "
            f"R = {self.R:.2f} um, swc_type = {self.swc_type}"
        )
        return node_str

    def _get_repr_dict(self):
        repr_dict = super()._get_repr_dict()
        repr_dict.update(
            {
                "xyz": self.xyz,
                "R": f"{self.R:1.6g}",
                "swc_type": self.swc_type,
            }
        )
        return repr_dict

    def __repr__(self):
        return repr(self._get_repr_dict())
[docs]
class MorphTree(STree):
"""
Subclass of simple tree that implements neuronal morphologies. Reads in
trees from '.swc' files (http://neuromorpho.org/).
Neural morphologies are assumed to follow the three-point soma convention.
Internally however, the soma is represented as a sphere. Hence nodes with
indices 2 and 3 do not represent anything and are skipped in iterations and
getters.
The `MorphTree` can also store a simplified version of the original tree,
where only nodes are retained that should hold computational parameters,
i.e. the root, the bifurcation nodes and the leafs at least, although the
user can also specify additional nodes. By default, the active tree is the
original one, where nodes correspond to the underlying '.swc' file. Trough the
`MorphTree.as_computational_tree` context, the computational tree can be
accessed.
For computational efficiency, it is possible to store sets of locations on
the morphology, under user-specified names. These sets are stored as
lists of `neat.MorphLoc`, and associated arrays are stored that contain the
corresponding node indices of the locations, their x-coordinates, their
distances to the soma and their distances to the nearest bifurcation in the
in the direction of the soma.
Parameters
----------
arg: `str` (optional), `neat.MorphNode` or subclass, or `neat.MorphTree` or subclass, or ``None``
- If `str`, represents the file path of the morphology file.
Assumed to follow the '.swc' format.
- If `neat.MorphNode` or derived class, initializes a the tree with the
provided node as root.
- If `neat.MorphTree` or derived class, initialized a copy of the provided
tree.
- Default is ``None``, which initialized an empty tree
types: list of int (optional)
The list of node types to be included. As per the '.swc' convention,
``1`` is soma, ``2`` is axon, ``3`` is basal dendrite and ``4`` apical
dendrite. Default is ``[1,3,4]``.
Attributes
----------
root: `neat.MorphNode` instance
The root of the tree.
locs: dict {str: list of `neat.MorphLoc`}
Stored sets of locations, key is the user-specified the name of the
set of locations. Initialized as empty dict.
nids: dict {str: np.array of int}
Node indices of locations. Initialized as empty dict.
xs: dict {str: np.array of float}
x-coordinates of locations. Initialized as empty dict.
d2s: dict {str: np.array of float}
distances to soma of locations. Initialized as empty dict.
d2b: dict {str: np.array of float}
distances to nearest bifurcation in the direction of the soma
of locations. Initialized as empty dict.
"""
def __init__(self, arg=None, types=(1, 3, 4)):
    """
    Initialize the tree from an SWC file path, a root node, another tree or
    nothing (see the class docstring for the accepted `arg` values).
    """
    # we initialize two root nodes, one for the original tree mimicking the
    # .swc file, and one for the coarse grained tree for computational efficiency
    self._computational_root = None
    self._original_root = None
    # to store sets of locations on the morphology
    self.locs = {}
    self._nids_orig = {}
    self._nids_comp = {}
    self._xs_orig = {}
    self._xs_comp = {}
    self.d2s = {}
    self.d2b = {}
    self.leafinds = {}
    # instantiate the tree structure. Only the path conversion is guarded, so
    # that a TypeError raised while parsing the file is not mistaken for
    # "`arg` is not a path"
    try:
        swc_file_path = pathlib.Path(arg)
    except TypeError:
        super().__init__(arg)
    else:
        self.read_swc_tree_from_file(swc_file_path, types=types)
    # STree will always initialize `self._root`, independent of whether the
    # input argument is a node or a tree.
    # When the input argument is a node, it will *only* initialize `self._root`.
    # When the input argument is a tree, it will call `self.__copy__`, which
    # ensures that both the `self._original_root` and `self._computational_root`
    # (when the comptree is already set, see `self.set_comp_tree()`) are initialized.
    # Therefore, the following code ensures that `self._original_root` is always set
    # to the correct value.
    if not self.check_computational_tree_active():
        self._original_root = self._root
[docs]
def __getitem__(self, index, skip_inds=(2, 3)):
    """
    Look up a node by its index in the active tree.

    Parameters
    ----------
    index: int
        the index of the node to be found
    skip_inds: tuple of ints
        indices whose nodes are not searched, defaults to ``(2, 3)``

    Returns
    -------
    `neat.MorphNode` or None
        the node, or ``None`` when no node with this index is in the tree
    """
    start_node = self.root
    found = self._find_node(start_node, index, skip_inds=skip_inds)
    return found
def _find_node(self, node, index, skip_inds=(2, 3)):
    """
    Search the subtree starting at `node` for the node with the given index.

    Uses an explicit stack instead of recursion, so deep trees do not hit the
    recursion limit. Children are obtained with
    ``get_child_nodes(skip_inds=skip_inds)``.

    Parameters
    ----------
    node: :class:`SNode` or None
        node where the search is started
    index: int
        the index of the node to be found
    skip_inds: tuple of ints
        passed on to `get_child_nodes`

    Returns
    -------
    :class:`SNode` or None
        the node with the given index, ``None`` when it is not found or when
        `node` is ``None``
    """
    stack = [] if node is None else [node]
    while stack:
        cnode = stack.pop()
        if cnode.index == index:
            return cnode
        stack.extend(cnode.get_child_nodes(skip_inds=skip_inds))
    return None  # Not found!
[docs]
def __iter__(self, node=None, skip_inds=(2, 3)):
    """
    Overloaded iterator from parent class that avoids iterating over the
    nodes with index 2 and 3

    Nodes are yielded depth-first in pre-order (a node before its children,
    children in order). A node whose index is in `skip_inds` is not yielded,
    and neither are its descendants. An explicit stack is used instead of
    nested generators, so long unbranched sections do not hit the recursion
    limit.

    Parameters
    ----------
    node: `neat.MorphNode`
        The starting node. Defaults to the root
    skip_inds: tuple of ints
        Indices of the nodes that are skipped by the iterator. Defaults
        to ``(2,3)``, the nodes that contain extra geometrical
        information on the soma.

    Yields
    ------
    `neat.MorphNode`
        Nodes in the tree
    """
    if node is None:
        node = self.root
    if node is None:
        return
    stack = [node]
    while stack:
        cnode = stack.pop()
        if cnode.index in skip_inds:
            continue
        yield cnode
        # push in reverse so that children are popped in their original order
        stack.extend(reversed(list(cnode.get_child_nodes(skip_inds=skip_inds))))
[docs]
def reset_indices(self):
    """
    Resets the indices in the order they appear in a depth-first iteration.

    The first node (the root) receives index ``1``. Indices ``2`` and ``3``
    are reserved for the 3-point soma nodes, which iteration and getters
    skip, so the remaining nodes are numbered from ``4`` onward. The soma
    nodes themselves are not visited and keep their indices.
    """
    # materialize first: renaming nodes during a lazy iteration can change
    # which nodes the iterator skips
    nodes = list(self)
    for ind, node in enumerate(nodes):
        node.index = 1 if ind == 0 else ind + 3
[docs]
def get_nodes(self, skip_inds=(2, 3)):
    """
    Collect the nodes of the active tree, leaving out certain indices.

    Overloads the parent function so that nodes can be skipped by index and
    so that the nodes of whichever tree is active are returned.

    Parameters
    ----------
    skip_inds: tuple of ints
        Indices of the nodes that are skipped. Defaults to ``(2,3)``, the
        nodes that contain extra geometrical information on the soma.

    Returns
    -------
    list of `neat.MorphNode`
    """
    gathered = []
    self._gather_nodes(self.root, node_list=gathered, skip_inds=skip_inds)
    return gathered
def set_nodes(self, illegal):
    """
    Setter of the `nodes` property. Assigning to `nodes` is not allowed.

    Raises
    ------
    AttributeError
        Always
    """
    raise AttributeError(
        f"Property 'nodes' of '{self.__class__.__name__}' object has no setter. "
        "Add nodes to the tree with `tree.add_node_with_parent(node)`"
    )
# read-only in practice: the setter always raises
nodes = property(fget=get_nodes, fset=set_nodes)
def _gather_nodes(self, node, node_list=None, skip_inds=(2, 3)):
    """
    Overloaded gathering function that avoids appending nodes with index 2
    or 3 to the list.

    Nodes are appended depth-first in pre-order. A node whose index is in
    `skip_inds` is not appended, but its children (as returned by
    ``get_child_nodes(skip_inds=skip_inds)``) are still visited. An explicit
    stack is used, so deep trees do not hit the recursion limit.

    Parameters
    ----------
    node: `neat.MorphNode` or None
        the starting node; nothing is gathered when ``None``
    node_list: list of `neat.MorphNode` (optional)
        list to which the nodes are appended. A new list is created when
        not given
    skip_inds: tuple of ints
        indices of the nodes that are not appended

    Returns
    -------
    list of `neat.MorphNode`
        `node_list`, with the gathered nodes appended
    """
    if node_list is None:
        node_list = []
    stack = [] if node is None else [node]
    while stack:
        cnode = stack.pop()
        if cnode.index not in skip_inds:
            node_list.append(cnode)
        # push in reverse so that children are popped in their original order
        stack.extend(reversed(list(cnode.get_child_nodes(skip_inds=skip_inds))))
    return node_list
[docs]
def get_leafs(self, node=None):
    """
    Return the leaf nodes of the active tree.

    Overloads `get_leafs` of the parent class, so that the leafs of
    whichever tree is active are returned.

    Parameters
    ----------
    node: `neat.SNode` (optional)
        The node whose subtree is searched. Defaults to the root

    Returns
    -------
    list of `neat.MorphNode`
    """
    candidates = self.__iter__(node)
    return list(filter(self.is_leaf, candidates))


leafs = property(get_leafs)
[docs]
def get_nodes_in_basal_subtree(self):
    """
    Collect the nodes of the basal subtree (``swc_type == 3``).

    Returns
    -------
    list of `neat.MorphNode`
        All basal nodes, in iteration order
    """
    basal = []
    for morph_node in self:
        if morph_node.swc_type == 3:
            basal.append(morph_node)
    return basal


basal_nodes = property(get_nodes_in_basal_subtree)
[docs]
def get_nodes_in_apical_subtree(self):
    """
    Collect the nodes of the apical subtree (``swc_type == 4``).

    Returns
    -------
    list of `neat.MorphNode`
        All apical nodes, in iteration order
    """
    apical = []
    for morph_node in self:
        if morph_node.swc_type == 4:
            apical.append(morph_node)
    return apical


apical_nodes = property(get_nodes_in_apical_subtree)
[docs]
def get_nodes_in_axonal_subtree(self):
    """
    Return the nodes associated with the axonal subtree

    Returns
    -------
    list of `neat.MorphNode`
        List of all nodes in the axonal subtree (``swc_type == 2``)
    """
    return [node for node in self if node.swc_type == 2]
def _check_computational_root(self):
    """
    Raise an error when no computational tree has been defined.

    Raises
    ------
    AttributeError
        If `_computational_root` is ``None``
    """
    if self._computational_root is not None:
        return
    raise AttributeError(
        "No computational tree has been defined. Use "
        "`MorphTree.set_comp_tree()` or its overwritten "
        "version in one of the derived classes"
    )
[docs]
def set_default_tree(
    self, default: Literal["original", "computational"] = "original"
):
    """
    Make the 'original' or the 'computational' tree the active one.

    Use this outside of the `as_computational_tree` and `as_original_tree`
    context managers. Those managers restore the previous tree on exit, so a
    default set inside them is not kept.

    Parameters
    ----------
    default : Literal['original', 'computational'], optional
        The tree to activate, 'original' when not given.

    Raises
    ------
    ValueError
        If `default` is neither 'original' nor 'computational'
    AttributeError
        If 'computational' is requested but no computational tree exists
    """
    if default not in ("original", "computational"):
        raise ValueError("Default tree can be 'original' or 'computational'.")
    if default == "computational":
        self._check_computational_root()
        new_root = self._computational_root
    else:
        new_root = self._original_root
    self.root = new_root
@property
@contextmanager
def as_computational_tree(self):
    """
    Context manager that makes the computational tree the active one for
    the duration of the ``with`` block.

    On exit, also after an exception, the computational tree stays active if
    it was active on entry; otherwise the original tree is activated.

    Raises
    ------
    AttributeError
        If no computational tree has been defined
    """
    self._check_computational_root()
    was_computational = self.check_computational_tree_active()
    self.root = self._computational_root
    try:
        yield self
    finally:
        if was_computational:
            self.root = self._computational_root
        else:
            self.root = self._original_root
@property
@contextmanager
def as_original_tree(self):
    """
    Context manager that makes the original tree the active one for the
    duration of the ``with`` block.

    On exit, also after an exception, the computational tree is reactivated
    if it was active on entry; otherwise the original tree stays active.
    """
    was_computational = self.check_computational_tree_active()
    self.root = self._original_root
    try:
        yield self
    finally:
        if was_computational:
            self.root = self._computational_root
        else:
            self.root = self._original_root
[docs]
def check_computational_tree_active(self):
    """
    Check whether the computational tree is the active one.

    Returns
    -------
    `bool`:
        ``True`` if a computational tree has been defined and `root` is its
        root node
    """
    comp_root = self._computational_root
    if comp_root is None:
        return False
    return self.root is comp_root
[docs]
def create_corresponding_node(self, node_index, p3d=None):
    """
    Create a node of the type that belongs to this tree class.

    Parameters
    ----------
    node_index: int
        index of the new node
    p3d: tuple (optional)
        ``(xyz, R, swc_type)``, passed on unchanged to the `neat.MorphNode`
        constructor. Defaults to ``None``

    Returns
    -------
    `neat.MorphNode`
    """
    new_node = MorphNode(node_index, p3d=p3d)
    return new_node
[docs]
def read_swc_tree_from_file(self, file_n, types=(1, 3, 4)):
    """
    Non-specific for a "tree data structure"
    Read and load a morphology from an SWC file and parse it into
    an `neat.MorphTree` object.

    On the NeuroMorpho.org website, 5 types of soma descriptions are
    considered (http://neuromorpho.org/neuroMorpho/SomaFormat.html).
    The "3-point soma" is the standard and most files are converted
    to this format during a curation step. `neat` follows this default
    specification and the *internal structure of `neat` implements
    the 3-point soma*. One-point soma descriptions and multi-cylinder
    descriptions with more than three nodes are also supported, but are
    converted to the standard three point description.

    Additionally, the root node of the tree must have ``index == 1``,
    ``swc_type == 1`` and occur first in the SWC file.

    Blank lines and lines starting with ``#`` are ignored.

    Parameters
    -----------
    file_n: str
        name of the file to open
    types: list of ints
        NeuroMorpho.org segment types to be loaded

    Returns
    -------
    `neat.MorphTree`
        the tree itself

    Examples
    --------
    The three point description is

    .. code-block:: python

        1 1 x y z r -1
        2 1 x y-r z r 1
        3 1 x y+r z r 1

    with `x,y,z` the coordinates of the soma center and `r` the soma radius

    This is a valid three point description

    .. code-block:: python

        # start of file
        1 1 45.3625 18.6775 -50.25 10.1267403895 -1
        2 1 45.3625 8.55075961052 -50.25 10.1267403895 1
        3 1 45.3625 28.8042403895 -50.25 10.1267403895 1
        # dendrite nodes
        4 3 37.76 12.99 -46.08 0.29 1
        5 3 26.7068019951 8.26344199599 -36.9426896493 0.795614809475 4
        # ...

    This is a valid multi-cylinder description

    .. code-block:: python

        # start of file
        1 1 1066.38 399.67 157.0 4.9215 -1
        2 1 1071.3 399.67 157.0 4.9215 1
        3 1 1076.22 399.67 157.0 4.9215 2
        4 1 1066.5 402.83 157.0 11.494 2
        5 1 1062.4 405.5 157.0 15.308 4
        6 1 1056.6 410.25 158.0 20.536 5
        7 1 1056.6 410.25 158.0 20.536 6
        8 1 1070.0 427.75 161.0 2.305 7
        # dendrite nodes
        9 3 1070.0 427.75 161.0 0.886 8
        # ...

    Raises
    ------
    ValueError
        If the SWC file is not consistent with the aforementioned conventions
    """
    # check soma-representation: 3-point soma or a non-standard representation
    soma_type = self.determine_soma_type(file_n)

    # parse the file; `all_nodes` maps index -> (swc_type, node, parent_index)
    all_nodes = dict()
    with open(file_n, "r") as swc_file:
        for line in swc_file:
            stripped = line.strip()
            # skip blank lines and comments
            if not stripped or stripped.startswith("#"):
                continue
            split = stripped.split()
            index = int(split[0])
            swc_type = int(split[1])
            x = float(split[2])
            y = float(split[3])
            z = float(split[4])
            radius = float(split[5])
            parent_index = int(split[6])
            # create the nodes
            if swc_type in types:
                p3d = (np.array([x, y, z]), radius, swc_type)
                node = self.create_corresponding_node(index, p3d)
                all_nodes[index] = (swc_type, node, parent_index)

    # check if node with index 1 is soma node (swc_type == 1)
    if 1 not in all_nodes or all_nodes[1][0] != 1:
        raise ValueError("Node with index 1 should be soma-type, i.e. swc_type == 1")

    if soma_type == 0:
        # one point soma representation
        for index, (swc_type, node, parent_index) in list(all_nodes.items()):
            if index == 1:
                self.root = node
            else:
                parent_node = all_nodes[parent_index][1]
                self.add_node_with_parent(node, parent_node)
        # we add nodes 2 and 3 to obtain a 3-point soma representation
        # (http://neuromorpho.org/neuroMorpho/SomaFormat.html)
        # and increment the indices of all nodes that are not the soma by 2
        for node in self.__iter__(skip_inds=[]):
            if self.is_root(node):
                # create p3ds for the extra soma nodes
                xyz_2 = copy.copy(node.xyz)
                xyz_2[1] -= node.R
                p3d_2 = (xyz_2, node.R, node.swc_type)
                xyz_3 = copy.copy(node.xyz)
                xyz_3[1] += node.R
                p3d_3 = (xyz_3, node.R, node.swc_type)
            else:
                node.index += 2
        # add extra soma nodes to the tree
        snode_2 = self.create_corresponding_node(2, p3d_2)
        snode_3 = self.create_corresponding_node(3, p3d_3)
        self.add_node_with_parent(snode_2, self[1])
        self.add_node_with_parent(snode_3, self[1])

    elif soma_type == 1:
        # three point soma representation
        for index, (swc_type, node, parent_index) in list(all_nodes.items()):
            if index == 1:
                self.root = node
            elif index in (2, 3):
                # the 3-point soma representation
                # (https://neuromorpho.org/SomaFormat.html)
                somanode = all_nodes[1][1]
                self.add_node_with_parent(node, somanode)
            else:
                parent_node = all_nodes[parent_index][1]
                self.add_node_with_parent(node, parent_node)
        # check if soma follows three point convention: all radii and the
        # distances of nodes 2 and 3 to node 1 should be equal
        radius_arr = np.array(
            [
                all_nodes[1][1].R,
                all_nodes[2][1].R,
                all_nodes[3][1].R,
                np.linalg.norm(all_nodes[2][1].xyz - all_nodes[1][1].xyz),
                np.linalg.norm(all_nodes[3][1].xyz - all_nodes[1][1].xyz),
            ]
        )
        if not np.allclose(
            np.abs(radius_arr - radius_arr[0]), np.zeros_like(radius_arr), atol=2e-2
        ):
            raise ValueError("Soma radii not consistent with three-point convention")

    elif soma_type == 2:
        # multiple cylinder soma representation
        self.root = all_nodes[1][1]
        # get all soma info
        soma_cylinders = []
        connected_to_root = []
        for index, (swc_type, node, parent_index) in list(all_nodes.items()):
            if swc_type == 1 and not index == 1:
                soma_cylinders.append((node, parent_index))
                if index > 1:
                    connected_to_root.append(index)
        # make soma
        s_node_2, s_node_3 = self._make_soma_from_cylinders(soma_cylinders, all_nodes)
        # add soma
        self.root.R = s_node_2.R
        self.add_node_with_parent(s_node_2, self.root)
        self.add_node_with_parent(s_node_3, self.root)
        # add the other points; nodes attached to a soma cylinder are
        # attached to the root instead
        for index, (swc_type, node, parent_index) in list(all_nodes.items()):
            if swc_type == 1:
                continue
            parent_node = all_nodes[parent_index][1]
            if parent_node.index in connected_to_root:
                self.add_node_with_parent(node, self.root)
            else:
                self.add_node_with_parent(node, parent_node)

    # set the lengths of the nodes
    for node in self:
        if node.parent_node is not None:
            L = np.sqrt(np.sum((node.parent_node.xyz - node.xyz) ** 2))
            # if the length is zero we can just delete it
            if L == 0:
                self.remove_single_node(node)
        else:
            L = 0.0
        node.set_length(L)
    # store the created root as the root of the original tree
    self._original_root = self.root
    return self
def _make_soma_from_cylinders(self, soma_cylinders, all_nodes):
    """
    Construct a 3-point soma from a multi-cylinder soma description.

    Step 1: calculate the lateral surface of all cylinders (each cylinder
            spans a soma node and its parent, with the parent's radius)
    Step 2: make a 3-point representation with the same surface, i.e. a
            sphere of radius ``sqrt(total_surface / (4 pi))``

    The root node is moved to the mean position of the root and the soma
    cylinder nodes. The root's original coordinate array is not modified in
    place.

    Parameters
    ----------
    soma_cylinders: list of tuple (`neat.MorphNode`, int)
        the soma nodes other than the root, with their parent indices
    all_nodes: dict {int: tuple}
        maps node index to ``(swc_type, node, parent_index)``

    Returns
    -------
    tuple of two `neat.MorphNode`
        the new soma nodes with indices 2 and 3
    """
    total_surf = 0
    # copy: accumulating into `self.root.xyz` directly would alter it in place
    xyz_sum = np.array(self.root.xyz, dtype=float)
    for node, parent_index in soma_cylinders:
        parent = all_nodes[parent_index][1]
        H = np.sqrt(np.sum((node.xyz - parent.xyz) ** 2))
        total_surf += 2 * np.pi * parent.R * H
        xyz_sum += node.xyz
    # define appropriate radius
    radius = np.sqrt(total_surf / (4.0 * np.pi))
    rp = xyz_sum / (len(soma_cylinders) + 1.0)
    rp2 = np.array([rp[0], rp[1] - radius, rp[2]])
    rp3 = np.array([rp[0], rp[1] + radius, rp[2]])
    self.root.xyz = rp
    # create the soma nodes
    s_node_2 = self.create_corresponding_node(2, (rp2, radius, 1))
    s_node_3 = self.create_corresponding_node(3, (rp3, radius, 1))
    return s_node_2, s_node_3
[docs]
def determine_soma_type(self, file_n):
    """
    Determine the soma type used in the SWC file.

    This method counts all soma entries (``swc_type == 1``) in the file.
    One-point, three-point and multi-cylinder soma descriptions are
    supported. Blank lines and lines starting with ``#`` are ignored.

    Furthermore, the root node of the tree must have ``index == 1``,
    ``swc_type == 1`` and occur first in the SWC file.

    Parameters
    ----------
    file_n: string
        Name of the file containing the SWC description

    Returns
    -------
    soma_type: int
        Integer indicating one of the supported SWC soma formats.
        0: One point soma
        1: Default three-point soma,
        2: multiple cylinder description

    Raises
    ------
    ValueError
        If soma type is not supported (zero or two soma nodes)
    """
    somas = 0
    with open(file_n, "r") as swc_file:
        for line in swc_file:
            stripped = line.strip()
            # skip blank lines and comments
            if not stripped or stripped.startswith("#"):
                continue
            if int(stripped.split()[1]) == 1:
                somas += 1

    if somas == 1:
        return 0
    if somas == 3:
        return 1
    if somas < 3:
        raise ValueError(
            "Soma description not supported, use 3-point or multi-cylinder description"
        )
    return 2
[docs]
def _evaluate_comp_criteria(self, node, eps=1e-8, rbool=False):
    """
    Return ``True`` if relative difference between node radius and child
    node radius is larger than margin ``eps``, or if the node is the root
    or does not have exactly one child (bifurcation or leaf).

    Parameters
    ----------
    node: `neat.MorphNode`
        node that is compared to its child node
    eps: float (optional, default ``1e-8``)
        the relative margin
    rbool: bool (optional, default ``False``)
        if truthy, it is returned directly without evaluating the criteria

    Returns
    -------
    bool
    """
    if rbool:
        return rbool
    if node.parent_node is None:
        return True
    if len(node.get_child_nodes()) != 1:
        return True
    cnode = node.child_nodes[0]
    return bool(np.abs(node.R - cnode.R) > eps * np.max([node.R, cnode.R]))
[docs]
def set_comp_tree(self, compnodes=None, eps=1e-8):
    """
    Sets the nodes that contain computational parameters. These are a priori
    either bifurcations, leafs, the root or nodes where the neurons'
    relevant parameters change.

    Any previously defined computational tree is removed first. The new
    computational tree is built from a deep copy of the nodes of the active
    tree. Nodes that are not retained are removed from the copy. Each retained
    non-root node gets the length and radius returned by
    ``path_length(..., compute_radius=1)`` between its parent in the copy and
    itself. Retained nodes, in both the copy and the active tree, are flagged
    with ``used_in_comp_tree = True``.

    Parameters
    ----------
    compnodes: list of ::class::`MorphNode`
        list of nodes that should be retained in the computational tree.
        Note that specifying bifurcations, leafs or the root is
        superfluous, since they are part of the computational tree by
        default. The list passed by the caller is not modified.
    eps: float (default ``1e-8``)
        relative margin for parameter change
    """
    self.remove_comp_tree()
    # work on a new list so that the caller's `compnodes` is not extended
    compnodes = [] if compnodes is None else list(compnodes)
    compnodes += [
        node for node in self if self._evaluate_comp_criteria(node, eps=eps)
    ]
    compnode_indices = {node.index for node in compnodes}
    # index -> node of the active tree; avoids a tree search per lookup
    orig_nodes = {node.index: node for node in self}

    nodes = copy.deepcopy(self.nodes)
    for node in nodes:
        if node.index not in compnode_indices:
            self.remove_single_node(node)
            continue
        orig_node = orig_nodes[node.index]
        if node.parent_node is not None:
            orig_bnode = node.parent_node
            L, R = self.path_length(
                {"node": orig_bnode.index, "x": 1.0},
                {"node": orig_node.index, "x": 1.0},
                compute_radius=1,
            )
            node.set_length(L)
            node.set_radius(R)
        node.used_in_comp_tree = True
        orig_node.used_in_comp_tree = True

    self._computational_root = next(node for node in nodes if node.index == 1)
    self._leafs_comp = [node for node in nodes if self.is_leaf(node)]
    self._nodes_comp = []
    self._gather_nodes(self._computational_root, self._nodes_comp)
    # create conversion of all coordinate arrays
    for name in self.locs:
        self._store_comp_locs(name)
def _find_comp_node_to_root(self, node):
    """
    !!! Computational tree has to be initialized, otherwise may result in
    error !!!

    If the input node is a node of the original tree, finds the first node
    on the path to the root that has an equivalent in the computational tree
    (i.e. with ``used_in_comp_tree`` set). If the input node has such an
    equivalent, it is returned itself.

    Walks up iteratively, so long unbranched sections do not hit the
    recursion limit.

    Parameters
    ----------
    node: `neat.MorphNode` instance
        the input node

    Returns
    -------
    `neat.MorphNode` instance
    """
    while not node.used_in_comp_tree:
        node = node.parent_node
    return node
def _find_comp_node_from_root(self, node):
    """
    !!! Computational tree has to be initialized, otherwise may result in
    error !!!

    If the input node is a node of the original tree, finds the first node
    away from the root (following the first child) that has an equivalent in
    the computational tree (i.e. with ``used_in_comp_tree`` set). If the
    input node has such an equivalent, it is returned itself.

    Walks down iteratively, so long unbranched sections do not hit the
    recursion limit.

    Parameters
    ----------
    node: `neat.MorphNode` instance
        the input node

    Returns
    -------
    `neat.MorphNode` instance
    """
    while not node.used_in_comp_tree:
        node = node.child_nodes[0]
    return node
[docs]
def remove_comp_tree(self):
    """
    Discard the computational tree.

    Clears the computational root and resets the ``used_in_comp_tree`` flag
    of every node visited by iterating over the active tree.
    """
    self._computational_root = None
    for morph_node in self:
        setattr(morph_node, "used_in_comp_tree", False)
[docs]
def convert_loc_arg_to_locs(self, loc_arg):
    """
    Convert a locations argument to a list of `neat.MorphLoc`.

    Parameters
    ----------
    loc_arg: list of dictionaries, tuples or `neat.MorphLoc`, or string
        * If list, each entry is passed to the `neat.MorphLoc` constructor
        * If string, the name of a list of locations stored in `self`

    Returns
    -------
    list of `neat.MorphLoc`
        List of locations, each referencing the current tree

    Raises
    ------
    KeyError
        If `loc_arg` is a string that is not the name of a stored set
    IOError
        If `loc_arg` is neither a list nor a string
    """
    if isinstance(loc_arg, str):
        self._try_name(loc_arg)
        return self.get_locs(loc_arg)
    if isinstance(loc_arg, list):
        return [MorphLoc(entry, self) for entry in loc_arg]
    raise IOError("`loc_arg` should be list of locs or string")
[docs]
def convert_node_arg_to_nodes(self, node_arg):
    """
    Converts a node argument to a list of nodes. Behaviour depends on the
    type of argument.

    Parameters
    ----------
    node_arg: ``None``, `neat.MorphNode`, {'apical', 'basal', 'axonal', 'somatic'} or iterable collection of instances of `neat.MorphNode`
        * `None`: returns all nodes
        * `neat.MorphNode`: returns list of nodes in the subtree of the given node
        * {'apical', 'basal', 'axonal', 'somatic'} (a single string, or an
          iterable of these strings): returns the nodes in the apical, basal
          or axonal subtree, or the soma node, respectively
        * iterable collection of `neat.MorphNode`: returns the same list of nodes

        If an iterable collection of original nodes is given, and the computational
        tree is active, a reduced list is returned where only the corresponding
        computational nodes are included (each at most once). If an iterable
        collection of computational nodes is given, and the original tree is
        active, a list of corresponding original nodes is given, but the in
        between nodes are not added.

    Returns
    -------
    list of `neat.MorphNode`

    Raises
    ------
    IOError
        If `node_arg` is not one of the accepted forms
    """
    region_names = ["apical", "basal", "axonal", "somatic"]
    err_msg = (
        "input should be (i) `None`, (ii) an instance of "
        "`neat.MorphNode`, (iii) one of the following 4 strings "
        "'somatic', 'apical', 'basal' or 'axonal' or (iv) an iterable "
        "collection of instances of :class:MorphNode"
    )

    nodes = []
    # `is None` rather than `== None`: the latter compares elementwise for
    # numpy arrays, making the truth value ambiguous
    if node_arg is None:
        nodes.extend(self.nodes)
    elif isinstance(node_arg, MorphNode):
        if self.check_computational_tree_active():
            # assure that a list of computational nodes is returned
            node_arg = self._find_comp_node_from_root(node_arg)
        node_arg = self[node_arg.index]
        nodes.extend(self.gather_nodes(node_arg))
    elif hasattr(node_arg, "__len__") and len(node_arg) > 0:
        # an iterable may be a string, a collection of strings or a
        # collection of nodes
        if all(isinstance(node, MorphNode) for node in node_arg):
            comp_active = self.check_computational_tree_active()
            # track indices to avoid a quadratic `not in list` check
            seen_indices = set()
            for node in node_arg:
                if comp_active:
                    # assure that a list of computational nodes is returned
                    compnode = self[self._find_comp_node_from_root(node).index]
                    if compnode.index not in seen_indices:
                        seen_indices.add(compnode.index)
                        nodes.append(compnode)
                else:
                    # assure that a list of original nodes is returned
                    nodes.append(self[node.index])
        else:
            requested = [node_arg] if isinstance(node_arg, str) else node_arg
            for entry in requested:
                if entry not in region_names:
                    raise IOError(err_msg)
            if "apical" in requested:
                nodes.extend(self.get_nodes_in_apical_subtree())
            if "basal" in requested:
                nodes.extend(self.get_nodes_in_basal_subtree())
            if "axonal" in requested:
                nodes.extend(self.get_nodes_in_axonal_subtree())
            if "somatic" in requested:
                nodes.extend([self[1]])
    else:
        raise IOError(err_msg)
    return nodes
[docs]
def path_length(self, loc1, loc2, compute_radius=0):
    """
    Find the length of the direct path between loc1 and loc2

    Parameters
    ----------
    loc1: dict, tuple or `neat.MorphLoc`
        one location
    loc2: dict, tuple or `neat.MorphLoc`
        other location
    compute_radius: bool
        if True, also computes the average weighted radius of the path

    Returns
    -------
    L, R (optional)
        L: float
            length of path, in micron
        R: float
            radius averaged over the path with each segment weighted by its
            length, in micron
    """
    # only plain dicts and tuples are converted; other objects (e.g.
    # `neat.MorphLoc`) are used as-is
    if type(loc1) in (dict, tuple):
        loc1 = MorphLoc(loc1, self)
    if type(loc2) in (dict, tuple):
        loc2 = MorphLoc(loc2, self)

    if loc1["node"] == loc2["node"]:
        node = self[loc1["node"]]
        if node.index == 1:
            L = 0.0  # soma is spherical and has no length
        else:
            L = node.L * np.abs(loc1["x"] - loc2["x"])
        if compute_radius:
            R = node.R
    else:
        node1 = self[loc1["node"]]
        node2 = self[loc2["node"]]
        # paths ordered from the root towards the respective node
        path1 = self.path_to_root(node1)[::-1]
        path2 = self.path_to_root(node2)[::-1]
        path = path1 if len(path1) < len(path2) else path2
        # first position where the two root paths diverge
        ind = next(
            (ii for ii in range(len(path)) if path1[ii] != path2[ii]), len(path)
        )

        if path1[ind - 1] == node1:
            # node1 lies on the path from node2 to the root
            L = node1.L * (1.0 - loc1["x"])
            L += sum(node.L for node in path2[ind:-1])
            L += node2.L * loc2["x"]
            if compute_radius:
                R = node1.R * node1.L * (1.0 - loc1["x"])
                R += sum(node.R * node.L for node in path2[ind:-1])
                R += node2.R * node2.L * loc2["x"]
                R /= L
        elif path2[ind - 1] == node2:
            # node2 lies on the path from node1 to the root; the
            # intermediate nodes therefore belong to `path1`
            L = node1.L * loc1["x"]
            L += sum(node.L for node in path1[ind:-1])
            L += node2.L * (1.0 - loc2["x"])
            if compute_radius:
                R = node1.R * node1.L * loc1["x"]
                R += sum(node.R * node.L for node in path1[ind:-1])
                R += node2.R * node2.L * (1.0 - loc2["x"])
                R /= L
        else:
            # the path passes through the last common ancestor
            L = node1.L * loc1["x"]
            L += sum(node.L for node in path1[ind:-1])
            L += sum(node.L for node in path2[ind:-1])
            L += node2.L * loc2["x"]
            if compute_radius:
                R = node1.R * node1.L * loc1["x"]
                R += sum(node.R * node.L for node in path1[ind:-1])
                R += sum(node.R * node.L for node in path2[ind:-1])
                R += node2.R * node2.L * loc2["x"]
                R /= L

    if compute_radius:
        return L, R
    return L
[docs]
@original_tree_decorator
def store_locs(self, locs, name, warn=True):
    """
    Store locations under a specified name

    Any set previously stored under `name` is removed first, together with
    its cached data.

    Parameters
    ----------
    locs: list of dicts, tuples or `neat.MorphLoc`
        the locations to be stored
    name: string
        name under which these locations are stored
    warn: bool (default ``True``)
        raise a `UserWarning` if two or more locations in `locs` refer
        to the soma. Choose ``False`` if this is desired to remove
        the warning.
    """
    # copy the list, converting every entry to a `MorphLoc`
    locs_ = [MorphLoc(loc, self) for loc in locs]
    n_soma_locs = sum(1 for loc in locs_ if loc["node"] == 1)
    if n_soma_locs > 1 and warn:
        warnings.warn(
            "There are multiple locations on the soma in this set of "
            "locations, this can cause issues in certain functions",
            UserWarning,
        )
    self.remove_locs(name)
    self.locs[name] = locs_
    self._nids_orig[name] = np.array([loc["node"] for loc in locs_])
    self._xs_orig[name] = np.array([loc["x"] for loc in locs_])
    # the computational-tree view is only cached when such a tree exists
    if self._computational_root is not None:
        self._store_comp_locs(name)
@computational_tree_decorator
def _store_comp_locs(self, name):
    # Cache node indices and x-coordinates of the location set `name` for
    # the computational tree; the decorator runs this body with the
    # computational tree active, so `loc[...]` is evaluated in that context.
    stored = self.locs[name]
    self._nids_comp[name] = np.array([stored_loc["node"] for stored_loc in stored])
    self._xs_comp[name] = np.array([stored_loc["x"] for stored_loc in stored])
[docs]
@original_tree_decorator
def add_loc(self, loc, name):
    """
    Add location to set of locations of given name

    Like `store_locs`, this runs with the original tree active, so the
    cached `_nids_orig`/`_xs_orig` arrays always receive original-tree
    coordinates.

    Parameters
    ----------
    loc: dict, tuple or `neat.MorphLoc`
        the location to be added
    name: str
        the name of the set of locations to which the location is added

    Raises
    ------
    KeyError
        If no set of locations is stored under `name`
    """
    # fail early with an informative message when `name` is unknown
    self._try_name(name)
    loc = MorphLoc(loc, self)
    self.locs[name].append(loc)
    self._nids_orig[name] = np.concatenate((self._nids_orig[name], [loc["node"]]))
    self._xs_orig[name] = np.concatenate((self._xs_orig[name], [loc["x"]]))
    if self._computational_root is not None:
        self._add_comp_loc(loc, name)
@computational_tree_decorator
def _add_comp_loc(self, loc, name):
    # Append the node index and x-coordinate of `loc`, read while the
    # decorator keeps the computational tree active, to the cached
    # computational-tree arrays of the location set `name`.
    new_nid = [loc["node"]]
    new_x = [loc["x"]]
    self._nids_comp[name] = np.concatenate((self._nids_comp[name], new_nid))
    self._xs_comp[name] = np.concatenate((self._xs_comp[name], new_x))
[docs]
def clear_locs(self):
    """
    Remove all sets of locations stored in the tree, together with the
    cached node indices, x-coordinates, distances and leaf indices.
    """
    for attr_name in (
        "locs",
        "_nids_orig",
        "_nids_comp",
        "_xs_orig",
        "_xs_comp",
        "d2s",
        "d2b",
        "leafinds",
    ):
        # a fresh dict for every attribute, so none of them share state
        setattr(self, attr_name, {})
[docs]
def remove_locs(self, name):
    """
    Remove a set of locations of a given name, together with all cached
    data derived from it. Names that are not in use are silently ignored.

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored
    """
    # Each store is cleaned independently: `_nids_comp`/`_xs_comp` only hold
    # an entry when a computational tree existed at storage time, and a
    # missing entry there must not prevent cleaning the remaining stores.
    for store in (
        self.locs,
        self._nids_orig,
        self._nids_comp,
        self._xs_orig,
        self._xs_comp,
        self.d2s,
        self.d2b,
        self.leafinds,
    ):
        store.pop(name, None)
[docs]
def _try_name(self, name):
    """
    Tests if the name is in use. Raises a KeyError when it is not in use,
    with a message listing the possible names

    Parameters
    ----------
    name: string
        name of the desired list of locations

    Raises
    ------
    KeyError
        If 'name' does not refer to a set of locations in use
    """
    try:
        self.locs[name]
    except KeyError as err:
        # format via an f-string so that non-string names (e.g. ints) do not
        # turn the intended KeyError into a TypeError
        err.args = (
            f"'{name}' name not in use. Possible names are "
            f"{list(self.locs.keys())}",
        )
        raise
[docs]
def get_locs(self, name):
    """
    Return the list of locations stored under `name`.

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored

    Returns
    -------
    list of `neat.MorphLoc`

    Raises
    ------
    KeyError
        If no set of locations is stored under `name`
    """
    self._try_name(name)
    stored_locs = self.locs[name]
    return stored_locs
[docs]
def get_node_indices(self, name):
    """
    Return the node indices of the locations stored under `name`.

    The indices are read through the `nids` property, i.e. they refer to
    the currently active tree (computational or original).

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored

    Returns
    -------
    numpy.array of ints

    Raises
    ------
    KeyError
        If no set of locations is stored under `name`
    """
    self._try_name(name)
    node_indices = self.nids[name]
    return node_indices
def get_node_idxs(self):
    """
    Return the dict of node-index arrays of the stored location sets for
    the currently active tree (computational or original).
    """
    return self._nids_comp if self.check_computational_tree_active() else self._nids_orig
def set_node_idxs(self, nids):
    """
    Replace the dict of node-index arrays for the currently active tree.
    """
    if not self.check_computational_tree_active():
        self._nids_orig = nids
    else:
        self._nids_comp = nids
nids = property(fget=get_node_idxs, fset=set_node_idxs)
[docs]
def get_x_coords(self, name):
    """
    Return the x-coordinates of the locations stored under `name`.

    The values are read through the `xs` property, i.e. they refer to the
    currently active tree (computational or original).

    Parameters
    ----------
    name: string
        name under which the desired list of locations is stored

    Returns
    -------
    numpy.array of floats

    Raises
    ------
    KeyError
        If no set of locations is stored under `name`
    """
    self._try_name(name)
    x_coords = self.xs[name]
    return x_coords
def get_xs(self):
    """
    Return the dict of x-coordinate arrays of the stored location sets for
    the currently active tree (computational or original).
    """
    if not self.check_computational_tree_active():
        return self._xs_orig
    return self._xs_comp
def set_xs(self, xs):
    """
    Replace the dict of x-coordinate arrays for the currently active tree.
    """
    target_attr = "_xs_comp" if self.check_computational_tree_active() else "_xs_orig"
    setattr(self, target_attr, xs)
xs = property(fget=get_xs, fset=set_xs)
[docs]
def get_loc_idxs_on_node(self, name, node):
    """
    Return the indices of the locations in the set `name` that lie on
    `node`, ordered by increasing x-coordinate.

    Parameters
    ----------
    name: string
        which list of locations to consider
    node: `neat.MorphNode`
        the node to consider. Should be part of the original
        tree

    Returns
    -------
    list of ints
        indices of locations on the node
    """
    self._try_name(name)
    x_coords = self.xs[name]
    on_node = np.flatnonzero(self.nids[name] == node.index)
    by_x = np.argsort(x_coords[on_node])
    return on_node[by_x].tolist()
[docs]
def get_loc_idxs_on_nodes(self, name, node_arg):
    """
    Return the indices of the locations in the set `name` that lie on any
    of the nodes described by `node_arg`. Nodes are visited in the order
    returned by `MorphTree.convert_node_arg_to_nodes`; within each node,
    locations are ordered by increasing x.

    Parameters
    ----------
    name: string
        which list of locations to consider
    node_arg:
        see documentation of `MorphTree.convert_node_arg_to_nodes`

    Returns
    -------
    list of ints
        indices of locations on the nodes
    """
    return [
        loc_idx
        for node in self.convert_node_arg_to_nodes(node_arg)
        for loc_idx in self.get_loc_idxs_on_node(name, node)
    ]
[docs]
def get_loc_idxs_on_path(self, name, node0, node1, xstart=0.0, xstop=1.0):
    """
    Returns a list of the indices of locations in the list of a given name
    that are on the given path. The path is taken to start at the input
    x-start coordinate of the first node in the list and to stop at the
    given x-stop coordinate of the last node in the list. Indices are
    ordered along the direction of travel from `node0` to `node1`.

    Parameters
    ----------
    name: string
        which list of locations to consider
    node0: :class:`SNode`
        start node of path
    node1: :class:`SNode`
        stop node of path
    xstart: float (in ``[0,1]``)
        starting coordinate on `node0`
    xstop: float (in ``[0,1]``)
        stopping coordinate on `node1`

    Returns
    -------
    list of ints
        Indices of locations on the path. If path is empty, an empty
        list is returned.
    """
    self._try_name(name)
    locs = self.locs[name]  # NOTE(review): not used below
    xs = self.xs[name]
    # find the path
    path = self.path_between_nodes(node0, node1)
    # find the location indices
    locinds = []
    if len(path) > 1:
        # first node in path
        node = path[0]
        ninds = np.array(self.get_loc_idxs_on_node(name, node)).astype(int)
        if node.parent_node == None:
            # root node: all its locations are included
            locinds.extend(ninds)
        else:
            if node.parent_node == path[1]:
                # path goes towards root: keep locs proximal to xstart,
                # ordered by decreasing x
                inds = np.where(xs[ninds] <= xstart)[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            else:
                # path goes away from root: keep locs distal to xstart,
                # ordered by increasing x
                inds = np.where(xs[ninds] >= xstart)[0]
                sortinds = np.argsort(xs[ninds][inds])
            locinds.extend(ninds[inds][sortinds])
        # middle nodes in path: `node` is path[ii + 1], so path[ii] is the
        # previous and path[ii + 2] the next node on the path
        for ii, node in enumerate(path[1:-1]):
            ninds = np.array(self.get_loc_idxs_on_node(name, node)).astype(int)
            if node.parent_node == None:
                locinds.extend(ninds)
            elif path[ii + 2] == node.parent_node:
                # path goes towards root
                sortinds = np.argsort(xs[ninds])
                locinds.extend(ninds[sortinds[::-1]])
            elif path[ii] == node.parent_node:
                # path goes away from root
                sortinds = np.argsort(xs[ninds])
                locinds.extend(ninds[sortinds])
            else:
                # turning point (path only goes on this node at x=1)
                inds = np.where((1.0 - xs[ninds]) < 1e-4)[0]
                if len(inds) > 0:
                    locinds.extend(ninds[inds])
        # last node in path
        node = path[-1]
        ninds = np.array(self.get_loc_idxs_on_node(name, node)).astype(int)
        if node.parent_node == None:
            locinds.extend(ninds)
        else:
            if node.parent_node == path[-2]:
                # path goes away from root: keep locs up to xstop
                inds = np.where(xs[ninds] <= xstop)[0]
                sortinds = np.argsort(xs[ninds][inds])
            else:
                # path goes towards root: keep locs from xstop onwards,
                # ordered by decreasing x
                inds = np.where(xs[ninds] >= xstop)[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            locinds.extend(ninds[inds][sortinds])
    elif len(path) == 1:
        # path lies within a single node: keep locs between xstart and
        # xstop, ordered from xstart towards xstop
        node = path[0]
        ninds = np.array(self.get_loc_idxs_on_node(name, node)).astype(int)
        if node.parent_node == None:
            locinds.extend(ninds)
        else:
            if xstart < xstop:
                inds = np.where(
                    np.logical_and(xs[ninds] >= xstart, xs[ninds] <= xstop)
                )[0]
                sortinds = np.argsort(xs[ninds][inds])
            else:
                inds = np.where(
                    np.logical_and(xs[ninds] >= xstop, xs[ninds] <= xstart)
                )[0]
                sortinds = np.argsort(xs[ninds][inds])[::-1]
            locinds.extend(ninds[inds][sortinds])
    return locinds
[docs]
def get_nearest_loc_idxs(
    self, locs, name, direction=0, check_siblings=True, pprint=False
):
    """
    For each location in the input location list, find the index of the
    closest location in a set of locations stored under a given name. The
    search can go in the up or down direction or in both directions.

    Parameters
    ----------
    locs: list of dicts, tuples or `neat.MorphLoc`
        the locations for which the nearest location index has to be
        found
    name: string
        name under which the reference list is stored
    direction: int
        flag to indicate whether to search in both directions (0), only
        in the direction of the root (1) or in the direction away from
        the root (2).
    check_siblings: bool (default ``True``)
        when searching towards the root, also look for locations in the
        sibling subtrees encountered along the way
    pprint: bool
        not used; retained for backwards compatibility

    Returns
    -------
    loc_idxices: list of ints
        indices of the locations closest to the given locs. An entry is
        ``None`` when no location was found in the searched direction(s).

    Raises
    ------
    ValueError
        If `direction` is not 0, 1 or 2
    """
    if direction not in (0, 1, 2):
        raise ValueError(f"`direction` should be 0, 1 or 2, got {direction!r}")
    self._try_name(name)
    search_to_root = direction in (0, 1)
    search_from_root = direction in (0, 2)
    # create the locs in a desirable format
    locs = [MorphLoc(loc, self) for loc in locs]
    # look for the location indices
    loc_idxices = []
    for loc in locs:
        loc_idx1 = None
        loc_idx2 = None
        if search_to_root:
            loc_idx1 = self._find_locs_to_root(
                loc, name, check_siblings=check_siblings
            )
        if search_from_root:
            loc_idx2 = self._find_locs_from_root(loc, name)
        # keep the index of the closest location, if it exists
        if loc_idx1 is None and search_from_root:
            loc_idxices.append(loc_idx2)
        elif loc_idx2 is None and search_to_root:
            loc_idxices.append(loc_idx1)
        else:
            # both directions produced a candidate: keep the nearer one,
            # preferring the one away from the root on ties
            L1 = self.path_length(loc, self.locs[name][loc_idx1])
            L2 = self.path_length(loc, self.locs[name][loc_idx2])
            loc_idxices.append(loc_idx2 if L1 >= L2 else loc_idx1)
    return loc_idxices
def _find_locs_from_root(self, loc, name):
    """
    Find the index of the location in the set `name` that is nearest to
    `loc` when moving away from the root.

    On the node of `loc`, the location with the smallest x that is not
    smaller than ``loc["x"]`` is taken (on the soma, node 1, the first
    location found there is taken). If there is none, the child subtrees
    are searched recursively and the candidate with the shortest path
    from x = 1 on the node of `loc` is selected.

    Parameters
    ----------
    loc: dict or `neat.MorphLoc`
        the starting location, indexable with "node" and "x"
    name: string
        name of the set of locations to search

    Returns
    -------
    int or ``None``
        index in the set `name`, or ``None`` if nothing was found. A child
        candidate at zero distance from x = 1 of a non-soma node is
        discarded (see the final branch below).
    """
    look_further = False
    # look if there are locs on the same node
    n_inds = np.where(loc["node"] == self.nids[name])[0]
    if len(n_inds) > 0:
        if loc["node"] == 1:
            loc_idx = n_inds[0]
        else:
            x_inds = np.where(loc["x"] <= self.xs[name][n_inds])[0]
            if len(x_inds) != 0:
                ind = np.argmin(self.xs[name][n_inds][x_inds])
                loc_idx = n_inds[x_inds[ind]]
            else:
                look_further = True
    else:
        look_further = True
    # if no locs on the same node, then proceed to child nodes
    # else, return the smallest location larger than loc
    if look_further:
        node = self[loc["node"]]
        cnodes = node.get_child_nodes()
        loc_idxs = []
        for cnode in cnodes:
            cloc_idx = self._find_locs_from_root(
                {"node": cnode.index, "x": 0.0}, name
            )
            if cloc_idx != None:
                loc_idxs.append(cloc_idx)
        # get the one that is closest, if they exist
        # NOTE(review): 1e4 acts as the initial "infinite" distance; if all
        # candidates are at least this far, the first candidate is kept.
        # `pl` also shadows the module-level `matplotlib.pyplot` alias here.
        pl_aux = 1e4
        ind_loc = 0
        for i, l_i in enumerate(loc_idxs):
            pl = self.path_length(
                {"node": loc["node"], "x": 1.0}, self.locs[name][l_i]
            )
            if pl < pl_aux:
                pl_aux = pl
                ind_loc = i
        if pl_aux > 0.0 and len(loc_idxs) > 0:
            loc_idx = loc_idxs[ind_loc]
        elif pl_aux == 0.0 and node.index == 1:
            # zero distance is only accepted when starting from the soma
            loc_idx = loc_idxs[ind_loc]
        else:
            loc_idx = None
    return loc_idx
def _find_locs_to_root(self, loc, name, check_siblings=True):
    """
    Find the index of the location in the set `name` that is nearest to
    `loc` when moving towards the root.

    On the node of `loc`, the location with the largest x that is not
    larger than ``loc["x"]`` is taken (on the soma, node 1, the first
    location found there is taken). If there is none, the search continues
    recursively on the parent node and, if `check_siblings` is set, in the
    subtrees of the sibling nodes. The candidate with the shortest path
    from x = 1 on the node of `loc` is selected.

    Parameters
    ----------
    loc: dict or `neat.MorphLoc`
        the starting location, indexable with "node" and "x"
    name: string
        name of the set of locations to search
    check_siblings: bool (default ``True``)
        whether to also search the subtrees of the sibling nodes

    Returns
    -------
    int or ``None``
        index in the set `name`, or ``None`` if nothing was found
    """
    look_further = False
    # look if there are locs on the same node
    n_inds = np.where(loc["node"] == self.nids[name])[0]
    if len(n_inds) > 0:
        if loc["node"] == 1:
            loc_idx = n_inds[0]
        else:
            x_inds = np.where(loc["x"] >= self.xs[name][n_inds])[0]
            if len(x_inds) != 0:
                ind = np.argmax(self.xs[name][n_inds][x_inds])
                loc_idx = n_inds[x_inds[ind]]
            else:
                look_further = True
    else:
        look_further = True
    if look_further:
        # if no locs on the same node, then proceed to resp. parent and child nodes
        node = self[loc["node"]]
        pnode = node.get_parent_node()
        loc_idxs = []
        # check parent node
        if pnode != None:
            ploc_idx = self._find_locs_to_root(
                {"node": pnode.index, "x": 1.0}, name, check_siblings=check_siblings
            )
            if ploc_idx != None:
                loc_idxs.append(ploc_idx)
        # check other child nodes of parent node (copied so that removing
        # `node` does not modify the parent's own child list)
        if pnode != None and check_siblings:
            ocnodes = copy.copy(pnode.get_child_nodes())
            ocnodes.remove(node)
        else:
            ocnodes = []
        for cnode in ocnodes:
            cloc_idx = self._find_locs_from_root(
                {"node": cnode.index, "x": 0.0}, name
            )
            if cloc_idx != None:
                loc_idxs.append(cloc_idx)
        # get the one that is closest, if they exist; all candidates are
        # measured from x = 1 on the node of `loc`, so the comparison
        # between parent and sibling candidates is consistent
        # NOTE(review): 1e4 acts as the initial "infinite" distance; if all
        # candidates are at least this far, the first candidate is kept.
        pl_aux = 1e4
        ind_loc = 0
        for i, l_i in enumerate(loc_idxs):
            pl = self.path_length(
                {"node": loc["node"], "x": 1.0}, self.locs[name][l_i]
            )
            if pl < pl_aux:
                pl_aux = pl
                ind_loc = i
        if pl_aux > 0.0 and len(loc_idxs) > 0:
            loc_idx = loc_idxs[ind_loc]
        else:
            loc_idx = None
    return loc_idx
[docs]
def get_nearest_neighbour_loc_idxs(self, loc, loc_arg):
    """
    Search nearest neighbours to `loc` in `loc_arg`.

    Parameters
    ----------
    loc: tuple, dict or `neat.MorphLoc`
        The locations for which nearest neighbours have to be found
    loc_arg: str or list of locs
        See documentation of `MorphTree.convert_loc_arg_to_locs`, the set of locations
        within which to look for nearest neighbours. A list is stored
        temporarily under the name "nn aux" and removed again afterwards,
        also when the search raises.

    Returns
    -------
    list of ints
        Indices of nearest neighbours of `loc` in `loc_arg` (without
        duplicates, in no particular order)

    Raises
    ------
    KeyError
        If `loc_arg` is a string that is not the name of a stored set
    """
    loc = MorphLoc(loc, self)
    # only a temporary set created here is removed at the end, never a
    # user-defined set that happens to be named "nn aux"
    is_aux = not isinstance(loc_arg, str)
    if is_aux:
        name = "nn aux"
        self.store_locs(loc_arg, name=name)
    else:
        name = loc_arg
        self._try_name(name)

    nns = []
    try:
        node = self[loc["node"]]
        locinds_aux = np.where(node.index == self.nids[name])[0]
        if len(locinds_aux) > 0:
            dx = self.xs[name][locinds_aux] - loc["x"]
            # nearest loc on this node away from the root, otherwise
            # continue the search in the child subtrees
            inds_down = np.where(dx >= 0)[0]
            if len(inds_down) > 0:
                ind_aux = np.argmin(dx[inds_down])
                nns.append(locinds_aux[inds_down][ind_aux])
            else:
                for c_node in node.child_nodes:
                    self._search_nn_from_root(c_node, nns, name)
            # nearest loc on this node towards the root, otherwise
            # continue the search towards the root
            inds_up = np.where(dx <= 0)[0]
            if len(inds_up) > 0:
                ind_aux = np.argmax(dx[inds_up])
                nns.append(locinds_aux[inds_up][ind_aux])
            else:
                self._search_nn_to_root(node, nns, name)
        else:
            for c_node in node.child_nodes:
                self._search_nn_from_root(c_node, nns, name)
            self._search_nn_to_root(node, nns, name)
    finally:
        if is_aux:
            self.remove_locs(name)
    return list(set(nns))
def _search_nn_to_root(self, node, nns, name):
    # Collect into `nns` the nearest location(s) of the set `name` met when
    # leaving `node` towards the root: the location with the largest x on
    # the parent node or, if the parent holds none, whatever is found further
    # up. The sibling subtrees of `node` are searched as well, unless the
    # parent location found has x within 1e-5 of 1.
    parent = node.parent_node
    if parent is None:
        return
    on_parent = np.where(parent.index == self.nids[name])[0]
    if len(on_parent) == 0:
        self._search_nn_to_root(parent, nns, name)
        x_parent = 0.0
    else:
        nearest = on_parent[np.argmax(self.xs[name][on_parent])]
        nns.append(nearest)
        x_parent = self.xs[name][nearest]
    if x_parent < 1.0 - 1e-5:
        for sibling in set(parent.child_nodes) - {node}:
            self._search_nn_from_root(sibling, nns, name)
def _search_nn_from_root(self, node, nns, name):
    # Depth-first walk away from the root starting at `node`: in every
    # branch, the first node holding locations of the set `name` contributes
    # its location with the smallest x to `nns`, and the walk stops there.
    # An explicit stack (children pushed in reverse) keeps the same visiting
    # order as a recursive traversal.
    stack = [node]
    while stack:
        current = stack.pop()
        on_node = np.where(current.index == self.nids[name])[0]
        if len(on_node) > 0:
            nns.append(on_node[np.argmin(self.xs[name][on_node])])
        else:
            stack.extend(list(current.child_nodes)[::-1])
[docs]
def get_leaf_loc_idxs(self, name, recompute=False):
    """
    Return the indices of the 'leaf' locations in the set stored under
    `name`, i.e. locations beyond which (away from the root) no other
    location of the set exists. The result is cached per name.

    Parameters
    ----------
    name: string
        name of the desired set of locations
    recompute: bool (optional, default ``False``)
        whether or not to force recomputing the leaf indices

    Returns
    -------
    list of inds
        the indices of the 'leaf' locations
    """
    if recompute or name not in self.leafinds:
        self._try_name(name)
        leafs = self.leafinds[name] = []
        for ind, loc in enumerate(self.locs[name]):
            if not self._has_loc_from_root(loc, name):
                leafs.append(ind)
    return self.leafinds[name]
def _has_loc_from_root(self, loc, name):
    # True if the set `name` holds a location further from the root than
    # `loc`: on the same node with a strictly larger x (not checked on the
    # soma, node 1), or anywhere in the subtrees of the node's children.
    if loc["node"] != 1:
        on_node = np.where(loc["node"] == self.nids[name])[0]
        if len(on_node) > 0 and np.any(loc["x"] < self.xs[name][on_node]):
            return True
    return any(
        self._has_loc_from_root({"node": child.index, "x": 0.0}, name)
        for child in self[loc["node"]].child_nodes
    )
[docs]
def distances_to_soma(self, loc_arg, recompute=False):
    """
    Compute the distance of each location in a given set to the soma

    Parameters
    ----------
    loc_arg: list of locations or string
        if list of locations, specifies the locations, if str,
        specifies the name under which the set of locations is stored
    recompute: bool (optional)
        whether or not to force recomputing the distances. Only relevant
        for a named set; distances for an explicit list are always
        computed and never cached.

    Returns
    -------
    np.array of float
        the distances to the soma of the corresponding locations

    Raises
    ------
    IOError
        If `loc_arg` is neither a list nor a string
    """
    if isinstance(loc_arg, str):
        name = loc_arg
        self._try_name(name)
        locs = self.get_locs(name)
        if not recompute and name in self.d2s:
            # cached result for this named set
            return self.d2s[name]
        cache_result = True
    elif isinstance(loc_arg, list):
        locs = [MorphLoc(loc, self) for loc in loc_arg]
        cache_result = False
    else:
        raise IOError("`loc_arg` should be list of locs or string")

    d2s = np.array(
        [self.path_length({"node": 1, "x": 0.0}, loc) for loc in locs]
    )
    if cache_result:
        self.d2s[name] = d2s
    return d2s
[docs]
def distances_to_bifurcation(self, name, recompute=False):
    """
    Compute the distance of each location to the nearest bifurcation in
    the 'up' direction (towards root)

    Parameters
    ----------
    name: str
        name of the set of locations
    recompute: bool (optional, default ``False``)
        whether or not to force recomputing the distances

    Returns
    -------
    list of floats
        the distances to the nearest bifurcation of the corresponding
        locations (one entry per location, 0.0 for locations on the soma)
    """
    if not recompute and name in self.d2b:
        return self.d2b[name]
    self._try_name(name)
    d2b = []
    # node index -> nearest bifurcation node towards the root, so the lookup
    # is done once per node even when several locations share a node
    bnodes = {}
    for loc in self.locs[name]:
        nidx = loc["node"]
        if nidx == 1:
            d2b.append(0.0)
            continue
        if nidx not in bnodes:
            bnode, _ = self.find_bifurcation_node_to_root(self[nidx])
            bnodes[nidx] = bnode
        d2b.append(
            self.path_length({"node": bnodes[nidx].index, "x": 1.0}, loc)
        )
    self.d2b[name] = d2b
    return d2b
[docs]
def distribute_locs_at_d2s(self, d2s, node_arg=None, name="dont save"):
    """
    Distributes locs on a given set of nodes at specified distances to the
    soma ('d2s'). If none of the specified distances fall on the specified
    nodes, the list of locations will be empty. The locations are stored if
    the name is set to be something other than 'dont save'. On each node,
    locations are ordered from low to high x-values.

    Parameters
    ----------
    d2s: array-like of floats
        the distances from the soma at which to put the locations (micron).
        Lists, tuples and scalars are accepted as well as numpy arrays.
    node_arg:
        see documentation of `MorphTree.convert_node_arg_to_nodes`
    name: string
        the name under which the locations are stored. Defaults to 'dont save'
        which means the locations are not stored

    Returns
    -------
    list of `neat.MorphLoc`
        the list of locations
    """
    # the comparisons and fancy indexing below require a 1d numpy array
    d2s = np.atleast_1d(np.asarray(d2s, dtype=float))
    locs = []
    for node in self.convert_node_arg_to_nodes(node_arg):
        if node.parent_node is not None:
            # distances from the soma to both ends of the node
            L0 = self.path_length(
                {"node": 1, "x": 0.5}, {"node": node.index, "x": 0.0}
            )
            L1 = self.path_length(
                {"node": 1, "x": 0.5}, {"node": node.index, "x": 1.0}
            )
            inds = np.where(np.logical_and(L0 < d2s, d2s <= L1))[0]
            Ls = np.sort(d2s[inds])
            locs.extend(
                [
                    MorphLoc((node.index, (L - L0) / (L1 - L0)), self)
                    for L in Ls
                    if L > 1e-12
                ]
            )
        elif np.any(np.abs(d2s) <= 1e-12):
            # node is soma, append a location on the soma
            locs.append(MorphLoc((node.index, 0.5), self))
    if name != "dont save":
        self.store_locs(locs, name=name)
    return locs
[docs]
def distribute_locs_random(
    self,
    num,
    dx=0.001,
    node_arg=None,
    add_soma=True,
    name="dont save",
    seed=None,
):
    """
    Returns a list of input locations randomly distributed on the tree.
    Nodes are drawn with probability proportional to their length and x is
    drawn uniformly on the chosen node, excluding an interval of size `dx`
    around each previously drawn location.

    Parameters
    ----------
    num: int
        number of inputs (not counting the optional soma location)
    dx: float (optional)
        minimal distance between input locations (micron)
    node_arg (optional):
        see documentation of `MorphTree.convert_node_arg_to_nodes`
    add_soma: bool (optional)
        whether or not to prepend a location on the soma
    name: string (optional)
        the name under which the locations are stored. Defaults to 'dont save'
        which means the locations are not stored
    seed: int (optional)
        Seed for numpy random number generator

    Returns
    -------
    list of `neat.MorphLoc`
        the locations
    """
    rng = np.random.default_rng(seed=seed)
    # setup, tag all nodes as empty intervals (no space is excluded for sampling)
    # when a location is sampled, we add an interval around that location to the tag,
    # which will exclude that interval from being sampled
    for node in self:
        node.content["tag"] = interval()
    # use the requested subset of nodes (the soma is never sampled here)
    nodes = [
        node for node in self.convert_node_arg_to_nodes(node_arg) if node.index != 1
    ]
    # initialize the loclist with or without soma
    if add_soma:
        locs = [MorphLoc({"node": 1, "x": 0.0}, self)]
        self.root.content["tag"] = interval([0.0, 1.0])
    else:
        locs = []
    # sampling probabilities proportional to node length
    # NOTE(review): the in-place division assumes a float array (node.L
    # floats) and yields NaNs with a RuntimeWarning if `nodes` is empty
    nodes_left = [node.index for node in nodes]
    probs_left = np.array([node.L for node in nodes])
    probs_left /= np.sum(probs_left)
    jj = num
    while jj > 0 and len(nodes_left) > 0:
        index = rng.choice(nodes_left, p=probs_left, replace=True)
        node = self[index]
        # draw an x-value not in the tagged intervals
        x = rng.uniform()
        ii = 0
        while x in node.content["tag"] and ii < 100:
            x = rng.uniform()
            ii += 1
        if ii == 100:
            # no free x found after 100 draws: consider the node full
            node.content["tag"] = interval([0.0, 1.0])
        else:
            # add the location
            locs.append(MorphLoc((index, x), self))
            jj -= 1
            # tag new intervals
            self._tag_nodes_from_root(locs[-1], node, dx=dx)
            self._tag_nodes_to_root(locs[-1], node, dx=dx)
        # if node is full, remove it from sampling set; this runs after both
        # branches, so nodes marked full above are dropped as well
        nodes_left_ = []
        probs_left_ = []
        for nidx, prob in zip(nodes_left, probs_left):
            if interval([0, 1]) not in self[nidx].content["tag"]:
                nodes_left_.append(nidx)
                probs_left_.append(prob)
        nodes_left = nodes_left_
        probs_left = np.array(probs_left_) / np.sum(probs_left_)
    # the temporary tags are removed again from all nodes
    self._remove_tags()
    # store the locations
    if name != "dont save":
        self.store_locs(locs, name=name)
    return locs
def _tag_nodes_from_root(self, start_loc, node, dx=0.001):
    """
    Tag the interval within `dx` of `start_loc` in the direction away from
    the root as excluded from sampling, continuing into the child nodes
    when the interval extends beyond the end (x = 1) of `node`.

    Parameters
    ----------
    start_loc: `neat.MorphLoc`
        the sampled location around which the exclusion zone is built
    node: `neat.MorphNode`
        the node currently being tagged
    dx: float
        remaining length of the exclusion zone (micron)
    """
    if dx <= 0:
        return
    # On the start node the zone begins at the sampled location; deeper in
    # the recursion it begins at x = 0 of the node. The plain 0.0 avoids
    # building a `MorphLoc` per node, and any normalisation MorphLoc may
    # apply to x = 0 locations.
    x0 = start_loc["x"] if node.index == start_loc["node"] else 0.0
    d0 = x0 * node.L
    d1 = d0 + dx
    if d1 > node.L:
        # exclude the rest of this node and pass the leftover of dx on
        node.content["tag"] = node.content["tag"] | interval([x0, 1.0])
        dx_left = dx - (node.L - d0)
        for cnode in node.child_nodes:
            self._tag_nodes_from_root(start_loc, cnode, dx=dx_left)
    else:
        node.content["tag"] = node.content["tag"] | interval([x0, d1 / node.L])
def _tag_nodes_to_root(self, start_loc, node, cnode=None, dx=0.001):
    """
    Tag the interval within `dx` of `start_loc` in the direction towards
    the root as excluded from sampling. When the interval extends beyond
    x = 0 of `node`, the parent node is tagged as well, and, deeper in the
    recursion, the sibling subtrees branching off at `node`.

    Parameters
    ----------
    start_loc: `neat.MorphLoc`
        the sampled location around which the exclusion zone is built
    node: `neat.MorphNode`
        the node currently being tagged
    cnode: `neat.MorphNode` or ``None``
        the child of `node` the recursion arrived from; it is skipped when
        tagging the sibling subtrees
    dx: float
        remaining length of the exclusion zone (micron)
    """
    on_start_node = node.index == start_loc["node"]
    # on the start node the zone ends at the sampled location, deeper in
    # the recursion it ends at x = 1 of the node
    x1 = start_loc["x"] if on_start_node else 1.0
    d0 = x1 * node.L
    d1 = d0 - dx
    if d1 < 0.0:
        # exclude this node up to x1 and pass the leftover of dx on
        node.content["tag"] = node.content["tag"] | interval([0.0, x1])
        dx -= d0
        # if we are deeper in the recursion, we also have to assess
        # sibling branches at bifurcations
        if not on_start_node:
            cnodes = node.child_nodes
            if len(cnodes) > 1:
                if cnode is not None:
                    cnodes = list(set(cnodes) - set([cnode]))
                for cn in cnodes:
                    self._tag_nodes_from_root(start_loc, cn, dx=dx)
        pnode = node.get_parent_node()
        if pnode is not None:
            self._tag_nodes_to_root(start_loc, pnode, node, dx=dx)
    else:
        # the upper bound is the x-coordinate on *this* node; `start_loc`
        # lies on a different node once we are deeper in the recursion
        node.content["tag"] = node.content["tag"] | interval([d1 / node.L, x1])
def _remove_tags(self):
    # Drop the temporary "tag" entry (used as sampling exclusion interval)
    # from the content dict of every node, if present.
    for morph_node in self:
        morph_node.content.pop("tag", None)
def distribute_locs_finite_diff(
    self, dx_max=15.0, node_arg=None, name="dont store"
):
    """
    Distribute locs in such a way that they correspond to the compartment
    locations under NEAT's finite difference approximation.

    Parameters
    ----------
    dx_max: float
        Maximum distance step between compartments (in [um]). By default,
        each node of this tree will correspond to at least one compartment,
        and thus one node in the comparment tree. If the length of a node
        exceeds `dx_max`, there will be the smallest possible number of
        equally spaced comparments so that the distance between them does
        not exceed `dx_max`. Note that if the computational tree is active,
        the computational nodes will be taken as a reference for placing
        the compartment locations.
    node_arg (optional):
        see documentation of `MorphTree.convert_node_arg_to_nodes`
    name: string (optional)
        the name under which the locations are stored. Defaults to
        'dont store', which means the locations are not stored; 'dont save',
        the sentinel of the other location-distributing methods, is
        accepted as well.

    Returns
    -------
    list of `neat.MorphLoc`
        the locations
    """
    set_as_comploc = self.check_computational_tree_active()
    locs = []
    for node in self.convert_node_arg_to_nodes(node_arg):
        if self.is_root(node):
            locs.append(
                MorphLoc((node.index, 0.5), self, set_as_comploc=set_as_comploc)
            )
        else:
            # at least one compartment per node, also for zero-length nodes
            n_comp = max(int(np.ceil(node.L / dx_max)), 1)
            locs.extend(
                MorphLoc(
                    (node.index, cc / n_comp), self, set_as_comploc=set_as_comploc
                )
                for cc in range(1, n_comp + 1)
            )
    if name not in ("dont store", "dont save"):
        self.store_locs(locs, name)
    return locs
[docs]
def extend_with_bifurcation_locs(self, loc_arg, name="dont save"):
    """
    Add the bifurcations lying in between the given locations to the set.

    A location at ``x=1`` of every in-between bifurcation node is placed
    after the input locations, and the combined list is de-duplicated with
    `MorphTree.unique_locs`.

    Parameters
    ----------
    loc_arg: list of `neat.MorphLoc` or string
        the locations
    name: string (optional)
        The name under which the extended list of locs will be stored.
        Defaults to 'dont save' which means it is not stored.

    Returns
    -------
    list of `neat.MorphLoc`
        the extended location list
    """
    locs = self.convert_loc_arg_to_locs(loc_arg)
    host_nodes = [self[loc["node"]] for loc in locs]
    bifurcation_locs = [
        MorphLoc((bif_node.index, 1.0), self)
        for bif_node in self.find_in_between_bifurcation_nodes(host_nodes)
    ]
    extended_locs = self.unique_locs(locs + bifurcation_locs)
    if name != "dont save":
        self.store_locs(extended_locs, name=name)
    return extended_locs
[docs]
def unique_locs(self, loc_arg, name="dont save"):
    """
    Gets the unique locations in the provided locs, keeping the first
    occurrence of each and preserving the input order.

    Parameters
    ----------
    loc_arg: list of `neat.MorphLoc` or string
        the locations
    name: string (optional)
        The name under which the list of unique locs will be stored.
        Defaults to 'dont save' which means they are not stored.

    Returns
    -------
    list of `neat.MorphLoc`
        the unique locs
    """
    locs = self.convert_loc_arg_to_locs(loc_arg)
    # equality-based de-duplication: locations are not assumed to be
    # hashable, so a set cannot be used
    unique = []
    for loc in locs:
        if loc not in unique:
            unique.append(loc)
    if name != "dont save":
        self.store_locs(unique, name=name)
    return unique
[docs]
def make_x_axis(self, dx=10.0, node_arg=None, loc_arg=None, scheme="iterator"):
    """
    Create a set of locs suitable for serving as the x-axis for 1D plotting.
    The neuron is put on a 1D axis with a depth-first ordering.

    The locations are stored under the name ``'xaxis'`` and the
    corresponding 1D coordinates in ``self.xaxis``. Every branch that ends
    in a leaf location starts at the x-coordinate where the previous one
    ended.

    Parameters
    ----------
    dx: float
        target separation between the plot points (micron)
    node_arg:
        see documentation of `MorphTree.convert_node_arg_to_nodes`
        The nodes on which the locations for the x-axis are distributed.
        When this is given as a list of nodes, assumes a depth first
        ordering.
    loc_arg: list of locs or string
        if list of locs, these locs will be used as x-axis, if string, name
        of set of locs on the morphology that will be used as x-axis.
        When given, node colors are not (re)set.
    scheme: str (default: 'iterator')
        scheme for coloring the nodes (only used when `loc_arg` is
        ``None``). Options are:
        'iterator': colors nodes in depth-first order
        'depth': colors nodes based on distance to leafs
        'random': colors nodes randomly

    Raises
    ------
    ValueError
        If `node_arg` is not a depth-first ordered list of nodes
    IOError
        If `loc_arg` is neither ``None``, a list nor a string
    """
    if loc_arg is None:
        # if comptree has not been set, create a basic one for plotting
        if self._computational_root is None:
            self.set_comp_tree()
        # distribute the x-axis locations
        self.distribute_locs_uniform(dx, node_arg=node_arg, name="xaxis")
        nodes = self.convert_node_arg_to_nodes(node_arg)
        # the first node must not be the child of any other node in the list
        for node in nodes:
            if nodes[0] in node.child_nodes:
                raise ValueError(
                    "Input `node_arg` is not a depth-first ordered"
                    " list of nodes."
                )
        # set the node colors for original and computational trees
        tempnode_orig = self._find_comp_node_from_root(nodes[0])
        self.set_node_colors(tempnode_orig, scheme=scheme)
        with self.as_computational_tree:
            tempnode_comp = self[tempnode_orig.index]
            self.set_node_colors(tempnode_comp, scheme=scheme)
    elif isinstance(loc_arg, list):
        self.store_locs(loc_arg, name="xaxis")
    elif isinstance(loc_arg, str):
        self.store_locs(self.get_locs(loc_arg), name="xaxis")
    else:
        raise IOError("`loc_arg` should be string or list of locs")

    # compute the x-axis 1D array: the first branch keeps its distances to
    # the soma, each following branch is shifted to continue where the
    # previous one ended
    pinds = self.get_leaf_loc_idxs("xaxis")
    d2s = self.distances_to_soma("xaxis")
    xaxis = d2s[0 : pinds[0] + 1].tolist()
    d_add = d2s[pinds[0]]
    for ii in range(len(pinds) - 1):
        i_start, i_stop = pinds[ii] + 1, pinds[ii + 1] + 1
        xaxis.extend((d_add + d2s[i_start:i_stop] - d2s[i_start]).tolist())
        d_add += d2s[i_stop - 1] - d2s[i_start]
    self.xaxis = np.array(xaxis)
[docs]
def set_node_colors(self, startnode=None, scheme="iterator"):
    """
    Set the color code for the nodes for 1D plotting

    All nodes first get color ``0.0``; the nodes in the subtree of
    `startnode` are then colored according to `scheme`, and the result is
    stored under ``node.content['color']``.

    Parameters
    ----------
    startnode: `neat.MorphNode` (optional)
        node from which to start coloring. If None, starts from root
    scheme: str (default: 'iterator')
        scheme for coloring the nodes. Options are:
        'iterator': colors nodes in depth-first order
        'depth': colors nodes based on distance to leafs
        'random': colors nodes randomly

    Raises
    ------
    ValueError
        If `scheme` is not one of the options above. The node colors are
        left untouched in that case.
    """
    if scheme not in ("iterator", "depth", "random"):
        raise ValueError(
            "Unknown scheme for node coloring, choose from 'iterator', 'depth' or 'random'"
        )
    if startnode is None:
        startnode = self.root
    for node in self:
        node.content["color"] = 0.0
    # single-element list so that the recursive helpers share one mutable
    # counter; always reset afterwards, also when a helper raises
    self.node_color = [0.0]
    try:
        if scheme == "iterator":
            self._set_node_colors_from_iterator(startnode)
        elif scheme == "depth":
            self._set_node_colors_from_depth(startnode)
        else:
            rng = np.random.default_rng()
            self.node_colors = rng.permuted(np.arange(len(self.leafs)))
            try:
                self._set_node_colors_from_rng(startnode, rng)
            finally:
                self.node_colors = None
    finally:
        self.node_color = None
def _set_node_colors_from_iterator(self, node):
    """
    Color the subtree of `node` in depth-first (pre-)order.

    Every node receives the current value of the shared counter
    ``self.node_color[0]``, which is incremented by one after each leaf, so
    that all nodes between two consecutive leafs share a color.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        current.content["color"] = self.node_color[0]
        if self.is_leaf(current):
            self.node_color[0] += 1.0
        # push in reverse so the first child is visited first
        pending.extend(reversed(current.child_nodes))
def _set_node_colors_from_rng(self, node, rng):
    """
    Color the subtree of `node` in depth-first (pre-)order with the
    permuted color table ``self.node_colors``.

    The shared counter ``self.node_color[0]`` indexes the table and is
    incremented after each leaf. `rng` is accepted for interface
    compatibility but not used here.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        current.content["color"] = self.node_colors[int(self.node_color[0])]
        if self.is_leaf(current):
            self.node_color[0] += 1.0
        pending.extend(current.child_nodes[::-1])
def _set_node_colors_from_depth(self, node):
node.content["color"] = self.node_color[0]
if len(node.child_nodes) > 1:
# find leafs in subtree
max_dist = []
for cnode in enumerate(node.child_nodes):
st_leafs = self.get_leafs(cnode)
max_d = max(
[
self.path_length((cnode.index, 0.0), (leaf.index, 1.0))
for leaf in st_leafs
]
)
max_dist.append(max_d)
# sort child nodes based on min distance to leafs
sorted_inds = np.argsort(max_dist)
for ind in sorted_inds:
cnode = node.child_nodes[ind]
self._set_node_colors_from_depth(cnode)
elif len(node.child_nodes) == 1:
self._set_node_colors_from_depth(node.child_nodes[0])
elif self.is_leaf(node):
self.node_color[0] += 1.0
else:
raise RuntimeError("This code should not be reachable")
[docs]
def get_x_values(self, locs):
    """
    Map locations to their coordinate on the 1D x-axis.

    Each input location is matched to a location of the stored ``'xaxis'``
    set via `MorphTree.get_nearest_loc_idxs`, and the ``self.xaxis`` value
    at that index is returned.

    Parameters
    ----------
    locs: list of tuples, dicts or `neat.MorphLoc`
        list of the locations

    Returns
    -------
    numpy.ndarray
        the x-axis value for every input location
    """
    nearest = np.array(self.get_nearest_loc_idxs(locs, "xaxis"))
    return self.xaxis[nearest.astype(int)]
[docs]
def plot_1d(self, ax, parr, *args, **kwargs):
    """
    Plot an array where each element corresponds to the matching location on
    the x-axis with a depth-first ordering on a 1D plot

    Every branch ending in a leaf location is drawn as a separate line. A
    ``label`` keyword is only applied to the first line, so that a legend
    gets a single entry.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be made
    parr: numpy.array of floats
        the array that will be plotted
    args, kwargs:
        arguments for `matplotlib.pyplot.plot`

    Returns
    -------
    lines: list of `matplotlib.lines.Line2D` instances
        the line segments corresponding to the value of the plotted array
        in each branch

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    # explicit raise so the check also holds under `python -O`
    if len(parr) != len(self.locs["xaxis"]):
        raise AssertionError(
            "`parr` should have one element per location on the x-axis"
        )
    pinds = self.get_leaf_loc_idxs("xaxis")
    # make the plot
    lines = []
    line = ax.plot(
        self.xaxis[0 : pinds[0] + 1], parr[0 : pinds[0] + 1], *args, **kwargs
    )
    lines.append(line[0])
    # `kwargs` is a dict local to this call, so it can be modified safely
    kwargs.pop("label", None)
    for ii in range(len(pinds) - 1):
        segment = slice(pinds[ii] + 1, pinds[ii + 1] + 1)
        line = ax.plot(self.xaxis[segment], parr[segment], *args, **kwargs)
        lines.append(line[0])
    return lines
def set_line_data(self, lines, parr):
    """
    Update the line objects with new data

    `lines` is expected to be ordered as returned by `MorphTree.plot_1d`:
    one line per branch ending in a leaf location of the x-axis.

    Parameters
    ----------
    lines: list of `matplotlib.lines.Line2D` instances
        the line segments of which the data has to be updated
    parr: numpy.array of floats
        the array that will be put in the line segments

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    # explicit raise so the check also holds under `python -O`
    if len(parr) != len(self.locs["xaxis"]):
        raise AssertionError(
            "`parr` should have one element per location on the x-axis"
        )
    pinds = self.get_leaf_loc_idxs("xaxis")
    lines[0].set_data(self.xaxis[0 : pinds[0] + 1], parr[0 : pinds[0] + 1])
    for ii in range(len(pinds) - 1):
        segment = slice(pinds[ii] + 1, pinds[ii + 1] + 1)
        lines[ii + 1].set_data(self.xaxis[segment], parr[segment])
[docs]
def plot_true_d2s(self, ax, parr, cmap=None, **kwargs):
    """
    Plot an array where each element corresponds to the matching location in
    the x-axis location list. Now all locations are plotted at their true
    distance from the soma.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be made
    parr: numpy.array of floats
        the array that will be plotted
    cmap: `matplotlib.colors.Colormap` instance
        If provided, the lines will be colored according to the branch
        to which they belong (``node.content['color']``), in colors
        specified by the colormap. Any 'color' keyword is then dropped.
    kwargs:
        keyword arguments for `matplotlib.pyplot.plot`. A ``label`` is only
        applied to the first line.

    Returns
    -------
    lines: list of `matplotlib.lines.Line2D`
        the line segments corresponding to the value of the plotted array
        in each branch

    Raises
    ------
    AssertionError
        When the number of elements in the data array is not equal to
        the number of elements on the x-axis
    """
    locs = self.locs["xaxis"]
    # explicit raise so the check also holds under `python -O`
    if len(parr) != len(locs):
        raise AssertionError(
            "`parr` should have one element per location on the x-axis"
        )
    pinds = self.get_leaf_loc_idxs("xaxis")
    d2s = self.distances_to_soma("xaxis")
    # node color of every x-axis location, normalized to [0, 1)
    cs = {node.index: node.content["color"] for node in self}
    cplot = [cs[loc["node"]] for loc in locs]
    max_cs = max(cplot)
    min_cs = min(cplot)
    if np.abs(max_cs - min_cs) < 1e-12:
        norm_cs = max_cs + 1e-2
    else:
        norm_cs = (max_cs - min_cs) * (1.0 + 1.0 / 100.0)
    # create the truespace plot
    lines = []
    if cmap is not None:
        kwargs["c"] = cmap((cplot[0] - min_cs) / norm_cs)
        kwargs.pop("color", None)
    line = ax.plot(d2s[0 : pinds[0] + 1], parr[0 : pinds[0] + 1], **kwargs)
    lines.append(line[0])
    kwargs.pop("label", None)
    for ii in range(len(pinds) - 1):
        i_start, i_stop = pinds[ii] + 1, pinds[ii + 1] + 1
        if cmap is not None:
            kwargs["c"] = cmap((cplot[i_start] - min_cs) / norm_cs)
        line = ax.plot(d2s[i_start:i_stop], parr[i_start:i_stop], **kwargs)
        lines.append(line[0])
    return lines
def _add_scalebar(
    self,
    ax,
    xlabel=r" $\mu$m",
    fstr_xlabel=r"%.0f ",
    lx_offset=0.1,
    ly_offset=0.1,
    bx_offset=0.05,
    by_offset=0.05,
    bc_offset=0.0,
    sb_width=4.0,
    text_kwargs_x=dict(size=15.0, rotation=0, va="center"),
):
    """
    Draw a horizontal scale bar below the data of `ax` and hide the x-ticks.

    The bar length is the spacing between the last two x-ticks. The bar is
    centred horizontally and annotated with ``fstr_xlabel % length + xlabel``.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the axes on which the scale bar is drawn
    xlabel: str
        unit string appended to the numeric label
    fstr_xlabel: str
        %-format string for the bar length
    lx_offset, ly_offset: float
        offset of the label with respect to the bar (axes fraction)
    bx_offset, by_offset: float
        offset of the bar with respect to the bottom of the axes (axes
        fraction)
    bc_offset: float
        currently unused, kept for interface compatibility
    sb_width: float
        line width of the bar
    text_kwargs_x: dict
        text properties of the label (not modified)
    """
    xticks = ax.get_xticks()
    xlim = ax.get_xlim()
    ylim = ax.get_ylim()
    # convert offsets given as axes fractions into data-unit differences
    transf = ax.transData.inverted() + ax.transAxes
    p0 = transf.transform((0.0, 0.0))
    bcx_offset = transf.transform((0.0, 1.0))[0] - p0[0]
    bx_offset = transf.transform((by_offset, -bx_offset))[1] - p0[1]
    lx_offset = transf.transform((ly_offset, -lx_offset))[1] - p0[1]
    # position and length
    sblen = xticks[-1] - xticks[-2]
    xpos = (xlim[0] + xlim[1]) / 2.0 + bcx_offset
    ypos = ylim[0] + bx_offset
    px = (xpos, ypos)
    xbar = ((xpos - sblen / 2.0, xpos + sblen / 2.0), (ypos, ypos))
    # draw the scale bar
    ax.plot(*xbar, "k-", lw=sb_width, clip_on=False)
    ax.annotate(
        fstr_xlabel % sblen + xlabel,
        xy=px,
        xytext=(px[0], px[1] + lx_offset),
        annotation_clip=False,
        transform=ax.transData,
        ha="center",
        **text_kwargs_x,
    )
    ax.tick_params(axis="x", which="both", length=0, color="none")
    ax.set_xlim(xlim)
[docs]
def color_x_axis(self, ax, cmap, add_scalebar=1):
    """
    Color the x-axis of a plot according to the morphology.

    !!! Has to be called after all lines are plotted !!!
    Furthermore, node colors have to be set first. This can be done with
    `MorphTree.set_node_colors()` or manually by adding a 'color' entry
    to the ``MorphNode.content`` dictionary.

    A thick line is drawn at the bottom of the current y-limits for every
    branch of the x-axis, colored after the node of its first location. The
    matplotlib x-axis itself is hidden.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object of which the x-axis will be colored
    cmap: `matplotlib.colors.Colormap` instance
        Colormap that determines the color of each branch
    add_scalebar: bool
        If ``True``, a scale bar is added below the colored axis.
    """
    locs = self.locs["xaxis"]
    # node color of every x-axis location, normalized to [0, 1)
    cs = {node.index: node.content["color"] for node in self}
    cplot = [cs[loc["node"]] for loc in locs]
    max_cs = max(cplot)
    min_cs = min(cplot)
    if np.abs(max_cs - min_cs) < 1e-12:
        norm_cs = max_cs + 1e-2
    else:
        norm_cs = (max_cs - min_cs) * (1.0 + 1.0 / 100.0)
    pinds = self.get_leaf_loc_idxs("xaxis")
    # explicit raise so the check also holds under `python -O`
    if len(pinds) == 0:
        raise AssertionError("The x-axis does not contain any leaf location")
    # plot colored xaxis just above the lower y-limit
    ylim = np.array(ax.get_ylim())
    y_axis = ylim[0] + 1e-9
    xseg = self.xaxis[0 : pinds[0] + 1]
    ax.plot(
        xseg,
        np.full(len(xseg), y_axis),
        c=cmap((cplot[0] - min_cs) / norm_cs),
        lw=10,
    )
    for ii in range(len(pinds) - 1):
        i_start, i_stop = pinds[ii] + 1, pinds[ii + 1] + 1
        xseg = self.xaxis[i_start:i_stop]
        node_index = locs[i_start]["node"]
        if node_index in cs:
            color = cmap((cs[node_index] - min_cs) / norm_cs)
        else:
            color = "k"
        ax.plot(xseg, np.full(len(xseg), y_axis), c=color, lw=10)
    ax.set_ylim((ylim[0], ylim[1]))
    # add scalebar
    if add_scalebar:
        self._add_scalebar(ax)
    ax.axes.get_xaxis().set_visible(False)
[docs]
def plot_2d_morphology(
    self,
    ax,
    node_arg=None,
    cs=None,
    cminmax=None,
    cmap=None,
    use_radius=1,
    draw_soma_circle=1,
    plotargs={},
    textargs={},
    marklocs=[],
    loc_args={},
    marklabels={},
    labelargs={},
    cb_draw=0,
    cb_orientation="vertical",
    cb_label="",
    sb_draw=1,
    sb_scale=100,
    sb_width=5.0,
    set_lims=True,
    lims_margin=0.1,
    soma_pos=None,
):
    """
    Plot the morphology projected on the x,y-plane

    The dict arguments are never modified. The node coordinates are
    temporarily translated by `soma_pos` and restored afterwards, also when
    an error occurs while plotting.

    Parameters
    ----------
    ax: `matplotlib.axes.Axes` instance
        the ax object on which the plot will be drawn
    node_arg:
        see documentation of `MorphTree.convert_node_arg_to_nodes`
    cs: dict {int: float}, None or 'node_color'
        If dict, node indices are keys and the float value will
        correspond to the plotted color. If None, the color of the tree
        will be the one specified in ``plotargs``. Note that the dict
        does not have to contain all node indices. The ones that are not
        featured in the dict are plot in the color specified in ``plotargs``.
        If 'node_color', colors will be those stored on the nodes. Note
        that choosing this option when there are nodes without 'color'
        as an entry in ``node.content`` will result in an error.
    cminmax: (float, float) or None (default)
        The min and max values of the color scale (if cs is provided).
        If None, the min and max values of cs are used.
    cmap: `matplotlib.colors.Colormap` instance
        colormap from which colors in ``cs`` are taken ('jet' if None)
    use_radius: bool
        If ``True``, uses the swc radius for the width of the line
        segments
    draw_soma_circle: bool
        If ``True``, draws the soma as a circle, otherwise doesn't draw
        soma
    plotargs: dict
        `kwargs` for `matplotlib.pyplot.plot`. 'c'- or 'color'-
        argument will be overwritten when cs is defined. 'lw'- or
        'linewidth' argument will be multiplied with the swc radius of
        the node if `use_radius` is ``True``.
    textargs: dict
        text properties for various labels in the plot
    marklocs: list of tuples, dicts or instances of `neat.MorphLoc`
        Location that will be plotted on the morphology
    loc_args: dict or list of dict
        `kwargs` for `matplotlib.pyplot.plot` for the location.
        Use only point markers and no lines! When it is a single dict
        all location will have the same marker. When it is a list it
        should have the same length as `marklocs`.
    marklabels: dict {int: string}
        Keys are indices of locations in `marklocs`, values are strings
        that are used to annotate the corresponding locations
    labelargs: dict
        text properties for the location annotation
    cb_draw: bool
        Whether or not to draw a `matplotlib.pyplot.colorbar()`
        instance.
    cb_orientation: string, 'vertical' or 'horizontal'
        The colorbars' orientation
    cb_label: string
        The label of the colorbar
    sb_draw: bool
        Whether or not to draw a scale bar
    sb_scale: float
        Length of the scale bar (micron)
    sb_width: float
        Width of the scale bar
    set_lims: bool (optional, default ``True``)
        set ``ax`` limits based on the morphology
    lims_margin: float
        the margin, as fraction of total width and height of tree, at
        which the limits are placed
    soma_pos: np.ndarray of float (shape=(3,))
        translate the soma position to the given coordinates
    """
    # work on private copies: the kwargs are normalized and updated below,
    # which would otherwise leak into the caller's dicts and into the shared
    # mutable defaults of the signature (e.g. compounding line widths)
    plotargs = dict(plotargs)
    if isinstance(loc_args, dict):
        loc_args = dict(loc_args)
    else:
        loc_args = [dict(loc_arg) for loc_arg in loc_args]

    # apply the translation, undone in the `finally` clause below
    if soma_pos is None:
        soma_pos = np.zeros((3,))
    for node in self:
        node.xyz += soma_pos
    try:
        # default cmap (`matplotlib.cm.get_cmap` was removed in Matplotlib 3.9)
        if cmap is None:
            cmap = pl.get_cmap("jet")
        # ensure color is indicated by the 'c'-parameter in `plotargs`
        if "color" in plotargs:
            plotargs["c"] = plotargs.pop("color")
        elif "c" not in plotargs:
            plotargs["c"] = "k"
        # define a norm for the colors, if defined
        if isinstance(cs, str) and cs == "node_color":
            cs = {node.index: node.content["color"] for node in self}
        if cs is not None:
            if cminmax is None:
                values = list(cs.values()) if isinstance(cs, dict) else list(cs)
                if len(values) > 0:
                    min_cs, max_cs = min(values), max(values)
                else:
                    min_cs, max_cs = 0.0, 1.0
            else:
                min_cs, max_cs = cminmax[0], cminmax[1]
            norm = pl.Normalize(vmin=min_cs, vmax=max_cs)
        # ensure linewidth is indicated as 'lw' in plotargs
        if "linewidth" in plotargs:
            plotargs["lw"] = plotargs.pop("linewidth")
        elif "lw" not in plotargs:
            plotargs["lw"] = 1.0
        plotargs_orig = copy.deepcopy(plotargs)
        # loc_args can be dictionary, so that the same properties hold for
        # every markloc, or can be list with the same size as marklocs, so
        # that every marker has different properties. `zorder` of the markers
        # is also set very high so that they are always in the foreground
        self.store_locs(marklocs, "plotlocs")
        xs = self.xs["plotlocs"]
        if isinstance(loc_args, dict):
            loc_args.setdefault("zorder", 1e4)
            loc_args = [loc_args for _ in marklocs]
        else:
            assert len(loc_args) == len(marklocs)
            for loc_arg in loc_args:
                loc_arg.setdefault("zorder", 1e4)
        # `marklabels` is a dictionary with as keys the index of the loc in
        # `marklocs` to which the label belongs. `labelargs` is the same for
        # every label
        for ind in marklabels:
            assert ind < len(marklocs)
        # plot the tree; the limits are the bounding box of the plotted
        # node coordinates (not forced to include the origin)
        xlim = [np.inf, -np.inf]
        ylim = [np.inf, -np.inf]
        for node in self.convert_node_arg_to_nodes(node_arg):
            xlim = [min(xlim[0], node.xyz[0]), max(xlim[1], node.xyz[0])]
            ylim = [min(ylim[0], node.xyz[1]), max(ylim[1], node.xyz[1])]
            # find the locations that are on the current node
            inds = self.get_loc_idxs_on_node("plotlocs", node)
            if node.parent_node is None:
                # node is soma, draw as circle if necessary
                if draw_soma_circle:
                    if cs is not None and node.index in cs:
                        plotargs["c"] = cmap(norm(cs[node.index]))
                    else:
                        plotargs["c"] = plotargs_orig["c"]
                    circ = patches.Circle(node.xyz[0:2], node.R, color=plotargs["c"])
                    ax.add_patch(circ)
                for ind in inds:
                    self._plot_loc(
                        ax,
                        ind,
                        node.xyz[0],
                        node.xyz[1],
                        loc_args,
                        marklabels,
                        labelargs,
                    )
            else:
                # plot line segment associated with node
                nxyz = node.xyz
                pxyz = node.parent_node.xyz
                if cs is not None and node.index in cs:
                    plotargs["c"] = cmap(norm(cs[node.index]))
                else:
                    plotargs["c"] = plotargs_orig["c"]
                if use_radius:
                    plotargs["lw"] = plotargs_orig["lw"] * node.R
                ax.plot([pxyz[0], nxyz[0]], [pxyz[1], nxyz[1]], **plotargs)
                # plot the locations
                for ind in inds:
                    locxyz = pxyz + (nxyz - pxyz) * xs[ind]
                    self._plot_loc(
                        ax, ind, locxyz[0], locxyz[1], loc_args, marklabels, labelargs
                    )
        if not np.isfinite(xlim[0]):
            # no nodes were plotted
            xlim, ylim = [0.0, 0.0], [0.0, 0.0]
        # margins
        dx = xlim[1] - xlim[0]
        dy = ylim[1] - ylim[0]
        xlim[0] -= dx * lims_margin
        xlim[1] += dx * lims_margin
        ylim[0] -= dy * lims_margin
        ylim[1] += dy * lims_margin
        # draw a scale bar
        if sb_draw:
            scale = sb_scale
            dy = ylim[1] - ylim[0]
            ax.plot(
                [xlim[0], xlim[0] + scale],
                [ylim[0], ylim[0]],
                "k",
                linewidth=sb_width,
                zorder=1e5,
            )
            txt = ax.annotate(
                f"{scale} μm",
                xy=(xlim[0] + scale / 2.0, ylim[0]),
                xycoords="data",
                xytext=(xlim[0] + scale / 2.0, ylim[0] - dy / 200.0),
                ha="center",
                va="top",
                **textargs,
            )
            txt.set_path_effects([patheffects.withStroke(foreground="w", linewidth=2)])
        if set_lims:
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)
        ax.set_aspect("equal", "datalim")
        if cs is not None and cb_draw:
            # create colorbar ax
            divider = make_axes_locatable(ax)
            if cb_orientation == "horizontal":
                cax = divider.append_axes("bottom", "5%", pad="3%")
            else:
                cax = divider.append_axes("right", "5%", pad="3%")
            # create a mappable (empty array: only the norm/cmap are used)
            sm = cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            # create the colorbar
            cb = pl.colorbar(sm, cax=cax, orientation=cb_orientation)
            ticks_cb = np.round(np.linspace(min_cs, max_cs, 7), decimals=1)
            cb.set_ticks(ticks_cb)
            if cb_orientation == "horizontal":
                cb.ax.xaxis.set_ticks_position("bottom")
            else:
                cb.ax.yaxis.set_ticks_position("right")
            cb.set_label(cb_label, **textargs)
        for side in ("top", "bottom", "right", "left"):
            ax.spines[side].set_color("none")
        ax.draw_frame = False
        ax.set_xticks([])
        ax.set_yticks([])
    finally:
        # remove the translation
        for node in self:
            node.xyz -= soma_pos
def _plot_loc(self, ax, ind, xval, yval, loc_args, marklabels, labelargs):
    """
    Draw marker number `ind` at (`xval`, `yval`) with ``loc_args[ind]``.

    When `marklabels` holds an entry for `ind`, the marker is annotated
    with that text, offset by 5 points and outlined in white.
    """
    ax.plot(xval, yval, **loc_args[ind])
    if ind not in marklabels:
        return
    label = ax.annotate(
        marklabels[ind],
        xy=(xval, yval),
        xycoords="data",
        xytext=(5, 5),
        textcoords="offset points",
        **labelargs,
    )
    label.set_path_effects([patheffects.withStroke(foreground="w", linewidth=2)])
[docs]
def plot_morphology_interactive(
    self,
    node_arg=None,
    use_radius=1,
    draw_soma_circle=1,
    plotargs={"c": "k", "lw": 1.0},
    project3d=False,
):
    """
    Show the morphology either in 3d or projected on the x,y-plane. When
    a line segment is clicked, the associated node is printed.

    Parameters
    ----------
    node_arg:
        see documentation of `MorphTree.convert_node_arg_to_nodes`
    use_radius: bool
        If ``True``, uses the swc radius for the width of the line
        segments
    draw_soma_circle: bool
        If ``True``, draws the soma as a circle, otherwise doesn't draw
        soma
    plotargs: dict
        `kwargs` for `matplotlib.pyplot.plot`. A 'color' entry is renamed
        to 'c' and 'linewidth' to 'lw'; missing entries default to
        ``c='k'`` and ``lw=1.0``. The dict itself is not modified.
    project3d: bool
        If ``True``, plots on a 3d axes, otherwise projects the morphology
        on the x,y-plane
    """
    fig = pl.figure("Morphology interactive")
    if project3d:
        # `gca(projection=...)` no longer works since Matplotlib 3.6: reuse an
        # existing 3d axes of this figure or create one explicitly
        ax = next((a for a in fig.axes if a.name == "3d"), None)
        if ax is None:
            ax = fig.add_subplot(projection="3d")
    else:
        ax = fig.gca()
    # private copy: 'lw' is overwritten per node below, which would otherwise
    # compound across calls through the shared default dict
    plotargs = dict(plotargs)
    if "color" in plotargs:
        plotargs["c"] = plotargs.pop("color")
    elif "c" not in plotargs:
        plotargs["c"] = "k"
    if "linewidth" in plotargs:
        plotargs["lw"] = plotargs.pop("linewidth")
    elif "lw" not in plotargs:
        plotargs["lw"] = 1.0
    plotargs_orig = copy.deepcopy(plotargs)
    # plot the tree; the label of each line is the key to find its node back
    node_line_associators = {}
    for ii, node in enumerate(self.convert_node_arg_to_nodes(node_arg)):
        if node.parent_node is not None:
            # plot line segment associated with node
            nxyz = node.xyz
            pxyz = node.parent_node.xyz
            if use_radius:
                plotargs["lw"] = plotargs_orig["lw"] * node.R
            coords = [[pxyz[0], nxyz[0]], [pxyz[1], nxyz[1]]]
            if project3d:
                coords.append([pxyz[2], nxyz[2]])
            ax.plot(*coords, label=str(ii), picker=2.0, **plotargs)
            node_line_associators[str(ii)] = node
        elif draw_soma_circle:
            # node is soma, draw as circle
            circ = patches.Circle(node.xyz[0:2], node.R, color=plotargs["c"])
            ax.add_patch(circ)
    ax.axes.get_xaxis().set_visible(0)
    ax.axes.get_yaxis().set_visible(0)
    ax.set_axis_off()

    # define the clickevent action
    def on_pick(event):
        node = node_line_associators[event.artist.get_label()]
        # print the associated node
        print(
            "\n>>> line segment at "
            + str(node)
            + ", distance to soma (um) = "
            + str(
                self.path_length(
                    {"node": node.index, "x": 1}, {"node": 1, "x": 0.0}
                )
            )
        )

    # show morphology
    fig.canvas.mpl_connect("pick_event", on_pick)
    pl.show()
def find_common_root(self, name):
    """
    Find the common root of a stored set of locations.

    The candidates are the nodes lying on the path to the root of every
    node that hosts a location of the set. The candidate of highest
    `order_of_node` is returned; ties are broken by the length of the
    path to the root, so that the most distal common node is chosen
    deterministically.

    Parameters
    ----------
    name: str
        name under which the set of locations is stored

    Returns
    -------
    `neat.MorphNode`
        the common root

    Raises
    ------
    ValueError
        If no locations are stored under `name`
    """
    self._try_name(name)
    node_inds = self.get_node_indices(name)
    if len(node_inds) == 0:
        raise ValueError(f"No locations stored under name '{name}'")
    # nodes shared by the paths to the root of all nodes
    paths = [set(self.path_to_root(self[node_ind])) for node_ind in node_inds]
    roots = set.intersection(*paths)
    return max(
        roots,
        key=lambda node: (self.order_of_node(node), len(self.path_to_root(node))),
    )
[docs]
def create_new_tree(
    self,
    loc_arg,
    name="new tree",
    fake_soma=False,
    new_tree=None,
):
    """
    Creates a new tree where the provided location in `loc_arg` are now the nodes.
    Note that if the soma is not in the list of locations, a common root location
    might be added if necessary.

    Distance relations between locations are maintained (note that this
    relation is stored in `L` attribute of `neat.MorphNode`, the `p3d`
    attribute containing the 3d coordinates does not maintain distances)

    The radius of a node is taken as the average radius between the location
    associated with the node and the location associated with the parent node,
    weighted by the lengths of all individual nodes.

    Parameters
    ----------
    loc_arg: list of `neat.MorphLoc` or string
        the locations.
    name: str (default 'new tree')
        The name under which the locations associated to the tree are stored.
        The root location is prepended to this set.
    fake_soma: bool (default `False`)
        if `True`, finds the common root of the set of locations and
        uses that as the soma of the new tree. If `False`, the real soma
        is used.
    new_tree: `None` or instance of subclass of `neat.MorphTree`
        The new tree instance.

    Returns
    -------
    `neat.MorphTree`
        The new tree.

    Raises
    ------
    `ValueError`
        If `new_tree` is not an instance of `neat.MorphTree` (or of a
        subclass)
    """
    locs = self.convert_loc_arg_to_locs(loc_arg)
    self.store_locs(locs, name)
    if fake_soma:
        # find the common root of the set of locations
        snode = self.find_common_root(name)
    else:
        # use the soma as root
        snode = self[1]
    # add the common root to the set of locations, if necessary; being
    # prepended, the root location always ends up at index 0
    new_locs = self.unique_locs([MorphLoc((snode.index, 1.0), self)] + locs)
    self.store_locs(new_locs, name)
    # create new tree
    if new_tree is None:
        new_tree = self.__class__()
    elif not isinstance(new_tree, MorphTree):
        raise ValueError(
            f"`new_tree` is an instance of {new_tree.__class__}, "
            f"but should be a subclass of <class 'neat.MorphTree'>."
        )
    # create the root node
    p3d = (snode.xyz, snode.R, snode.swc_type)
    new_snode = new_tree.create_corresponding_node(1, p3d)
    new_snode.L = snode.L
    new_tree.set_root(new_snode)
    new_nodes = [new_snode]
    # the root node is associated with the root location (index 0), also
    # when `snode` is a fake soma that is not node 1
    new_snode.content["loc idx"] = 0
    # make two other soma nodes
    if fake_soma:
        for index in [2, 3]:
            new_cnode = new_tree.create_corresponding_node(index, p3d)
            new_tree.add_node_with_parent(new_cnode, new_snode)
            new_nodes.append(new_cnode)
    else:
        for cnode in snode.get_child_nodes(skip_inds=[]):
            if cnode.index in [2, 3]:
                p3d = (cnode.xyz, cnode.R, cnode.swc_type)
                new_cnode = new_tree.create_corresponding_node(cnode.index, p3d)
                new_tree.add_node_with_parent(new_cnode, new_snode)
                new_nodes.append(new_cnode)
    # make rest of tree
    for cnode in snode.child_nodes:
        self._add_nodes_to_tree(
            cnode,
            new_snode,
            new_tree,
            new_nodes,
            name,
        )
    # set the lengths and radii of the nodes from the paths between the
    # locations associated with each node and its parent
    for new_node in new_tree:
        if new_node.parent_node is not None:
            L, R = self.path_length(
                new_locs[new_node.parent_node.content["loc idx"]],
                new_locs[new_node.content["loc idx"]],
                compute_radius=True,
            )
        else:
            L = 0.0
            R = self[1].R
        new_node.L = L
        new_node.R = R
    # store the created root as the root of the original tree
    new_tree._original_root = new_tree.root
    return new_tree
def _add_nodes_to_tree(self, node, new_pnode, new_tree, new_nodes, name):
    """
    Recursively add one node to `new_tree` for every location of the set
    `name` that lies on `node` or on its descendants.

    The locations on a single node are chained from proximal to distal
    (increasing ``x``), the last one becoming the parent of the nodes
    created for the children of `node`. New nodes are numbered
    ``len(new_nodes) + 1``, get coordinates linearly interpolated between
    the parent node and `node`, a placeholder radius of ``-1.0`` (set by
    the caller later) and their location index under ``'loc idx'``.
    """
    loc_xs = self.xs[name]
    idxs_here = self.get_loc_idxs_on_node(name, node)
    proximal_to_distal = np.argsort(loc_xs[idxs_here])
    for loc_idx in np.array(idxs_here)[proximal_to_distal]:
        frac = loc_xs[loc_idx]
        interp_xyz = node.parent_node.xyz * (1.0 - frac) + node.xyz * frac
        created = new_tree.create_corresponding_node(
            len(new_nodes) + 1, (interp_xyz, -1.0, node.swc_type)
        )
        created.content["loc idx"] = loc_idx
        new_tree.add_node_with_parent(created, new_pnode)
        new_nodes.append(created)
        new_pnode = created
    for child in node.child_nodes:
        self._add_nodes_to_tree(child, new_pnode, new_tree, new_nodes, name)
[docs]
def create_compartment_tree(self, loc_arg):
    """
    Creates a new compartment tree where the provided set of locations
    correspond to the nodes.

    The root of the new tree corresponds to the most proximal location on
    the common root of the locations (see `MorphTree.find_common_root`). If
    there is no location on that node, a location at its distal end is
    prepended to the stored set (with a warning).

    Parameters
    ----------
    loc_arg: list of locations or str
        if list of locations, specifies the locations, if str,
        specifies the name under which the set of location is stored
        that should be used to create the new tree

    Returns
    -------
    `neat.CompartmentTree`
        The compartment tree.

    Raises
    ------
    IOError
        If `loc_arg` is neither a list nor a string
    """
    # process input argument
    if isinstance(loc_arg, list):
        locs = [MorphLoc(loc, self) for loc in loc_arg]
        name = "comp_locs"
        self.store_locs(locs, name=name)
    elif isinstance(loc_arg, str):
        name = loc_arg
        self._try_name(name)
    else:
        raise IOError("`loc_arg` should be list of locs or string")
    # create new tree
    new_tree = CompartmentTree()
    # find the common root of the set of locations
    snode = self.find_common_root(name)
    # check if that root is in set of locations
    possible_loc_idxs = self.get_loc_idxs_on_node(name, snode)
    if len(possible_loc_idxs) > 0:
        # chain the locations on the common root from proximal to distal, as
        # is done for all other nodes, so that the most distal one becomes
        # the parent of the rest of the tree
        xs = self.xs[name]
        root_loc_idxs = np.array(possible_loc_idxs)[
            np.argsort(xs[possible_loc_idxs])
        ]
        # create the new root node
        new_pnode = CompartmentNode(0, loc_idx=int(root_loc_idxs[0]))
        new_tree.set_root(new_pnode)
        new_nodes = [new_pnode]
        # create other nodes
        for loc_idx in root_loc_idxs[1:]:
            new_node = CompartmentNode(len(new_nodes), loc_idx=int(loc_idx))
            new_tree.add_node_with_parent(new_node, new_pnode)
            new_nodes.append(new_node)
            # set new node as next parent node
            new_pnode = new_node
    else:
        warnings.warn(
            "Locations of name `"
            + name
            + "` do not define a root - "
            + "adding root to set of locations"
        )
        # the added root location gets index 0 in the re-stored set
        locs = self.get_locs(name)
        locs = [(snode.index, 1.0)] + locs
        self.store_locs(locs, name=name)
        # create the new root node
        new_pnode = CompartmentNode(0, loc_idx=0)
        new_tree.set_root(new_pnode)
        new_nodes = [new_pnode]
    # make rest of tree
    for cnode in snode.child_nodes:
        self._add_compartment_nodes_to_tree(
            cnode, new_pnode, new_tree, new_nodes, name
        )
    return new_tree
def _add_compartment_nodes_to_tree(
    self, node, new_pnode, new_tree, new_nodes, name
):
    """
    Recursively add one `CompartmentNode` to `new_tree` for every location
    of the set `name` that lies on `node` or on its descendants.

    The locations on a single node are chained from proximal to distal
    (increasing ``x``), the last one becoming the parent of the nodes
    created for the children of `node`. New nodes are numbered
    ``len(new_nodes)``.
    """
    loc_xs = self.xs[name]
    idxs_here = self.get_loc_idxs_on_node(name, node)
    for loc_idx in np.array(idxs_here)[np.argsort(loc_xs[idxs_here])]:
        created = CompartmentNode(len(new_nodes), loc_idx=int(loc_idx))
        new_tree.add_node_with_parent(created, new_pnode)
        new_nodes.append(created)
        new_pnode = created
    for child in node.child_nodes:
        self._add_compartment_nodes_to_tree(
            child, new_pnode, new_tree, new_nodes, name
        )
[docs]
@original_tree_decorator
def __copy__(self, new_tree=None):
    """
    Fill the ``new_tree`` with its corresponding nodes in the same
    structure as ``self``, and copies all node variables that both tree
    classes have in common

    The original tree is copied through the parent-class ``__copy__``. If
    a computational tree is available, it is copied as well. Any
    ``AttributeError`` raised while doing so is treated as "no
    computational tree" and skipped.

    Parameters
    ----------
    new_tree: :class:`STree` or derived class (default is ``None``)
        the tree class in which the ``self`` is copied. If ``None``,
        returns a copy of ``self``.

    Returns
    -------
    The new tree instance
    """
    if new_tree is None:
        new_tree = self.__class__()
    # copy the original tree via the parent-class implementation
    super().__copy__(new_tree=new_tree)
    new_tree._original_root = new_tree.root
    try:
        with self.as_computational_tree:
            # set the computational tree
            new_node = new_tree.create_corresponding_node(self.root.index)
            self.root.__copy__(new_node=new_node)
            new_tree._computational_root = new_node
            with new_tree.as_computational_tree:
                self._recurse_copy(self.root, new_tree)
    except AttributeError:
        # computational tree is not defined
        # NOTE(review): this also swallows any other AttributeError raised
        # in the block above (e.g. a `new_tree` class without
        # `as_computational_tree`), which can leave a partially copied
        # computational tree behind — confirm this best-effort is intended
        pass
    return new_tree