Source code for neat.channels.ionchannels

# -*- coding: utf-8 -*-
#
# ionchannels.py
#
# This file is part of NEST.
#
# Copyright (C) 2004 The NEST Initiative
#
# NEST is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 2 of the License, or
# (at your option) any later version.
#
# NEST is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with NEST.  If not, see <http://www.gnu.org/licenses/>.

import ast
import os
import re
import warnings

import numpy as np
import sympy as sp

from ..factorydefaults import DefaultPhysiology

# CONC_DICT = {
#     'na': 10.,  # mM
#     'k': 54.4,  # mM
#     'ca': 1e-4,  # mM
# }

# TEMP_DEFAULT = 36.

# E_ION_DICT = {
#     'na': 50.,
#     'k': -85.,
#     'ca': 50.,
# }


class IfExpVisitor(ast.NodeVisitor):
    """
    Find the first `IfExp` node (a conditional expression ``a if c else b``)
    in an ast, in the order in which `ast.NodeVisitor` traverses the tree.

    Conditional expressions nested inside the found node are not visited;
    later (sibling) conditional expressions are visited but ignored.
    """

    def __init__(self):
        # the `IfExp` node found during the current traversal, if any
        self.ifexp_node = None

    def visit_IfExp(self, node):
        """
        Function can NOT be renamed, is implicitly called by `ast.NodeVisitor.visit()`

        Records `node` only if no earlier `IfExp` was found, so that the first
        match is kept. Children are not visited (no `generic_visit` call).
        """
        if self.ifexp_node is None:
            self.ifexp_node = node

    def find_IfExp_node(self, node):
        """
        Return the first `IfExp` node in the ast rooted at `node`.

        Parameters
        ----------
        node: `ast.AST`
            root of the tree to search

        Returns
        -------
        `ast.IfExp` or None
            the first conditional expression, or None if there is none
        """
        # start from a clean state, and leave the visitor reusable afterwards
        self.ifexp_node = None
        try:
            self.visit(node)
            return self.ifexp_node
        finally:
            self.ifexp_node = None


class _func(object):
    def __init__(self, eval_func_aux, eval_func_vtrap, e_trap):
        self.eval_func_aux = eval_func_aux
        self.eval_func_vtrap = eval_func_vtrap
        self.e_trap = e_trap

    def __call__(self, *args):
        vv = args[0]
        if isinstance(vv, float):
            if np.abs(vv - self.e_trap) < 0.001:
                return self.eval_func_vtrap(*args)
            else:
                return self.eval_func_aux(*args)
        else:
            fv_return = np.zeros_like(vv)
            bool_vtrap = np.abs(vv - self.e_trap) < 0.0001
            inds_vtrap = np.where(bool_vtrap)
            args_ = [a[inds_vtrap] for a in args]
            fv_return[inds_vtrap] = self.eval_func_vtrap(*args_)
            inds = np.where(np.logical_not(bool_vtrap))
            args_ = [a[inds] for a in args]
            fv_return[inds] = self.eval_func_aux(*args_)
            return fv_return


def _insert_function_prefixes(
    string, prefix="np", functions=["exp", "sin", "cos", "tan", "pi"]
):
    """
    Prefix all occurences in the input `string` of the functions in the
    `functions` list with the provided `prefix`.

    Parameters
    ----------
    string: string
        the input string
    prefix: string, optional
        the prefix that is put before each function. Defaults to `'np'`
    functions: list of strings, optional
        the list of functions that will be prefixed. Defaults to
        `['exp', 'sin', 'cos', 'tan', 'pi']`

    Returns
    -------
    string

    Examples
    --------
    >>> _insert_function_prefixes('5. * exp(0.) + 3. * cos(pi)')
    '5. * np.exp(0.) + 3. * np.cos(pi)'
    """
    for func_name in functions:
        numpy_string = ""
        while len(string) > 0:
            ind = string.find(func_name)
            if ind == -1:
                numpy_string += string
                string = ""
            else:
                numpy_string += string[0:ind] + prefix + "." + func_name
                string = string[ind + len(func_name) :]
        string = numpy_string
    return string


def _broadcast(fun):
    """
    This function is to be used in together with `sympy.lambdify` to ensure that
    lambda functions generated from constant expressions are broadcast to the
    input shape
    """
    return lambda *x: np.broadcast_arrays(fun(*x), *x)[0]


class SPDict(dict):
    """
    Dictionary that accepts both strings and similarly named sympy symbols as keys

    A lookup with a string that is not found is retried with the sympy symbol
    of that name, and vice versa. Membership tests (`in`) follow the same rule.
    """

    @staticmethod
    def _alternate_key(key):
        """
        Return the string/symbol counterpart of `key`, or None if `key` is
        neither a string nor a sympy symbol.
        """
        if isinstance(key, sp.Symbol):
            return str(key)
        if isinstance(key, str):
            return sp.symbols(key)
        return None

    def __getitem__(self, key):
        try:
            return super().__getitem__(key)
        except KeyError:
            alt_key = self._alternate_key(key)
            if alt_key is None:
                raise
            return super().__getitem__(alt_key)

    def __contains__(self, key):
        if super().__contains__(key):
            return True
        alt_key = self._alternate_key(key)
        # consistent with `__getitem__`: symbols also match their string name
        return alt_key is not None and super().__contains__(alt_key)


class CallDict(SPDict):
    """
    Dictionary whose values are callables that share one signature.

    Calling the dictionary itself calls every stored value with the same
    positional arguments.
    """

    def __call__(self, *args):
        """
        Evaluate every stored callable with `args`.

        Returns
        -------
        SPDict
            maps ``str(key)`` to the return value of the callable stored under `key`
        """
        results = SPDict()
        for name, func in self.items():
            results[str(name)] = func(*args)
        return results


