Optimization of expand in basic_modif.py
Making Expr immutable, and then optimize __hash__ with this
This commit is contained in:
parent
c27d706a26
commit
948e4da10c
6 changed files with 92 additions and 37 deletions
|
@ -14,6 +14,7 @@ from python_symb.Parsing.parse import infix_str_to_postfix
|
|||
class Expr(Tree):
|
||||
"""
|
||||
A class to represent an expression tree
|
||||
is immutable
|
||||
|
||||
value: the value of the node
|
||||
children: the subtrees of the root (Default : None)
|
||||
|
@ -23,19 +24,39 @@ class Expr(Tree):
|
|||
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
|
||||
|
||||
"""
|
||||
__slots__ = ['value', 'children', '_is_frozen', '_hash']
|
||||
__match_args__ = ('value', 'children')
|
||||
|
||||
def __init__(self, value, children=None):
|
||||
def __init__(self, value, children=None, debug=False):
|
||||
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
|
||||
self._is_frozen = False
|
||||
super().__init__(value, children if children else [])
|
||||
assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr'
|
||||
self._is_frozen = True
|
||||
self._hash = None
|
||||
|
||||
match value:
|
||||
case BinOperator() as op:
|
||||
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
|
||||
if debug:
|
||||
|
||||
case UnaryOperator() as op:
|
||||
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
|
||||
assert all([isinstance(child, Expr) for child in
|
||||
self.children]), f'Invalid children: {self.children} all child should be Expr'
|
||||
|
||||
match value:
|
||||
case BinOperator() as op:
|
||||
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
|
||||
|
||||
case UnaryOperator() as op:
|
||||
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
if key == "_hash" or key == "_is_frozen":
|
||||
super().__setattr__(key, value)
|
||||
elif getattr(self, "_is_frozen", True):
|
||||
raise TypeError(f"{self.__class__.__name__} is immutable")
|
||||
super().__setattr__(key, value)
|
||||
|
||||
def __delattr__(self, item):
|
||||
if getattr(self, "_is_frozen", True):
|
||||
raise TypeError(f"{self.__class__.__name__} is immutable")
|
||||
super().__delattr__(item)
|
||||
|
||||
|
||||
@staticmethod
|
||||
|
@ -91,7 +112,6 @@ class Expr(Tree):
|
|||
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def bin_op_constructor(self, other, op):
|
||||
"""
|
||||
|
@ -139,19 +159,26 @@ class Expr(Tree):
|
|||
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
|
||||
see test_eq
|
||||
"""
|
||||
if self._hash:
|
||||
return self._hash
|
||||
|
||||
match self:
|
||||
|
||||
case Expr(value) if self.is_leaf:
|
||||
return hash(value)
|
||||
self._hash = hash(value)
|
||||
return self._hash
|
||||
|
||||
case Expr(UnaryOperator() as op, [child]):
|
||||
return hash(op.name + str(hash(child)))
|
||||
self._hash = hash(op.name + str(hash(child)))
|
||||
return self._hash
|
||||
|
||||
case Expr(BinOperator() as op, [left, right]):
|
||||
if op.properties.commutative and op.properties.associative:
|
||||
return hash(op.name) + hash(left) + hash(right)
|
||||
self._hash = hash(op.name) + hash(left) + hash(right)
|
||||
return self._hash
|
||||
else:
|
||||
return hash(op.name) + hash(str(hash(left)) + str(hash(right)))
|
||||
self._hash = hash(op.name) + hash(str(hash(left)) + str(hash(right)))
|
||||
return self._hash
|
||||
|
||||
case _:
|
||||
print(f'Invalid type: {type(self)}')
|
||||
|
|
0
python_symb/IndependantTools/programming_tools.py
Normal file
0
python_symb/IndependantTools/programming_tools.py
Normal file
|
@ -1,6 +1,6 @@
|
|||
from __future__ import annotations
|
||||
from typing import Iterable, Generator
|
||||
from python_symb.IndependantTools.tools import gcd
|
||||
from python_symb.IndependantTools.math_tools import gcd
|
||||
|
||||
|
||||
class Fraction:
|
||||
|
|
|
@ -8,6 +8,7 @@ Number = Union[int, float]
|
|||
|
||||
|
||||
def expand(expr: Expr) -> Expr:
|
||||
|
||||
"""
|
||||
Expand an expression
|
||||
|
||||
|
@ -23,14 +24,22 @@ def expand(expr: Expr) -> Expr:
|
|||
return expr
|
||||
|
||||
match expr:
|
||||
|
||||
case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive:
|
||||
return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children]))
|
||||
expanded_right = expand(right)
|
||||
return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children]))
|
||||
|
||||
case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive:
|
||||
return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children]))
|
||||
expanded_left = expand(left)
|
||||
return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children]))
|
||||
|
||||
case Expr(BinOperator() as Op, [left, right]):
|
||||
return Expr(Op, [expand(left), expand(right)])
|
||||
left_leaf, right_leaf = left.is_leaf, right.is_leaf
|
||||
if not left_leaf:
|
||||
left = expand(left)
|
||||
if not right_leaf:
|
||||
right = expand(right)
|
||||
return Expr(Op, [left, right])
|
||||
|
||||
return expr
|
||||
|
||||
|
|
|
@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var
|
|||
from python_symb.Expressions.expr import Expr
|
||||
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
|
||||
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
|
||||
import time
|
||||
from functools import wraps
|
||||
import cProfile
|
||||
|
||||
|
||||
def timeit(func):
|
||||
@wraps(func)
|
||||
def measure_time(*args, **kwargs):
|
||||
start_time = time.time()
|
||||
result = func(*args, **kwargs)
|
||||
end_time = time.time()
|
||||
print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time))
|
||||
return result
|
||||
return measure_time
|
||||
|
||||
|
||||
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
|
||||
|
||||
def create_var():
|
||||
global x, y
|
||||
x = var('x')
|
||||
y = var('y')
|
||||
global x, y, z, a, b, c
|
||||
x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c')
|
||||
|
||||
|
||||
def test_sum(print_result=False):
|
||||
|
@ -69,6 +57,17 @@ def test_expand(print_result=False):
|
|||
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
|
||||
|
||||
|
||||
def test_nested_expand(print_result=False):
|
||||
expr = (x+(x+y)*(x+y))*(x+y)
|
||||
expanded = expand(expr)
|
||||
expanded = regroup(expanded, Add)
|
||||
expanded = regroup(expanded, Mul)
|
||||
expanded_str = expanded.to_infix_str()
|
||||
if print_result:
|
||||
print("(x+(x+y)*(x+y))*(x+y)")
|
||||
print(f"expanded: {expanded_str}")
|
||||
|
||||
|
||||
def test_regroup(print_result=False):
|
||||
expr = x+x+2*x+y+x+3*y+x
|
||||
regrouped = regroup(expr, Add)
|
||||
|
@ -81,7 +80,7 @@ def test_regroup(print_result=False):
|
|||
|
||||
|
||||
def test_newton_bin(print_result=False):
|
||||
expr = (x+y)**4
|
||||
expr = (x+y)**2
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
expanded = expand(ungrouped)
|
||||
expand_grouped = regroup(expanded, Add)
|
||||
|
@ -96,11 +95,11 @@ def test_newton_bin(print_result=False):
|
|||
print(f"expanded: {expanded.to_infix_str()}")
|
||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||
|
||||
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
||||
#assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
|
||||
|
||||
|
||||
@timeit
|
||||
def test_big_bin_for_performance(print_result=False):
|
||||
expr = (x+y)**11
|
||||
expr = (x+y)**12
|
||||
print("a")
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
print("b")
|
||||
|
@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False):
|
|||
print(f"expanded: {expanded.to_infix_str()}")
|
||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||
|
||||
|
||||
def test_trinomial(print_result=False):
|
||||
expr = (x+y+z)**3
|
||||
ungrouped = ungroup(expr, Exp)
|
||||
expanded = expand(ungrouped)
|
||||
expand_grouped = regroup(expanded, Add)
|
||||
expand_grouped = regroup(expand_grouped, Mul)
|
||||
|
||||
if print_result:
|
||||
print("---")
|
||||
print("trinomial test")
|
||||
print("---")
|
||||
print(f"expr: {expr.to_infix_str()}")
|
||||
print(f"ungrouped: {ungrouped.to_infix_str()}")
|
||||
print(f"expanded: {expanded.to_infix_str()}")
|
||||
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_var()
|
||||
print_r = True
|
||||
|
@ -127,9 +145,10 @@ if __name__ == '__main__':
|
|||
test_parse(print_result=print_r)
|
||||
test_expr_to_str(print_result=print_r)
|
||||
test_expand(print_result=print_r)
|
||||
test_nested_expand(print_result=print_r)
|
||||
test_regroup(print_result=print_r)
|
||||
test_newton_bin(print_result=print_r)
|
||||
test_big_bin_for_performance(print_result=print_r)
|
||||
#test_trinomial(print_result=print_r)
|
||||
cProfile.run('test_big_bin_for_performance(print_result=print_r)')
|
||||
print("All tests passed")
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue