diff --git a/ivy/functional/frontends/jax/lax/control_flow_operators.py b/ivy/functional/frontends/jax/lax/control_flow_operators.py index 3fb33fb590749..7a5761bf77801 100644 --- a/ivy/functional/frontends/jax/lax/control_flow_operators.py +++ b/ivy/functional/frontends/jax/lax/control_flow_operators.py @@ -58,3 +58,14 @@ def while_loop(cond_fun, body_fun, init_val): while cond_fun(val): val = body_fun(val) return val + +@to_ivy_arrays_and_back +def scan(f, init, xs, length=None): + if xs is None: + xs = [None] * length + carry = init + ys = [] + for x in xs: + carry, y = f(carry, x) + ys.append(y) + return carry, ivy.stack(ys) \ No newline at end of file