From c54793c406a10231fc61f1b2dc74f54e356c87cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Barth=C3=A9lemy?= Date: Sat, 24 Feb 2024 03:57:37 +0100 Subject: [PATCH] Add memoization in basic_modif.py --- .../IndependantTools/programming_tools.py | 8 ++++++++ python_symb/TreeModification/basic_modif.py | 17 +++++++++++++---- python_symb/test.py | 10 +++++----- 3 files changed, 26 insertions(+), 9 deletions(-) diff --git a/python_symb/IndependantTools/programming_tools.py b/python_symb/IndependantTools/programming_tools.py index e69de29..72bd015 100644 --- a/python_symb/IndependantTools/programming_tools.py +++ b/python_symb/IndependantTools/programming_tools.py @@ -0,0 +1,8 @@ +def memoize(f): + cache = {} + def memoized_function(*args): + if args not in cache: + cache[args] = f(*args) + return cache[args] + + return memoized_function \ No newline at end of file diff --git a/python_symb/TreeModification/basic_modif.py b/python_symb/TreeModification/basic_modif.py index 827cc8f..52b541e 100644 --- a/python_symb/TreeModification/basic_modif.py +++ b/python_symb/TreeModification/basic_modif.py @@ -2,11 +2,12 @@ from python_symb.Expressions.expr import Expr from python_symb.MathTypes.symbols import Var from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp from typing import Union +from python_symb.IndependantTools.programming_tools import memoize Number = Union[int, float] - +@memoize def expand(expr: Expr) -> Expr: """ @@ -17,7 +18,7 @@ def expand(expr: Expr) -> Expr: example : 5*(a+b) -> 5*a + 5*b - Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [5, Expr(a)]), Expr(Mul, [5, Expr(b)])]) + Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [Expr(5), Expr(a)]), Expr(Mul, [Expr(5), Expr(b)])]) """ if expr.is_leaf: @@ -64,19 +65,27 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr: # represent number of times the expression appears in the expression, # custom expr hash make for instance x+y and y+x the same when counting in motifs motifs = {} + mem = {} + def collect_motifs(expr: Expr): + match expr: + case _ if expr in mem: + motifs[expr] += mem[expr] + case Expr(BinOperator() as op, [left, right]) if op == focus_op: collect_motifs(left) collect_motifs(right) case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and isinstance(right.value, Number): + mem[expr] = right.value motifs[left] = motifs.get(expr, 0) + right.value case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and op.properties.commutative and isinstance(left.value, Number): - - motifs[right] = 1 if right not in motifs else motifs[right] + left.value + mem[expr] = left.value + motifs[right] = left.value if right not in motifs else motifs[right] + left.value case _: + mem[expr] = 1 motifs[expr] = 1 if expr not in motifs else motifs[expr] + 1 collect_motifs(expr) diff --git a/python_symb/test.py b/python_symb/test.py index c5d3331..4967ec5 100644 --- a/python_symb/test.py +++ b/python_symb/test.py @@ -80,7 +80,7 @@ def test_regroup(print_result=False): def test_newton_bin(print_result=False): - expr = (x+y)**2 + expr = (x+y)**4 ungrouped = ungroup(expr, Exp) expanded = expand(ungrouped) expand_grouped = regroup(expanded, Add) @@ -95,11 +95,11 @@ def test_newton_bin(print_result=False): print(f"expanded: {expanded.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}") - #assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) + assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) def test_big_bin_for_performance(print_result=False): - expr = (x+y)**12 + expr = (x+y)**17 print("a") ungrouped = ungroup(expr, Exp) print("b") @@ -116,12 +116,12 @@ def test_big_bin_for_performance(print_result=False): print("---") print(f"expr: {expr.to_infix_str()}") print(f"ungrouped: {ungrouped.to_infix_str()}") - print(f"expanded: {expanded.to_infix_str()}") + #print(f"expanded: {expanded.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}") def test_trinomial(print_result=False): - expr = (x+y+z)**3 + expr = (x+y+z+a)**2 ungrouped = ungroup(expr, Exp) expanded = expand(ungrouped) expand_grouped = regroup(expanded, Add)