Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: added arithmetic expression support, closes #1093 #1444

Open
wants to merge 1 commit into
base: staging
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions evadb/expression/abstract_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ExpressionType(IntEnum):
ARITHMETIC_SUBTRACT = auto()
ARITHMETIC_MULTIPLY = auto()
ARITHMETIC_DIVIDE = auto()
ARITHMETIC_MODULUS = auto()

FUNCTION_EXPRESSION = auto()

Expand Down
55 changes: 52 additions & 3 deletions evadb/expression/arithmetic_expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,62 @@ def __init__(
super().__init__(exp_type, rtype=ExpressionReturnType.FLOAT, children=children)

def evaluate(self, *args, **kwargs):
vl = self.get_child(0).evaluate(*args, **kwargs)
vr = self.get_child(1).evaluate(*args, **kwargs)
lbatch = self.get_child(0).evaluate(*args, **kwargs)
rbatch = self.get_child(1).evaluate(*args, **kwargs)

return Batch.combine_batches(vl, vr, self.etype)
assert len(lbatch) == len(
rbatch
), f"Left and Right batch does not have equal elements: left: {len(lbatch)} right: {len(rbatch)}"

assert self.etype in [
ExpressionType.ARITHMETIC_ADD,
ExpressionType.ARITHMETIC_SUBTRACT,
ExpressionType.ARITHMETIC_DIVIDE,
ExpressionType.ARITHMETIC_MULTIPLY,
ExpressionType.ARITHMETIC_MODULUS,
], f"Expression type not supported {self.etype}"

if self.etype == ExpressionType.ARITHMETIC_ADD:
return Batch.from_add(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return Batch.from_subtract(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return Batch.from_multiply(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return Batch.from_divide(lbatch, rbatch)
elif self.etype == ExpressionType.ARITHMETIC_MODULUS:
return Batch.from_modulus(lbatch, rbatch)

return Batch.combine_batches(lbatch, rbatch, self.etype)

def get_symbol(self) -> str:
if self.etype == ExpressionType.ARITHMETIC_ADD:
return "+"
elif self.etype == ExpressionType.ARITHMETIC_SUBTRACT:
return "-"
elif self.etype == ExpressionType.ARITHMETIC_MULTIPLY:
return "*"
elif self.etype == ExpressionType.ARITHMETIC_DIVIDE:
return "/"
elif self.etype == ExpressionType.ARITHMETIC_MODULUS:
return "%"

def __str__(self) -> str:
expr_str = "("
if self.get_child(0):
expr_str += f"{self.get_child(0)}"
if self.etype:
expr_str += f" {self.get_symbol()} "
if self.get_child(1):
expr_str += f"{self.get_child(1)}"
expr_str += ")"
return expr_str

def __eq__(self, other):
is_subtree_equal = super().__eq__(other)
if not isinstance(other, ArithmeticExpression):
return False
return is_subtree_equal and self.etype == other.etype

def __hash__(self) -> int:
return super().__hash__()
20 changes: 20 additions & 0 deletions evadb/models/storage/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ def deserialize(cls, data):
obj = PickleSerializer.deserialize(data)
return cls(frames=obj["frames"])

@classmethod
def from_add(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() + batch2.to_numpy()))

@classmethod
def from_subtract(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() - batch2.to_numpy()))

@classmethod
def from_multiply(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() * batch2.to_numpy()))

@classmethod
def from_divide(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() / batch2.to_numpy()))

@classmethod
def from_modulus(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() % batch2.to_numpy()))

@classmethod
def from_eq(cls, batch1: Batch, batch2: Batch) -> Batch:
return Batch(pd.DataFrame(batch1.to_numpy() == batch2.to_numpy()))
Expand Down
16 changes: 11 additions & 5 deletions evadb/parser/evadb.lark
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,13 @@ predicate: predicate NOT? IN "(" (select_statement | expressions) ")" ->in_pred
| predicate comparison_operator predicate -> binary_comparison_predicate
| predicate comparison_operator (ALL | ANY | SOME) "(" select_statement ")" ->subquery_comparison_predicate
| assign_var ->expression_atom_predicate
| expression_atom
| arithmetic_expression

arithmetic_expression: product
| arithmetic_expression add_sub_operator product -> arithmetic_expression_atom

product: expression_atom
| product div_mul_mod_operator expression_atom -> arithmetic_expression_atom

assign_var.1: LOCAL_ID VAR_ASSIGN expression_atom

Expand All @@ -332,8 +338,7 @@ expression_atom.2: constant ->constant_expression_atom
| unary_operator expression_atom ->unary_expression_atom
| "(" expression ("," expression)* ")" ->nested_expression_atom
| "(" select_statement ")" ->subquery_expession_atom
| expression_atom bit_operator expression_atom ->bit_expression_atom
| expression_atom math_operator expression_atom
| expression_atom bit_operator expression_atom ->bit_expression_atom

unary_operator: EXCLAMATION_SYMBOL | BIT_NOT_OP | PLUS | MINUS | NOT

Expand All @@ -343,7 +348,8 @@ logical_operator: AND | XOR | OR

bit_operator: "<<" | ">>" | "&" | "^" | "|"

math_operator: STAR | DIVIDE | MODULUS | DIV | MOD | PLUS | MINUS | MINUSMINUS
div_mul_mod_operator: DIVIDE | STAR | MODULUS
add_sub_operator: PLUS | MINUS

// KEYWORDS

Expand Down Expand Up @@ -526,7 +532,7 @@ OR_ASSIGN: "|="

STAR: "*"
DIVIDE: "/"
MODULUS: "%"
MODULUS: "%"
PLUS: "+"
MINUSMINUS: "--"
MINUS: "-"
Expand Down
23 changes: 23 additions & 0 deletions evadb/parser/lark_visitor/_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from evadb.catalog.catalog_type import ColumnType
from evadb.expression.abstract_expression import ExpressionType
from evadb.expression.arithmetic_expression import ArithmeticExpression
from evadb.expression.comparison_expression import ComparisonExpression
from evadb.expression.constant_value_expression import ConstantValueExpression
from evadb.expression.logical_expression import LogicalExpression
Expand Down Expand Up @@ -60,6 +61,28 @@ def constant(self, tree):

return self.visit_children(tree)

def arithmetic_expression_atom(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
right = self.visit(tree.children[2])
return ArithmeticExpression(op, left, right)

def div_mul_mod_operator(self, tree):
op = str(tree.children[0])
if op == "*":
return ExpressionType.ARITHMETIC_MULTIPLY
elif op == "/":
return ExpressionType.ARITHMETIC_DIVIDE
elif op == "%":
return ExpressionType.ARITHMETIC_MODULUS

def add_sub_operator(self, tree):
op = str(tree.children[0])
if op == "+":
return ExpressionType.ARITHMETIC_ADD
elif op == "-":
return ExpressionType.ARITHMETIC_SUBTRACT

def logical_expression(self, tree):
left = self.visit(tree.children[0])
op = self.visit(tree.children[1])
Expand Down