Add memoization in basic_modif.py

This commit is contained in:
Clément Barthélemy 2024-02-24 03:57:37 +01:00
parent 948e4da10c
commit c54793c406
3 changed files with 26 additions and 9 deletions

View file

@ -2,11 +2,12 @@ from python_symb.Expressions.expr import Expr
from python_symb.MathTypes.symbols import Var
from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp
from typing import Union
from python_symb.IndependantTools.programming_tools import memoize
Number = Union[int, float]
@memoize
def expand(expr: Expr) -> Expr:
"""
@ -17,7 +18,7 @@ def expand(expr: Expr) -> Expr:
example :
5*(a+b) -> 5*a + 5*b
Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [5, Expr(a)]), Expr(Mul, [5, Expr(b)])])
Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [Expr(5), Expr(a)]), Expr(Mul, [Expr(5), Expr(b)])])
"""
if expr.is_leaf:
@ -64,19 +65,27 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr:
# represent number of times the expression appears in the expression,
# custom expr hash make for instance x+y and y+x the same when counting in motifs
motifs = {}
mem = {}
def collect_motifs(expr: Expr):
match expr:
case _ if expr in mem:
motifs[expr] += mem[expr]
case Expr(BinOperator() as op, [left, right]) if op == focus_op:
collect_motifs(left)
collect_motifs(right)
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and isinstance(right.value, Number):
mem[expr] = right.value
motifs[left] = motifs.get(expr, 0) + right.value
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and op.properties.commutative and isinstance(left.value, Number):
motifs[right] = 1 if right not in motifs else motifs[right] + left.value
mem[expr] = left.value
motifs[right] = left.value if right not in motifs else motifs[right] + left.value
case _:
mem[expr] = 1
motifs[expr] = 1 if expr not in motifs else motifs[expr] + 1
collect_motifs(expr)