Optimization of expand in basic_modif.py

Making Expr immutable, and then optimize __hash__ with this
This commit is contained in:
Clément Barthélemy 2024-02-24 02:57:07 +01:00
parent c27d706a26
commit 948e4da10c
6 changed files with 92 additions and 37 deletions

View file

@ -14,6 +14,7 @@ from python_symb.Parsing.parse import infix_str_to_postfix
class Expr(Tree):
"""
A class to represent an expression tree
is immutable
value: the value of the node
children: the subtrees of the root (Default : None)
@ -23,19 +24,39 @@ class Expr(Tree):
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
"""
__slots__ = ['value', 'children', '_is_frozen', '_hash']
__match_args__ = ('value', 'children')
def __init__(self, value, children=None):
def __init__(self, value, children=None, debug=False):
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
self._is_frozen = False
super().__init__(value, children if children else [])
assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr'
self._is_frozen = True
self._hash = None
match value:
case BinOperator() as op:
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
if debug:
case UnaryOperator() as op:
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
assert all([isinstance(child, Expr) for child in
self.children]), f'Invalid children: {self.children} all child should be Expr'
match value:
case BinOperator() as op:
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
case UnaryOperator() as op:
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
def __setattr__(self, key, value):
if key == "_hash" or key == "_is_frozen":
super().__setattr__(key, value)
elif getattr(self, "_is_frozen", True):
raise TypeError(f"{self.__class__.__name__} is immutable")
super().__setattr__(key, value)
def __delattr__(self, item):
if getattr(self, "_is_frozen", True):
raise TypeError(f"{self.__class__.__name__} is immutable")
super().__delattr__(item)
@staticmethod
@ -91,7 +112,6 @@ class Expr(Tree):
@staticmethod
def bin_op_constructor(self, other, op):
"""
@ -139,19 +159,26 @@ class Expr(Tree):
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
see test_eq
"""
if self._hash:
return self._hash
match self:
case Expr(value) if self.is_leaf:
return hash(value)
self._hash = hash(value)
return self._hash
case Expr(UnaryOperator() as op, [child]):
return hash(op.name + str(hash(child)))
self._hash = hash(op.name + str(hash(child)))
return self._hash
case Expr(BinOperator() as op, [left, right]):
if op.properties.commutative and op.properties.associative:
return hash(op.name) + hash(left) + hash(right)
self._hash = hash(op.name) + hash(left) + hash(right)
return self._hash
else:
return hash(op.name) + hash(str(hash(left)) + str(hash(right)))
self._hash = hash(op.name) + hash(str(hash(left)) + str(hash(right)))
return self._hash
case _:
print(f'Invalid type: {type(self)}')

View file

@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Iterable, Generator
from python_symb.IndependantTools.tools import gcd
from python_symb.IndependantTools.math_tools import gcd
class Fraction:

View file

@ -8,6 +8,7 @@ Number = Union[int, float]
def expand(expr: Expr) -> Expr:
"""
Expand an expression
@ -23,14 +24,22 @@ def expand(expr: Expr) -> Expr:
return expr
match expr:
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children]))
expanded_right = expand(right)
return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children]))
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children]))
expanded_left = expand(left)
return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children]))
case Expr(BinOperator() as Op, [left, right]):
return Expr(Op, [expand(left), expand(right)])
left_leaf, right_leaf = left.is_leaf, right.is_leaf
if not left_leaf:
left = expand(left)
if not right_leaf:
right = expand(right)
return Expr(Op, [left, right])
return expr

View file

@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var
from python_symb.Expressions.expr import Expr
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
import time
from functools import wraps
import cProfile
def timeit(func):
@wraps(func)
def measure_time(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time))
return result
return measure_time
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
def create_var():
global x, y
x = var('x')
y = var('y')
global x, y, z, a, b, c
x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c')
def test_sum(print_result=False):
@ -69,6 +57,17 @@ def test_expand(print_result=False):
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
def test_nested_expand(print_result=False):
expr = (x+(x+y)*(x+y))*(x+y)
expanded = expand(expr)
expanded = regroup(expanded, Add)
expanded = regroup(expanded, Mul)
expanded_str = expanded.to_infix_str()
if print_result:
print("(x+(x+y)*(x+y))*(x+y)")
print(f"expanded: {expanded_str}")
def test_regroup(print_result=False):
expr = x+x+2*x+y+x+3*y+x
regrouped = regroup(expr, Add)
@ -81,7 +80,7 @@ def test_regroup(print_result=False):
def test_newton_bin(print_result=False):
expr = (x+y)**4
expr = (x+y)**2
ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add)
@ -96,11 +95,11 @@ def test_newton_bin(print_result=False):
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
#assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
@timeit
def test_big_bin_for_performance(print_result=False):
expr = (x+y)**11
expr = (x+y)**12
print("a")
ungrouped = ungroup(expr, Exp)
print("b")
@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False):
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
def test_trinomial(print_result=False):
expr = (x+y+z)**3
ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add)
expand_grouped = regroup(expand_grouped, Mul)
if print_result:
print("---")
print("trinomial test")
print("---")
print(f"expr: {expr.to_infix_str()}")
print(f"ungrouped: {ungrouped.to_infix_str()}")
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
if __name__ == '__main__':
create_var()
print_r = True
@ -127,9 +145,10 @@ if __name__ == '__main__':
test_parse(print_result=print_r)
test_expr_to_str(print_result=print_r)
test_expand(print_result=print_r)
test_nested_expand(print_result=print_r)
test_regroup(print_result=print_r)
test_newton_bin(print_result=print_r)
test_big_bin_for_performance(print_result=print_r)
#test_trinomial(print_result=print_r)
cProfile.run('test_big_bin_for_performance(print_result=print_r)')
print("All tests passed")