diff --git a/python_symb/Expressions/expr.py b/python_symb/Expressions/expr.py index d26d828..eac9577 100644 --- a/python_symb/Expressions/expr.py +++ b/python_symb/Expressions/expr.py @@ -85,11 +85,9 @@ class Expr(Tree): case Expr(BinOperator() as op, [left, right]): op_name = op.name if not(implicit_mul and op == Mul) else '' if op.precedence < parent_precedence: - print("hehehe") - print(self, op.precedence, parent_precedence) - return f"({left.to_infix_str(op.precedence)}{op_name}{right.to_infix_str(op.precedence)})" + return f"({left.to_infix_str(op.precedence, implicit_mul)}{op_name}{right.to_infix_str(op.precedence, implicit_mul)})" else: - return f"{left.to_infix_str(op.precedence)}{op_name}{right.to_infix_str(op.precedence)}" + return f"{left.to_infix_str(op.precedence, implicit_mul)}{op_name}{right.to_infix_str(op.precedence, implicit_mul)}" @@ -123,8 +121,18 @@ class Expr(Tree): other_expr = other if isinstance(other, Expr) else Expr(other) return Expr.bin_op_constructor(other_expr, self, Mul) + def __pow__(self, other): + other_expr = other if isinstance(other, Expr) else Expr(other) + return Expr.bin_op_constructor(self, other_expr, Exp) + + def __neg__(self): + return Expr(Mul, [Expr(-1), self]) + def __sub__(self, other): - return Expr.bin_op_constructor(self, other, Min) + return self + (-other) + + def __rsub__(self, other): + return other + (-self) def __hash__(self): """ @@ -155,52 +163,41 @@ class Expr(Tree): """temporary""" return self.bad_eq(other) + def delete_node(self, node: Expr) -> Expr: + """ + return a new expression without all occurences of the node (with the equality defined by the __eq__ method) + + :param node: node to delete everywhere + """ + match self: + + case Expr(value) if self.is_leaf: + if self == node: + raise "Cannot delete a leaf node, don't" + else: + return self + + case Expr(UnaryOperator() as op, [child]): + if child == node: + raise ValueError("Cannot delete a node that is the child of a unary operator, don't.") + else: + return Expr(op, [child.delete_node(node)]) + + case Expr(BinOperator() as op, [left, right]): + if left == node: + return right + elif right == node: + return left + else: + return Expr(op, [left.delete_node(node), right.delete_node(node)]) + + case _: + raise ValueError(f'Invalid type: {type(self)}') -def test(): - x, y = var('x'), var('y') - a, b = var('a'), var('b') - def test1(): - expr1 = x + y - expr2 = 5+x - print(expr1 + expr2) - def test2(): - from python_symb.MathTypes.operator_file import Sin - expr = Sin(x+y) - print(expr) - def test_eq(): - expr = x + y + 3 - expr2 = 3 + x + y - print("----") - print(expr) - print(expr2) - print(expr == expr2) - def test_return_to_string(): - expr = x+y - new_expr = 5*expr - print(new_expr) - print(f"new_expr: {new_expr.to_infix_str()}") - - expr2 = x*x*x + y*y*y - print(expr2) - print(f"expr2: {expr2.to_infix_str()}") - - - print("test1") - test1() - print("test2") - test2() - print("test_eq") - test_eq() - print("test_return_to_string") - test_return_to_string() - - -if __name__ == '__main__': - test() diff --git a/python_symb/MathTypes/operator_file.py b/python_symb/MathTypes/operator_file.py index 7a7f862..8a7a46b 100644 --- a/python_symb/MathTypes/operator_file.py +++ b/python_symb/MathTypes/operator_file.py @@ -7,9 +7,13 @@ class Operator(Symbols): """ Represent an operator, like +, *, sin, anything that can be applied to an expression """ + # Store all the instances of Operator, used in the parser instances = {} + # The deconstruct operator of a repeated operator is used to deconstruct an expression (x+y)^2 -> (x+y)*(x+y) + # Mul is the deconstruct operator of Add + deconstruct_op_dict = {} - def __init__(self, name: str, precedence: int, call: Callable, repeated_op: Operator = None): + def __init__(self, name: str, precedence: int, call: Callable, repeated_op: BinOperator = None): """ :param name of the operator :param precedence: precedence of the operator, higher is better @@ -24,9 +28,16 @@ class Operator(Symbols): self.repeated_op = repeated_op Operator.instances[name] = self + if repeated_op: + Operator.deconstruct_op_dict[repeated_op] = self + def __repr__(self): return f'{self.name}' + @property + def deconstruct_op(self): + return Operator.deconstruct_op_dict.get(self, None) + class UnaryOperator(Operator): """ @@ -53,7 +64,7 @@ class BinProperties: """ def __init__(self, associative: bool, commutative: True, - left_distributivity: Set[str], right_distributivity: Set[str]): + left_distributivity: Set[str], right_distributivity: Set[str], neutral_element=None, absorbing_element=None): """ :param associative: True if the operator is associative :param commutative: True if the operator is commutative @@ -72,18 +83,29 @@ class BinProperties: self.commutative = commutative self.left_distributive = left_distributivity self.right_distributive = right_distributivity + self.neutral_element = neutral_element + self.absorbing_element = absorbing_element class BinOperator(Operator): """ Represent a binary operator, like +, *, etc... all operators that take two arguments + """ # Used to store all the instances of BinOperator, used in the parser instances = {} - def __init__(self, name: str, precedence: int, properties: BinProperties, call: Callable, repeated_op: Operator = None ): + def __init__(self, name: str, precedence: int, properties: BinProperties, call: Callable, repeated_op: BinOperator = None): + """ + :param name: name of the operator + :param precedence: precedence of the operator, higher is better + :param properties: properties of the operator + :param call: function to apply the operator + :param repeated_op: if you repeat the operator what do you get ? (for exemple a+a+a+a -> 4*a, the repeated_op of Add is Mul) + :param deconstruct_op: if you deconstruct the operator what do you get ? (for exemple 4*a -> a+a+a+a, the deconstruct_op of Mul is Add) + """ BinOperator.instances[name] = self super().__init__(name, precedence, call, repeated_op) self.properties = properties @@ -99,13 +121,13 @@ class BinOperator(Operator): """ Generic operators """ -ExpProperties = BinProperties(False, False, set(), set()) +ExpProperties = BinProperties(False, False, set(), set(), 1) Exp = BinOperator('^', 4, ExpProperties, lambda x, y: x ** y) -MulProperties = BinProperties(True, True, {'+'}, {'+'}) +MulProperties = BinProperties(True, True, {'+'}, {'+'}, 1, 0) Mul = BinOperator('*', 3, MulProperties, lambda x, y: x * y, Exp) -AddProperties = BinProperties(True, True, set(), set()) +AddProperties = BinProperties(True, True, set(), set(), 0) Add = BinOperator('+', 2, AddProperties, lambda x, y: x + y, Mul) diff --git a/python_symb/TreeModification/basic_modif.py b/python_symb/TreeModification/basic_modif.py index ce71a39..ce86013 100644 --- a/python_symb/TreeModification/basic_modif.py +++ b/python_symb/TreeModification/basic_modif.py @@ -1,6 +1,6 @@ from python_symb.Expressions.expr import Expr from python_symb.MathTypes.symbols import Var -from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul +from python_symb.MathTypes.operator_file import Operator, BinOperator, Add, Mul, Exp from typing import Union @@ -23,10 +23,10 @@ def expand(expr: Expr) -> Expr: return expr match expr: - case Expr(BinOperator() as Op1, [Expr(Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive: + case Expr(BinOperator() as Op1, [Expr(BinOperator() as Op2, op2_children), right]) if Op2.name in Op1.properties.left_distributive: return expand(Expr(Op2, [Expr(Op1, [expand(op2_child), expand(right)]) for op2_child in op2_children])) - case Expr(BinOperator() as Op1, [left, Expr(Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive: + case Expr(BinOperator() as Op1, [left, Expr(BinOperator() as Op2, op2_children)]) if Op2.name in Op1.properties.right_distributive: return expand(Expr(Op2, [Expr(Op1, [expand(left), expand(op2_child)]) for op2_child in op2_children])) case Expr(BinOperator() as Op, [left, right]): @@ -53,7 +53,7 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr: # Motifs : Key : (Expr) -> Value : int # represent number of times the expression appears in the expression, - # custom hash make for instance x+y and y+x the same when counting + # custom expr hash make for instance x+y and y+x the same when counting in motifs motifs = {} def collect_motifs(expr: Expr): @@ -76,11 +76,16 @@ def _regroup(expr: Expr, focus_op: BinOperator) -> Expr: def reconstruct(tuple_motifs): match tuple_motifs: case [(expr, int(a))]: - if focus_op.repeated_op.properties.commutative: + if a == focus_op.repeated_op.properties.neutral_element: + return expr + elif focus_op.repeated_op.properties.commutative: return Expr(focus_op.repeated_op, [Expr(a), expr]) return Expr(focus_op.repeated_op, [expr, Expr(a)]) + case [(expr, int(a)), *rest]: - if focus_op.repeated_op.properties.commutative: + if a == focus_op.repeated_op.properties.neutral_element: + return Expr(focus_op, [expr, reconstruct(rest)]) + elif focus_op.repeated_op.properties.commutative: return Expr(focus_op, [Expr(focus_op.repeated_op, [Expr(a), expr]), reconstruct(rest)]) return Expr(focus_op, [Expr(focus_op.repeated_op, [expr, Expr(a)]), reconstruct(rest)]) @@ -112,42 +117,46 @@ def regroup(expr: Expr, focus_op: BinOperator) -> Expr: return expr -def test(): +def ungroup(expr: Expr, focus_op: BinOperator) -> Expr: + """ + Ungroup an expression + + :param expr: expression to ungroup + :param focus_op: operator to ungroup + + exemple : + with focus_op = Exp + (x+y) ^ 2 -> (x+y)*(x+y) + """ + + def recreate(expr: Expr, nb_repeat: int) -> Expr: + if nb_repeat == 1: + return expr + return Expr(focus_op.deconstruct_op, [expr, recreate(expr, nb_repeat-1)]) + + if expr.is_leaf: + return expr + + match expr: + case Expr(BinOperator() as op, [left, leaf]) if op == focus_op and isinstance(leaf.value, int): + return recreate(ungroup(left, focus_op), leaf.value) + case Expr(BinOperator() as op, [leaf, right]) if op == focus_op and isinstance(leaf.value, int): + return recreate(ungroup(right, focus_op), leaf.value) + case Expr(BinOperator() as op, [left, right]): + return Expr(op, [ungroup(left, focus_op), ungroup(right, focus_op)]) + + return expr + + + - x, y = Var('x'), Var('y') - a, b = Var('a'), Var('b') - def test_expand(): - expr = (x+y)*(x+y)*(x+y) - expr = expand(expr) - expr = regroup(expr, Add) - print(f"(x+y)*(x+y)*(x+y) -> {expr.to_infix_str()}") - expr = (x+y+a)*b - print(f"(x+y+a)*b -> {expand(expr).to_infix_str()}") - def test_regroup(): - expr = x+2*x+y+y+2*y - print(f"x+2*x+y+y+2*y -> {regroup(expr, Add).to_infix_str()}") - def test_power(): - expr = x*x*x + y*y*y - print(f"x*x*x -> {regroup(expr, Mul).to_infix_str()}") - def test_all(): - expr = (x+y)*(x+y)*(x+y) - expanded_expr = expand(expr) - regrouped_expr = regroup(expanded_expr, Add) - print(f"(x+y)*(x+y)*(x+y) -> {regrouped_expr.to_infix_str()}") - test_expand() - test_regroup() - test_power() - #test_all() - -if __name__ == "__main__": - test() diff --git a/python_symb/test.py b/python_symb/test.py new file mode 100644 index 0000000..ca94884 --- /dev/null +++ b/python_symb/test.py @@ -0,0 +1,137 @@ +from python_symb.MathTypes.symbols import var +from python_symb.Expressions.expr import Expr +from python_symb.MathTypes.operator_file import Add, Mul, Exp, Sin +from python_symb.TreeModification.basic_modif import expand, regroup, ungroup +import time +from functools import wraps + + +def timeit(func): + @wraps(func) + def measure_time(*args, **kwargs): + start_time = time.time() + result = func(*args, **kwargs) + end_time = time.time() + print("@timefn: {} took {} seconds.".format(func.__name__, end_time - start_time)) + return result + return measure_time + + +# assert equal use none perfect implementation of __eq__ in expr (using weird handmade hash), be careful + +def create_var(): + global x, y + x = var('x') + y = var('y') + + +def test_sum(print_result=False): + expr = x + y + expr2 = 5 + x + if print_result: + print(f"expr1: {expr}") + print(f"expr2: {expr2}") + + assert expr == Expr(Add, [Expr(x), Expr(y)]) + assert expr2 == Expr(Add, [Expr(5), Expr(x)]) + + +def test_parse(print_result=False): + str_expr = "5*(x+y) + x*sin(y)" + expr = Expr.from_infix_str(str_expr) + if print_result: + print(f"expr: {expr}") + + assert expr == Expr(Add, [Expr(Mul, [Expr(5), Expr(Add, [Expr(x), Expr(y)])]), Expr(Mul, [Expr(x), Expr(Sin, [Expr(y)])])]) + + +def test_expr_to_str(print_result=False): + expr = 5*(x+y) + x*Sin(y) + str_expr = expr.to_infix_str(implicit_mul=True) + str_expr_no_implicit = expr.to_infix_str(implicit_mul=False) + + if print_result: + print(f"str_expr: {str_expr}") + print(f"str_expr_no_implicit: {str_expr_no_implicit}") + + assert str_expr == "5(x+y)+xsin(y)" + assert str_expr_no_implicit == "5*(x+y)+x*sin(y)" + + +def test_expand(print_result=False): + expr = (x+2)*(x+y) + expanded = expand(expr) + expanded_str = expanded.to_infix_str() + if print_result: + print(f"expanded: {expanded_str}") + + assert expanded_str == "xx+xy+2x+2y" + assert expanded == Expr(Add, [Expr(Add, [Expr(Mul, [Expr(x), Expr(x)]), Expr(Mul, [Expr(x), Expr(y)])]), Expr(Add, [Expr(Mul, [Expr(2), Expr(x)]), Expr(Mul, [Expr(2), Expr(y)])])]) + + +def test_regroup(print_result=False): + expr = x+x+2*x+y+x+3*y+x + regrouped = regroup(expr, Add) + regrouped_str = regrouped.to_infix_str() + if print_result: + print(f"regrouped: {regrouped_str}") + + assert regrouped_str == "6x+4y" + assert regrouped == Expr(Add, [Expr(Mul, [Expr(6), Expr(x)]), Expr(Mul, [Expr(4), Expr(y)])]) + + +def test_newton_bin(print_result=False): + expr = (x+y)**4 + ungrouped = ungroup(expr, Exp) + expanded = expand(ungrouped) + expand_grouped = regroup(expanded, Add) + expand_grouped = regroup(expand_grouped, Mul) + + if print_result: + print("---") + print("newton binomial test") + print("---") + print(f"expr: {expr.to_infix_str()}") + print(f"ungrouped: {ungrouped.to_infix_str()}") + print(f"expanded: {expanded.to_infix_str()}") + print(f"expand_grouped: {expand_grouped.to_infix_str()}") + + assert expand_grouped == Expr(Add, [Expr(Exp, [Expr(x), Expr(4)]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(Exp, [Expr(x), Expr(3)]),Expr(y)])]),Expr(Add, [Expr(Mul, [Expr(6),Expr(Mul, [Expr(Exp, [Expr(x), Expr(2)]),Expr(Exp, [Expr(y), Expr(2)])])]),Expr(Add, [Expr(Mul, [Expr(4),Expr(Mul, [Expr(x),Expr(Exp, [Expr(y), Expr(3)])])]),Expr(Exp, [Expr(y), Expr(4)])])])])]) + +@timeit +def test_big_bin_for_performance(print_result=False): + expr = (x+y)**11 + print("a") + ungrouped = ungroup(expr, Exp) + print("b") + expanded = expand(ungrouped) + print("c") + expand_grouped = regroup(expanded, Add) + print("d") + expand_grouped = regroup(expand_grouped, Mul) + print("e") + + if print_result: + print("---") + print("ultra big newton binomial test") + print("---") + print(f"expr: {expr.to_infix_str()}") + print(f"ungrouped: {ungrouped.to_infix_str()}") + print(f"expanded: {expanded.to_infix_str()}") + print(f"expand_grouped: {expand_grouped.to_infix_str()}") + +if __name__ == '__main__': + create_var() + print_r = True + test_sum(print_result=print_r) + test_parse(print_result=print_r) + test_expr_to_str(print_result=print_r) + test_expand(print_result=print_r) + test_regroup(print_result=print_r) + test_newton_bin(print_result=print_r) + test_big_bin_for_performance(print_result=print_r) + print("All tests passed") + + + +