NEURON
sympy_solver.cpp
Go to the documentation of this file.
1 /*
2  * Copyright 2023 Blue Brain Project, EPFL.
3  * See the top-level LICENSE file for details.
4  *
5  * SPDX-License-Identifier: Apache-2.0
6  */
7 
8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
10 
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
13 
14 #include "ast/program.hpp"
16 #include "parser/nmodl_driver.hpp"
17 #include "utils/test_utils.hpp"
28 
29 
30 using namespace nmodl;
31 using namespace codegen;
32 using namespace visitor;
33 using namespace test;
34 using namespace test_utils;
35 
36 using Catch::Matchers::ContainsSubstring; // ContainsSubstring in newer Catch2
37 
39 
40 using ast::AstNodeType;
42 
43 
44 //=============================================================================
45 // SympySolver visitor tests
46 //=============================================================================
47 
// Parse the NMODL source `text`, run the preparatory passes on the resulting
// program and then SympySolverVisitor(pade, cse) over it; returns the AST
// obtained from driver.parse_string(text) after it has been rewritten in place.
//
// `pade` and `cse` are forwarded unchanged to the SympySolverVisitor
// constructor; `kinetic` gates the `if (kinetic)` branch below.
// NOTE(review): this listing omits several statements of this function (the
// declaration of `driver`, and the statements under the "symbol table",
// "unroll loops", kinetic and parent-check comments) -- confirm what those
// passes do against the full source before relying on them.
48 auto run_sympy_solver_visitor_ast(const std::string& text,
49  bool pade = false,
50  bool cse = false,
51  bool kinetic = false) {
52  // construct AST from text
54  const auto& ast = driver.parse_string(text);
55 
56  // construct symbol table from AST
58 
59  // unroll loops and fold constants
64 
65  if (kinetic) {
67  }
68 
69  // run SympySolver on AST
70  SympySolverVisitor(pade, cse).visit_program(*ast);
71 
72  // check that, after visitor rearrangement, parents are still up-to-date
74 
75  return ast;
76 }
77 
78 std::vector<std::string> run_sympy_solver_visitor(
79  const std::string& text,
80  bool pade = false,
81  bool cse = false,
82  AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
83  bool kinetic = false) {
84  std::vector<std::string> results;
85 
86  const auto& ast = run_sympy_solver_visitor_ast(text, pade, cse, kinetic);
87 
88  // run lookup visitor to extract results from AST
89  for (const auto& eq: collect_nodes(*ast, {ret_nodetype})) {
90  results.push_back(to_nmodl(eq));
91  }
92 
93  return results;
94 }
95 
/**
 * \brief Check whether a list of variables (e.g. the text of a LOCAL statement) has duplicates.
 *
 * Names may be separated by commas and/or any amount of whitespace, so
 * "LOCAL a, b", "LOCAL a,b" and "LOCAL a,  b" are all parsed as {LOCAL, a, b}.
 *
 * \param result text of the variable list (taken by value because it is modified)
 * \return true if no token occurs more than once
 */
