Optimization of expand in basic_modif.py

Making Expr immutable, and then optimize __hash__ with this
This commit is contained in:
Clément Barthélemy 2024-02-24 02:57:07 +01:00
parent c27d706a26
commit 948e4da10c
6 changed files with 92 additions and 37 deletions

View file

@ -14,6 +14,7 @@ from python_symb.Parsing.parse import infix_str_to_postfix
class Expr(Tree): class Expr(Tree):
""" """
A class to represent an expression tree A class to represent an expression tree
is immutable
value: the value of the node value: the value of the node
children: the subtrees of the root (Default : None) children: the subtrees of the root (Default : None)
@ -23,19 +24,39 @@ class Expr(Tree):
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))]) Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
""" """
__slots__ = ['value', 'children', '_is_frozen', '_hash']
__match_args__ = ('value', 'children') __match_args__ = ('value', 'children')
def __init__(self, value, children=None): def __init__(self, value, children=None, debug=False):
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
self._is_frozen = False
super().__init__(value, children if children else []) super().__init__(value, children if children else [])
assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr' self._is_frozen = True
self._hash = None
match value: if debug:
case BinOperator() as op:
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
case UnaryOperator() as op: assert all([isinstance(child, Expr) for child in
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}' self.children]), f'Invalid children: {self.children} all child should be Expr'
match value:
case BinOperator() as op:
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
case UnaryOperator() as op:
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
def __setattr__(self, key, value):
if key == "_hash" or key == "_is_frozen":
super().__setattr__(key, value)
elif getattr(self, "_is_frozen", True):
raise TypeError(f"{self.__class__.__name__} is immutable")
super().__setattr__(key, value)
def __delattr__(self, item):
if getattr(self, "_is_frozen", True):
raise TypeError(f"{self.__class__.__name__} is immutable")
super().__delattr__(item)
@staticmethod @staticmethod
@ -91,7 +112,6 @@ class Expr(Tree):
@staticmethod @staticmethod
def bin_op_constructor(self, other, op): def bin_op_constructor(self, other, op):
""" """
@ -139,19 +159,26 @@ class Expr(Tree):
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
see test_eq see test_eq
""" """
if self._hash:
return self._hash
match self: match self:
case Expr(value) if self.is_leaf: case Expr(value) if self.is_leaf:
return hash(value) self._hash = hash(value)
return self._hash
case Expr(UnaryOperator() as op, [child]): case Expr(UnaryOperator() as op, [child]):
return hash(op.name + str(hash(child))) self._hash = hash(op.name + str(hash(child)))
return self._hash
case Expr(BinOperator() as op, [left, right]): case Expr(BinOperator() as op, [left, right]):
if op.properties.commutative and op.properties.associative: if op.properties.commutative and op.properties.associative:
return hash(op.name) + hash(left) + hash(right) self._hash = hash(op.name) + hash(left) + hash(right)
return self._hash
else: else:
return hash(op.name) + hash(str(hash(left)) + str(hash(right))) self._hash = hash(op.name) + hash(str(hash(left)) + str(hash(right)))
return self._hash
case _: case _:
print(f'Invalid type: {type(self)}') print(f'Invalid type: {type(self)}')

View file

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Iterable, Generator from typing import Iterable, Generator
from python_symb.IndependantTools.tools import gcd from python_symb.IndependantTools.math_tools import gcd
class Fraction: class Fraction:

View file

@ -8,6 +8,7 @@ Number = Union[int, float]
def expand(expr: Expr) -> Expr: def expand(expr: Expr) -> Expr:
""" """
Expand an expression Expand an expression
@ -23,14 +24,22 @@ def expand(expr: Expr) -> Expr:
return expr return expr
match expr: match expr:
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive: case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children])) expanded_right = expand(right)
return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children]))
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive: case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children])) expanded_left = expand(left)
return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children]))
case Expr(BinOperator() as Op, [left, right]): case Expr(BinOperator() as Op, [left, right]):
return Expr(Op, [expand(left), expand(right)]) left_leaf, right_leaf = left.is_leaf, right.is_leaf
if not left_leaf:
left = expand(left)
if not right_leaf:
right = expand(right)
return Expr(Op, [left, right])
return expr return expr

View file

@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var
from python_symb.Expressions.expr import Expr from python_symb.Expressions.expr import Expr
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
import time import cProfile
from functools import wraps
def timeit(func):
@wraps(func)
def measure_time(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time))
return result
return measure_time
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful # assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
def create_var(): def create_var():
global x, y global x, y, z, a, b, c
x = var('x') x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c')
y = var('y')
def test_sum(print_result=False): def test_sum(print_result=False):
@ -69,6 +57,17 @@ def test_expand(print_result=False):
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])]) assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
def test_nested_expand(print_result=False):
expr = (x+(x+y)*(x+y))*(x+y)
expanded = expand(expr)
expanded = regroup(expanded, Add)
expanded = regroup(expanded, Mul)
expanded_str = expanded.to_infix_str()
if print_result:
print("(x+(x+y)*(x+y))*(x+y)")
print(f"expanded: {expanded_str}")
def test_regroup(print_result=False): def test_regroup(print_result=False):
expr = x+x+2*x+y+x+3*y+x expr = x+x+2*x+y+x+3*y+x
regrouped = regroup(expr, Add) regrouped = regroup(expr, Add)
@ -81,7 +80,7 @@ def test_regroup(print_result=False):
def test_newton_bin(print_result=False): def test_newton_bin(print_result=False):
expr = (x+y)**4 expr = (x+y)**2
ungrouped = ungroup(expr, Exp) ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped) expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add) expand_grouped = regroup(expanded, Add)
@ -96,11 +95,11 @@ def test_newton_bin(print_result=False):
print(f"expanded: {expanded.to_infix_str()}") print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}")
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) #assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
@timeit
def test_big_bin_for_performance(print_result=False): def test_big_bin_for_performance(print_result=False):
expr = (x+y)**11 expr = (x+y)**12
print("a") print("a")
ungrouped = ungroup(expr, Exp) ungrouped = ungroup(expr, Exp)
print("b") print("b")
@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False):
print(f"expanded: {expanded.to_infix_str()}") print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}")
def test_trinomial(print_result=False):
expr = (x+y+z)**3
ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add)
expand_grouped = regroup(expand_grouped, Mul)
if print_result:
print("---")
print("trinomial test")
print("---")
print(f"expr: {expr.to_infix_str()}")
print(f"ungrouped: {ungrouped.to_infix_str()}")
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
if __name__ == '__main__': if __name__ == '__main__':
create_var() create_var()
print_r = True print_r = True
@ -127,9 +145,10 @@ if __name__ == '__main__':
test_parse(print_result=print_r) test_parse(print_result=print_r)
test_expr_to_str(print_result=print_r) test_expr_to_str(print_result=print_r)
test_expand(print_result=print_r) test_expand(print_result=print_r)
test_nested_expand(print_result=print_r)
test_regroup(print_result=print_r) test_regroup(print_result=print_r)
test_newton_bin(print_result=print_r) #test_trinomial(print_result=print_r)
test_big_bin_for_performance(print_result=print_r) cProfile.run('test_big_bin_for_performance(print_result=print_r)')
print("All tests passed") print("All tests passed")