Add memoization in basic_modif.py
This commit is contained in:
parent
948e4da10c
commit
c54793c406
3 changed files with 26 additions and 9 deletions
|
@ -0,0 +1,8 @@
|
|||
def memoize(f):
|
||||
cache = {}
|
||||
def memoized_function(*args):
|
||||
if args not in cache:
|
||||
cache[args] = f(*args)
|
||||
return cache[args]
|
||||
|
||||
return memoized_function
|
|
@ -2,11 +2,12 @@ from python_symb.Expressions.expr import Expr
|
|||
from python_symb.MathTypes.symbols import Var
|
||||
from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp
|
||||
from typing import Union
|
||||
from python_symb.IndependantTools.programming_tools import memoize
|
||||
|
||||
|
||||
Number = Union[int, float]
|
||||
|
||||
|
||||
@memoize
|
||||
def expand(expr: Expr) -> Expr:
|
||||
|
||||
"""
|
||||
|
@ -17,7 +18,7 @@ def expand(expr: Expr) -> Expr:
|
|||
|
||||
example :
|
||||
5*(a+b) -> 5*a + 5*b
|
||||
Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [5, Expr(a)]), Expr(Mul, [5, Expr(b)])])
|
||||
Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [Expr(5), Expr(a)]), Expr(Mul, [Expr(5), Expr(b)])])
|
||||
"""
|
||||
|
||||
if expr.is_leaf:
|
||||
|
@ -64,19 +65,27 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr:
|
|||
# represent number of times the expression appears in the expression,
|
||||
# custom expr hash make for instance x+y and y+x the same when counting in motifs
|
||||
motifs = {}
|
||||
mem = {}
|
||||
|
||||
|
||||
def collect_motifs(expr: Expr):
|
||||
|
||||
match expr:
|
||||
case _ if expr in mem:
|
||||
motifs[expr] += mem[expr]
|
||||
|
||||
case Expr(BinOperator() as op, [left, right]) if op == focus_op:
|
||||
collect_motifs(left)
|
||||
collect_motifs(right)
|
||||
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and isinstance(right.value, Number):
|
||||
mem[expr] = right.value
|
||||
motifs[left] = motifs.get(expr, 0) + right.value
|
||||
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and op.properties.commutative and isinstance(left.value, Number):
|
||||
|
||||
motifs[right] = 1 if right not in motifs else motifs[right] + left.value
|
||||
mem[expr] = left.value
|
||||
motifs[right] = left.value if right not in motifs else motifs[right] + left.value
|
||||
|
||||
case _:
|
||||
mem[expr] = 1
|
||||
motifs[expr] = 1 if expr not in motifs else motifs[expr] + 1
|
||||
|
||||
collect_motifs(expr)
|
||||
|
|
|
@ -80,7 +80,7 @@ def test_regroup(print_result=False):
|
|||
|
||||
|
||||
def test_newton_bin(print_result=False):
|
||||
expr = (x+y)**2
|
||||
expr = (x+y)**4
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
expanded = expand(ungrouped)
|
||||
expand_grouped = regroup(expanded, Add)
|
||||
|
@ -95,11 +95,11 @@ def test_newton_bin(print_result=False):
|
|||
print(f"expanded: {expanded.to_infix_str()}")
|
||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||
|
||||
#assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
||||
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
||||
|
||||
|
||||
def test_big_bin_for_performance(print_result=False):
|
||||
expr = (x+y)**12
|
||||
expr = (x+y)**17
|
||||
print("a")
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
print("b")
|
||||
|
@ -116,12 +116,12 @@ def test_big_bin_for_performance(print_result=False):
|
|||
print("---")
|
||||
print(f"expr: {expr.to_infix_str()}")
|
||||
print(f"ungrouped: {ungrouped.to_infix_str()}")
|
||||
print(f"expanded: {expanded.to_infix_str()}")
|
||||
#print(f"expanded: {expanded.to_infix_str()}")
|
||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||
|
||||
|
||||
def test_trinomial(print_result=False):
|
||||
expr = (x+y+z)**3
|
||||
expr = (x+y+z+a)**2
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
expanded = expand(ungrouped)
|
||||
expand_grouped = regroup(expanded, Add)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue