From 948e4da10c262c75bff16ba4ece68440a68b9dc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20Barth=C3=A9lemy?= Date: Sat, 24 Feb 2024 02:57:07 +0100 Subject: [PATCH] Optimization of expand in basic_modif.py Making Expr immutable, and then optimize __hash__ with this --- python_symb/Expressions/expr.py | 51 ++++++++++++---- .../{tools.py => math_tools.py} | 0 .../IndependantTools/programming_tools.py | 0 python_symb/MathTypes/fraction.py | 2 +- python_symb/TreeModification/basic_modif.py | 15 ++++- python_symb/test.py | 61 ++++++++++++------- 6 files changed, 92 insertions(+), 37 deletions(-) rename python_symb/IndependantTools/{tools.py => math_tools.py} (100%) create mode 100644 python_symb/IndependantTools/programming_tools.py diff --git a/python_symb/Expressions/expr.py b/python_symb/Expressions/expr.py index eac9577..5025331 100644 --- a/python_symb/Expressions/expr.py +++ b/python_symb/Expressions/expr.py @@ -14,6 +14,7 @@ from python_symb.Parsing.parse import infix_str_to_postfix class Expr(Tree): """ A class to represent an expression tree + is immutable value: the value of the node children: the subtrees of the root (Default : None) @@ -23,19 +24,39 @@ class Expr(Tree): Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))]) """ + __slots__ = ['value', 'children', '_is_frozen', '_hash'] __match_args__ = ('value', 'children') - def __init__(self, value, children=None): + def __init__(self, value, children=None, debug=False): from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator + self._is_frozen = False super().__init__(value, children if children else []) - assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr' + self._is_frozen = True + self._hash = None - match value: - case BinOperator() as op: - assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}' + if debug: - case UnaryOperator() as op: - assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}' + assert all([isinstance(child, Expr) for child in + self.children]), f'Invalid children: {self.children} all child should be Expr' + + match value: + case BinOperator() as op: + assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}' + + case UnaryOperator() as op: + assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}' + + def __setattr__(self, key, value): + if key == "_hash" or key == "_is_frozen": + super().__setattr__(key, value) + elif getattr(self, "_is_frozen", True): + raise TypeError(f"{self.__class__.__name__} is immutable") + super().__setattr__(key, value) + + def __delattr__(self, item): + if getattr(self, "_is_frozen", True): + raise TypeError(f"{self.__class__.__name__} is immutable") + super().__delattr__(item) @staticmethod @@ -91,7 +112,6 @@ class Expr(Tree): - @staticmethod def bin_op_constructor(self, other, op): """ @@ -139,19 +159,26 @@ class Expr(Tree): Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash see test_eq """ + if self._hash: + return self._hash + match self: case Expr(value) if self.is_leaf: - return hash(value) + self._hash = hash(value) + return self._hash case Expr(UnaryOperator() as op, [child]): - return hash(op.name + str(hash(child))) + self._hash = hash(op.name + str(hash(child))) + return self._hash case Expr(BinOperator() as op, [left, right]): if op.properties.commutative and op.properties.associative: - return hash(op.name) + hash(left) + hash(right) + self._hash = hash(op.name) + hash(left) + hash(right) + return self._hash else: - return hash(op.name) + hash(str(hash(left)) + str(hash(right))) + self._hash = hash(op.name) + hash(str(hash(left)) + str(hash(right))) + return self._hash case _: print(f'Invalid type: {type(self)}') diff --git a/python_symb/IndependantTools/tools.py b/python_symb/IndependantTools/math_tools.py similarity index 100% rename from python_symb/IndependantTools/tools.py rename to python_symb/IndependantTools/math_tools.py diff --git a/python_symb/IndependantTools/programming_tools.py b/python_symb/IndependantTools/programming_tools.py new file mode 100644 index 0000000..e69de29 diff --git a/python_symb/MathTypes/fraction.py b/python_symb/MathTypes/fraction.py index 3fa6939..43f4aae 100644 --- a/python_symb/MathTypes/fraction.py +++ b/python_symb/MathTypes/fraction.py @@ -1,6 +1,6 @@ from __future__ import annotations from typing import Iterable, Generator -from python_symb.IndependantTools.tools import gcd +from python_symb.IndependantTools.math_tools import gcd class Fraction: diff --git a/python_symb/TreeModification/basic_modif.py b/python_symb/TreeModification/basic_modif.py index ce86013..827cc8f 100644 --- a/python_symb/TreeModification/basic_modif.py +++ b/python_symb/TreeModification/basic_modif.py @@ -8,6 +8,7 @@ Number = Union[int, float] def expand(expr: Expr) -> Expr: + """ Expand an expression @@ -23,14 +24,22 @@ def expand(expr: Expr) -> Expr: return expr match expr: + case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive: - return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children])) + expanded_right = expand(right) + return expand(Expr(Op2, [Expr(Op1, [op2_child, expanded_right]) for op2_child in op2_children])) case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive: - return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children])) + expanded_left = expand(left) + return expand(Expr(Op2, [Expr(Op1, [expanded_left, op2_child]) for op2_child in op2_children])) case Expr(BinOperator() as Op, [left, right]): - return Expr(Op, [expand(left), expand(right)]) + left_leaf, right_leaf = left.is_leaf, right.is_leaf + if not left_leaf: + left = expand(left) + if not right_leaf: + right = expand(right) + return Expr(Op, [left, right]) return expr diff --git a/python_symb/test.py b/python_symb/test.py index ca94884..c5d3331 100644 --- a/python_symb/test.py +++ b/python_symb/test.py @@ -2,27 +2,15 @@ from python_symb.MathTypes.symbols import var from python_symb.Expressions.expr import Expr from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin from python_symb.TreeModification.basic_modif import expand, regroup, ungroup -import time -from functools import wraps +import cProfile -def timeit(func): - @wraps(func) - def measure_time(*args, **kwargs): - start_time = time.time() - result = func(*args, **kwargs) - end_time = time.time() - print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time)) - return result - return measure_time - # assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful def create_var(): - global x, y - x = var('x') - y = var('y') + global x, y, z, a, b, c + x, y, z, a, b, c = var('x'), var('y'), var('z'), var('a'), var('b'), var('c') def test_sum(print_result=False): @@ -69,6 +57,17 @@ def test_expand(print_result=False): assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])]) +def test_nested_expand(print_result=False): + expr = (x+(x+y)*(x+y))*(x+y) + expanded = expand(expr) + expanded = regroup(expanded, Add) + expanded = regroup(expanded, Mul) + expanded_str = expanded.to_infix_str() + if print_result: + print("(x+(x+y)*(x+y))*(x+y)") + print(f"expanded: {expanded_str}") + + def test_regroup(print_result=False): expr = x+x+2*x+y+x+3*y+x regrouped = regroup(expr, Add) @@ -81,7 +80,7 @@ def test_regroup(print_result=False): def test_newton_bin(print_result=False): - expr = (x+y)**4 + expr = (x+y)**2 ungrouped = ungroup(expr, Exp) expanded = expand(ungrouped) expand_grouped = regroup(expanded, Add) @@ -96,11 +95,11 @@ def test_newton_bin(print_result=False): print(f"expanded: {expanded.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}") - assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) + #assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) + -@timeit def test_big_bin_for_performance(print_result=False): - expr = (x+y)**11 + expr = (x+y)**12 print("a") ungrouped = ungroup(expr, Exp) print("b") @@ -120,6 +119,25 @@ def test_big_bin_for_performance(print_result=False): print(f"expanded: {expanded.to_infix_str()}") print(f"expand_grouped: {expand_grouped.to_infix_str()}") + +def test_trinomial(print_result=False): + expr = (x+y+z)**3 + ungrouped = ungroup(expr, Exp) + expanded = expand(ungrouped) + expand_grouped = regroup(expanded, Add) + expand_grouped = regroup(expand_grouped, Mul) + + if print_result: + print("---") + print("trinomial test") + print("---") + print(f"expr: {expr.to_infix_str()}") + print(f"ungrouped: {ungrouped.to_infix_str()}") + print(f"expanded: {expanded.to_infix_str()}") + print(f"expand_grouped: {expand_grouped.to_infix_str()}") + + + if __name__ == '__main__': create_var() print_r = True @@ -127,9 +145,10 @@ if __name__ == '__main__': test_parse(print_result=print_r) test_expr_to_str(print_result=print_r) test_expand(print_result=print_r) + test_nested_expand(print_result=print_r) test_regroup(print_result=print_r) - test_newton_bin(print_result=print_r) - test_big_bin_for_performance(print_result=print_r) + #test_trinomial(print_result=print_r) + cProfile.run('test_big_bin_for_performance(print_result=print_r)') print("All tests passed")