180 lines
5.9 KiB
Python
180 lines
5.9 KiB
Python
from python_symb.Expressions.expr import Expr
|
|
from python_symb.MathTypes.symbols import Var
|
|
from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp
|
|
from typing import Union
|
|
from python_symb.IndependantTools.programming_tools import memoize
|
|
|
|
|
|
Number = Union[int, float]
|
|
|
|
@memoize
|
|
def expand(expr: Expr) -> Expr:
|
|
|
|
"""
|
|
Expand an expression
|
|
|
|
:param expr: expression to expand
|
|
:return: expanded expression
|
|
|
|
example :
|
|
5*(a+b) -> 5*a + 5*b
|
|
Expr(Mul, [5, Expr(Add, [Expr(a), Expr(b)])]) -> Expr(Add, [Expr(Mul, [Expr(5), Expr(a)]), Expr(Mul, [Expr(5), Expr(b)])])
|
|
"""
|
|
|
|
if expr.is_leaf:
|
|
return expr
|
|
|
|
match expr:
|
|
|
|
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
|
|
expanded_right = expand(right)
|
|
return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children]))
|
|
|
|
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
|
|
expanded_left = expand(left)
|
|
return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children]))
|
|
|
|
case Expr(BinOperator() as Op, [left, right]):
|
|
left_leaf, right_leaf = left.is_leaf, right.is_leaf
|
|
if not left_leaf:
|
|
left = expand(left)
|
|
if not right_leaf:
|
|
right = expand(right)
|
|
return Expr(Op, [left, right])
|
|
|
|
return expr
|
|
|
|
|
|
def _regroup(expr: Expr, focus_op: BinOperator) -> Expr:
|
|
"""
|
|
regroup an expression, with the contraint that the value of expr is focus_op
|
|
Will be used to regroup an expression
|
|
|
|
:param expr: expression to regroup
|
|
:param focus_op: operator to regroup
|
|
:return
|
|
x+x+x+x -> 4*x
|
|
Expr(Add, [Expr(x), Expr(Add, [Expr(x), Expr(Add, [Expr(x), Expr(x)])])]) -> Expr(Mul, [4, Expr(x)])
|
|
|
|
with Mul == Add.repeated_op
|
|
"""
|
|
assert focus_op.repeated_op is not None, f'{focus_op} has no repeated_op'
|
|
assert expr.value == focus_op, f'{expr.value} is not a {focus_op}'
|
|
|
|
# Motifs : Key : (Expr) -> Value : int
|
|
# represent number of times the expression appears in the expression,
|
|
# custom expr hash make for instance x+y and y+x the same when counting in motifs
|
|
motifs = {}
|
|
mem = {}
|
|
|
|
|
|
def collect_motifs(expr: Expr):
|
|
|
|
match expr:
|
|
case _ if expr in mem:
|
|
motifs[expr] += mem[expr]
|
|
|
|
case Expr(BinOperator() as op, [left, right]) if op == focus_op:
|
|
collect_motifs(left)
|
|
collect_motifs(right)
|
|
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and isinstance(right.value, Number):
|
|
mem[expr] = right.value
|
|
motifs[left] = motifs.get(expr, 0) + right.value
|
|
case Expr(BinOperator() as op, [left, right]) if op == focus_op.repeated_op and op.properties.commutative and isinstance(left.value, Number):
|
|
mem[expr] = left.value
|
|
motifs[right] = left.value if right not in motifs else motifs[right] + left.value
|
|
|
|
case _:
|
|
mem[expr] = 1
|
|
motifs[expr] = 1 if expr not in motifs else motifs[expr] + 1
|
|
|
|
collect_motifs(expr)
|
|
tuple_motifs = list(motifs.items())
|
|
|
|
def reconstruct(tuple_motifs):
|
|
match tuple_motifs:
|
|
case [(expr, int(a))]:
|
|
if a == focus_op.repeated_op.properties.neutral_element:
|
|
return expr
|
|
elif focus_op.repeated_op.properties.commutative:
|
|
return Expr(focus_op.repeated_op, [Expr(a), expr])
|
|
return Expr(focus_op.repeated_op, [expr, Expr(a)])
|
|
|
|
case [(expr, int(a)), *rest]:
|
|
if a == focus_op.repeated_op.properties.neutral_element:
|
|
return Expr(focus_op, [expr, reconstruct(rest)])
|
|
elif focus_op.repeated_op.properties.commutative:
|
|
return Expr(focus_op, [Expr(focus_op.repeated_op, [Expr(a), expr]), reconstruct(rest)])
|
|
return Expr(focus_op, [Expr(focus_op.repeated_op, [expr, Expr(a)]), reconstruct(rest)])
|
|
|
|
return reconstruct(tuple_motifs)
|
|
|
|
def regroup(expr: Expr, focus_op: BinOperator) -> Expr:
|
|
"""
|
|
Regroup an expression
|
|
|
|
:param expr: expression to regroup
|
|
:param focus_op: operator to regroup
|
|
:return: regrouped expression
|
|
|
|
example :
|
|
x+x+x+x -> 4*x
|
|
Expr(Add, [Expr(x), Expr(Add, [Expr(x), Expr(Add, [Expr(x), Expr(x)])])]) -> Expr(Mul, [4, Expr(x)])
|
|
"""
|
|
|
|
if expr.is_leaf:
|
|
return expr
|
|
|
|
match expr:
|
|
case Expr(BinOperator() as op, [left, right]) if op == focus_op:
|
|
return _regroup(expr, focus_op)
|
|
|
|
case Expr(BinOperator() as op, [left, right]):
|
|
return Expr(op, [regroup(left, focus_op), regroup(right, focus_op)])
|
|
|
|
return expr
|
|
|
|
|
|
def ungroup(expr: Expr, focus_op: BinOperator) -> Expr:
|
|
"""
|
|
Ungroup an expression
|
|
|
|
:param expr: expression to ungroup
|
|
:param focus_op: operator to ungroup
|
|
|
|
exemple :
|
|
with focus_op = Exp
|
|
(x+y) ^ 2 -> (x+y)*(x+y)
|
|
"""
|
|
|
|
def recreate(expr: Expr, nb_repeat: int) -> Expr:
|
|
if nb_repeat == 1:
|
|
return expr
|
|
return Expr(focus_op.deconstruct_op, [expr, recreate(expr, nb_repeat-1)])
|
|
|
|
if expr.is_leaf:
|
|
return expr
|
|
|
|
match expr:
|
|
case Expr(BinOperator() as op, [left, leaf]) if op == focus_op and isinstance(leaf.value, int):
|
|
return recreate(ungroup(left, focus_op), leaf.value)
|
|
case Expr(BinOperator() as op, [leaf, right]) if op == focus_op and isinstance(leaf.value, int):
|
|
return recreate(ungroup(right, focus_op), leaf.value)
|
|
case Expr(BinOperator() as op, [left, right]):
|
|
return Expr(op, [ungroup(left, focus_op), ungroup(right, focus_op)])
|
|
|
|
return expr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|