[docs] class IonChannel(object): """ Base ion channel class that implements linearization and code generation for NEURON (.mod-files) and C++. Userdefined ion channels should inherit from this class and implement the `define()` function, where the specific attributes of the ion channel are set. The ion channel current is of the form .. math:: i_{chan} = \overline{g} \, p_o(x_1, ... , x_n) \, (e - v) where $p_o$ is the open probability defined as a function of a number of state variables. State variables evolve according to .. math:: \dot{x}_i = f_i(x_i, v, c_1, ..., c_k) with $c_1, ..., c_n$ the (optional) set of concentrations the ion channel depends on. There are two canonical ways to define $f_i$, either based on reaction rates :math:`\\alpha` and :math:`\\beta`: .. math:: \dot{x}_i = \\alpha_i(v) \, (1 - x_i) - \\beta_i(v) \, x_i, or based on an asymptotic value :math:`x_i^{\infty}` and time-scale :math:`\\tau_i` .. math:: \dot{x}_i = \\frac{x_i^{\infty}(v) - x_i}{\\tau_i(v)}. `IonChannel` accepts handles either description. For the former description, dicts `self.alpha` and `self.beta` must be defined with as keys the names of every state variable in the open probability. Similarly, for the latter description, dicts `self.tauinf` and `self.varinf` must be defined with as keys the name of every state variable. The user **must** define the attributes `p_open`, and either `alpha` and `beta` or `tauinf` and `varinf` in the `define()` function. The other attributes `ion`, `conc`, `q10`, `temp`, and `e` are optional. Parameters ---------- p_open: str The open probability of the ion channel. alpha, beta: dict {str: str} dictionary of the rate function for each state variables. Keys must correspond to the name of every state variable in `p_open`, values must be formulas written as strings with `v` and possible ion as variabels tauinf, varinf: dict {str: str} state variable time scale and asymptotic activation level. 
Keys must correspond to the name of every state variable in `p_open`, values must be formulas written as strings with `v` and possible ion as variabels ion: str ('na', 'ca', 'k' or ''), optional The ion to which the ion channel is permeable conc: set of str (containing 'na', 'ca', 'k') or dict of {str: float} The concentrations the ion channel activation depends on. Can be a set of ions or a dict with the ions as keys and default values as float. q10: str, optional Temperature dependence of the state variable rate functions. May be a float or a string convertible to a sympy expression containing the `temp` parameter (temperature in ``[deg C]``). This factor divides the time-scales :math:`\tau_i(v)` of the ion channel. If not given, default is 1. temp: float, optional The temperature at which the ion channel is evaluated. Can be modified after initializiation by calling `IonChannel.set_default_params(temp=new_temperature)`. If not given, the evaluates `self.q10` at the default temperature of 36 degC. e: float, optional Reversal of the ion channel in ``[mV]``. functions that need it allow the default value to be overwritten with a keyword argument. If nothing is provided, will take a default reversal for `self.ion` (which is -85 mV for 'K', 50 mV for 'Na' and 50 mV for 'Ca'). If no ion is provided, errors will occur if functions that need `e` are called without specifying the value as a keyword argument. Examples -------- >>> class Na_Ta(IonChannel): >>> def define(self): >>> # from (Colbert and Pan, 2002), Used in (Hay, 2011) >>> self.ion = 'na' >>> # concentrations the ion channel depends on >>> self.conc = {} >>> # define channel open probability >>> self.p_open = 'h * m ** 3' >>> # define activation functions >>> self.alpha, self.beta = {}, {} >>> self.alpha['m'] = '0.182 * (v + 38.) / (1. - exp(-(v + 38.) / 6.))' # 1/ms >>> self.beta['m'] = '-0.124 * (v + 38.) / (1. - exp( (v + 38.) / 6.))' # 1/ms >>> self.alpha['h'] = '-0.015 * (v + 66.) / (1. - exp( (v + 66.) 
/ 6.))' # 1/ms >>> self.beta['h'] = '0.015 * (v + 66.) / (1. - exp(-(v + 66.) / 6.))' # 1/ms >>> # temperature factor for reaction rates >>> self.q10 = '2.3^((temp - 23.)/10.)' """ def __init__(self, **kwargs): """ Will give an ``AttributeError`` if initialized as is. Should only be initialized from its' derived classes that implement specific ion channel types. """ # initialize default configuration self.cfg = DefaultPhysiology() # define the channel based on user specified state variables and activations self.define() # ion that carries the channel current if not hasattr(self, "ion"): self.ion = "" # temperature factor, if it exist if not hasattr(self, "q10"): self.q10 = "1." self.q10 = sp.sympify(self.q10, evaluate=False) # sympy temperature symbols assert len(self.q10.free_symbols) <= 1 if len(self.q10.free_symbols) > 0: assert str(list(self.q10.free_symbols)[0]) == "temp" self.sp_t = list(self.q10.free_symbols)[0] else: self.sp_t = sp.symbols("temp") # the voltage variable self.sp_v = sp.symbols("v") # extract the state variables self.p_open = sp.sympify(self.p_open) self.statevars = self.p_open.free_symbols # if voltage occurs directly in open probability, # remove it from statevars if self.sp_v in self.statevars: self.statevars.remove(self.sp_v) if not "tauinf" in self.__dict__: self.tauinf = {} if not "varinf" in self.__dict__: self.varinf = {} for svar in self.ordered_statevars: key = str(svar) if key in (self.varinf.keys() | self.tauinf.keys()): self.varinf[svar] = sp.sympify(self.varinf[key], evaluate=False) self.tauinf[svar] = sp.sympify(self.tauinf[key], evaluate=False) self.varinf[svar] = sp.simplify(self.varinf[svar]) self.tauinf[svar] = sp.simplify(self.tauinf[svar] / self.q10) del self.varinf[key] del self.tauinf[key] # construct the rate functions if "alpha" in self.__dict__ and "beta" in self.__dict__: for svar in self.ordered_statevars: key = str(svar) if key in (self.alpha.keys() | self.beta.keys()): self.alpha[svar] = 
sp.sympify(self.alpha[key], evaluate=False) self.beta[svar] = sp.sympify(self.beta[key], evaluate=False) self.varinf[svar] = sp.simplify( self.alpha[svar] / (self.alpha[svar] + self.beta[svar]) ) self.tauinf[svar] = sp.simplify( (1.0 / self.q10) / (self.alpha[svar] + self.beta[svar]) ) del self.alpha del self.beta # check if rate equations where defined if len(self.varinf) == 0 or len(self.tauinf) == 0: raise AttributeError( "Necessary attributes not defined, define either " + "`alpha` and `beta` or `tauinf` and `varinf`." ) self.varinf, self.tauinf = SPDict(self.varinf), SPDict(self.tauinf) # set the right hand side of the differential equation for # state variables self.fstatevar = SPDict() for svar in self.ordered_statevars: self.fstatevar[svar] = (-svar + self.varinf[svar]) / self.tauinf[svar] # concentrations the ion channel depends on if not hasattr(self, "conc"): # if concentration ions are not defined, attempt to extract them from # the state variable functions self.conc = set() for key, expr in self.fstatevar.items(): self.conc |= expr.free_symbols # set union # remove everything that is not a concentration self.conc -= self.statevars self.conc -= {self.sp_v, self.sp_t} # if no default concentrations are defined, default values are taken # from default concentration values if not hasattr(self.conc, "values"): self.conc = SPDict( { sp.symbols(str(ion)): self.cfg.conc[str(ion)] for ion in list(sorted(self.conc)) } ) # sympy concentration symbols self.sp_c = [ion for ion in self.conc] # default parameters self.default_params = SPDict({}) self.default_params[str(self.sp_t)] = ( self.temp if "temp" in self.__dict__ else self.cfg.temp ) try: self.default_params["e"] = ( self.e if "e" in self.__dict__ else self.cfg.e_rev[self.ion] ) except KeyError: warnings.warn("No default reversal potential defined.") # self._lambdify_channel() self.set_default_params(**kwargs) def __getstate__(self): """ remove lambdified functions from dict as they can not be pickled """ d = 
dict(self.__dict__) del d["f_statevar"] del d["f_varinf"] del d["f_tauinf"] del d["f_p_open"] del d["dp_dx"], d["df_dv"], d["df_dx"], d["df_dc"] return d def __setstate__(self, s): """ since lambdified functions were not pickled we need to restore them """ self.__dict__ = s self._lambdify_channel()
[docs] def set_default_params(self, **kwargs): """ **kwargs Default values for temperature (`temp`), reversal (`e`) """ self.default_params.update(kwargs) # set the lambda functions for efficient numpy evaluation self._lambdify_channel()
def _substitute_defaults(self, expr): """ Substitute default values in input expression Parameters ---------- expr: sympy expression """ for param, val in self.default_params.items(): expr = expr.subs(sp.symbols(param), val) return expr @property def ordered_statevars(self): return list(sorted(self.statevars, key=str)) def _lambdify_channel(self): """ Create lambda functions based on sympy expression for relevant ion channel functions """ from sympy.utilities.autowrap import ufuncify # arguments for lambda function args = [self.sp_v] + self.ordered_statevars + self.sp_c args_ = [self.sp_v] + self.sp_c # lambdified open probability self.f_p_open = _broadcast(sp.lambdify(args, self.p_open)) # storatestate variable function self.f_statevar = CallDict() self.f_varinf, self.f_tauinf = CallDict(), CallDict() # storage of derivatives self.dp_dx = CallDict() self.df_dv, self.df_dx, self.df_dc = CallDict(), CallDict(), CallDict() for svar, f_svar in self.fstatevar.items(): f_svar = self._substitute_defaults(f_svar) varinf = self._substitute_defaults(self.varinf[svar]) tauinf = self._substitute_defaults(self.tauinf[svar]) # state variable function self.f_statevar = _broadcast(sp.lambdify(args, f_svar)) # state variable activation & timescale self.f_varinf[svar] = _broadcast(sp.lambdify(args_, varinf)) self.f_tauinf[svar] = _broadcast(sp.lambdify(args_, tauinf)) # derivatives of open probability to state variables self.dp_dx[svar] = _broadcast( sp.lambdify(args, sp.diff(self.p_open, svar, 1)) ) # derivatives of state variable function to voltage self.df_dv[svar] = _broadcast( sp.lambdify(args, sp.diff(f_svar, self.sp_v, 1)) ) # derivatives of state variable function to state variable self.df_dx[svar] = _broadcast(sp.lambdify(args, sp.diff(f_svar, svar, 1))) # derivatives of state variable function to concentrations self.df_dc[svar] = CallDict( { c: _broadcast(sp.lambdify(args, sp.diff(f_svar, c, 1))) for c in self.sp_c } ) def _args_as_list(self, v, w_statevar=True, 
**kwargs): """ Converts arguments to list for lambdified functions """ arg_list = [v] if w_statevar: for svar in self.ordered_statevars: key = str(svar) try: arg_list.append(kwargs[key]) except KeyError: # state variable is not in kwargs # set default value based on voltage args = self._args_as_list(v, w_statevar=False, **kwargs) arg_list.append(self.f_varinf[svar](*args)) for c in self.sp_c: key = str(c) try: arg_list.append(kwargs[key]) except KeyError: # ion is not in kwargs # set stored default value arg_list.append(self.conc[c]) return arg_list
[docs] def compute_p_open(self, v, **kwargs): """ Compute the open probability of the ion channel Parameters ---------- v: float or `np.ndarray` of float The voltage at which to evaluate the open probability **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- float or `np.ndarray` of float The open probability """ args = self._args_as_list(v, **kwargs) return self.f_p_open(*args)
[docs] def compute_derivatives(self, v, **kwargs): """ Compute: (i) the derivatives of the open probability to the state variables (ii) The derivatives of state functions to the voltage (iii) The derivatives of state functions to the state variables Parameters ---------- v: float or `np.ndarray` The voltage at which to evaluate the open probability **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- tuple of three floats or three `np.ndarray`s of float The derivatives """ args = self._args_as_list(v, **kwargs) return self.dp_dx(*args), self.df_dv(*args), self.df_dx(*args)
[docs] def compute_derivativesConc(self, v, **kwargs): """ Compute the derivatives of the state functions to the concentrations Parameters ---------- v: float or `np.ndarray` The voltage at which to evaluate the open probability **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- tuple of three floats or three `np.ndarray`s of float The derivatives """ args = self._args_as_list(v, **kwargs) return self.df_dc(*args)
[docs] def compute_varinf(self, v): """ Compute the asymptotic values for the state variables at a given activation level Parameters ---------- v: float or `np.ndarray` The voltage at which to evaluate the open probability Returns ------- dict of `np.ndarray` of dict of float The asymptotic activations, items are of same type (and shape) as `v` """ args = self._args_as_list(v, w_statevar=False, **{}) return self.f_varinf(*args)
[docs] def compute_tauinf(self, v): """ Compute the time-scales for the state variables at a given activation level Parameters ---------- v: float or `np.ndarray` The voltage at which to evaluate the open probability Returns ------- dict of `np.ndarray` of dict of float The asymptotic activations, items are of same type (and shape) as `v` """ args = self._args_as_list(v, w_statevar=False, **{}) return self.f_tauinf(*args)
def compute_lin_statevar_response(self, v, freqs, v_resp, **kwargs): """ Combute the linearizations of the individual state variables Parameters ---------- v: float or `np.ndarray` The voltage(s) ``[mV]`` around which to linearize the ion channel freqs float, complex, or `np.ndarray` of float or complex: The frequencies ``[Hz]`` at which to evaluate the linearized contribution v_resp: `np.ndarray` (``dtype=complex``, ``ndim=1``, ``shape=(s,k)``) Linearized voltage responses in the frequency domain, evaluated at ``s`` frequencies and ``k`` locations **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- `SPDict` of float, complex or `np.ndarray` of float or complex The linearized current. Key are the state variable name. Shape of each entry is dimension of `freqs` followed by the dimensions of `v`. """ dp_dx, df_dv, df_dx = self.compute_derivatives(v, **kwargs) # determine the output shape according to numpy broadcasting rules args_aux = [v_resp] + self._args_as_list(v, **kwargs) out_shape = np.broadcast(*args_aux).shape lin_svar = SPDict( { str(svar): np.zeros(out_shape, dtype=np.array(freqs).dtype) for svar in self.ordered_statevars } ) for svar, dp_dx_ in dp_dx.items(): df_dv_ = df_dv[svar] * 1e3 # convert to 1 / s df_dx_ = df_dx[svar] * 1e3 # convert to 1 / s # add to the impedance contribution lin_svar[str(svar)] = df_dv_ / (freqs - df_dx_) * v_resp return lin_svar
[docs] def compute_linear(self, v, freqs, **kwargs): """ Combute the contributions of the state variables to the linearized channel current Parameters ---------- v: float or `np.ndarray` The voltage ``[mV]`` at which to evaluate the open probability freqs float, complex, or `np.ndarray` of float or complex: The frequencies ``[Hz]`` at which to evaluate the linearized contribution **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- float, complex or `np.ndarray` of float or complex The linearized current. Shape is dimension of `freqs` followed by the dimensions of `v`. """ dp_dx, df_dv, df_dx = self.compute_derivatives(v, **kwargs) # determine the output shape according to numpy broadcasting rules args_aux = [freqs] + self._args_as_list(v, **kwargs) out_shape = np.broadcast(*args_aux).shape lin_f = np.zeros(out_shape, dtype=np.array(freqs).dtype) for svar, dp_dx_ in dp_dx.items(): df_dv_ = df_dv[svar] * 1e3 # convert to 1 / s df_dx_ = df_dx[svar] * 1e3 # convert to 1 / s # add to the impedance contribution lin_f += dp_dx_ * df_dv_ / (freqs - df_dx_) return lin_f
[docs] def compute_linear_conc(self, v, freqs, ion, **kwargs): """ Combute the contributions of the state variables to the linearized channel current Parameters ---------- v: float or `np.ndarray` The voltage ``[mV]`` at which to evaluate the open probability freqs: float, complex, or `np.ndarray` of float or complex: The frequencies ``[Hz]`` at which to evaluate the linearized contribution ion: str The ion name for which to compute the linearized contribution **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- float, complex or `np.ndarray` of float or complex The linearized current. Shape is dimension of `freqs` followed by the dimensions of `v`. """ dp_dx, df_dv, df_dx = self.compute_derivatives(v, **kwargs) df_dc = self.compute_derivativesConc(v, **kwargs) # determine the output shape according to numpy broadcasting rules args_aux = [freqs] + self._args_as_list(v, **kwargs) out_shape = np.broadcast(*args_aux).shape lin_f = np.zeros(out_shape, dtype=np.array(freqs).dtype) for svar, dp_dx_ in dp_dx.items(): df_dc_ = df_dc[svar][ion] * 1e3 # convert to 1 / s df_dx_ = df_dx[svar] * 1e3 # convert to 1 / s # add to the impedance contribution lin_f += dp_dx_ * df_dc_ / (freqs - df_dx_) return lin_f
def _get_reversal(self, e): if e is None: try: e = self.default_params["e"] except KeyError: raise KeyError("No default reversal defined, provide value for `e`.") return e
[docs] def compute_lin_sum(self, v, freqs, e=None, **kwargs): """ Combute the linearized channel current contribution (without concentributions from the concentration - see `compute_lin_conc()`) Parameters ---------- v: float or `np.ndarray` The voltage ``[mV]`` at which to evaluate the open probability freqs: float, complex, or `np.ndarray` of float or complex: The frequencies ``[Hz]`` at which to evaluate the linearized contribution e: float or `None` The reversal potential of the channel. Defaults to the value stored in `self.default_params['e']` if not provided. **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- float, complex or `np.ndarray` of float or complex The linearized current. Shape is dimension of `freqs` followed by the dimensions of `v`. """ e = self._get_reversal(e) return (e - v) * self.compute_linear(v, freqs, **kwargs) - self.compute_p_open( v, **kwargs )
[docs] def compute_lin_conc(self, v, freqs, ion, e=None, **kwargs): """ Combute the linearized channel current contribution from the concentrations Parameters ---------- v: float or `np.ndarray` The voltage ``[mV]`` at which to evaluate the open probability freqs: float, complex, or `np.ndarray` of float or complex: The frequencies ``[Hz]`` at which to evaluate the linearized contribution ion: str The ion name for which to compute the linearized contribution e: float or `None` The reversal potential of the channel. Defaults to the value stored in `self.default_params['e']` if not provided. **kwargs: float or `np.ndarray` Optional values for the state variables and concentrations. Returns ------- float, complex or `np.ndarray` of float or complex The linearized current. Shape is dimension of `freqs` followed by the dimensions of `v`. """ e = self._get_reversal(e) return (e - v) * self.compute_linear_conc(v, freqs, ion, **kwargs)
def write_mod_file(self, path, g=0.0, e=None):
    """
    Writes a modfile of the ion channel for simulations with neuron

    The file is called ``I<ClassName>.mod`` and is (over)written in `path`.

    Parameters
    ----------
    path: str
        Directory in which the mod file is written.
    g: float
        Maximal conductance. It is multiplied by ``1e-6`` and written with
        unit ``(S/cm2)``, so presumably it is given in ``[uS/cm^2]``.
    e: float or `None`
        The reversal potential ``[mV]``. Defaults to the value stored in
        `self.default_params['e']` if not provided.

    Returns
    -------
    str
        The file name (without directory) of the written mod file.
    """
    cname = self.__class__.__name__
    sv = [str(svar) for svar in self.ordered_statevars]
    cs = [str(conc) for conc in self.conc]
    e = self._get_reversal(e)

    modname = "I" + cname + ".mod"
    fname = os.path.join(path, modname)

    # the context manager closes the file even if code generation
    # (e.g. `sp.printing.ccode`) raises halfway through
    with open(fname, "w") as file:
        file.write(
            ": This mod file is automaticaly generated by the "
            + "``neat.channels.ionchannels`` module\n\n"
        )

        file.write("NEURON {\n")
        file.write(" SUFFIX I%s\n" % cname)
        if self.ion == "":
            file.write(" NONSPECIFIC_CURRENT i" + "\n")
        else:
            file.write(" USEION %s WRITE i%s\n" % (self.ion, self.ion))
        # concentration dependencies read the intracellular concentration
        for c in cs:
            file.write(" USEION %s READ %si\n" % (c, c))
        file.write(" RANGE g, e" + "\n")
        taustring = "tau_" + ", tau_".join(sv)
        varstring = "_inf, ".join(sv) + "_inf"
        file.write(" RANGE %s, %s\n" % (varstring, taustring))
        file.write(" THREADSAFE" + "\n")
        file.write("}\n\n")

        file.write("PARAMETER {\n")
        file.write(" g = " + str(g * 1e-6) + " (S/cm2)" + "\n")
        file.write(" e = " + str(e) + " (mV)" + "\n")
        file.write(" celsius (degC)\n")
        file.write("}\n\n")

        file.write("UNITS {\n")
        file.write(" (mA) = (milliamp)" + "\n")
        file.write(" (mV) = (millivolt)" + "\n")
        file.write(" (mM) = (milli/liter)" + "\n")
        file.write("}\n\n")

        file.write("ASSIGNED {\n")
        file.write(" i%s (mA/cm2)\n" % self.ion)
        for var in sv:
            file.write(" %s_inf \n" % var)
            file.write(" tau_%s (ms) \n" % var)
        for ion in cs:
            file.write(" " + ion + "i (mM)" + "\n")
        file.write(" v (mV)" + "\n")
        file.write(" %s (degC)\n" % (self.sp_t))
        file.write("}\n\n")

        file.write("STATE {\n")
        for var in sv:
            file.write(" %s\n" % var)
        file.write("}\n\n")

        calcstring = "i%s = g * (%s) * (v - e)" % (
            self.ion,
            sp.printing.ccode(self.p_open),
        )
        file.write("BREAKPOINT {\n")
        file.write(" SOLVE states METHOD cnexp" + "\n")
        file.write(" %s\n" % calcstring)
        file.write("}\n\n")

        # extra arguments of `rates()`: ", <ion1>i, <ion2>i, ..." (empty for
        # channels without concentration dependence)
        concstring = "i, ".join(cs)
        if len(cs) > 0:
            concstring = ", " + concstring
            concstring += "i"
        file.write("INITIAL {\n")
        file.write(" rates(v%s)\n" % concstring)
        for var in sv:
            file.write(" %s = %s_inf\n" % (var, var))
        file.write("}\n\n")

        file.write("DERIVATIVE states {\n")
        file.write(" rates(v%s)\n" % concstring)
        for var in sv:
            file.write(" %s' = (%s_inf - %s) / tau_%s \n" % (var, var, var, var))
        file.write("}\n\n")

        # substitution for common neuron names
        repl_pairs = [(str(c), str(c) + "i") for c in self.conc]

        file.write("PROCEDURE rates(v%s) {\n" % concstring)
        file.write(" %s = celsius\n" % str(self.sp_t))
        for var, svar in zip(sv, self.ordered_statevars):
            vi = sp.printing.ccode(self.varinf[svar], assign_to=f"{var}_inf")
            ti = sp.printing.ccode(self.tauinf[svar], assign_to=f"tau_{var}")
            # NOTE(review): plain substring replacement, this would also hit
            # any other identifier that contains the ion name
            for repl_pair in repl_pairs:
                vi = vi.replace(*repl_pair)
                ti = ti.replace(*repl_pair)
            # no ";" in mod-file, add indent
            vi = vi.replace(";", "").replace("\n", "\n ")
            ti = ti.replace(";", "").replace("\n", "\n ")
            file.write(f" {vi}\n")
            file.write(f" {ti}\n")
        file.write("}\n\n")

    return modname

