Regroup test in test.py

Added neutral_element and absorbing_element in BinProperties
New simplifications functions in TreeModification
This commit is contained in:
Clément Barthélemy 2024-02-21 20:13:47 +01:00
parent 8ee24d763b
commit c27d706a26
4 changed files with 251 additions and 86 deletions

View file

@ -1,6 +1,6 @@
from python_symb.Expressions.expr import Expr
from python_symb.MathTypes.symbols import Var
from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul
from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp
from typing import Union
@ -23,10 +23,10 @@ def expand(expr: Expr) -> Expr:
return expr
match expr:
case Expr(BinOperator() as Op1, [Expr(Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children]))
case Expr(BinOperator() as Op1, [left, Expr(Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children]))
case Expr(BinOperator() as Op, [left, right]):
@ -53,7 +53,7 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr:
# Motifs : Key : (Expr) -> Value : int
# represent number of times the expression appears in the expression,
# custom hash make for instance x+y and y+x the same when counting
# custom expr hash make for instance x+y and y+x the same when counting in motifs
motifs = {}
def collect_motifs(expr: Expr):
@ -76,11 +76,16 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr:
def reconstruct(tuple_motifs):
match tuple_motifs:
case [(expr, int(a))]:
if focus_op.repeated_op.properties.commutative:
if a == focus_op.repeated_op.properties.neutral_element:
return expr
elif focus_op.repeated_op.properties.commutative:
return Expr(focus_op.repeated_op, [Expr(a), expr])
return Expr(focus_op.repeated_op, [expr, Expr(a)])
case [(expr, int(a)), *rest]:
if focus_op.repeated_op.properties.commutative:
if a == focus_op.repeated_op.properties.neutral_element:
return Expr(focus_op, [expr, reconstruct(rest)])
elif focus_op.repeated_op.properties.commutative:
return Expr(focus_op, [Expr(focus_op.repeated_op, [Expr(a), expr]), reconstruct(rest)])
return Expr(focus_op, [Expr(focus_op.repeated_op, [expr, Expr(a)]), reconstruct(rest)])
@ -112,42 +117,46 @@ def regroup(expr: Expr, focus_op: BinOperator) -> Expr:
return expr
def test():
def ungroup(expr: Expr, focus_op: BinOperator) -> Expr:
"""
Ungroup an expression
:param expr: expression to ungroup
:param focus_op: operator to ungroup
exemple :
with focus_op = Exp
(x+y) ^ 2 -> (x+y)*(x+y)
"""
def recreate(expr: Expr, nb_repeat: int) -> Expr:
if nb_repeat == 1:
return expr
return Expr(focus_op.deconstruct_op, [expr, recreate(expr, nb_repeat-1)])
if expr.is_leaf:
return expr
match expr:
case Expr(BinOperator() as op, [left, leaf]) if op == focus_op and isinstance(leaf.value, int):
return recreate(ungroup(left, focus_op), leaf.value)
case Expr(BinOperator() as op, [leaf, right]) if op == focus_op and isinstance(leaf.value, int):
return recreate(ungroup(right, focus_op), leaf.value)
case Expr(BinOperator() as op, [left, right]):
return Expr(op, [ungroup(left, focus_op), ungroup(right, focus_op)])
return expr
x, y = Var('x'), Var('y')
a, b = Var('a'), Var('b')
def test_expand():
expr = (x+y)*(x+y)*(x+y)
expr = expand(expr)
expr = regroup(expr, Add)
print(f"(x+y)*(x+y)*(x+y) -> {expr.to_infix_str()}")
expr = (x+y+a)*b
print(f"(x+y+a)*b -> {expand(expr).to_infix_str()}")
def test_regroup():
expr = x+2*x+y+y+2*y
print(f"x+2*x+y+y+2*y -> {regroup(expr, Add).to_infix_str()}")
def test_power():
expr = x*x*x + y*y*y
print(f"x*x*x -> {regroup(expr, Mul).to_infix_str()}")
def test_all():
expr = (x+y)*(x+y)*(x+y)
expanded_expr = expand(expr)
regrouped_expr = regroup(expanded_expr, Add)
print(f"(x+y)*(x+y)*(x+y) -> {regrouped_expr.to_infix_str()}")
test_expand()
test_regroup()
test_power()
#test_all()
if __name__ == "__main__":
test()