Skip to content

Commit

Permalink
fix(window_functions): specify if bounds are ROWS or RANGE (#1131)
Browse files Browse the repository at this point in the history
xref #1130
  • Loading branch information
gforsyth authored Sep 19, 2024
1 parent 7fe0f20 commit 311aad6
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 0 deletions.
7 changes: 7 additions & 0 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,9 @@ def value_op(
)


_bounds_enum = {"rows": 1, "range": 2}


@translate.register(ops.WindowOp) # type: ignore
def window_op(
op: ops.WindowOp, # type: ignore
Expand All @@ -537,6 +540,9 @@ def window_op(
end = op.end
func = op.func
func_args = op.func.args
how = op.how

bounds_type = _bounds_enum[how]

lower_bound, upper_bound = _translate_window_bounds(start, end)

Expand All @@ -558,6 +564,7 @@ def window_op(
],
lower_bound=lower_bound,
upper_bound=upper_bound,
bounds_type=bounds_type,
)
)

Expand Down
32 changes: 32 additions & 0 deletions ibis_substrait/tests/compiler/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,3 +584,35 @@ def test_join_chain_indexing_in_group_by(compiler):
.selection.direct_reference.struct_field.field
== 7
)


_window_hows = {
"unspecified": "BOUNDS_TYPE_UNSPECIFIED",
"range": "BOUNDS_TYPE_RANGE",
"rows": "BOUNDS_TYPE_ROWS",
}


@pytest.mark.parametrize(
"bounds",
[
(-4, 2),
(1, 5),
(None, None),
(2, 4),
],
)
@pytest.mark.parametrize("how", ["range", "rows"])
def test_aggregation_window_how(t, compiler, bounds, how):
how_arg = {how: bounds}
expr = t.projection(
[t.full_name.length().mean().over(ibis.window(group_by="age", **how_arg))]
)
result = translate(expr, compiler=compiler)

bounds_type_int = result.project.expressions[0].window_function.bounds_type

assert (
stalg.Expression.WindowFunction.BoundsType.Name(bounds_type_int)
== _window_hows[how]
)

0 comments on commit 311aad6

Please sign in to comment.