def _create_nestml_funcstr(self, code_str, n_spaces=0, indent=8):
    """
    Convert a (possibly conditional) Python expression into NESTML
    statements that assign its value to ``val``.

    By default the single line version of if... else... statements is
    printed by `sympy.pycode()` and `ast.unparse()`. This function
    recursively expands such inline conditionals across multiple lines.
    Any expression surrounding a conditional (e.g. a prefactor in
    ``2*(a if c else b)``) is kept in both branches.

    Parameters
    ----------
    code_str: str
        Python source of a single expression, e.g. the output of
        `sympy.pycode()`.
    n_spaces: int
        Extra indentation (in spaces) of each nested ``if``/``else`` body
        relative to its parent statement.
    indent: int
        Indentation (in spaces) of the statements generated at this level.

    Returns
    -------
    str
        Newline-terminated NESTML statements. If sympy cannot convert the
        expression (`TypeError`), a warning is issued and `code_str` is
        returned unchanged.
    """
    tree = ast.parse(code_str)
    iev = IfExpVisitor()
    ifexp = iev.find_IfExp_node(tree)

    if ifexp is None:
        # no conditional left: a single assignment to `val`
        try:
            return (
                " " * indent + f"val = {sp.printing.ccode(sp.sympify(code_str))}\n"
            )
        except TypeError as err:
            # best effort: keep the raw expression, but report the failure
            # instead of only printing it
            warnings.warn(f"Could not convert `{code_str}` to NESTML code: {err}")
            return code_str

    # sanity check
    assert iev.find_IfExp_node(ifexp.test) is None

    def _select_branch(use_body):
        # Re-parse, so that each branch starts from a pristine tree, and swap
        # the conditional for one of its branches *inside* the full tree, so
        # that the expression surrounding the conditional is preserved.
        branch_tree = ast.parse(code_str)
        target = IfExpVisitor().find_IfExp_node(branch_tree)
        branch = target.body if use_body else target.orelse

        class _SwapIfExp(ast.NodeTransformer):
            def visit(transformer, node):
                if node is target:
                    return branch
                return super().visit(node)

        return ast.unparse(_SwapIfExp().visit(branch_tree))

    sub_indent = n_spaces + indent
    # if test is True
    cond_1_str = self._create_nestml_funcstr(
        _select_branch(True), n_spaces=n_spaces, indent=sub_indent
    )
    # if test is False
    cond_0_str = self._create_nestml_funcstr(
        _select_branch(False), n_spaces=n_spaces, indent=sub_indent
    )

    return (
        " " * indent
        + f"if {ast.unparse(ifexp.test)}:\n"
        + cond_1_str
        + " " * indent
        + "else:\n"
        + cond_0_str
    )

def write_nestml_blocks(
    self,
    blocks=("state", "parameters", "equations", "function"),
    v_comp=-75.0,
    g=0.0,
    e=None,
):
    """
    Generate the NESTML code snippets describing this ion channel.

    Parameters
    ----------
    blocks: sequence of str
        The snippets to generate, any of ``"state"``, ``"parameters"``,
        ``"equations"`` and ``"function"``. Defaults to all four. The default
        is a tuple rather than a list, to avoid a shared mutable default.
    v_comp: float
        Voltage at which the state variables are initialized, through
        `self.compute_varinf(v_comp)`.
    g: float
        Value written for the parameter ``gbar_<ClassName>``.
    e: float or `None`
        Value written for the parameter ``e_<ClassName>``. Defaults to the
        value stored in `self.default_params['e']` if not provided.

    Returns
    -------
    dict
        Maps every entry of `blocks` to its NESTML code (an empty string for
        names that are not recognised).
    """
    cname = self.__class__.__name__
    sv = [str(svar) for svar in self.ordered_statevars]
    sv_suff = [sv_ + "_" + cname for sv_ in sv]
    e = self._get_reversal(e)
    sv_init = self.compute_varinf(v_comp)

    blocks_dict = {block: "" for block in blocks}

    # parameter list of the generated NESTML functions
    func_call_args = ", ".join(["v_comp real"] + [f"{ckey} real" for ckey in self.conc])
    # arguments with which the equations call these functions
    func_args = ", ".join(["v_comp"] + [f"c_{ckey}" for ckey in self.conc])

    if "state" in blocks:
        state_str = "\n" + " # state variables %s\n" % cname
        for sv_, sv_key in zip(sv_suff, sv):
            state_str += " %s real = %.8f\n" % (sv_, sv_init[sv_key])
        blocks_dict["state"] += state_str

    if "parameters" in blocks:
        param_str = (
            "\n"
            + " # parameters %s\n" % cname
            + " gbar_%s real = %.2f\n" % (cname, g)
            + " e_%s real = %.2f\n" % (cname, e)
        )
        blocks_dict["parameters"] += param_str

    if "equations" in blocks:
        # reformulate open probability in terms of suffixed variables
        p_open_ = self.p_open
        for svar, sv_ in zip(self.ordered_statevars, sv_suff):
            p_open_ = p_open_.subs(svar, sp.symbols(sv_))
        p_open_ = p_open_.subs(self.sp_v, sp.UnevaluatedExpr(sp.symbols("v_comp")))

        eq_str = (
            "\n"
            + " # equation %s\n" % cname
            + " inline i_%s real = gbar_%s * (%s) * (e_%s - v_comp) @mechanism::channel\n"
            % (cname, cname, str(p_open_), cname)
        )
        for var, var_suff in zip(sv, sv_suff):
            eq_str += (
                f" {var_suff}' = ( {var}_inf_{cname}( {func_args} ) - {var_suff} )"
                f" / ( tau_{var}_{cname}( {func_args} ) * 1s )\n"
            )
        eq_str += "\n"
        blocks_dict["equations"] += eq_str

    if "function" in blocks:
        # The header and the `val real` / `return val` lines must line up
        # with the statements `_create_nestml_funcstr` generates
        # (`indent=8`), as NESTML blocks are delimited by indentation.
        func_indent = " " * 4
        body_indent = " " * 8

        func_str = "\n" + func_indent + "# functions %s\n" % cname
        for svar, sv_, sv_suff_ in zip(self.ordered_statevars, sv, sv_suff):
            # substitute possible default values; concentrations stay
            # symbolic and are passed as function arguments
            varinf_func = self._substitute_defaults(self.varinf[svar])

            # print activation function to nestml file
            varinf_func = varinf_func.subs(
                svar, sp.UnevaluatedExpr(sp.symbols(sv_suff_))
            )
            varinf_func = varinf_func.subs(
                self.sp_v, sp.UnevaluatedExpr(sp.symbols("v_comp"))
            )
            code_str = sp.pycode(varinf_func, fully_qualified_modules=False)
            func_str += (
                f"{func_indent}function {sv_}_inf_{cname} ({func_call_args}) real:\n"
                f"{body_indent}val real\n"
                f"{self._create_nestml_funcstr(code_str, n_spaces=4, indent=8)}"
                f"{body_indent}return val\n\n"
            )

            # substitute possible default values and concentrations
            # NOTE(review): unlike the activation function above, the default
            # concentrations are substituted here, so the concentration
            # arguments of the generated tau function are unused -- confirm
            # this is intended
            tauinf_func = self._substitute_defaults(self.tauinf[svar])
            for ckey, cval in self.conc.items():
                tauinf_func = tauinf_func.subs(ckey, cval)
            tauinf_func = tauinf_func.subs(
                svar, sp.UnevaluatedExpr(sp.symbols(sv_suff_))
            )
            tauinf_func = tauinf_func.subs(
                self.sp_v, sp.UnevaluatedExpr(sp.symbols("v_comp"))
            )
            code_str = sp.pycode(tauinf_func, fully_qualified_modules=False)
            func_str += (
                f"\n{func_indent}function tau_{sv_}_{cname} ({func_call_args}) real:\n"
                f"{body_indent}val real\n"
                f"{self._create_nestml_funcstr(code_str, n_spaces=4, indent=8)}"
                f"{body_indent}return val\n\n"
            )

        blocks_dict["function"] += func_str

    return blocks_dict

