Python-symbolic/python_symb/Expressions/expr.py
2024-02-20 21:21:55 +01:00

213 lines
6.1 KiB
Python

from __future__ import annotations
import python_symb.MathTypes.symbols
from python_symb.Expressions.tree import Tree
from python_symb.MathTypes.operator_file import UnaryOperator, BinOperator, Add, Mul, Exp
from typing import List
from python_symb.MathTypes.symbols import Var, var
from python_symb.Parsing.parse import infix_str_to_postfix
class Expr(Tree):
"""
A class to represent an expression tree
value: the value of the node
children: the subtrees of the root (Default : None)
exemple :
5*(2+3) represented as
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
"""
__match_args__ = ('value', 'children')
def __init__(self, value, children=None):
from python_symb.MathTypes.operator_file import BinOperator, UnaryOperator
super().__init__(value, children if children else [])
assert all([isinstance(child, Expr) for child in self.children]), f'Invalid children: {self.children} all child should be Expr'
match value:
case BinOperator() as op:
assert len(self.children) == 2, f'Invalid number of children for BinOperator{op}: {len(self.children)}'
case UnaryOperator() as op:
assert len(self.children) == 1, f'Invalid number of children for UnaryOperator{op}: {len(self.children)}'
@staticmethod
def from_postfix_list(postfix: List):
"""
Create an expression tree from a postfix list of tokens
tokens are : int, float, Var, UnaryOperator, BinOperator
exemple :
x = Var('x')
[5, 2, Add] -> Expr(Add, [Expr(5), Expr(2)])
"""
def aux():
first = postfix.pop()
match first:
case int() | Var():
return Expr(first)
case UnaryOperator():
return Expr(first, [aux()])
case BinOperator():
return Expr(first, [aux(), aux()])
return aux()
@staticmethod
def from_infix_str(expr_str) -> Expr:
"""
Create an expression tree from an infix string
"""
expr_rev_polish = infix_str_to_postfix(expr_str)
return Expr.from_postfix_list(expr_rev_polish)
def to_infix_str(self, parent_precedence=-100, implicit_mul=True) -> str:
"""
Return the infix string of the expression
"""
match self:
case Expr(value) if self.is_leaf:
return str(value)
case Expr(UnaryOperator() as op, [child]):
return f"{op.name}({child.to_infix_str(parent_precedence=op.precedence)})"
case Expr(BinOperator() as op, [left, right]):
op_name = op.name if not(implicit_mul and op == Mul) else ''
if op.precedence < parent_precedence:
print("hehehe")
print(self, op.precedence, parent_precedence)
return f"({left.to_infix_str(op.precedence)}{op_name}{right.to_infix_str(op.precedence)})"
else:
return f"{left.to_infix_str(op.precedence)}{op_name}{right.to_infix_str(op.precedence)}"
@staticmethod
def bin_op_constructor(self, other, op):
"""
Construct a binary operation
"""
match other:
case Expr():
return Expr(op, [self, other])
case Var() | int() | float():
return Expr(op, [self, Expr(other)])
case _:
return ValueError(f'Invalid type for operation: {other} : {type(other)}')
def __add__(self, other):
other_expr = other if isinstance(other, Expr) else Expr(other)
return Expr.bin_op_constructor(self, other_expr, Add)
def __radd__(self, other):
other_expr = other if isinstance(other, Expr) else Expr(other)
return Expr.bin_op_constructor(other_expr, self, Add)
def __mul__(self, other):
other_expr = other if isinstance(other, Expr) else Expr(other)
return Expr.bin_op_constructor(self, other_expr, Mul)
def __rmul__(self, other):
other_expr = other if isinstance(other, Expr) else Expr(other)
return Expr.bin_op_constructor(other_expr, self, Mul)
def __sub__(self, other):
return Expr.bin_op_constructor(self, other, Min)
def __hash__(self):
"""
Two equivalent expressions (without more modification like factorisation, or expanding) should have the same hash
see test_eq
"""
match self:
case Expr(value) if self.is_leaf:
return hash(value)
case Expr(UnaryOperator() as op, [child]):
return hash(op.name + str(hash(child)))
case Expr(BinOperator() as op, [left, right]):
if op.properties.commutative and op.properties.associative:
return hash(op.name) + hash(left) + hash(right)
else:
return hash(op.name) + hash(str(hash(left)) + str(hash(right)))
case _:
print(f'Invalid type: {type(self)}')
def bad_eq(self, other):
return self.__hash__() == other.__hash__()
def __eq__(self, other):
"""temporary"""
return self.bad_eq(other)
def test():
x, y = var('x'), var('y')
a, b = var('a'), var('b')
def test1():
expr1 = x + y
expr2 = 5+x
print(expr1 + expr2)
def test2():
from python_symb.MathTypes.operator_file import Sin
expr = Sin(x+y)
print(expr)
def test_eq():
expr = x + y + 3
expr2 = 3 + x + y
print("----")
print(expr)
print(expr2)
print(expr == expr2)
def test_return_to_string():
expr = x+y
new_expr = 5*expr
print(new_expr)
print(f"new_expr: {new_expr.to_infix_str()}")
expr2 = x*x*x + y*y*y
print(expr2)
print(f"expr2: {expr2.to_infix_str()}")
print("test1")
test1()
print("test2")
test2()
print("test_eq")
test_eq()
print("test_return_to_string")
test_return_to_string()
if __name__ == '__main__':
test()