Optimization of expand in basic_modif.py

Making Expr immutable, and then optimize __hash__ with this
This commit is contained in:
Clément Barthélemy 2024-02-24 02:57:07 +01:00
parent c27d706a26
commit 948e4da10c
6 changed files with 92 additions and 37 deletions

View file

@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var
from python_symb.Expressions.expr import Expr
from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin
from python_symb.TreeModification.basic_modif import expand, regroup, ungroup
import time
from functools import wraps
import cProfile
def timeit(func):
@wraps(func)
def measure_time(*args, **kwargs):
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time))
return result
return measure_time
# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful
def create_var():
global x, y
x = var('x')
y = var('y')
global x, y, z, a, b, c
x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c')
def test_sum(print_result=False):
@ -69,6 +57,17 @@ def test_expand(print_result=False):
assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])])
def test_nested_expand(print_result=False):
expr = (x+(x+y)*(x+y))*(x+y)
expanded = expand(expr)
expanded = regroup(expanded, Add)
expanded = regroup(expanded, Mul)
expanded_str = expanded.to_infix_str()
if print_result:
print("(x+(x+y)*(x+y))*(x+y)")
print(f"expanded: {expanded_str}")
def test_regroup(print_result=False):
expr = x+x+2*x+y+x+3*y+x
regrouped = regroup(expr, Add)
@ -81,7 +80,7 @@ def test_regroup(print_result=False):
def test_newton_bin(print_result=False):
expr = (x+y)**4
expr = (x+y)**2
ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add)
@ -96,11 +95,11 @@ def test_newton_bin(print_result=False):
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
#assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])])
@timeit
def test_big_bin_for_performance(print_result=False):
expr = (x+y)**11
expr = (x+y)**12
print("a")
ungrouped = ungroup(expr, Exp)
print("b")
@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False):
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
def test_trinomial(print_result=False):
expr = (x+y+z)**3
ungrouped = ungroup(expr, Exp)
expanded = expand(ungrouped)
expand_grouped = regroup(expanded, Add)
expand_grouped = regroup(expand_grouped, Mul)
if print_result:
print("---")
print("trinomial test")
print("---")
print(f"expr: {expr.to_infix_str()}")
print(f"ungrouped: {ungrouped.to_infix_str()}")
print(f"expanded: {expanded.to_infix_str()}")
print(f"expand_grouped: {expand_grouped.to_infix_str()}")
if __name__ == '__main__':
create_var()
print_r = True
@ -127,9 +145,10 @@ if __name__ == '__main__':
test_parse(print_result=print_r)
test_expr_to_str(print_result=print_r)
test_expand(print_result=print_r)
test_nested_expand(print_result=print_r)
test_regroup(print_result=print_r)
test_newton_bin(print_result=print_r)
test_big_bin_for_performance(print_result=print_r)
#test_trinomial(print_result=print_r)
cProfile.run('test_big_bin_for_performance(print_result=print_r)')
print("All tests passed")