def write_cpp_code(self, path):
    """
    Append the C++ implementation of this ion channel to the files
    ``Ionchannels.h`` (class declaration) and ``Ionchannels.cc`` (method
    definitions) in `path`.

    Concentration dependent ion channels get constant concentrations
    substituted for c++ simulation: every concentration becomes a member
    ``m_<ion>`` hardcoded to its value in `self.conc`.

    Parameters
    ----------
    path: str
        Directory containing ``Ionchannels.h`` and ``Ionchannels.cc``. Both
        files are opened in append mode.
    """
    c_name = self.__class__.__name__
    svs = [str(svar) for svar in self.ordered_statevars]

    # rewrite open probabilities
    p_open_m = self.p_open
    p_open_m_inf = self.p_open
    for svar in self.ordered_statevars:
        p_open_m = p_open_m.subs(svar, sp.symbols("m_" + str(svar)))
        p_open_m_inf = p_open_m_inf.subs(
            svar, sp.symbols("m_" + str(svar) + "_inf")
        )

    # substitute concentrations in expression
    def _replaceConc(expr_str, prefix="", suffix=""):
        for ion in self.conc:
            expr_str = expr_str.replace(str(ion), prefix + str(ion) + suffix)
        return expr_str

    # open header and cc files; the context manager closes both, even when
    # code generation raises
    with open(os.path.join(path, "Ionchannels.cc"), "a") as fcc, open(
        os.path.join(path, "Ionchannels.h"), "a"
    ) as fh:
        # define class and functions in header file
        fh.write("class %s: public IonChannel{\n" % c_name)
        fh.write("private:" + "\n")
        for svar in self.ordered_statevars:
            sv = sp.printing.ccode(svar)
            fh.write(" double m_%s;\n" % sv)
            fh.write(" double m_%s_inf, m_tau_%s;\n" % (sv, sv))
            # 10000. serves as the "not fixed" sentinel checked in the
            # Newton functions below
            fh.write(" double m_v_%s = 10000.;\n" % sv)
        fh.write(" double m_p_open_eq = 0.0, m_p_open = 0.0;\n")
        # hardcode default concentrations
        for ion, conc in self.conc.items():
            fh.write(" double m_%s = %.8f;\n" % (ion, conc))
        fh.write("public:" + "\n")
        fh.write(" void calcFunStatevar(double v) override;" + "\n")
        fh.write(" double calcPOpen() override;" + "\n")
        fh.write(" void setPOpen() override;" + "\n")
        fh.write(" void setPOpenEQ(double v) override;" + "\n")
        fh.write(" void advance(double dt) override;" + "\n")
        fh.write(" double getCond() override;" + "\n")
        fh.write(" double getCondNewton() override;" + "\n")
        fh.write(" double f(double v) override;" + "\n")
        fh.write(" double DfDv(double v) override;" + "\n")
        fh.write(" void setfNewtonConstant(double* vs, int v_size) override;" + "\n")
        fh.write(" double fNewton(double v) override;" + "\n")
        fh.write(" double DfDvNewton(double v) override;" + "\n")
        fh.write("};" + "\n")

        # function in cc file
        fcc.write("void %s::calcFunStatevar(double v){\n" % c_name)
        for svar in self.ordered_statevars:
            varinf = self._substitute_defaults(self.varinf[svar])
            tauinf = self._substitute_defaults(self.tauinf[svar])
            sv = str(svar)
            vi = _replaceConc(sp.printing.ccode(varinf), prefix="m_")
            ti = _replaceConc(sp.printing.ccode(tauinf), prefix="m_")
            fcc.write(" m_%s_inf = %s;\n" % (sv, vi))
            if sv == "m":
                # instantaneous approximation possible if statevar is
                # activation (denoted by 'm')
                fcc.write(" if(m_instantaneous)" + "\n")
                fcc.write(
                    " m_tau_%s = %s;\n" % (sv, sp.printing.ccode(sp.Float(1e-5)))
                )
                fcc.write(" else" + "\n")
                fcc.write(" m_tau_%s = %s;\n" % (sv, ti))
            else:
                fcc.write(" m_tau_%s = %s;\n" % (sv, ti))
        fcc.write("}\n")

        fcc.write("double %s::calcPOpen(){\n" % c_name)
        fcc.write(" return %s;\n" % sp.printing.ccode(p_open_m))
        fcc.write("}\n")

        fcc.write("void %s::setPOpen(){\n" % c_name)
        fcc.write(" m_p_open = calcPOpen();\n")
        fcc.write("}\n")

        fcc.write("void %s::setPOpenEQ(double v){\n" % c_name)
        fcc.write(" calcFunStatevar(v);\n")
        fcc.write("\n")
        for sv in svs:
            fcc.write(" m_%s = m_%s_inf;\n" % (sv, sv))
        fcc.write(" m_p_open_eq = %s;\n" % sp.printing.ccode(p_open_m_inf))
        fcc.write("}\n")

        # exact exponential update of each state variable over a step dt
        fcc.write("void %s::advance(double dt){\n" % c_name)
        for sv in svs:
            fcc.write(" double p0_%s = exp(-dt / m_tau_%s);\n" % (sv, sv))
            fcc.write(" m_%s *= p0_%s ;\n" % (sv, sv))
            fcc.write(" m_%s += (1. - p0_%s) * m_%s_inf;\n" % (sv, sv, sv))
        fcc.write("}\n")

        fcc.write("double %s::getCond(){\n" % c_name)
        fcc.write(" return m_g_bar * (m_p_open - m_p_open_eq);\n")
        fcc.write("}\n")
        fcc.write("double %s::getCondNewton(){\n" % c_name)
        fcc.write(" return m_g_bar;\n")
        fcc.write("}\n")

        # function for temporal integration
        fcc.write("double %s::f(double v){\n" % c_name)
        fcc.write(" return (m_e_rev - v);\n")
        fcc.write("}\n")
        fcc.write("double %s::DfDv(double v){\n" % c_name)
        fcc.write(" return -1.;\n")
        fcc.write("}\n")

        # set voltage values to evaluate at constant voltage during newton
        # iteration
        # NOTE(review): the generated size check only prints an error and
        # then still reads `vs[ii]`, which is out of bounds when `v_size` is
        # too small -- consider returning early in the generated code
        fcc.write("void %s::setfNewtonConstant(double* vs, int v_size){\n" % c_name)
        fcc.write(" if(v_size != %d)" % len(self.ordered_statevars) + "\n")
        fcc.write(
            ' cerr << "input arg [vs] has incorrect size, '
            + 'should have same size as number of channel state variables" << endl'
            + ";\n"
        )
        for ii, svar in enumerate(self.ordered_statevars):
            fcc.write(" m_v_%s = vs[%d];\n" % (str(svar), ii))
        fcc.write("}\n")

        # functions for solving Newton iteration
        fcc.write("double %s::fNewton(double v){\n" % c_name)
        for svar in self.ordered_statevars:
            sv = "v_" + str(svar)
            # substitute default parameters
            vi = self._substitute_defaults(self.varinf[svar])
            # write ccode and substitute variable names
            vi_ccode = sp.printing.ccode(vi)
            vi_ccode = vi_ccode.replace(str(self.sp_v), sv)
            vi_ccode = _replaceConc(vi_ccode, prefix="m_")
            # assign dynamic or fixed voltage to the activation: values above
            # 1000. (the 10000. default) mean "use the dynamic voltage v"
            fcc.write(" double %s;\n" % (sv))
            fcc.write(" if(m_%s > 1000.){\n" % sv)
            fcc.write(" %s = v;\n" % (sv))
            fcc.write(" } else{\n")
            fcc.write(" %s = m_%s;\n" % (sv, sv))
            fcc.write(" }" + "\n")
            fcc.write(" double %s = %s;\n" % (str(svar), vi_ccode))
        fcc.write(
            " return (m_e_rev - v) * (%s - m_p_open_eq);\n"
            % sp.printing.ccode(self.p_open)
        )
        fcc.write("}\n")

        fcc.write("double %s::DfDvNewton(double v){\n" % c_name)
        dp_o = {svar: sp.diff(self.p_open, svar, 1) for svar in self.ordered_statevars}
        # print derivatives
        for svar in self.ordered_statevars:
            sv = "v_" + str(svar)
            # substitute default parameters
            vi = self._substitute_defaults(self.varinf[svar])
            # write ccode and substitute variable names
            vi_ccode = sp.printing.ccode(vi)
            vi_ccode = vi_ccode.replace(str(self.sp_v), sv)
            vi_ccode = _replaceConc(vi_ccode, prefix="m_")
            # compute voltage derivatives
            dvi_dv = sp.diff(vi, self.sp_v, 1)
            dvi_dv_ccode = sp.printing.ccode(dvi_dv)
            dvi_dv_ccode = dvi_dv_ccode.replace(str(self.sp_v), sv)
            dvi_dv_ccode = _replaceConc(dvi_dv_ccode, prefix="m_")
            # compute derivative; a fixed voltage contributes no derivative
            fcc.write(" double %s;\n" % sv)
            fcc.write(" double d%s_dv;\n" % str(svar))
            fcc.write(" if(m_%s > 1000.){\n" % sv)
            fcc.write(" %s = v;\n" % sv)
            fcc.write(" d%s_dv = %s;\n" % (str(svar), dvi_dv_ccode))
            fcc.write(" } else{\n")
            fcc.write(" %s = m_%s;\n" % (sv, sv))
            fcc.write(" d%s_dv = 0;\n" % str(svar))
            fcc.write(" }\n")
            fcc.write(" double %s = %s;\n" % (str(svar), vi_ccode))
        # chain rule: sum over state variables of dp_open/dx * dx/dv
        expr_str = " + ".join(
            [
                "%s * d%s_dv" % (sp.printing.ccode(dp_o[svar]), str(svar))
                for svar in self.ordered_statevars
            ]
        )
        fcc.write(
            " return -1. * (%s - m_p_open_eq) + (%s) * (m_e_rev - v);\n"
            % (sp.printing.ccode(self.p_open), expr_str)
        )
        fcc.write("}\n")
        fh.write("\n")
        fcc.write("\n")