bool is_unique_vars(std::string result) {
    // Turn commas into separators instead of deleting them: deleting would glue
    // "a,b" into the single token "ab" and hide duplicates such as "a,a".
    std::replace(result.begin(), result.end(), ',', ' ');

    std::istringstream ss(result);
    std::unordered_set<std::string> seen;
    std::string token;
    // operator>> skips any run of whitespace, so repeated separators never
    // yield empty tokens that would be mistaken for duplicates.
    while (ss >> token) {
        if (!seen.insert(token).second) {
            return false;
        }
    }
    return true;
}
111 
112 
113 /**
114  * \brief Compare nmodl blocks that contain systems of equations (i.e. derivative, linear, etc.)
115  *
116  * This is basically and advanced string == string comparison where we detect the (various) systems
117  * of equations and check if they are equivalent. Implemented mostly in python since we need a call
118  * to sympy to simplify the equations.
119  *
120  * - compare_systems_of_eq The core of the code. \p result_dict and \p expected_dict are
121  * dictionaries that represent the systems of equations in this way:
122  *
123  * a = b*x + c -> result_dict['a'] = 'b*x + c'
124  *
125  * where the variable \p a become a key \p k of the dictionary.
126  *
127  * In there we go over all the equations in \p result_dict and \p expected_dict and check that
128  * result_dict[k] - expected_dict[k] simplifies to 0.
129  *
130  * - sanitize is to transform the equations in something treatable by sympy (i.e. pow(dt, 3) ->
131  * dt**3
132  * - reduce back-substitution of the temporary variables
133  *
134  * \p require_fail requires that the equations are different. Used only for unit-test this function
135  *
136  * \warning do not use this method when there are tmp variables not in the form: tmp_<number>
137  */
138 void compare_blocks(const std::string& result,
139  const std::string& expected,
140  const bool require_fail = false) {
141  using namespace pybind11::literals;
142 
143  auto locals =
144  pybind11::dict("result"_a = result, "expected"_a = expected, "is_equal"_a = false);
145  pybind11::exec(R"(
146  # Comments are in the doxygen for better highlighting
147  def compare_blocks(result, expected):
148 
149  def sanitize(s):
150  import re
151  d = {'\[(\d+)\]':'_\\1', 'pow\‍((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
152  out = s
153  for key, val in d.items():
154  out = re.sub(key, val, out)
155  return out
156 
157  def compare_systems_of_eq(result_dict, expected_dict):
158  from sympy.parsing.sympy_parser import parse_expr
159  try:
160  for k, v in result_dict.items():
161  if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
162  return False
163  except KeyError:
164  return False
165 
166  result_dict.clear()
167  expected_dict.clear()
168  return True
169 
170  def reduce(s):
171  max_tmp = -1
172  d = {}
173 
174  sout = ""
175  # split of sout and a dict with the tmp variables
176  for line in s.split('\n'):
177  line_split = line.lstrip().split('=')
178 
179  if len(line_split) == 2 and line_split[0].startswith('tmp_'):
180  # back-substitution of tmp variables in tmp variables
181  tmp_var = line_split[0].strip()
182  if tmp_var in d:
183  continue
184 
185  max_tmp = max(max_tmp, int(tmp_var[4:]))
186  for k, v in d.items():
187  line_split[1] = line_split[1].replace(k, f'({v})')
188  d[tmp_var] = line_split[1]
189  elif 'LOCAL' in line:
190  sout += line.split('tmp_0')[0] + '\n'
191  else:
192  sout += line + '\n'
193 
194  # Back-substitution of the tmps
195  # so that we do not replace tmp_11 with (tmp_1)1
196  for j in range(max_tmp, -1, -1):
197  k = f'tmp_{j}'
198  sout = sout.replace(k, f'({d[k]})')
199 
200  return sout
201 
202  result = reduce(sanitize(result)).split('\n')
203  expected = reduce(sanitize(expected)).split('\n')
204 
205  if len(result) != len(expected):
206  return False
207 
208  result_dict = {}
209  expected_dict = {}
210  for token1, token2 in zip(result, expected):
211  if token1 == token2:
212  if not compare_systems_of_eq(result_dict, expected_dict):
213  return False
214  continue
215 
216  eq1 = token1.split('=')
217  eq2 = token2.split('=')
218  if len(eq1) == 2 and len(eq2) == 2:
219  result_dict[eq1[0]] = eq1[1]
220  expected_dict[eq2[0]] = eq2[1]
221  continue
222 
223  return False
224  return compare_systems_of_eq(result_dict, expected_dict)
225 
226  is_equal = compare_blocks(result, expected))",
227  pybind11::globals(),
228  locals);
229 
230  // Error log
231  if (require_fail == locals["is_equal"].cast<bool>()) {
232  if (require_fail) {
233  REQUIRE(result != expected);
234  } else {
235  REQUIRE(result == expected);
236  }
237  } else { // so that we signal to ctest that an assert was performed
238  REQUIRE(true);
239  }
240 }
241 
242 
// NOTE(review): the signature line of this function is not visible in this
// listing; from its use in the cnexp scenario it is presumably
// run_sympy_visitor_passes(node) -- confirm against the full source.
// Builds the symbol table once, then applies SympySolverVisitor five times to
// the same `node`: three passes with one instance (v_sympy1) and two with a
// second instance (v_sympy2), interleaved. This exercises repeated passes and
// reuse of a visitor instance on an already-solved AST.
244  // construct symbol table from AST
245  SymtabVisitor v_symtab;
246  v_symtab.visit_program(node);
247 
248  // run SympySolver on AST several times
249  SympySolverVisitor v_sympy1;
250  v_sympy1.visit_program(node);
251  v_sympy1.visit_program(node);
252 
253  // also use a second instance of SympySolver
254  SympySolverVisitor v_sympy2;
255  v_sympy2.visit_program(node);
256  v_sympy1.visit_program(node);
257  v_sympy2.visit_program(node);
258 }
259 
260 
// Collects output in a local string stream and returns its contents as a
// std::string.
// NOTE(review): the signature (presumably ast_to_string(node), as called in
// the cnexp scenario) and the line that writes into `stream` are not visible
// in this listing -- confirm what representation of the AST is produced.
262  std::stringstream stream;
264  return stream.str();
265 }
266 
// Self-test of compare_blocks(): equivalent inputs must compare equal, and
// non-equivalent inputs must be reported as different (require_fail = true).
// NOTE: whitespace inside the raw strings below is significant for
// compare_blocks (lines are compared textually and equation left-hand sides
// are used verbatim as dictionary keys).
267 SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]") {
268  GIVEN("Empty strings") {
269  THEN("Strings are equal") {
270  compare_blocks("", "");
271  }
272  }
273  GIVEN("Equivalent equation") {
274  THEN("Strings are equal") {
275  compare_blocks("a = 3*b + c", "a = 2*b + b + c");
276  }
277  }
278  GIVEN("Equivalent systems of equations") {
279  std::string result = R"(
280  x = 3*b + c
281  y = 2*a + b)";
282  std::string expected = R"(
283  x = b+2*b + c
284  y = 2*a + 2*b-b)";
285  THEN("Systems of equations are equal") {
286  compare_blocks(result, expected);
287  }
288  }
// Exercises sanitize (A[0] -> A_0, pow(a, n) -> a**n) and reduce (back-substitution
// of tmp_0/tmp_1: tmp_1 = (a + c) - a = c).
289  GIVEN("Equivalent systems of equations with brackets") {
290  std::string result = R"(
291  DERIVATIVE {
292  A[0] = 3*b + c
293  y = pow(a, 3) + b
294  })";
295  std::string expected = R"(
296  DERIVATIVE {
297  tmp_0 = a + c
298  tmp_1 = tmp_0 - a
299  A[0] = b+2*b + tmp_1
300  y = pow(a, 2)*a + 2*b-b
301  })";
302  THEN("Blocks are equal") {
303  compare_blocks(result, expected);
304  }
305  }
// The equations are algebraically equivalent; the blocks are meant to differ
// only by whitespace.
// NOTE(review): the whitespace difference is not visible in this source listing.
306  GIVEN("Different systems of equations (additional space)") {
307  std::string result = R"(
308  DERIVATIVE {
309  x = 3*b + c
310  y = 2*a + b
311  })";
312  std::string expected = R"(
313  DERIVATIVE {
314  x = b+2*b + c
315  y = 2*a + 2*b-b
316  })";
317  THEN("Blocks are different") {
318  compare_blocks(result, expected, true);
319  }
320  }
// After back-substitution, x = 3*b + ((a - c) - a) = 3*b - c in the result,
// which differs from x = 3*b + c in the expected block.
321  GIVEN("Different systems of equations") {
322  std::string result = R"(
323  DERIVATIVE {
324  tmp_0 = a - c
325  tmp_1 = tmp_0 - a
326  x = 3*b + tmp_1
327  y = 2*a + b
328  })";
329  std::string expected = R"(
330  DERIVATIVE {
331  x = b+2*b + c
332  y = 2*a + 2*b-b
333  })";
334  THEN("Blocks are different") {
335  compare_blocks(result, expected, true);
336  }
337  }
338 }
339 
// Runs the solver with pade = true and cse = true (the two `true` arguments)
// on derivative blocks whose user LOCALs are named `tmp` / `tmp_0`, names that
// presumably collide with solver-generated temporaries. The first LOCAL list
// statement in the result must not declare any name twice (checked with
// is_unique_vars).
340 SCENARIO("Check local vars name-clash prevention", "[visitor][sympy]") {
341  GIVEN("LOCAL tmp") {
342  std::string nmodl_text = R"(
343  STATE {
344  x y
345  }
346  BREAKPOINT {
347  SOLVE states METHOD sparse
348  }
349  DERIVATIVE states {
350  LOCAL tmp, b
351  x' = tmp + b
352  y' = tmp + b
353  })";
354  THEN("There are no duplicate vars in LOCAL") {
355  auto result =
356  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
357  REQUIRE(!result.empty());
358  REQUIRE(is_unique_vars(result[0]));
359  }
360  }
361  GIVEN("LOCAL tmp_0") {
362  std::string nmodl_text = R"(
363  STATE {
364  x y
365  }
366  BREAKPOINT {
367  SOLVE states METHOD sparse
368  }
369  DERIVATIVE states {
370  LOCAL tmp_0, b
371  x' = tmp_0 + b
372  y' = tmp_0 + b
373  })";
374  THEN("There are no duplicate vars in LOCAL") {
375  auto result =
376  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::LOCAL_LIST_STATEMENT);
377  REQUIRE(!result.empty());
378  REQUIRE(is_unique_vars(result[0]));
379  }
380  }
381 }
382 
// Tests for derivative blocks solved with METHOD cnexp (analytic integration)
// or METHOD euler (forward Euler): each THEN checks the solved equations as
// NMODL strings.
// NOTE(review): the lines declaring `result` inside each THEN are not visible
// in this listing; they presumably call run_sympy_solver_visitor(nmodl_text)
// with its default DIFF_EQ_EXPRESSION node type -- confirm.
383 SCENARIO("Solve ODEs with cnexp or euler method using SympySolverVisitor",
384  "[visitor][sympy][cnexp][euler]") {
385  GIVEN("Derivative block without ODE, solver method cnexp") {
386  std::string nmodl_text = R"(
387  BREAKPOINT {
388  SOLVE states METHOD cnexp
389  }
390  DERIVATIVE states {
391  m = m + h
392  }
393  )";
394  THEN("No ODEs found - do nothing") {
396  REQUIRE(result.empty());
397  }
398  }
399  GIVEN("Derivative block with ODES, solver method is euler") {
400  std::string nmodl_text = R"(
401  BREAKPOINT {
402  SOLVE states METHOD euler
403  }
404  DERIVATIVE states {
405  m' = (mInf-m)/mTau
406  h' = (hInf-h)/hTau
407  z = a*b + c
408  }
409  )";
410  THEN("Construct forwards Euler solutions") {
412  REQUIRE(result.size() == 2);
413  REQUIRE(result[0] == "m = (-dt*(m-mInf)+m*mTau)/mTau");
414  REQUIRE(result[1] == "h = (-dt*(h-hInf)+h*hTau)/hTau");
415  }
416  }
417  GIVEN("Derivative block with calling external functions passes sympy") {
418  std::string nmodl_text = R"(
419  BREAKPOINT {
420  SOLVE states METHOD euler
421  }
422  DERIVATIVE states {
423  m' = sawtooth(m)
424  n' = sin(n)
425  p' = my_user_func(p)
426  }
427  )";
428  THEN("Construct forward Euler interpreting external functions as symbols") {
430  REQUIRE(result.size() == 3);
431  REQUIRE(result[0] == "m = dt*sawtooth(m)+m");
432  REQUIRE(result[1] == "n = dt*sin(n)+n");
433  REQUIRE(result[2] == "p = dt*my_user_func(p)+p");
434  }
435  }
436  GIVEN("Derivative block with ODE, 1 state var in array, solver method euler") {
437  std::string nmodl_text = R"(
438  STATE {
439  m[1]
440  }
441  BREAKPOINT {
442  SOLVE states METHOD euler
443  }
444  DERIVATIVE states {
445  m'[0] = (mInf-m[0])/mTau
446  }
447  )";
448  THEN("Construct forwards Euler solutions") {
450  REQUIRE(result.size() == 1);
451  REQUIRE(result[0] == "m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
452  }
453  }
// NOTE(review): the THEN description below says "forwards Euler", but this case
// uses METHOD cnexp and checks an exponential (analytic) solution.
454  GIVEN("Derivative block with ODE, 1 state var in array, solver method cnexp") {
455  std::string nmodl_text = R"(
456  STATE {
457  m[1]
458  }
459  BREAKPOINT {
460  SOLVE states METHOD cnexp
461  }
462  DERIVATIVE states {
463  m'[0] = (mInf-m[0])/mTau
464  }
465  )";
466  THEN("Construct forwards Euler solutions") {
468  REQUIRE(result.size() == 1);
469  REQUIRE(result[0] == "m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
470  }
471  }
472  GIVEN("Derivative block with linear ODES, solver method cnexp") {
473  std::string nmodl_text = R"(
474  BREAKPOINT {
475  SOLVE states METHOD cnexp
476  }
477  DERIVATIVE states {
478  m' = (mInf-m)/mTau
479  z = a*b + c
480  h' = hInf/hTau - h/hTau
481  }
482  )";
483  THEN("Integrate equations analytically") {
485  REQUIRE(result.size() == 2);
486  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
487  REQUIRE(result[1] == "h = hInf-(-h+hInf)*exp(-dt/hTau)");
488  }
489  }
490  GIVEN("Derivative block including non-linear but solvable ODES, solver method cnexp") {
491  std::string nmodl_text = R"(
492  BREAKPOINT {
493  SOLVE states METHOD cnexp
494  }
495  DERIVATIVE states {
496  m' = (mInf-m)/mTau
497  h' = c2 * h*h
498  }
499  )";
500  THEN("Integrate equations analytically") {
502  REQUIRE(result.size() == 2);
503  REQUIRE(result[0] == "m = mInf-(-m+mInf)*exp(-dt/mTau)");
504  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
505  }
506  }
507  GIVEN("Derivative block including array of 2 state vars, solver method cnexp") {
508  std::string nmodl_text = R"(
509  BREAKPOINT {
510  SOLVE states METHOD cnexp
511  }
512  STATE {
513  X[2]
514  }
515  DERIVATIVE states {
516  X'[0] = (mInf-X[0])/mTau
517  X'[1] = c2 * X[1]*X[1]
518  }
519  )";
520  THEN("Integrate equations analytically") {
522  REQUIRE(result.size() == 2);
523  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
524  REQUIRE(result[1] == "X[1] = -X[1]/(c2*dt*X[1]-1.0)");
525  }
526  }
527  GIVEN("Derivative block including loop over array vars, solver method cnexp") {
528  std::string nmodl_text = R"(
529  DEFINE N 3
530  BREAKPOINT {
531  SOLVE states METHOD cnexp
532  }
533  ASSIGNED {
534  mTau[N]
535  }
536  STATE {
537  X[N]
538  }
539  DERIVATIVE states {
540  FROM i=0 TO N-1 {
541  X'[i] = (mInf-X[i])/mTau[i]
542  }
543  }
544  )";
545  THEN("Integrate equations analytically") {
547  REQUIRE(result.size() == 3);
548  REQUIRE(result[0] == "X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
549  REQUIRE(result[1] == "X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
550  REQUIRE(result[2] == "X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
551  }
552  }
// NOTE(review): the THEN description below says "analytically", but this case
// uses METHOD euler and checks forward Euler updates.
553  GIVEN("Derivative block including loop over array vars, solver method euler") {
554  std::string nmodl_text = R"(
555  DEFINE N 3
556  BREAKPOINT {
557  SOLVE states METHOD euler
558  }
559  ASSIGNED {
560  mTau[N]
561  }
562  STATE {
563  X[N]
564  }
565  DERIVATIVE states {
566  FROM i=0 TO N-1 {
567  X'[i] = (mInf-X[i])/mTau[i]
568  }
569  }
570  )";
571  THEN("Integrate equations analytically") {
573  REQUIRE(result.size() == 3);
574  REQUIRE(result[0] == "X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
575  REQUIRE(result[1] == "X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
576  REQUIRE(result[2] == "X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
577  }
578  }
// ODEs that sympy cannot solve are expected to be left untouched (still in
// derivative form). For z' and y' either outcome is accepted, because the
// result depends on the installed sympy version.
579  GIVEN("Derivative block including ODES that can't currently be solved, solver method cnexp") {
580  std::string nmodl_text = R"(
581  BREAKPOINT {
582  SOLVE states METHOD cnexp
583  }
584  DERIVATIVE states {
585  z' = a/z + b/z/z
586  h' = c2 * h*h
587  x' = a
588  y' = c3 * y*y*y
589  }
590  )";
591  THEN("Integrate equations analytically where possible, otherwise leave untouched") {
593  REQUIRE(result.size() == 4);
594  /// sympy 1.9 able to solve ode but not older versions
595  REQUIRE((result[0] == "z' = a/z+b/z/z" ||
596  result[0] ==
597  "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
598  REQUIRE(result[1] == "h = -h/(c2*dt*h-1.0)");
599  REQUIRE(result[2] == "x = a*dt+x");
600  /// sympy 1.4 able to solve ode but not older versions
601  REQUIRE((result[3] == "y' = c3*y*y*y" ||
602  result[3] == "y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
603  }
604  }
// Idempotence check: after the AST has been solved once, further SympySolver
// passes (run_sympy_visitor_passes) must not throw and must leave the AST's
// string form unchanged.
// NOTE(review): the driver declaration and the symtab / SympySolver statements
// of this GIVEN are not visible in this listing.
605  GIVEN("Derivative block with cnexp solver method, AST after SympySolver pass") {
606  std::string nmodl_text = R"(
607  BREAKPOINT {
608  SOLVE states METHOD cnexp
609  }
610  DERIVATIVE states {
611  m' = (mInf-m)/mTau
612  }
613  )";
614  // construct AST from text
616  auto ast = driver.parse_string(nmodl_text);
617 
618  // construct symbol table from AST
620 
621  // run SympySolver on AST
623 
624  std::string AST_string = ast_to_string(*ast);
625 
626  THEN("More SympySolver passes do nothing to the AST and don't throw") {
627  REQUIRE_NOTHROW(run_sympy_visitor_passes(*ast));
628  REQUIRE(AST_string == ast_to_string(*ast));
629  }
630  }
631 }
632 
633 SCENARIO("Solve ODEs with derivimplicit method using SympySolverVisitor",
634  "[visitor][sympy][derivimplicit]") {
635  GIVEN("Derivative block with derivimplicit solver method and conditional block") {
636  std::string nmodl_text = R"(
637  STATE {
638  m
639  }
640  BREAKPOINT {
641  SOLVE states METHOD derivimplicit
642  }
643  DERIVATIVE states {
644  IF (mInf == 1) {
645  mInf = mInf+1
646  }
647  m' = (mInf-m)/mTau
648  }
649  )";
650  std::string expected_result = R"(
651  DERIVATIVE states {
652  EIGEN_NEWTON_SOLVE[1]{
653  LOCAL old_m
654  }{
655  IF (mInf == 1) {
656  mInf = mInf+1
657  }
658  old_m = m
659  }{
660  nmodl_eigen_x[0] = m
661  }{
662  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
663  nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
664  }{
665  m = nmodl_eigen_x[0]
666  }{
667  }
668  })";
669  THEN("SympySolver correctly inserts ode to block") {
670  CAPTURE(nmodl_text);
671  auto result =
672  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
673  compare_blocks(result[0], reindent_text(expected_result));
674  }
675  }
676 
677  GIVEN("Derivative block, sparse, print in order") {
678  std::string nmodl_text = R"(
679  STATE {
680  x y
681  }
682  BREAKPOINT {
683  SOLVE states METHOD sparse
684  }
685  DERIVATIVE states {
686  LOCAL a, b
687  y' = a
688  x' = b
689  })";
690  std::string expected_result = R"(
691  DERIVATIVE states {
692  EIGEN_NEWTON_SOLVE[2]{
693  LOCAL a, b, old_y, old_x
694  }{
695  old_y = y
696  old_x = x
697  }{
698  nmodl_eigen_x[0] = x
699  nmodl_eigen_x[1] = y
700  }{
701  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
702  nmodl_eigen_j[0] = 0
703  nmodl_eigen_j[2] = -1/dt
704  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
705  nmodl_eigen_j[1] = -1/dt
706  nmodl_eigen_j[3] = 0
707  }{
708  x = nmodl_eigen_x[0]
709  y = nmodl_eigen_x[1]
710  }{
711  }
712  })";
713 
714  THEN("Construct & solve linear system for backwards Euler") {
715  auto result =
716  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
717 
718  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
719  }
720  }
721  GIVEN("Derivative block, sparse, print in order, vectors") {
722  std::string nmodl_text = R"(
723  STATE {
724  M[2]
725  }
726  BREAKPOINT {
727  SOLVE states METHOD sparse
728  }
729  DERIVATIVE states {
730  LOCAL a, b
731  M'[1] = a
732  M'[0] = b
733  })";
734  std::string expected_result = R"(
735  DERIVATIVE states {
736  EIGEN_NEWTON_SOLVE[2]{
737  LOCAL a, b, old_M_1, old_M_0
738  }{
739  old_M_1 = M[1]
740  old_M_0 = M[0]
741  }{
742  nmodl_eigen_x[0] = M[0]
743  nmodl_eigen_x[1] = M[1]
744  }{
745  nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
746  nmodl_eigen_j[0] = 0
747  nmodl_eigen_j[2] = -1/dt
748  nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
749  nmodl_eigen_j[1] = -1/dt
750  nmodl_eigen_j[3] = 0
751  }{
752  M[0] = nmodl_eigen_x[0]
753  M[1] = nmodl_eigen_x[1]
754  }{
755  }
756  })";
757 
758  THEN("Construct & solve linear system for backwards Euler") {
759  auto result =
760  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
761 
762  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
763  }
764  }
765  GIVEN("Derivative block, sparse, derivatives mixed with local variable reassignment") {
766  std::string nmodl_text = R"(
767  STATE {
768  x y
769  }
770  BREAKPOINT {
771  SOLVE states METHOD sparse
772  }
773  DERIVATIVE states {
774  LOCAL a, b
775  x' = a
776  b = b + 1
777  y' = b
778  })";
779  std::string expected_result = R"(
780  DERIVATIVE states {
781  EIGEN_NEWTON_SOLVE[2]{
782  LOCAL a, b, old_x, old_y
783  }{
784  old_x = x
785  old_y = y
786  }{
787  nmodl_eigen_x[0] = x
788  nmodl_eigen_x[1] = y
789  }{
790  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
791  nmodl_eigen_j[0] = -1/dt
792  nmodl_eigen_j[2] = 0
793  b = b+1
794  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
795  nmodl_eigen_j[1] = 0
796  nmodl_eigen_j[3] = -1/dt
797  }{
798  x = nmodl_eigen_x[0]
799  y = nmodl_eigen_x[1]
800  }{
801  }
802  })";
803 
804  THEN("Construct & solve linear system for backwards Euler") {
805  auto result =
806  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
807 
808  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
809  }
810  }
811  GIVEN(
812  "Throw exception during derivative variable reassignment interleaved in the differential "
813  "equation set") {
814  std::string nmodl_text = R"(
815  STATE {
816  x y
817  }
818  BREAKPOINT {
819  SOLVE states METHOD sparse
820  }
821  DERIVATIVE states {
822  LOCAL a, b
823  x' = a
824  x = x + 1
825  y' = b + x
826  })";
827 
828  THEN(
829  "Throw an error because state variable assignments are not allowed inside the system "
830  "of differential "
831  "equations") {
832  REQUIRE_THROWS_WITH(
833  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK),
834  Catch::Matchers::ContainsSubstring(
835  "State variable assignment(s) interleaved in system of "
836  "equations/differential equations") &&
837  Catch::Matchers::StartsWith("SympyReplaceSolutionsVisitor"));
838  }
839  }
840  GIVEN("Derivative block in control flow block") {
841  std::string nmodl_text = R"(
842  STATE {
843  x y
844  }
845  BREAKPOINT {
846  SOLVE states METHOD sparse
847  }
848  DERIVATIVE states {
849  LOCAL a, b
850  if (a == 1) {
851  x' = a
852  y' = b
853  }
854  })";
855  std::string expected_result = R"(
856  DERIVATIVE states {
857  LOCAL a, b
858  IF (a == 1) {
859  EIGEN_NEWTON_SOLVE[2]{
860  LOCAL old_x, old_y
861  }{
862  old_x = x
863  old_y = y
864  }{
865  nmodl_eigen_x[0] = x
866  nmodl_eigen_x[1] = y
867  }{
868  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
869  nmodl_eigen_j[0] = -1/dt
870  nmodl_eigen_j[2] = 0
871  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
872  nmodl_eigen_j[1] = 0
873  nmodl_eigen_j[3] = -1/dt
874  }{
875  x = nmodl_eigen_x[0]
876  y = nmodl_eigen_x[1]
877  }{
878  }
879  }
880  })";
881 
882  THEN("Construct & solve linear system for backwards Euler") {
883  auto result =
884  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
885 
886  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
887  }
888  }
889  GIVEN(
890  "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
891  "block") {
892  std::string nmodl_text = R"(
893  STATE {
894  x y
895  }
896  BREAKPOINT {
897  SOLVE states METHOD sparse
898  }
899  DERIVATIVE states {
900  LOCAL a, b
901  x' = a * y+b
902  if (b == 1) {
903  a = a + 1
904  }
905  y' = x + a*y
906  })";
907  std::string expected_result = R"(
908  DERIVATIVE states {
909  EIGEN_NEWTON_SOLVE[2]{
910  LOCAL a, b, old_x, old_y
911  }{
912  old_x = x
913  old_y = y
914  }{
915  nmodl_eigen_x[0] = x
916  nmodl_eigen_x[1] = y
917  }{
918  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
919  nmodl_eigen_j[0] = -1/dt
920  nmodl_eigen_j[2] = a
921  IF (b == 1) {
922  a = a+1
923  }
924  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
925  nmodl_eigen_j[1] = 1.0
926  nmodl_eigen_j[3] = a-1/dt
927  }{
928  x = nmodl_eigen_x[0]
929  y = nmodl_eigen_x[1]
930  }{
931  }
932  })";
933  std::string expected_result_cse = R"(
934  DERIVATIVE states {
935  EIGEN_NEWTON_SOLVE[2]{
936  LOCAL a, b, old_x, old_y
937  }{
938  old_x = x
939  old_y = y
940  }{
941  nmodl_eigen_x[0] = x
942  nmodl_eigen_x[1] = y
943  }{
944  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
945  nmodl_eigen_j[0] = -1/dt
946  nmodl_eigen_j[2] = a
947  IF (b == 1) {
948  a = a+1
949  }
950  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
951  nmodl_eigen_j[1] = 1.0
952  nmodl_eigen_j[3] = a-1/dt
953  }{
954  x = nmodl_eigen_x[0]
955  y = nmodl_eigen_x[1]
956  }{
957  }
958  })";
959 
960  THEN("Construct & solve linear system for backwards Euler") {
961  auto result =
962  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
963  auto result_cse =
964  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
965 
966  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
967  compare_blocks(reindent_text(result_cse[0]), reindent_text(expected_result_cse));
968  }
969  }
970 
971  GIVEN("Derivative block of coupled & linear ODES, solver method sparse") {
972  std::string nmodl_text = R"(
973  STATE {
974  x y z
975  }
976  BREAKPOINT {
977  SOLVE states METHOD sparse
978  }
979  DERIVATIVE states {
980  LOCAL a, b, c, d, h
981  x' = a*z + b*h
982  y' = c + 2*x
983  z' = d*z - y
984  }
985  )";
986  std::string expected_result = R"(
987  DERIVATIVE states {
988  EIGEN_NEWTON_SOLVE[3]{
989  LOCAL a, b, c, d, h, old_x, old_y, old_z
990  }{
991  old_x = x
992  old_y = y
993  old_z = z
994  }{
995  nmodl_eigen_x[0] = x
996  nmodl_eigen_x[1] = y
997  nmodl_eigen_x[2] = z
998  }{
999  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1000  nmodl_eigen_j[0] = -1/dt
1001  nmodl_eigen_j[3] = 0
1002  nmodl_eigen_j[6] = a
1003  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1004  nmodl_eigen_j[1] = 2.0
1005  nmodl_eigen_j[4] = -1/dt
1006  nmodl_eigen_j[7] = 0
1007  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1008  nmodl_eigen_j[2] = 0
1009  nmodl_eigen_j[5] = -1.0
1010  nmodl_eigen_j[8] = d-1/dt
1011  }{
1012  x = nmodl_eigen_x[0]
1013  y = nmodl_eigen_x[1]
1014  z = nmodl_eigen_x[2]
1015  }{
1016  }
1017  })";
1018  std::string expected_cse_result = R"(
1019  DERIVATIVE states {
1020  EIGEN_NEWTON_SOLVE[3]{
1021  LOCAL a, b, c, d, h, old_x, old_y, old_z
1022  }{
1023  old_x = x
1024  old_y = y
1025  old_z = z
1026  }{
1027  nmodl_eigen_x[0] = x
1028  nmodl_eigen_x[1] = y
1029  nmodl_eigen_x[2] = z
1030  }{
1031  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1032  nmodl_eigen_j[0] = -1/dt
1033  nmodl_eigen_j[3] = 0
1034  nmodl_eigen_j[6] = a
1035  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1036  nmodl_eigen_j[1] = 2.0
1037  nmodl_eigen_j[4] = -1/dt
1038  nmodl_eigen_j[7] = 0
1039  nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1040  nmodl_eigen_j[2] = 0
1041  nmodl_eigen_j[5] = -1.0
1042  nmodl_eigen_j[8] = d-1/dt
1043  }{
1044  x = nmodl_eigen_x[0]
1045  y = nmodl_eigen_x[1]
1046  z = nmodl_eigen_x[2]
1047  }{
1048  }
1049  })";
1050 
1051  THEN("Construct & solve linear system for backwards Euler") {
1052  auto result =
1053  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1054  auto result_cse =
1055  run_sympy_solver_visitor(nmodl_text, true, true, AstNodeType::DERIVATIVE_BLOCK);
1056 
1057  compare_blocks(result[0], reindent_text(expected_result));
1058  compare_blocks(result_cse[0], reindent_text(expected_cse_result));
1059  }
1060  }
1061  GIVEN("Derivative block including ODES with sparse method (from nmodl paper)") {
1062  std::string nmodl_text = R"(
1063  STATE {
1064  mc m
1065  }
1066  BREAKPOINT {
1067  SOLVE scheme1 METHOD sparse
1068  }
1069  DERIVATIVE scheme1 {
1070  mc' = -a*mc + b*m
1071  m' = a*mc - b*m
1072  }
1073  )";
1074  std::string expected_result = R"(
1075  DERIVATIVE scheme1 {
1076  EIGEN_NEWTON_SOLVE[2]{
1077  LOCAL old_mc, old_m
1078  }{
1079  old_mc = mc
1080  old_m = m
1081  }{
1082  nmodl_eigen_x[0] = mc
1083  nmodl_eigen_x[1] = m
1084  }{
1085  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1086  nmodl_eigen_j[0] = -a-1/dt
1087  nmodl_eigen_j[2] = b
1088  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1089  nmodl_eigen_j[1] = a
1090  nmodl_eigen_j[3] = -b-1/dt
1091  }{
1092  mc = nmodl_eigen_x[0]
1093  m = nmodl_eigen_x[1]
1094  }{
1095  }
1096  })";
1097  THEN("Construct & solve linear system") {
1098  CAPTURE(nmodl_text);
1099  auto result =
1100  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1101  compare_blocks(result[0], reindent_text(expected_result));
1102  }
1103  }
1104  GIVEN("Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1105  std::string nmodl_text = R"(
1106  STATE {
1107  mc m
1108  }
1109  BREAKPOINT {
1110  SOLVE scheme1 METHOD sparse
1111  }
1112  DERIVATIVE scheme1 {
1113  mc' = -a*mc + b*m
1114  m' = a*mc - b*m
1115  CONSERVE m = 1 - mc
1116  }
1117  )";
1118  std::string expected_result = R"(
1119  DERIVATIVE scheme1 {
1120  EIGEN_NEWTON_SOLVE[2]{
1121  LOCAL old_mc
1122  }{
1123  old_mc = mc
1124  }{
1125  nmodl_eigen_x[0] = mc
1126  nmodl_eigen_x[1] = m
1127  }{
1128  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1129  nmodl_eigen_j[0] = -a-1/dt
1130  nmodl_eigen_j[2] = b
1131  nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1132  nmodl_eigen_j[1] = -1.0
1133  nmodl_eigen_j[3] = -1.0
1134  }{
1135  mc = nmodl_eigen_x[0]
1136  m = nmodl_eigen_x[1]
1137  }{
1138  }
1139  })";
1140  THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1141  CAPTURE(nmodl_text);
1142  auto result =
1143  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1144  compare_blocks(result[0], reindent_text(expected_result));
1145  }
1146  }
1147  GIVEN(
1148  "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1149  "= ...") {
1150  std::string nmodl_text = R"(
1151  STATE {
1152  mc m
1153  }
1154  BREAKPOINT {
1155  SOLVE scheme1 METHOD sparse
1156  }
1157  DERIVATIVE scheme1 {
1158  mc' = -a*mc + b*m
1159  m' = a*mc - b*m
1160  CONSERVE m + mc = 1
1161  }
1162  )";
1163  std::string expected_result = R"(
1164  DERIVATIVE scheme1 {
1165  EIGEN_NEWTON_SOLVE[2]{
1166  LOCAL old_mc, old_m
1167  }{
1168  old_mc = mc
1169  old_m = m
1170  }{
1171  nmodl_eigen_x[0] = mc
1172  nmodl_eigen_x[1] = m
1173  }{
1174  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1175  nmodl_eigen_j[0] = -a-1/dt
1176  nmodl_eigen_j[2] = b
1177  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1178  nmodl_eigen_j[1] = a
1179  nmodl_eigen_j[3] = -b-1/dt
1180  }{
1181  mc = nmodl_eigen_x[0]
1182  m = nmodl_eigen_x[1]
1183  }{
1184  }
1185  })";
1186  THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1187  CAPTURE(nmodl_text);
1188  auto result =
1189  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1190  compare_blocks(result[0], reindent_text(expected_result));
1191  }
1192  }
1193  GIVEN("Derivative block with ODES with sparse method, two CONSERVE statements") {
1194  std::string nmodl_text = R"(
1195  STATE {
1196  c1 o1 o2 p0 p1
1197  }
1198  BREAKPOINT {
1199  SOLVE ihkin METHOD sparse
1200  }
1201  DERIVATIVE ihkin {
1202  LOCAL alpha, beta, k3p, k4, k1ca, k2
1203  evaluate_fct(v, cai)
1204  CONSERVE p1 = 1-p0
1205  CONSERVE o2 = 1-c1-o1
1206  c1' = (-1*(alpha*c1-beta*o1))
1207  o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1208  o2' = (1*(k3p*o1-k4*o2))
1209  p0' = (-1*(k1ca*p0-k2*p1))
1210  p1' = (1*(k1ca*p0-k2*p1))
1211  })";
1212  std::string expected_result = R"(
1213  DERIVATIVE ihkin {
1214  EIGEN_NEWTON_SOLVE[5]{
1215  LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1216  }{
1217  evaluate_fct(v, cai)
1218  old_c1 = c1
1219  old_o1 = o1
1220  old_p0 = p0
1221  }{
1222  nmodl_eigen_x[0] = c1
1223  nmodl_eigen_x[1] = o1
1224  nmodl_eigen_x[2] = o2
1225  nmodl_eigen_x[3] = p0
1226  nmodl_eigen_x[4] = p1
1227  }{
1228  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1229  nmodl_eigen_j[0] = -alpha-1/dt
1230  nmodl_eigen_j[5] = beta
1231  nmodl_eigen_j[10] = 0
1232  nmodl_eigen_j[15] = 0
1233  nmodl_eigen_j[20] = 0
1234  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1235  nmodl_eigen_j[1] = alpha
1236  nmodl_eigen_j[6] = -beta-k3p-1/dt
1237  nmodl_eigen_j[11] = k4
1238  nmodl_eigen_j[16] = 0
1239  nmodl_eigen_j[21] = 0
1240  nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1241  nmodl_eigen_j[2] = -1.0
1242  nmodl_eigen_j[7] = -1.0
1243  nmodl_eigen_j[12] = -1.0
1244  nmodl_eigen_j[17] = 0
1245  nmodl_eigen_j[22] = 0
1246  nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1247  nmodl_eigen_j[3] = 0
1248  nmodl_eigen_j[8] = 0
1249  nmodl_eigen_j[13] = 0
1250  nmodl_eigen_j[18] = -k1ca-1/dt
1251  nmodl_eigen_j[23] = k2
1252  nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1253  nmodl_eigen_j[4] = 0
1254  nmodl_eigen_j[9] = 0
1255  nmodl_eigen_j[14] = 0
1256  nmodl_eigen_j[19] = -1.0
1257  nmodl_eigen_j[24] = -1.0
1258  }{
1259  c1 = nmodl_eigen_x[0]
1260  o1 = nmodl_eigen_x[1]
1261  o2 = nmodl_eigen_x[2]
1262  p0 = nmodl_eigen_x[3]
1263  p1 = nmodl_eigen_x[4]
1264  }{
1265  }
1266  })";
1267  THEN(
1268  "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1269  "algebraic relations") {
1270  CAPTURE(nmodl_text);
1271  auto result =
1272  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1273  compare_blocks(result[0], reindent_text(expected_result));
1274  }
1275  }
1276  GIVEN("Derivative block including ODES with sparse method - single var in array") {
1277  std::string nmodl_text = R"(
1278  STATE {
1279  W[1]
1280  }
1281  ASSIGNED {
1282  A[2]
1283  B[1]
1284  }
1285  BREAKPOINT {
1286  SOLVE scheme1 METHOD sparse
1287  }
1288  DERIVATIVE scheme1 {
1289  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1290  }
1291  )";
1292  std::string expected_result = R"(
1293  DERIVATIVE scheme1 {
1294  EIGEN_NEWTON_SOLVE[1]{
1295  LOCAL old_W_0
1296  }{
1297  old_W_0 = W[0]
1298  }{
1299  nmodl_eigen_x[0] = W[0]
1300  }{
1301  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1302  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1303  }{
1304  W[0] = nmodl_eigen_x[0]
1305  }{
1306  }
1307  })";
1308  THEN("Construct & solver linear system") {
1309  CAPTURE(nmodl_text);
1310  auto result =
1311  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1312  compare_blocks(result[0], reindent_text(expected_result));
1313  }
1314  }
1315  GIVEN("Derivative block including ODES with sparse method - array vars") {
1316  std::string nmodl_text = R"(
1317  STATE {
1318  M[2]
1319  }
1320  ASSIGNED {
1321  A[2]
1322  B[2]
1323  }
1324  BREAKPOINT {
1325  SOLVE scheme1 METHOD sparse
1326  }
1327  DERIVATIVE scheme1 {
1328  M'[0] = -A[0]*M[0] + B[0]*M[1]
1329  M'[1] = A[1]*M[0] - B[1]*M[1]
1330  }
1331  )";
1332  std::string expected_result = R"(
1333  DERIVATIVE scheme1 {
1334  EIGEN_NEWTON_SOLVE[2]{
1335  LOCAL old_M_0, old_M_1
1336  }{
1337  old_M_0 = M[0]
1338  old_M_1 = M[1]
1339  }{
1340  nmodl_eigen_x[0] = M[0]
1341  nmodl_eigen_x[1] = M[1]
1342  }{
1343  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1344  nmodl_eigen_j[0] = -A[0]-1/dt
1345  nmodl_eigen_j[2] = B[0]
1346  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1347  nmodl_eigen_j[1] = A[1]
1348  nmodl_eigen_j[3] = -B[1]-1/dt
1349  }{
1350  M[0] = nmodl_eigen_x[0]
1351  M[1] = nmodl_eigen_x[1]
1352  }{
1353  }
1354  })";
1355  THEN("Construct & solver linear system") {
1356  CAPTURE(nmodl_text);
1357  auto result =
1358  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1359  compare_blocks(result[0], reindent_text(expected_result));
1360  }
1361  }
1362  GIVEN("Derivative block including ODES with derivimplicit method - single var in array") {
1363  std::string nmodl_text = R"(
1364  STATE {
1365  W[1]
1366  }
1367  ASSIGNED {
1368  A[2]
1369  B[1]
1370  }
1371  BREAKPOINT {
1372  SOLVE scheme1 METHOD derivimplicit
1373  }
1374  DERIVATIVE scheme1 {
1375  W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1376  }
1377  )";
1378  std::string expected_result = R"(
1379  DERIVATIVE scheme1 {
1380  EIGEN_NEWTON_SOLVE[1]{
1381  LOCAL old_W_0
1382  }{
1383  old_W_0 = W[0]
1384  }{
1385  nmodl_eigen_x[0] = W[0]
1386  }{
1387  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1388  nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1389  }{
1390  W[0] = nmodl_eigen_x[0]
1391  }{
1392  }
1393  })";
1394  THEN("Construct newton solve block") {
1395  CAPTURE(nmodl_text);
1396  auto result =
1397  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1398  compare_blocks(result[0], reindent_text(expected_result));
1399  }
1400  }
1401  GIVEN("Derivative block including ODES with derivimplicit method") {
1402  std::string nmodl_text = R"(
1403  STATE {
1404  m h n
1405  }
1406  BREAKPOINT {
1407  SOLVE states METHOD derivimplicit
1408  }
1409  DERIVATIVE states {
1410  rates(v)
1411  m' = (minf-m)/mtau - 3*h
1412  h' = (hinf-h)/htau + m*m
1413  n' = (ninf-n)/ntau
1414  }
1415  )";
1416  /// new derivative block with EigenNewtonSolverBlock node
1417  std::string expected_result = R"(
1418  DERIVATIVE states {
1419  EIGEN_NEWTON_SOLVE[3]{
1420  LOCAL old_m, old_h, old_n
1421  }{
1422  rates(v)
1423  old_m = m
1424  old_h = h
1425  old_n = n
1426  }{
1427  nmodl_eigen_x[0] = m
1428  nmodl_eigen_x[1] = h
1429  nmodl_eigen_x[2] = n
1430  }{
1431  nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1432  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1433  nmodl_eigen_j[3] = -3.0
1434  nmodl_eigen_j[6] = 0
1435  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1436  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1437  nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1438  nmodl_eigen_j[7] = 0
1439  nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1440  nmodl_eigen_j[2] = 0
1441  nmodl_eigen_j[5] = 0
1442  nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1443  }{
1444  m = nmodl_eigen_x[0]
1445  h = nmodl_eigen_x[1]
1446  n = nmodl_eigen_x[2]
1447  }{
1448  }
1449  })";
1450  THEN("Construct newton solve block") {
1451  CAPTURE(nmodl_text);
1452  auto result =
1453  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1454  compare_blocks(result[0], reindent_text(expected_result));
1455  }
1456  }
1457  GIVEN("Multiple derivative blocks each with derivimplicit method") {
1458  std::string nmodl_text = R"(
1459  STATE {
1460  m h
1461  }
1462  BREAKPOINT {
1463  SOLVE states1 METHOD derivimplicit
1464  SOLVE states2 METHOD derivimplicit
1465  }
1466 
1467  DERIVATIVE states1 {
1468  m' = (minf-m)/mtau
1469  h' = (hinf-h)/htau + m*m
1470  }
1471 
1472  DERIVATIVE states2 {
1473  h' = (hinf-h)/htau + m*m
1474  m' = (minf-m)/mtau + h
1475  }
1476  )";
1477  /// EigenNewtonSolverBlock in each derivative block
1478  std::string expected_result_0 = R"(
1479  DERIVATIVE states1 {
1480  EIGEN_NEWTON_SOLVE[2]{
1481  LOCAL old_m, old_h
1482  }{
1483  old_m = m
1484  old_h = h
1485  }{
1486  nmodl_eigen_x[0] = m
1487  nmodl_eigen_x[1] = h
1488  }{
1489  nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1490  nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1491  nmodl_eigen_j[2] = 0
1492  nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1493  nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1494  nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1495  }{
1496  m = nmodl_eigen_x[0]
1497  h = nmodl_eigen_x[1]
1498  }{
1499  }
1500  })";
1501  std::string expected_result_1 = R"(
1502  DERIVATIVE states2 {
1503  EIGEN_NEWTON_SOLVE[2]{
1504  LOCAL old_h, old_m
1505  }{
1506  old_h = h
1507  old_m = m
1508  }{
1509  nmodl_eigen_x[0] = m
1510  nmodl_eigen_x[1] = h
1511  }{
1512  nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1513  nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1514  nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1515  nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1516  nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1517  nmodl_eigen_j[3] = 1.0
1518  }{
1519  m = nmodl_eigen_x[0]
1520  h = nmodl_eigen_x[1]
1521  }{
1522  }
1523  })";
1524  THEN("Construct newton solve block") {
1525  auto result =
1526  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK);
1527  CAPTURE(nmodl_text);
1528  compare_blocks(result[0], reindent_text(expected_result_0));
1529  compare_blocks(result[1], reindent_text(expected_result_1));
1530  }
1531  }
1532 }
1533 
1534 
1535 //=============================================================================
1536 // LINEAR solve block tests
1537 //=============================================================================
1538 
1539 SCENARIO("LINEAR solve block (SympySolver Visitor)", "[sympy][linear]") {
1540  GIVEN("1 state-var symbolic LINEAR solve block") {
1541  std::string nmodl_text = R"(
1542  STATE {
1543  x
1544  }
1545  LINEAR lin {
1546  ~ 2*a*x = 1
1547  })";
1548  std::string expected_text = R"(
1549  LINEAR lin {
1550  x = 0.5/a
1551  })";
1552  THEN("solve analytically") {
1553  auto result =
1554  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1555  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1556  }
1557  }
1558  GIVEN("2 state-var LINEAR solve block") {
1559  std::string nmodl_text = R"(
1560  STATE {
1561  x y
1562  }
1563  LINEAR lin {
1564  ~ x + 4*y = 5*a
1565  ~ x - y = 0
1566  })";
1567  std::string expected_text = R"(
1568  LINEAR lin {
1569  x = a
1570  y = a
1571  })";
1572  THEN("solve analytically") {
1573  auto result =
1574  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1575  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1576  }
1577  }
1578  GIVEN("Linear block, print in order, vectors") {
1579  std::string nmodl_text = R"(
1580  STATE {
1581  M[2]
1582  }
1583  LINEAR lin {
1584  ~ M[1] = M[0] + 1
1585  ~ M[0] = 2
1586  })";
1587  std::string expected_result = R"(
1588  LINEAR lin {
1589  M[1] = 3.0
1590  M[0] = 2.0
1591  })";
1592 
1593  THEN("Construct & solve linear system") {
1594  auto result =
1595  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1596 
1597  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1598  }
1599  }
1600  GIVEN("Linear block, by value replacement, interleaved") {
1601  std::string nmodl_text = R"(
1602  STATE {
1603  x y
1604  }
1605  LINEAR lin {
1606  LOCAL a
1607  a = 0
1608  ~ x = y + a
1609  a = 1
1610  ~ y = a
1611  a = 2
1612  })";
1613  std::string expected_result = R"(
1614  LINEAR lin {
1615  LOCAL a
1616  a = 0
1617  x = 2.0*a
1618  a = 1
1619  y = a
1620  a = 2
1621  })";
1622 
1623  THEN("Construct & solve linear system") {
1624  auto result =
1625  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1626 
1627  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1628  }
1629  }
1630  GIVEN("Linear block in control flow block") {
1631  std::string nmodl_text = R"(
1632  STATE {
1633  x y
1634  }
1635  LINEAR lin {
1636  LOCAL a
1637  if (a == 1) {
1638  ~ x = y + a
1639  ~ y = a
1640  }
1641  })";
1642  std::string expected_result = R"(
1643  LINEAR lin {
1644  LOCAL a
1645  IF (a == 1) {
1646  x = 2.0*a
1647  y = a
1648  }
1649  })";
1650 
1651  THEN("Construct & solve linear system") {
1652  auto result =
1653  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1654 
1655  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1656  }
1657  }
1658  GIVEN("Linear block, linear equations mixed with control flow blocks and reassignments") {
1659  std::string nmodl_text = R"(
1660  STATE {
1661  x y
1662  }
1663  LINEAR lin {
1664  LOCAL a
1665  ~ x = y + a
1666  if (a == 1) {
1667  a = a + 1
1668  x = a + 1
1669  }
1670  ~ y = a
1671  })";
1672  std::string expected_result = R"(
1673  LINEAR lin {
1674  LOCAL a
1675  x = 2.0*a
1676  IF (a == 1) {
1677  a = a+1
1678  x = a+1
1679  }
1680  y = a
1681  })";
1682 
1683  THEN("Construct & solve linear system") {
1684  auto result =
1685  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1686 
1687  compare_blocks(reindent_text(result[0]), reindent_text(expected_result));
1688  }
1689  }
1690  GIVEN("4 state-var LINEAR solve block") {
1691  std::string nmodl_text = R"(
1692  STATE {
1693  w x y z
1694  }
1695  LINEAR lin {
1696  ~ w + z/3.2 = -2.0*y
1697  ~ x + 4*c*y = -5.343*a
1698  ~ a + x/b + z - y = 0.842*b*b
1699  ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1700  })";
1701  std::string expected_text = R"(
1702  LINEAR lin {
1703  EIGEN_LINEAR_SOLVE[4]{
1704  }{
1705  }{
1706  nmodl_eigen_x[0] = w
1707  nmodl_eigen_x[1] = x
1708  nmodl_eigen_x[2] = y
1709  nmodl_eigen_x[3] = z
1710  nmodl_eigen_f[0] = 0
1711  nmodl_eigen_f[1] = 5.343*a
1712  nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1713  nmodl_eigen_f[3] = -1.43543/c
1714  nmodl_eigen_j[0] = -1.0
1715  nmodl_eigen_j[4] = 0
1716  nmodl_eigen_j[8] = -2.0
1717  nmodl_eigen_j[12] = -0.3125
1718  nmodl_eigen_j[1] = 0
1719  nmodl_eigen_j[5] = -1.0
1720  nmodl_eigen_j[9] = -4.0*c
1721  nmodl_eigen_j[13] = 0
1722  nmodl_eigen_j[2] = 0
1723  nmodl_eigen_j[6] = -1/b
1724  nmodl_eigen_j[10] = 1.0
1725  nmodl_eigen_j[14] = -1.0
1726  nmodl_eigen_j[3] = 0
1727  nmodl_eigen_j[7] = -1.0
1728  nmodl_eigen_j[11] = -1.3
1729  nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1730  }{
1731  w = nmodl_eigen_x[0]
1732  x = nmodl_eigen_x[1]
1733  y = nmodl_eigen_x[2]
1734  z = nmodl_eigen_x[3]
1735  }{
1736  }
1737  })";
1738  THEN("return matrix system to solve") {
1739  auto result =
1740  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1741  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1742  }
1743  }
1744 
1745  GIVEN("LINEAR solve block with an explicit SOLVEFOR statement") {
1746  std::string nmodl_text = R"(
1747  STATE {
1748  x
1749  y
1750  z
1751  }
1752  LINEAR lin SOLVEFOR x, y {
1753  ~ 3 * x = v - y
1754  ~ x = z * y - 5
1755  })";
1756  std::string expected_text = R"(
1757  LINEAR lin SOLVEFOR x,y{
1758  y = (v+15.0)/(3.0*z+1.0)
1759  x = (v*z-5.0)/(3.0*z+1.0)
1760  })";
1761  THEN("solve analytically") {
1762  auto result =
1763  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::LINEAR_BLOCK);
1764  REQUIRE(reindent_text(result[0]) == reindent_text(expected_text));
1765  }
1766  }
1767 }
1768 
1769 //=============================================================================
1770 // NONLINEAR solve block tests
1771 //=============================================================================
1772 
1773 SCENARIO("Solve NONLINEAR block using SympySolver Visitor", "[visitor][solver][sympy][nonlinear]") {
1774  GIVEN("1 state-var numeric NONLINEAR solve block") {
1775  std::string nmodl_text = R"(
1776  STATE {
1777  x
1778  }
1779  NONLINEAR nonlin {
1780  ~ x = 5
1781  })";
1782  std::string expected_text = R"(
1783  NONLINEAR nonlin {
1784  EIGEN_NEWTON_SOLVE[1]{
1785  }{
1786  }{
1787  nmodl_eigen_x[0] = x
1788  }{
1789  nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1790  nmodl_eigen_j[0] = -1.0
1791  }{
1792  x = nmodl_eigen_x[0]
1793  }{
1794  }
1795  })";
1796 
1797  THEN("return F & J for newton solver") {
1798  auto result =
1799  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1800  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1801  }
1802  }
1803  GIVEN("array state-var numeric NONLINEAR solve block") {
1804  std::string nmodl_text = R"(
1805  STATE {
1806  s[3]
1807  }
1808  NONLINEAR nonlin {
1809  ~ s[0] = 1
1810  ~ s[1] = 3
1811  ~ s[2] + s[1] = s[0]
1812  })";
1813  std::string expected_text = R"(
1814  NONLINEAR nonlin {
1815  EIGEN_NEWTON_SOLVE[3]{
1816  }{
1817  }{
1818  nmodl_eigen_x[0] = s[0]
1819  nmodl_eigen_x[1] = s[1]
1820  nmodl_eigen_x[2] = s[2]
1821  }{
1822  nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1823  nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1824  nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1825  nmodl_eigen_j[0] = -1.0
1826  nmodl_eigen_j[3] = 0
1827  nmodl_eigen_j[6] = 0
1828  nmodl_eigen_j[1] = 0
1829  nmodl_eigen_j[4] = -1.0
1830  nmodl_eigen_j[7] = 0
1831  nmodl_eigen_j[2] = 1.0
1832  nmodl_eigen_j[5] = -1.0
1833  nmodl_eigen_j[8] = -1.0
1834  }{
1835  s[0] = nmodl_eigen_x[0]
1836  s[1] = nmodl_eigen_x[1]
1837  s[2] = nmodl_eigen_x[2]
1838  }{
1839  }
1840  })";
1841  THEN("return F & J for newton solver") {
1842  auto result =
1843  run_sympy_solver_visitor(nmodl_text, false, false, AstNodeType::NON_LINEAR_BLOCK);
1844  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1845  }
1846  }
1847 }
1848 SCENARIO("Solve KINETIC block using SympySolver Visitor", "[visitor][solver][sympy][kinetic]") {
1849  GIVEN("KINETIC block with not inlined function should work") {
1850  std::string nmodl_text = R"(
1851  BREAKPOINT {
1852  SOLVE kstates METHOD sparse
1853  }
1854  STATE {
1855  C1
1856  C2
1857  }
1858  FUNCTION alfa(v(mV)) {
1859  alfa = v
1860  }
1861  KINETIC kstates {
1862  ~ C1 <-> C2 (alfa(v), alfa(v))
1863  })";
1864  std::string expected_text = R"(
1865  DERIVATIVE kstates {
1866  EIGEN_NEWTON_SOLVE[2]{
1867  LOCAL kf0_, kb0_, old_C1, old_C2
1868  }{
1869  kb0_ = alfa(v)
1870  kf0_ = alfa(v)
1871  old_C1 = C1
1872  old_C2 = C2
1873  }{
1874  nmodl_eigen_x[0] = C1
1875  nmodl_eigen_x[1] = C2
1876  }{
1877  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1878  nmodl_eigen_j[0] = -kf0_-1/dt
1879  nmodl_eigen_j[2] = kb0_
1880  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1881  nmodl_eigen_j[1] = kf0_
1882  nmodl_eigen_j[3] = -kb0_-1/dt
1883  }{
1884  C1 = nmodl_eigen_x[0]
1885  C2 = nmodl_eigen_x[1]
1886  }{
1887  }
1888  })";
1889  THEN("Run Kinetic and Sympy Visitor") {
1890  std::vector<std::string> result;
1891  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1892  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1893  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1894  }
1895  }
1896  GIVEN("Protected names in Sympy are respected") {
1897  std::string nmodl_text = R"(
1898  BREAKPOINT {
1899  SOLVE kstates METHOD sparse
1900  }
1901  STATE {
1902  C1
1903  C2
1904  }
1905  FUNCTION beta(v(mV)) {
1906  beta = v
1907  }
1908  FUNCTION lowergamma(v(mV)) {
1909  lowergamma = v
1910  }
1911  KINETIC kstates {
1912  ~ C1 <-> C2 (beta(v), lowergamma(v))
1913  })";
1914  std::string expected_text = R"(
1915  DERIVATIVE kstates {
1916  EIGEN_NEWTON_SOLVE[2]{
1917  LOCAL kf0_, kb0_, old_C1, old_C2
1918  }{
1919  kf0_ = beta(v)
1920  kb0_ = lowergamma(v)
1921  old_C1 = C1
1922  old_C2 = C2
1923  }{
1924  nmodl_eigen_x[0] = C1
1925  nmodl_eigen_x[1] = C2
1926  }{
1927  nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1928  nmodl_eigen_j[0] = -kf0_-1/dt
1929  nmodl_eigen_j[2] = kb0_
1930  nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1931  nmodl_eigen_j[1] = kf0_
1932  nmodl_eigen_j[3] = -kb0_-1/dt
1933  }{
1934  C1 = nmodl_eigen_x[0]
1935  C2 = nmodl_eigen_x[1]
1936  }{
1937  }
1938  })";
1939  THEN("Run Kinetic and Sympy Visitor") {
1940  std::vector<std::string> result;
1941  REQUIRE_NOTHROW(result = run_sympy_solver_visitor(
1942  nmodl_text, false, false, AstNodeType::DERIVATIVE_BLOCK, true));
1943  compare_blocks(reindent_text(result[0]), reindent_text(expected_text));
1944  }
1945  }
1946 }
1947 
1948 SCENARIO("Replace unimplementable cnexp solution with derivimplicit solution",
1949  "[visitor][sympy][cnexp][derivimplicit]") {
1950  GIVEN("Derivative block that has a LambertW analytic solution") {
1951  std::string nmodl_text = R"(
1952  STATE {
1953  a
1954  }
1955  BREAKPOINT {
1956  SOLVE states METHOD cnexp
1957  }
1958  DERIVATIVE states {
1959  a' = -a/(1 + a)
1960  }
1961  )";
1962  THEN("The method has been replaced with derivimplicit") {
1964  REQUIRE_THAT(to_nmodl(result), Catch::Matchers::ContainsSubstring("derivimplicit"));
1965  }
1966  }
1967 }
Visitor for checking parents of ast nodes
Represents top level AST node for whole NMODL input.
Definition: program.hpp:39
Class that binds all pieces together for parsing nmodl file.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Perform constant folding of integer/float/double expressions.
Visitor for kinetic block statements
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for printing AST back to NMODL
void visit_program(const ast::Program &node) override
visit node of type ast::Program
Visitor for systems of algebraic and differential equations
void visit_program(ast::Program &node) override
visit node of type ast::Program
Concrete visitor for constructing symbol table from AST.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for checking parents of ast nodes
int check_ast(const ast::Ast &node)
A small wrapper to have a nicer call in parser.cpp.
Visitor for printing C++ code compatible with legacy api of CoreNEURON
Perform constant folding of integer/float/double expressions.
int nmodl_text
Definition: modl.cpp:58
AstNodeType
Enum type for every AST node type.
Definition: ast_decl.hpp:166
bool parse_string(const std::string &input)
parser Units provided as string (used for testing)
Definition: unit_driver.cpp:40
Visitor to inline local procedure and function calls
Visitor for kinetic block statements
Unroll for loop in the AST.
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
Definition: test_utils.cpp:55
encapsulates code generation backend implementations
Definition: ast_common.hpp:26
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Visitor that solves ODEs using old solvers of NEURON
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
static Node * node(Object *)
Definition: netcvode.cpp:291
static double remove(void *v)
Definition: ocdeck.cpp:205
#define text
Definition: plot.cpp:60
Auto generated AST classes declaration.
Replace solve block statements with actual solution node in the AST.
void compare_blocks(const std::string &result, const std::string &expected, const bool require_fail=false)
Compare nmodl blocks that contain systems of equations (i.e.
std::string ast_to_string(ast::Program &node)
std::vector< std::string > run_sympy_solver_visitor(const std::string &text, bool pade=false, bool cse=false, AstNodeType ret_nodetype=AstNodeType::DIFF_EQ_EXPRESSION, bool kinetic=false)
bool is_unique_vars(std::string result)
auto run_sympy_solver_visitor_ast(const std::string &text, bool pade=false, bool cse=false, bool kinetic=false)
SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]")
void run_sympy_visitor_passes(ast::Program &node)
Visitor for systems of algebraic and differential equations
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::parser::UnitDriver driver
Definition: parser.cpp:28