8 #include <catch2/catch_test_macros.hpp>
9 #include <catch2/matchers/catch_matchers_string.hpp>
11 #include <pybind11/embed.h>
12 #include <pybind11/stl.h>
30 using namespace nmodl;
32 using namespace visitor;
34 using namespace test_utils;
36 using Catch::Matchers::ContainsSubstring;
51 bool kinetic =
false) {
79 const std::string&
text,
82 AstNodeType ret_nodetype = AstNodeType::DIFF_EQ_EXPRESSION,
83 bool kinetic =
false) {
84 std::vector<std::string> results;
99 std::stringstream ss(
result);
102 std::unordered_set<std::string> old_vars;
104 while (getline(ss, token,
' ')) {
105 if (!old_vars.insert(token).second) {
139 const std::string& expected,
140 const bool require_fail =
false) {
141 using namespace pybind11::literals;
144 pybind11::dict(
"result"_a =
result,
"expected"_a = expected,
"is_equal"_a =
false);
146 # Comments are in the doxygen for better highlighting
147 def compare_blocks(result, expected):
151 d = {'\[(\d+)\]':'_\\1', 'pow\((\w+), ?(\d+)\)':'\\1**\\2', 'beta': 'beta_var', 'gamma': 'gamma_var'}
153 for key, val in d.items():
154 out = re.sub(key, val, out)
157 def compare_systems_of_eq(result_dict, expected_dict):
158 from sympy.parsing.sympy_parser import parse_expr
160 for k, v in result_dict.items():
161 if parse_expr(f'simplify(({v})-({expected_dict[k]}))'):
167 expected_dict.clear()
175 # split of sout and a dict with the tmp variables
176 for line in s.split('\n'):
177 line_split = line.lstrip().split('=')
179 if len(line_split) == 2 and line_split[0].startswith('tmp_'):
180 # back-substitution of tmp variables in tmp variables
181 tmp_var = line_split[0].strip()
185 max_tmp = max(max_tmp, int(tmp_var[4:]))
186 for k, v in d.items():
187 line_split[1] = line_split[1].replace(k, f'({v})')
188 d[tmp_var] = line_split[1]
189 elif 'LOCAL' in line:
190 sout += line.split('tmp_0')[0] + '\n'
194 # Back-substitution of the tmps
195 # so that we do not replace tmp_11 with (tmp_1)1
196 for j in range(max_tmp, -1, -1):
198 sout = sout.replace(k, f'({d[k]})')
202 result = reduce(sanitize(result)).split('\n')
203 expected = reduce(sanitize(expected)).split('\n')
205 if len(result) != len(expected):
210 for token1, token2 in zip(result, expected):
212 if not compare_systems_of_eq(result_dict, expected_dict):
216 eq1 = token1.split('=')
217 eq2 = token2.split('=')
218 if len(eq1) == 2 and len(eq2) == 2:
219 result_dict[eq1[0]] = eq1[1]
220 expected_dict[eq2[0]] = eq2[1]
224 return compare_systems_of_eq(result_dict, expected_dict)
226 is_equal = compare_blocks(result, expected))",
231 if (require_fail == locals[
"is_equal"].cast<bool>()) {
233 REQUIRE(
result != expected);
235 REQUIRE(
result == expected);
262 std::stringstream stream;
267 SCENARIO(
"Check compare_blocks in sympy unit tests",
"[visitor][sympy]") {
268 GIVEN(
"Empty strings") {
269 THEN(
"Strings are equal") {
273 GIVEN(
"Equivalent equation") {
274 THEN(
"Strings are equal") {
278 GIVEN(
"Equivalent systems of equations") {
282 std::string expected = R"(
285 THEN("Systems of equations are equal") {
289 GIVEN(
"Equivalent systems of equations with brackets") {
295 std::string expected = R"(
300 y = pow(a, 2)*a + 2*b-b
302 THEN("Blocks are equal") {
306 GIVEN(
"Different systems of equations (additional space)") {
312 std::string expected = R"(
317 THEN("Blocks are different") {
321 GIVEN(
"Different systems of equations") {
329 std::string expected = R"(
334 THEN("Blocks are different") {
340 SCENARIO(
"Check local vars name-clash prevention",
"[visitor][sympy]") {
347 SOLVE states METHOD sparse
354 THEN("There are no duplicate vars in LOCAL") {
361 GIVEN(
"LOCAL tmp_0") {
367 SOLVE states METHOD sparse
374 THEN("There are no duplicate vars in LOCAL") {
383 SCENARIO(
"Solve ODEs with cnexp or euler method using SympySolverVisitor",
384 "[visitor][sympy][cnexp][euler]") {
385 GIVEN(
"Derivative block without ODE, solver method cnexp") {
388 SOLVE states METHOD cnexp
394 THEN("No ODEs found - do nothing") {
399 GIVEN(
"Derivative block with ODES, solver method is euler") {
402 SOLVE states METHOD euler
410 THEN("Construct forwards Euler solutions") {
412 REQUIRE(
result.size() == 2);
413 REQUIRE(
result[0] ==
"m = (-dt*(m-mInf)+m*mTau)/mTau");
414 REQUIRE(
result[1] ==
"h = (-dt*(h-hInf)+h*hTau)/hTau");
417 GIVEN(
"Derivative block with calling external functions passes sympy") {
420 SOLVE states METHOD euler
428 THEN("Construct forward Euler interpreting external functions as symbols") {
430 REQUIRE(
result.size() == 3);
431 REQUIRE(
result[0] ==
"m = dt*sawtooth(m)+m");
432 REQUIRE(
result[1] ==
"n = dt*sin(n)+n");
433 REQUIRE(
result[2] ==
"p = dt*my_user_func(p)+p");
436 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method euler") {
442 SOLVE states METHOD euler
445 m'[0] = (mInf-m[0])/mTau
448 THEN("Construct forwards Euler solutions") {
450 REQUIRE(
result.size() == 1);
451 REQUIRE(
result[0] ==
"m[0] = (dt*(mInf-m[0])+mTau*m[0])/mTau");
454 GIVEN(
"Derivative block with ODE, 1 state var in array, solver method cnexp") {
460 SOLVE states METHOD cnexp
463 m'[0] = (mInf-m[0])/mTau
466 THEN("Construct forwards Euler solutions") {
468 REQUIRE(
result.size() == 1);
469 REQUIRE(
result[0] ==
"m[0] = mInf-(mInf-m[0])*exp(-dt/mTau)");
472 GIVEN(
"Derivative block with linear ODES, solver method cnexp") {
475 SOLVE states METHOD cnexp
480 h' = hInf/hTau - h/hTau
483 THEN("Integrate equations analytically") {
485 REQUIRE(
result.size() == 2);
486 REQUIRE(
result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
487 REQUIRE(
result[1] ==
"h = hInf-(-h+hInf)*exp(-dt/hTau)");
490 GIVEN(
"Derivative block including non-linear but solvable ODES, solver method cnexp") {
493 SOLVE states METHOD cnexp
500 THEN("Integrate equations analytically") {
502 REQUIRE(
result.size() == 2);
503 REQUIRE(
result[0] ==
"m = mInf-(-m+mInf)*exp(-dt/mTau)");
504 REQUIRE(
result[1] ==
"h = -h/(c2*dt*h-1.0)");
507 GIVEN(
"Derivative block including array of 2 state vars, solver method cnexp") {
510 SOLVE states METHOD cnexp
516 X'[0] = (mInf-X[0])/mTau
517 X'[1] = c2 * X[1]*X[1]
520 THEN("Integrate equations analytically") {
522 REQUIRE(
result.size() == 2);
523 REQUIRE(
result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau)");
524 REQUIRE(
result[1] ==
"X[1] = -X[1]/(c2*dt*X[1]-1.0)");
527 GIVEN(
"Derivative block including loop over array vars, solver method cnexp") {
531 SOLVE states METHOD cnexp
541 X'[i] = (mInf-X[i])/mTau[i]
545 THEN("Integrate equations analytically") {
547 REQUIRE(
result.size() == 3);
548 REQUIRE(
result[0] ==
"X[0] = mInf-(mInf-X[0])*exp(-dt/mTau[0])");
549 REQUIRE(
result[1] ==
"X[1] = mInf-(mInf-X[1])*exp(-dt/mTau[1])");
550 REQUIRE(
result[2] ==
"X[2] = mInf-(mInf-X[2])*exp(-dt/mTau[2])");
553 GIVEN(
"Derivative block including loop over array vars, solver method euler") {
557 SOLVE states METHOD euler
567 X'[i] = (mInf-X[i])/mTau[i]
571 THEN("Integrate equations analytically") {
573 REQUIRE(
result.size() == 3);
574 REQUIRE(
result[0] ==
"X[0] = (dt*(mInf-X[0])+X[0]*mTau[0])/mTau[0]");
575 REQUIRE(
result[1] ==
"X[1] = (dt*(mInf-X[1])+X[1]*mTau[1])/mTau[1]");
576 REQUIRE(
result[2] ==
"X[2] = (dt*(mInf-X[2])+X[2]*mTau[2])/mTau[2]");
579 GIVEN(
"Derivative block including ODES that can't currently be solved, solver method cnexp") {
582 SOLVE states METHOD cnexp
591 THEN("Integrate equations analytically where possible, otherwise leave untouched") {
593 REQUIRE(
result.size() == 4);
595 REQUIRE((
result[0] ==
"z' = a/z+b/z/z" ||
597 "z = (0.5*pow(a, 2)*pow(z, 2)-a*b*z+pow(b, 2)*log(a*z+b))/pow(a, 3)"));
598 REQUIRE(
result[1] ==
"h = -h/(c2*dt*h-1.0)");
599 REQUIRE(
result[2] ==
"x = a*dt+x");
601 REQUIRE((
result[3] ==
"y' = c3*y*y*y" ||
602 result[3] ==
"y = sqrt(-pow(y, 2)/(2.0*c3*dt*pow(y, 2)-1.0))"));
605 GIVEN(
"Derivative block with cnexp solver method, AST after SympySolver pass") {
608 SOLVE states METHOD cnexp
626 THEN(
"More SympySolver passes do nothing to the AST and don't throw") {
633 SCENARIO(
"Solve ODEs with derivimplicit method using SympySolverVisitor",
634 "[visitor][sympy][derivimplicit]") {
635 GIVEN(
"Derivative block with derivimplicit solver method and conditional block") {
641 SOLVE states METHOD derivimplicit
650 std::string expected_result = R"(
652 EIGEN_NEWTON_SOLVE[1]{
662 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+mInf)+mTau*(-nmodl_eigen_x[0]+old_m))/(dt*mTau)
663 nmodl_eigen_j[0] = (-dt-mTau)/(dt*mTau)
669 THEN("SympySolver correctly inserts ode to block") {
677 GIVEN(
"Derivative block, sparse, print in order") {
683 SOLVE states METHOD sparse
690 std::string expected_result = R"(
692 EIGEN_NEWTON_SOLVE[2]{
693 LOCAL a, b, old_y, old_x
701 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_y)/dt
703 nmodl_eigen_j[2] = -1/dt
704 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_x)/dt
705 nmodl_eigen_j[1] = -1/dt
714 THEN("Construct & solve linear system for backwards Euler") {
721 GIVEN(
"Derivative block, sparse, print in order, vectors") {
727 SOLVE states METHOD sparse
734 std::string expected_result = R"(
736 EIGEN_NEWTON_SOLVE[2]{
737 LOCAL a, b, old_M_1, old_M_0
742 nmodl_eigen_x[0] = M[0]
743 nmodl_eigen_x[1] = M[1]
745 nmodl_eigen_f[0] = (-nmodl_eigen_x[1]+a*dt+old_M_1)/dt
747 nmodl_eigen_j[2] = -1/dt
748 nmodl_eigen_f[1] = (-nmodl_eigen_x[0]+b*dt+old_M_0)/dt
749 nmodl_eigen_j[1] = -1/dt
752 M[0] = nmodl_eigen_x[0]
753 M[1] = nmodl_eigen_x[1]
758 THEN("Construct & solve linear system for backwards Euler") {
765 GIVEN(
"Derivative block, sparse, derivatives mixed with local variable reassignment") {
771 SOLVE states METHOD sparse
779 std::string expected_result = R"(
781 EIGEN_NEWTON_SOLVE[2]{
782 LOCAL a, b, old_x, old_y
790 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
791 nmodl_eigen_j[0] = -1/dt
794 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
796 nmodl_eigen_j[3] = -1/dt
804 THEN("Construct & solve linear system for backwards Euler") {
812 "Throw exception during derivative variable reassignment interleaved in the differential "
819 SOLVE states METHOD sparse
829 "Throw an error because state variable assignments are not allowed inside the system "
834 Catch::Matchers::ContainsSubstring(
835 "State variable assignment(s) interleaved in system of "
836 "equations/differential equations") &&
837 Catch::Matchers::StartsWith(
"SympyReplaceSolutionsVisitor"));
840 GIVEN(
"Derivative block in control flow block") {
846 SOLVE states METHOD sparse
855 std::string expected_result = R"(
859 EIGEN_NEWTON_SOLVE[2]{
868 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+a*dt+old_x)/dt
869 nmodl_eigen_j[0] = -1/dt
871 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+b*dt+old_y)/dt
873 nmodl_eigen_j[3] = -1/dt
882 THEN("Construct & solve linear system for backwards Euler") {
890 "Derivative block, sparse, coupled derivatives mixed with reassignment and control flow "
897 SOLVE states METHOD sparse
907 std::string expected_result = R"(
909 EIGEN_NEWTON_SOLVE[2]{
910 LOCAL a, b, old_x, old_y
918 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
919 nmodl_eigen_j[0] = -1/dt
924 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
925 nmodl_eigen_j[1] = 1.0
926 nmodl_eigen_j[3] = a-1/dt
933 std::string expected_result_cse = R"(
935 EIGEN_NEWTON_SOLVE[2]{
936 LOCAL a, b, old_x, old_y
944 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[1]*a+b)+old_x)/dt
945 nmodl_eigen_j[0] = -1/dt
950 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]+nmodl_eigen_x[1]*a)+old_y)/dt
951 nmodl_eigen_j[1] = 1.0
952 nmodl_eigen_j[3] = a-1/dt
960 THEN("Construct & solve linear system for backwards Euler") {
971 GIVEN(
"Derivative block of coupled & linear ODES, solver method sparse") {
977 SOLVE states METHOD sparse
986 std::string expected_result = R"(
988 EIGEN_NEWTON_SOLVE[3]{
989 LOCAL a, b, c, d, h, old_x, old_y, old_z
999 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1000 nmodl_eigen_j[0] = -1/dt
1001 nmodl_eigen_j[3] = 0
1002 nmodl_eigen_j[6] = a
1003 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1004 nmodl_eigen_j[1] = 2.0
1005 nmodl_eigen_j[4] = -1/dt
1006 nmodl_eigen_j[7] = 0
1007 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1008 nmodl_eigen_j[2] = 0
1009 nmodl_eigen_j[5] = -1.0
1010 nmodl_eigen_j[8] = d-1/dt
1012 x = nmodl_eigen_x[0]
1013 y = nmodl_eigen_x[1]
1014 z = nmodl_eigen_x[2]
1018 std::string expected_cse_result = R"(
1020 EIGEN_NEWTON_SOLVE[3]{
1021 LOCAL a, b, c, d, h, old_x, old_y, old_z
1027 nmodl_eigen_x[0] = x
1028 nmodl_eigen_x[1] = y
1029 nmodl_eigen_x[2] = z
1031 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(nmodl_eigen_x[2]*a+b*h)+old_x)/dt
1032 nmodl_eigen_j[0] = -1/dt
1033 nmodl_eigen_j[3] = 0
1034 nmodl_eigen_j[6] = a
1035 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(2.0*nmodl_eigen_x[0]+c)+old_y)/dt
1036 nmodl_eigen_j[1] = 2.0
1037 nmodl_eigen_j[4] = -1/dt
1038 nmodl_eigen_j[7] = 0
1039 nmodl_eigen_f[2] = (-nmodl_eigen_x[2]+dt*(-nmodl_eigen_x[1]+nmodl_eigen_x[2]*d)+old_z)/dt
1040 nmodl_eigen_j[2] = 0
1041 nmodl_eigen_j[5] = -1.0
1042 nmodl_eigen_j[8] = d-1/dt
1044 x = nmodl_eigen_x[0]
1045 y = nmodl_eigen_x[1]
1046 z = nmodl_eigen_x[2]
1051 THEN("Construct & solve linear system for backwards Euler") {
1061 GIVEN(
"Derivative block including ODES with sparse method (from nmodl paper)") {
1067 SOLVE scheme1 METHOD sparse
1069 DERIVATIVE scheme1 {
1074 std::string expected_result = R"(
1075 DERIVATIVE scheme1 {
1076 EIGEN_NEWTON_SOLVE[2]{
1082 nmodl_eigen_x[0] = mc
1083 nmodl_eigen_x[1] = m
1085 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1086 nmodl_eigen_j[0] = -a-1/dt
1087 nmodl_eigen_j[2] = b
1088 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1089 nmodl_eigen_j[1] = a
1090 nmodl_eigen_j[3] = -b-1/dt
1092 mc = nmodl_eigen_x[0]
1093 m = nmodl_eigen_x[1]
1097 THEN("Construct & solve linear system") {
1104 GIVEN(
"Derivative block with ODES with sparse method, CONSERVE statement of form m = ...") {
1110 SOLVE scheme1 METHOD sparse
1112 DERIVATIVE scheme1 {
1118 std::string expected_result = R"(
1119 DERIVATIVE scheme1 {
1120 EIGEN_NEWTON_SOLVE[2]{
1125 nmodl_eigen_x[0] = mc
1126 nmodl_eigen_x[1] = m
1128 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1129 nmodl_eigen_j[0] = -a-1/dt
1130 nmodl_eigen_j[2] = b
1131 nmodl_eigen_f[1] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]+1.0
1132 nmodl_eigen_j[1] = -1.0
1133 nmodl_eigen_j[3] = -1.0
1135 mc = nmodl_eigen_x[0]
1136 m = nmodl_eigen_x[1]
1140 THEN("Construct & solve linear system, replace ODE for m with rhs of CONSERVE statement") {
1148 "Derivative block with ODES with sparse method, invalid CONSERVE statement of form m + mc "
1155 SOLVE scheme1 METHOD sparse
1157 DERIVATIVE scheme1 {
1163 std::string expected_result = R"(
1164 DERIVATIVE scheme1 {
1165 EIGEN_NEWTON_SOLVE[2]{
1171 nmodl_eigen_x[0] = mc
1172 nmodl_eigen_x[1] = m
1174 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*a+nmodl_eigen_x[1]*b)+old_mc)/dt
1175 nmodl_eigen_j[0] = -a-1/dt
1176 nmodl_eigen_j[2] = b
1177 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*a-nmodl_eigen_x[1]*b)+old_m)/dt
1178 nmodl_eigen_j[1] = a
1179 nmodl_eigen_j[3] = -b-1/dt
1181 mc = nmodl_eigen_x[0]
1182 m = nmodl_eigen_x[1]
1186 THEN("Construct & solve linear system, ignore invalid CONSERVE statement") {
1193 GIVEN(
"Derivative block with ODES with sparse method, two CONSERVE statements") {
1199 SOLVE ihkin METHOD sparse
1202 LOCAL alpha, beta, k3p, k4, k1ca, k2
1203 evaluate_fct(v, cai)
1205 CONSERVE o2 = 1-c1-o1
1206 c1' = (-1*(alpha*c1-beta*o1))
1207 o1' = (1*(alpha*c1-beta*o1))+(-1*(k3p*o1-k4*o2))
1208 o2' = (1*(k3p*o1-k4*o2))
1209 p0' = (-1*(k1ca*p0-k2*p1))
1210 p1' = (1*(k1ca*p0-k2*p1))
1212 std::string expected_result = R"(
1214 EIGEN_NEWTON_SOLVE[5]{
1215 LOCAL alpha, beta, k3p, k4, k1ca, k2, old_c1, old_o1, old_p0
1217 evaluate_fct(v, cai)
1222 nmodl_eigen_x[0] = c1
1223 nmodl_eigen_x[1] = o1
1224 nmodl_eigen_x[2] = o2
1225 nmodl_eigen_x[3] = p0
1226 nmodl_eigen_x[4] = p1
1228 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*alpha+nmodl_eigen_x[1]*beta)+old_c1)/dt
1229 nmodl_eigen_j[0] = -alpha-1/dt
1230 nmodl_eigen_j[5] = beta
1231 nmodl_eigen_j[10] = 0
1232 nmodl_eigen_j[15] = 0
1233 nmodl_eigen_j[20] = 0
1234 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*alpha-nmodl_eigen_x[1]*beta-nmodl_eigen_x[1]*k3p+nmodl_eigen_x[2]*k4)+old_o1)/dt
1235 nmodl_eigen_j[1] = alpha
1236 nmodl_eigen_j[6] = -beta-k3p-1/dt
1237 nmodl_eigen_j[11] = k4
1238 nmodl_eigen_j[16] = 0
1239 nmodl_eigen_j[21] = 0
1240 nmodl_eigen_f[2] = -nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]+1.0
1241 nmodl_eigen_j[2] = -1.0
1242 nmodl_eigen_j[7] = -1.0
1243 nmodl_eigen_j[12] = -1.0
1244 nmodl_eigen_j[17] = 0
1245 nmodl_eigen_j[22] = 0
1246 nmodl_eigen_f[3] = (-nmodl_eigen_x[3]+dt*(-nmodl_eigen_x[3]*k1ca+nmodl_eigen_x[4]*k2)+old_p0)/dt
1247 nmodl_eigen_j[3] = 0
1248 nmodl_eigen_j[8] = 0
1249 nmodl_eigen_j[13] = 0
1250 nmodl_eigen_j[18] = -k1ca-1/dt
1251 nmodl_eigen_j[23] = k2
1252 nmodl_eigen_f[4] = -nmodl_eigen_x[3]-nmodl_eigen_x[4]+1.0
1253 nmodl_eigen_j[4] = 0
1254 nmodl_eigen_j[9] = 0
1255 nmodl_eigen_j[14] = 0
1256 nmodl_eigen_j[19] = -1.0
1257 nmodl_eigen_j[24] = -1.0
1259 c1 = nmodl_eigen_x[0]
1260 o1 = nmodl_eigen_x[1]
1261 o2 = nmodl_eigen_x[2]
1262 p0 = nmodl_eigen_x[3]
1263 p1 = nmodl_eigen_x[4]
1268 "Construct & solve linear system, replacing ODEs for p1 and o2 with CONSERVE statement "
1269 "algebraic relations") {
1276 GIVEN(
"Derivative block including ODES with sparse method - single var in array") {
1286 SOLVE scheme1 METHOD sparse
1288 DERIVATIVE scheme1 {
1289 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1292 std::string expected_result = R"(
1293 DERIVATIVE scheme1 {
1294 EIGEN_NEWTON_SOLVE[1]{
1299 nmodl_eigen_x[0] = W[0]
1301 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1302 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1304 W[0] = nmodl_eigen_x[0]
1308 THEN("Construct & solver linear system") {
1315 GIVEN(
"Derivative block including ODES with sparse method - array vars") {
1325 SOLVE scheme1 METHOD sparse
1327 DERIVATIVE scheme1 {
1328 M'[0] = -A[0]*M[0] + B[0]*M[1]
1329 M'[1] = A[1]*M[0] - B[1]*M[1]
1332 std::string expected_result = R"(
1333 DERIVATIVE scheme1 {
1334 EIGEN_NEWTON_SOLVE[2]{
1335 LOCAL old_M_0, old_M_1
1340 nmodl_eigen_x[0] = M[0]
1341 nmodl_eigen_x[1] = M[1]
1343 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[1]*B[0])+old_M_0)/dt
1344 nmodl_eigen_j[0] = -A[0]-1/dt
1345 nmodl_eigen_j[2] = B[0]
1346 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*A[1]-nmodl_eigen_x[1]*B[1])+old_M_1)/dt
1347 nmodl_eigen_j[1] = A[1]
1348 nmodl_eigen_j[3] = -B[1]-1/dt
1350 M[0] = nmodl_eigen_x[0]
1351 M[1] = nmodl_eigen_x[1]
1355 THEN("Construct & solver linear system") {
1362 GIVEN(
"Derivative block including ODES with derivimplicit method - single var in array") {
1372 SOLVE scheme1 METHOD derivimplicit
1374 DERIVATIVE scheme1 {
1375 W'[0] = -A[0]*W[0] + B[0]*W[0] + 3*A[1]
1378 std::string expected_result = R"(
1379 DERIVATIVE scheme1 {
1380 EIGEN_NEWTON_SOLVE[1]{
1385 nmodl_eigen_x[0] = W[0]
1387 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*A[0]+nmodl_eigen_x[0]*B[0]+3.0*A[1])+old_W_0)/dt
1388 nmodl_eigen_j[0] = -A[0]+B[0]-1/dt
1390 W[0] = nmodl_eigen_x[0]
1394 THEN("Construct newton solve block") {
1401 GIVEN(
"Derivative block including ODES with derivimplicit method") {
1407 SOLVE states METHOD derivimplicit
1411 m' = (minf-m)/mtau - 3*h
1412 h' = (hinf-h)/htau + m*m
1417 std::string expected_result = R
"(
1419 EIGEN_NEWTON_SOLVE[3]{
1420 LOCAL old_m, old_h, old_n
1427 nmodl_eigen_x[0] = m
1428 nmodl_eigen_x[1] = h
1429 nmodl_eigen_x[2] = n
1431 nmodl_eigen_f[0] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt-3.0*nmodl_eigen_x[1]+minf/mtau+old_m/dt
1432 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1433 nmodl_eigen_j[3] = -3.0
1434 nmodl_eigen_j[6] = 0
1435 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1436 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1437 nmodl_eigen_j[4] = (-dt-htau)/(dt*htau)
1438 nmodl_eigen_j[7] = 0
1439 nmodl_eigen_f[2] = (dt*(-nmodl_eigen_x[2]+ninf)+ntau*(-nmodl_eigen_x[2]+old_n))/(dt*ntau)
1440 nmodl_eigen_j[2] = 0
1441 nmodl_eigen_j[5] = 0
1442 nmodl_eigen_j[8] = (-dt-ntau)/(dt*ntau)
1444 m = nmodl_eigen_x[0]
1445 h = nmodl_eigen_x[1]
1446 n = nmodl_eigen_x[2]
1450 THEN("Construct newton solve block") {
1457 GIVEN(
"Multiple derivative blocks each with derivimplicit method") {
1463 SOLVE states1 METHOD derivimplicit
1464 SOLVE states2 METHOD derivimplicit
1467 DERIVATIVE states1 {
1469 h' = (hinf-h)/htau + m*m
1472 DERIVATIVE states2 {
1473 h' = (hinf-h)/htau + m*m
1474 m' = (minf-m)/mtau + h
1478 std::string expected_result_0 = R
"(
1479 DERIVATIVE states1 {
1480 EIGEN_NEWTON_SOLVE[2]{
1486 nmodl_eigen_x[0] = m
1487 nmodl_eigen_x[1] = h
1489 nmodl_eigen_f[0] = (dt*(-nmodl_eigen_x[0]+minf)+mtau*(-nmodl_eigen_x[0]+old_m))/(dt*mtau)
1490 nmodl_eigen_j[0] = (-dt-mtau)/(dt*mtau)
1491 nmodl_eigen_j[2] = 0
1492 nmodl_eigen_f[1] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau- nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1493 nmodl_eigen_j[1] = 2.0*nmodl_eigen_x[0]
1494 nmodl_eigen_j[3] = (-dt-htau)/(dt*htau)
1496 m = nmodl_eigen_x[0]
1497 h = nmodl_eigen_x[1]
1501 std::string expected_result_1 = R"(
1502 DERIVATIVE states2 {
1503 EIGEN_NEWTON_SOLVE[2]{
1509 nmodl_eigen_x[0] = m
1510 nmodl_eigen_x[1] = h
1512 nmodl_eigen_f[0] = pow(nmodl_eigen_x[0], 2)-nmodl_eigen_x[1]/htau-nmodl_eigen_x[1]/dt+hinf/htau+old_h/dt
1513 nmodl_eigen_j[0] = 2.0*nmodl_eigen_x[0]
1514 nmodl_eigen_j[2] = (-dt-htau)/(dt*htau)
1515 nmodl_eigen_f[1] = -nmodl_eigen_x[0]/mtau-nmodl_eigen_x[0]/dt+nmodl_eigen_x[1]+minf/mtau+old_m/dt
1516 nmodl_eigen_j[1] = (-dt-mtau)/(dt*mtau)
1517 nmodl_eigen_j[3] = 1.0
1519 m = nmodl_eigen_x[0]
1520 h = nmodl_eigen_x[1]
1524 THEN("Construct newton solve block") {
1539 SCENARIO(
"LINEAR solve block (SympySolver Visitor)",
"[sympy][linear]") {
1540 GIVEN(
"1 state-var symbolic LINEAR solve block") {
1548 std::string expected_text = R"(
1552 THEN("solve analytically") {
1558 GIVEN(
"2 state-var LINEAR solve block") {
1567 std::string expected_text = R"(
1572 THEN("solve analytically") {
1578 GIVEN(
"Linear block, print in order, vectors") {
1587 std::string expected_result = R"(
1593 THEN("Construct & solve linear system") {
1600 GIVEN(
"Linear block, by value replacement, interleaved") {
1613 std::string expected_result = R"(
1623 THEN("Construct & solve linear system") {
1630 GIVEN(
"Linear block in control flow block") {
1642 std::string expected_result = R"(
1651 THEN("Construct & solve linear system") {
1658 GIVEN(
"Linear block, linear equations mixed with control flow blocks and reassignments") {
1672 std::string expected_result = R"(
1683 THEN("Construct & solve linear system") {
1690 GIVEN(
"4 state-var LINEAR solve block") {
1696 ~ w + z/3.2 = -2.0*y
1697 ~ x + 4*c*y = -5.343*a
1698 ~ a + x/b + z - y = 0.842*b*b
1699 ~ x + 1.3*y - 0.1*z/(a*a*b) = 1.43543/c
1701 std::string expected_text = R"(
1703 EIGEN_LINEAR_SOLVE[4]{
1706 nmodl_eigen_x[0] = w
1707 nmodl_eigen_x[1] = x
1708 nmodl_eigen_x[2] = y
1709 nmodl_eigen_x[3] = z
1710 nmodl_eigen_f[0] = 0
1711 nmodl_eigen_f[1] = 5.343*a
1712 nmodl_eigen_f[2] = a-0.84199999999999997*pow(b, 2)
1713 nmodl_eigen_f[3] = -1.43543/c
1714 nmodl_eigen_j[0] = -1.0
1715 nmodl_eigen_j[4] = 0
1716 nmodl_eigen_j[8] = -2.0
1717 nmodl_eigen_j[12] = -0.3125
1718 nmodl_eigen_j[1] = 0
1719 nmodl_eigen_j[5] = -1.0
1720 nmodl_eigen_j[9] = -4.0*c
1721 nmodl_eigen_j[13] = 0
1722 nmodl_eigen_j[2] = 0
1723 nmodl_eigen_j[6] = -1/b
1724 nmodl_eigen_j[10] = 1.0
1725 nmodl_eigen_j[14] = -1.0
1726 nmodl_eigen_j[3] = 0
1727 nmodl_eigen_j[7] = -1.0
1728 nmodl_eigen_j[11] = -1.3
1729 nmodl_eigen_j[15] = 0.10000000000000001/(pow(a, 2)*b)
1731 w = nmodl_eigen_x[0]
1732 x = nmodl_eigen_x[1]
1733 y = nmodl_eigen_x[2]
1734 z = nmodl_eigen_x[3]
1738 THEN("return matrix system to solve") {
1745 GIVEN(
"LINEAR solve block with an explicit SOLVEFOR statement") {
1752 LINEAR lin SOLVEFOR x, y {
1756 std::string expected_text = R"(
1757 LINEAR lin SOLVEFOR x,y{
1758 y = (v+15.0)/(3.0*z+1.0)
1759 x = (v*z-5.0)/(3.0*z+1.0)
1761 THEN("solve analytically") {
1773 SCENARIO(
"Solve NONLINEAR block using SympySolver Visitor",
"[visitor][solver][sympy][nonlinear]") {
1774 GIVEN(
"1 state-var numeric NONLINEAR solve block") {
1782 std::string expected_text = R"(
1784 EIGEN_NEWTON_SOLVE[1]{
1787 nmodl_eigen_x[0] = x
1789 nmodl_eigen_f[0] = 5.0-nmodl_eigen_x[0]
1790 nmodl_eigen_j[0] = -1.0
1792 x = nmodl_eigen_x[0]
1797 THEN("return F & J for newton solver") {
1803 GIVEN(
"array state-var numeric NONLINEAR solve block") {
1811 ~ s[2] + s[1] = s[0]
1813 std::string expected_text = R"(
1815 EIGEN_NEWTON_SOLVE[3]{
1818 nmodl_eigen_x[0] = s[0]
1819 nmodl_eigen_x[1] = s[1]
1820 nmodl_eigen_x[2] = s[2]
1822 nmodl_eigen_f[0] = 1.0-nmodl_eigen_x[0]
1823 nmodl_eigen_f[1] = 3.0-nmodl_eigen_x[1]
1824 nmodl_eigen_f[2] = nmodl_eigen_x[0]-nmodl_eigen_x[1]-nmodl_eigen_x[2]
1825 nmodl_eigen_j[0] = -1.0
1826 nmodl_eigen_j[3] = 0
1827 nmodl_eigen_j[6] = 0
1828 nmodl_eigen_j[1] = 0
1829 nmodl_eigen_j[4] = -1.0
1830 nmodl_eigen_j[7] = 0
1831 nmodl_eigen_j[2] = 1.0
1832 nmodl_eigen_j[5] = -1.0
1833 nmodl_eigen_j[8] = -1.0
1835 s[0] = nmodl_eigen_x[0]
1836 s[1] = nmodl_eigen_x[1]
1837 s[2] = nmodl_eigen_x[2]
1841 THEN("return F & J for newton solver") {
1848 SCENARIO(
"Solve KINETIC block using SympySolver Visitor",
"[visitor][solver][sympy][kinetic]") {
1849 GIVEN(
"KINETIC block with not inlined function should work") {
1852 SOLVE kstates METHOD sparse
1858 FUNCTION alfa(v(mV)) {
1862 ~ C1 <-> C2 (alfa(v), alfa(v))
1864 std::string expected_text = R"(
1865 DERIVATIVE kstates {
1866 EIGEN_NEWTON_SOLVE[2]{
1867 LOCAL kf0_, kb0_, old_C1, old_C2
1874 nmodl_eigen_x[0] = C1
1875 nmodl_eigen_x[1] = C2
1877 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1878 nmodl_eigen_j[0] = -kf0_-1/dt
1879 nmodl_eigen_j[2] = kb0_
1880 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1881 nmodl_eigen_j[1] = kf0_
1882 nmodl_eigen_j[3] = -kb0_-1/dt
1884 C1 = nmodl_eigen_x[0]
1885 C2 = nmodl_eigen_x[1]
1889 THEN("Run Kinetic and Sympy Visitor") {
1890 std::vector<std::string>
result;
1892 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));
1896 GIVEN(
"Protected names in Sympy are respected") {
1899 SOLVE kstates METHOD sparse
1905 FUNCTION beta(v(mV)) {
1908 FUNCTION lowergamma(v(mV)) {
1912 ~ C1 <-> C2 (beta(v), lowergamma(v))
1914 std::string expected_text = R"(
1915 DERIVATIVE kstates {
1916 EIGEN_NEWTON_SOLVE[2]{
1917 LOCAL kf0_, kb0_, old_C1, old_C2
1920 kb0_ = lowergamma(v)
1924 nmodl_eigen_x[0] = C1
1925 nmodl_eigen_x[1] = C2
1927 nmodl_eigen_f[0] = (-nmodl_eigen_x[0]+dt*(-nmodl_eigen_x[0]*kf0_+nmodl_eigen_x[1]*kb0_)+old_C1)/dt
1928 nmodl_eigen_j[0] = -kf0_-1/dt
1929 nmodl_eigen_j[2] = kb0_
1930 nmodl_eigen_f[1] = (-nmodl_eigen_x[1]+dt*(nmodl_eigen_x[0]*kf0_-nmodl_eigen_x[1]*kb0_)+old_C2)/dt
1931 nmodl_eigen_j[1] = kf0_
1932 nmodl_eigen_j[3] = -kb0_-1/dt
1934 C1 = nmodl_eigen_x[0]
1935 C2 = nmodl_eigen_x[1]
1939 THEN("Run Kinetic and Sympy Visitor") {
1940 std::vector<std::string>
result;
1942 nmodl_text,
false,
false, AstNodeType::DERIVATIVE_BLOCK,
true));
1948 SCENARIO(
"Replace unimplementable cnexp solution with derivimplicit solution",
1949 "[visitor][sympy][cnexp][derivimplicit]") {
1950 GIVEN(
"Derivative block that has a LambertW analytic solution") {
1956 SOLVE states METHOD cnexp
1962 THEN("The method has been replaced with derivimplicit") {
1964 REQUIRE_THAT(
to_nmodl(
result), Catch::Matchers::ContainsSubstring(
"derivimplicit"));
Visitor for checking parents of ast nodes
Represents top level AST node for whole NMODL input.
Class that binds all pieces together for parsing nmodl file.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Perform constant folding of integer/float/double expressions.
Visitor for kinetic block statements
void visit_program(ast::Program &node) override
visit node of type ast::Program
Unroll for loop in the AST.
Visitor for printing AST back to NMODL
void visit_program(const ast::Program &node) override
visit node of type ast::Program
Visitor for systems of algebraic and differential equations
void visit_program(ast::Program &node) override
visit node of type ast::Program
Concrete visitor for constructing symbol table from AST.
void visit_program(ast::Program &node) override
visit node of type ast::Program
Visitor for checking parents of ast nodes
int check_ast(const ast::Ast &node)
A small wrapper to have a nicer call in parser.cpp.
Visitor for printing C++ code compatible with legacy api of CoreNEURON
Perform constant folding of integer/float/double expressions.
AstNodeType
Enum type for every AST node type.
bool parse_string(const std::string &input)
Parse units provided as a string (used for testing)
Visitor to inline local procedure and function calls
Visitor for kinetic block statements
Unroll for loop in the AST.
std::string reindent_text(const std::string &text, int indent_level)
Reindent nmodl text for text-to-text comparison.
encapsulates code generation backend implementations
std::vector< std::shared_ptr< const ast::Ast > > collect_nodes(const ast::Ast &node, const std::vector< ast::AstNodeType > &types)
traverse node recursively and collect nodes of given types
std::string to_nmodl(const ast::Ast &node, const std::set< ast::AstNodeType > &exclude_types)
Given AST node, return the NMODL string representation.
Visitor that solves ODEs using old solvers of NEURON
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
static Node * node(Object *)
static double remove(void *v)
Auto generated AST classes declaration.
Replace solve block statements with actual solution node in the AST.
void compare_blocks(const std::string &result, const std::string &expected, const bool require_fail=false)
Compare nmodl blocks that contain systems of equations (i.e.
std::string ast_to_string(ast::Program &node)
std::vector< std::string > run_sympy_solver_visitor(const std::string &text, bool pade=false, bool cse=false, AstNodeType ret_nodetype=AstNodeType::DIFF_EQ_EXPRESSION, bool kinetic=false)
bool is_unique_vars(std::string result)
auto run_sympy_solver_visitor_ast(const std::string &text, bool pade=false, bool cse=false, bool kinetic=false)
SCENARIO("Check compare_blocks in sympy unit tests", "[visitor][sympy]")
void run_sympy_visitor_passes(ast::Program &node)
Visitor for systems of algebraic and differential equations
THIS FILE IS GENERATED AT BUILD TIME AND SHALL NOT BE EDITED.
nmodl::parser::UnitDriver driver