Optimization of expand in basic_modif.py
Making Expr immutable, and then optimize __hash__ with this
This commit is contained in:
parent
c27d706a26
commit
948e4da10c
6 changed files with 92 additions and 37 deletions
|
@ -14,6 +14,7 @@ from python_symb.Parsing.parse import infix_str_to_postfix
|
||||||
class Expr(Tree):
|
class Expr(Tree):
|
||||||
"""
|
"""
|
||||||
A class to represent an expression tree
|
A class to represent an expression tree
|
||||||
|
is immutable
|
||||||
|
|
||||||
value: the value of the node
|
value: the value of the node
|
||||||
children: the subtrees of the root (Default : None)
|
children: the subtrees of the root (Default : None)
|
||||||
|
@ -23,19 +24,39 @@ class Expr(Tree):
|
||||||
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
|
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
__slots__ = ['value', 'children', '_is_frozen', '_hash']
|
||||||
__match_args__ = ('value', 'children')
|
__match_args__ = ('value', 'children')
|
||||||
|
|
||||||
def __init__(self, value, children=None):
|
def __init__(self, value, children=None, debug=False):
|
||||||
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
|
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
|
||||||
|
self._is_frozen = False
|
||||||
super().__init__(value, children if children else [])
|
super().__init__(value, children if children else [])
|
||||||
assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr'
|
self._is_frozen = True
|
||||||
|
self._hash = None
|
||||||
|
|
||||||
match value:
|
if debug:
|
||||||
case BinOperator() as op:
|
|
||||||
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
|
|
||||||
|
|
||||||
case UnaryOperator() as op:
|
assert all([isinstance(child, Expr) for child in
|
||||||
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
|
self.children]), f'Invalid children: {self.children} all child should be Expr'
|
||||||
|
|
||||||
|
match value:
|
||||||
|
case BinOperator() as op:
|
||||||
|
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
|
||||||
|
|
||||||
|
case UnaryOperator() as op:
|
||||||
|
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
|
||||||
|
|
||||||
|
def __setattr__(self, key, value):
|
||||||
|
if key == "_hash" or key == "_is_frozen":
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
elif getattr(self, "_is_frozen", True):
|
||||||
|
raise TypeError(f"{self.__class__.__name__} is immutable")
|
||||||
|
super().__setattr__(key, value)
|
||||||
|
|
||||||
|
def __delattr__(self, item):
|
||||||
|
if getattr(self, "_is_frozen", True):
|
||||||
|
raise TypeError(f"{self.__class__.__name__} is immutable")
|
||||||
|
super().__delattr__(item)
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -91,7 +112,6 @@ class Expr(Tree):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bin_op_constructor(self, other, op):
|
def bin_op_constructor(self, other, op):
|
||||||
"""
|
"""
|
||||||
|
@ -139,19 +159,26 @@ class Expr(Tree):
|
||||||
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
|
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
|
||||||
see test_eq
|
see test_eq
|
||||||
"""
|
"""
|
||||||
|
if self._hash:
|
||||||
|
return self._hash
|
||||||
|
|
||||||
match self:
|
match self:
|
||||||
|
|
||||||
case Expr(value) if self.is_leaf:
|
case Expr(value) if self.is_leaf:
|
||||||
return hash(value)
|
self._hash = hash(value)
|
||||||
|
return self._hash
|
||||||
|
|
||||||
case Expr(UnaryOperator() as op, [child]):
|
case Expr(UnaryOperator() as op, [child]):
|
||||||
return hash(op.name + str(hash(child)))
|
self._hash = hash(op.name + str(hash(child)))
|
||||||
|
return self._hash
|
||||||
|
|
||||||
case Expr(BinOperator() as op, [left, right]):
|
case Expr(BinOperator() as op, [left, right]):
|
||||||
if op.properties.commutative and op.properties.associative:
|
if op.properties.commutative and op.properties.associative:
|
||||||
return hash(op.name) + hash(left) + hash(right)
|
self._hash = hash(op.name) + hash(left) + hash(right)
|
||||||
|
return self._hash
|
||||||
else:
|
else:
|
||||||
return hash(op.name) + hash(str(hash(left)) + str(hash(right)))
|
self._hash = hash(op.name) + hash(str(hash(left)) + str(hash(right)))
|
||||||
|
return self._hash
|
||||||
|
|
||||||
case _:
|
case _:
|
||||||
print(f'Invalid type: {type(self)}')
|
print(f'Invalid type: {type(self)}')
|
||||||
|
|
0
python_symb/IndependantTools/programming_tools.py
Normal file
0
python_symb/IndependantTools/programming_tools.py
Normal file
|
@ -1,6 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import Iterable, Generator
|
from typing import Iterable, Generator
|
||||||
from python_symb.IndependantTools.tools import gcd
|
from python_symb.IndependantTools.math_tools import gcd
|
||||||
|
|
||||||
|
|
||||||
class Fraction:
|
class Fraction:
|
||||||
|
|
|
@ -8,6 +8,7 @@ Number = Union[int, float]
|
||||||
|
|
||||||
|
|
||||||
def expand(expr: Expr) -> Expr:
|
def expand(expr: Expr) -> Expr:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Expand an expression
|
Expand an expression
|
||||||
|
|
||||||
|
@ -23,14 +24,22 @@ def expand(expr: Expr) -> Expr:
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
match expr:
|
match expr:
|
||||||
|
|
||||||
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
|
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
|
||||||
return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children]))
|
expanded_right = expand(right)
|
||||||
|
return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children]))
|
||||||
|
|
||||||
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
|
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
|
||||||
return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children]))
|
expanded_left = expand(left)
|
||||||
|
return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children]))
|
||||||
|
|
||||||
case Expr(BinOperator() as Op, [left, right]):
|
case Expr(BinOperator() as Op, [left, right]):
|
||||||
return Expr(Op, [expand(left), expand(right)])
|
left_leaf, right_leaf = left.is_leaf, right.is_leaf
|
||||||
|
if not left_leaf:
|
||||||
|
left = expand(left)
|
||||||
|
if not right_leaf:
|
||||||
|
right = expand(right)
|
||||||
|
return Expr(Op, [left, right])
|
||||||
|
|
||||||
return expr
|
return expr
|
||||||
|
|
||||||
|
|
|
@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var
|
||||||
from python_symb.Expressions.expr import Expr
|
from python_symb.Expressions.expr import Expr
|
||||||
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
|
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
|
||||||
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
|
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
|
||||||
import time
|
import cProfile
|
||||||
from functools import wraps
|
|
||||||
|
|
||||||
|
|
||||||
def timeit(func):
|
|
||||||
@wraps(func)
|
|
||||||
def measure_time(*args, **kwargs):
|
|
||||||
start_time = time.time()
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
end_time = time.time()
|
|
||||||
print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time))
|
|
||||||
return result
|
|
||||||
return measure_time
|
|
||||||
|
|
||||||
|
|
||||||
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
|
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
|
||||||
|
|
||||||
def create_var():
|
def create_var():
|
||||||
global x, y
|
global x, y, z, a, b, c
|
||||||
x = var('x')
|
x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c')
|
||||||
y = var('y')
|
|
||||||
|
|
||||||
|
|
||||||
def test_sum(print_result=False):
|
def test_sum(print_result=False):
|
||||||
|
@ -69,6 +57,17 @@ def test_expand(print_result=False):
|
||||||
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
|
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
|
||||||
|
|
||||||
|
|
||||||
|
def test_nested_expand(print_result=False):
|
||||||
|
expr = (x+(x+y)*(x+y))*(x+y)
|
||||||
|
expanded = expand(expr)
|
||||||
|
expanded = regroup(expanded, Add)
|
||||||
|
expanded = regroup(expanded, Mul)
|
||||||
|
expanded_str = expanded.to_infix_str()
|
||||||
|
if print_result:
|
||||||
|
print("(x+(x+y)*(x+y))*(x+y)")
|
||||||
|
print(f"expanded: {expanded_str}")
|
||||||
|
|
||||||
|
|
||||||
def test_regroup(print_result=False):
|
def test_regroup(print_result=False):
|
||||||
expr = x+x+2*x+y+x+3*y+x
|
expr = x+x+2*x+y+x+3*y+x
|
||||||
regrouped = regroup(expr, Add)
|
regrouped = regroup(expr, Add)
|
||||||
|
@ -81,7 +80,7 @@ def test_regroup(print_result=False):
|
||||||
|
|
||||||
|
|
||||||
def test_newton_bin(print_result=False):
|
def test_newton_bin(print_result=False):
|
||||||
expr = (x+y)**4
|
expr = (x+y)**2
|
||||||
ungrouped = ungroup(expr, Exp)
|
ungrouped = ungroup(expr, Exp)
|
||||||
expanded = expand(ungrouped)
|
expanded = expand(ungrouped)
|
||||||
expand_grouped = regroup(expanded, Add)
|
expand_grouped = regroup(expanded, Add)
|
||||||
|
@ -96,11 +95,11 @@ def test_newton_bin(print_result=False):
|
||||||
print(f"expanded: {expanded.to_infix_str()}")
|
print(f"expanded: {expanded.to_infix_str()}")
|
||||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||||
|
|
||||||
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
#assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
||||||
|
|
||||||
|
|
||||||
@timeit
|
|
||||||
def test_big_bin_for_performance(print_result=False):
|
def test_big_bin_for_performance(print_result=False):
|
||||||
expr = (x+y)**11
|
expr = (x+y)**12
|
||||||
print("a")
|
print("a")
|
||||||
ungrouped = ungroup(expr, Exp)
|
ungrouped = ungroup(expr, Exp)
|
||||||
print("b")
|
print("b")
|
||||||
|
@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False):
|
||||||
print(f"expanded: {expanded.to_infix_str()}")
|
print(f"expanded: {expanded.to_infix_str()}")
|
||||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||||
|
|
||||||
|
|
||||||
|
def test_trinomial(print_result=False):
|
||||||
|
expr = (x+y+z)**3
|
||||||
|
ungrouped = ungroup(expr, Exp)
|
||||||
|
expanded = expand(ungrouped)
|
||||||
|
expand_grouped = regroup(expanded, Add)
|
||||||
|
expand_grouped = regroup(expand_grouped, Mul)
|
||||||
|
|
||||||
|
if print_result:
|
||||||
|
print("---")
|
||||||
|
print("trinomial test")
|
||||||
|
print("---")
|
||||||
|
print(f"expr: {expr.to_infix_str()}")
|
||||||
|
print(f"ungrouped: {ungrouped.to_infix_str()}")
|
||||||
|
print(f"expanded: {expanded.to_infix_str()}")
|
||||||
|
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
create_var()
|
create_var()
|
||||||
print_r = True
|
print_r = True
|
||||||
|
@ -127,9 +145,10 @@ if __name__ == '__main__':
|
||||||
test_parse(print_result=print_r)
|
test_parse(print_result=print_r)
|
||||||
test_expr_to_str(print_result=print_r)
|
test_expr_to_str(print_result=print_r)
|
||||||
test_expand(print_result=print_r)
|
test_expand(print_result=print_r)
|
||||||
|
test_nested_expand(print_result=print_r)
|
||||||
test_regroup(print_result=print_r)
|
test_regroup(print_result=print_r)
|
||||||
test_newton_bin(print_result=print_r)
|
#test_trinomial(print_result=print_r)
|
||||||
test_big_bin_for_performance(print_result=print_r)
|
cProfile.run('test_big_bin_for_performance(print_result=print_r)')
|
||||||
print("All tests passed")
|
print("All tests passed")
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue