Python-symbolic/python_symb/Expressions/expr.py

111 lines
2.6 KiB
Python

from __future__ import annotations
import python_symb.MathTypes.symbols
from python_symb.Expressions.tree import Tree
from python_symb.MathTypes.operator_file import Add, Mul, Min
from python_symb.MathTypes.operator_file import UnaryOperator, BinOperator, Add, Mul, Min
from typing import List
from python_symb.MathTypes.symbols import Var
from python_symb.Parsing.parse import infix_str_to_postfix
class Expr(Tree):
"""
A class to represent an expression tree
value: the value of the node
children: the subtrees of the root (Default : None)
exemple :
5*(2+3) represented as
Expr(Mul, [Expr(5), Expr(Expr(Add, [Expr(2), Expr(3)]))])
"""
def __init__(self, value, children=None):
super().__init__(value, children if children else [])
@staticmethod
def from_postfix_list(postfix: List):
"""
Create an expression tree from a postfix list of tokens
tokens are : int, float, Var, UnaryOperator, BinOperator
exemple :
x = Var('x')
[5, 2, Add] -> Expr(Add, [Expr(5), Expr(2)])
"""
def aux():
first = postfix.pop()
match first:
case int() | Var():
return Expr(first)
case UnaryOperator():
return Expr(first, [aux()])
case BinOperator():
return Expr(first, [aux(), aux()])
return aux()
@staticmethod
def from_infix_str(expr_str) -> Expr:
"""
Create an expression tree from an infix string
"""
expr_rev_polish = infix_str_to_postfix(expr_str)
return Expr.from_postfix_list(expr_rev_polish)
def bin_op_constructor(self, other, op):
"""
Construct a binary operation
"""
match other:
case Expr():
return Expr(op, [self, other])
case Var() | int() | float():
return Expr(op, [self, Expr(other)])
case _:
return ValueError(f'Invalid type for operation: {other} : {type(other)}')
def __add__(self, other):
return self.bin_op_constructor(other, Add)
def __mul__(self, other):
return self.bin_op_constructor(other, Mul)
def __sub__(self, other):
return self.bin_op_constructor(other, Min)
def test1():
x, y = Var('x'), Var('y')
expr1 = x + y
expr2 = 5+x
print(expr1 + expr2)
def test2():
from python_symb.MathTypes.operator_file import Sin
a, b = Var('a'), Var('b')
expr = Sin(a+b)
print(expr)
if __name__ == '__main__':
test1()
test2()