Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added r_ and c_ to jax.numpy #21768

Merged
merged 8 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions ivy/functional/frontends/jax/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ def at(self):
def T(self):
return self.ivy_array.T

@property
def ndim(self):
return self.ivy_array.ndim

# Instance Methods #
# ---------------- #

Expand Down Expand Up @@ -146,6 +150,16 @@ def any(self, *, axis=None, out=None, keepdims=False, where=None):
self._ivy_array, axis=axis, keepdims=keepdims, out=out, where=where
)

def reshape(self, *args, order="C"):
if not isinstance(args[0], int):
if len(args) > 1:
raise TypeError(
"Shapes must be 1D sequences of concrete values of integer type,"
f" got {args}."
)
args = args[0]
return jax_frontend.numpy.reshape(self, tuple(args), order)

def __add__(self, other):
return jax_frontend.numpy.add(self, other)

Expand Down Expand Up @@ -264,8 +278,7 @@ def __setitem__(self, idx, val):
)

def __iter__(self):
ndim = len(self.shape)
if ndim == 0:
if self.ndim == 0:
raise TypeError("iteration over a 0-d Array not supported")
for i in range(self.shape[0]):
yield self[i]
Expand Down
101 changes: 101 additions & 0 deletions ivy/functional/frontends/jax/numpy/indexing.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# global
import inspect
import abc

# local
import ivy
from ivy.functional.frontends.jax.func_wrapper import (
to_ivy_arrays_and_back,
)
from .creation import linspace, arange, array
from .manipulations import transpose, concatenate, expand_dims


@to_ivy_arrays_and_back
Expand Down Expand Up @@ -96,3 +99,101 @@ def indices(dimensions, dtype=int, sparse=False):
else:
grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij")
return ivy.stack(grid, axis=0).astype(dtype)


def _make_1d_grid_from_slice(s):
step = 1 if s.step is None else s.step
start = 0 if s.start is None else s.start
if s.step is not None and ivy.is_complex_dtype(s.step):
newobj = linspace(start, s.stop, int(abs(step)))
else:
newobj = arange(start, s.stop, step)
return newobj


class _AxisConcat(abc.ABC):
axis: int
ndmin: int
trans1d: int

def __getitem__(self, key):
key_tup = key if isinstance(key, tuple) else (key,)

params = [self.axis, self.ndmin, self.trans1d, -1]

directive = key_tup[0]
if isinstance(directive, str):
key_tup = key_tup[1:]
# check two special cases: matrix directives
if directive == "r":
params[-1] = 0
elif directive == "c":
params[-1] = 1
else:
vec = directive.split(",")
k = len(vec)
if k < 4:
vec += params[k:]
else:
# ignore everything after the first three comma-separated ints
vec = vec[:3] + [params[-1]]
try:
params = list(map(int, vec))
except ValueError as err:
raise ValueError(
f"could not understand directive {directive!r}"
) from err

axis, ndmin, trans1d, matrix = params

output = []
for item in key_tup:
if isinstance(item, slice):
newobj = _make_1d_grid_from_slice(item)
item_ndim = 0
elif isinstance(item, str):
raise ValueError("string directive must be placed at the beginning")
else:
newobj = array(item, copy=False)
item_ndim = newobj.ndim

newobj = array(newobj, copy=False, ndmin=ndmin)

if trans1d != -1 and ndmin - item_ndim > 0:
shape_obj = tuple(range(ndmin))
# Calculate number of left shifts, with overflow protection by mod
num_lshifts = ndmin - abs(ndmin + trans1d + 1) % ndmin
shape_obj = tuple(shape_obj[num_lshifts:] + shape_obj[:num_lshifts])

newobj = transpose(newobj, shape_obj)

output.append(newobj)

res = concatenate(tuple(output), axis=axis)

if matrix != -1 and res.ndim == 1:
# insert 2nd dim at axis 0 or 1
res = expand_dims(res, matrix)

return res

def __len__(self) -> int:
return 0


class RClass(_AxisConcat):
axis = 0
ndmin = 1
trans1d = -1


r_ = RClass()


class CClass(_AxisConcat):
axis = -1
ndmin = 2
trans1d = 0


c_ = CClass()
66 changes: 66 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from ivy_tests.test_ivy.test_functional.test_core.test_statistical import (
_get_castable_dtype,
)
from ivy_tests.test_ivy.test_frontends.test_jax.test_numpy.test_manipulations import (
_get_input_and_reshape,
)

CLASS_TREE = "ivy.functional.frontends.jax.numpy.ndarray"

Expand Down Expand Up @@ -55,6 +58,24 @@ def test_jax_array_dtype(
assert x.dtype == dtype[0]


@given(
dtype_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
),
)
def test_jax_array_ndim(
dtype_x,
backend_fw,
):
dtype, data = dtype_x
with update_backend(backend_fw) as ivy_backend:
jax_frontend = ivy_backend.utils.dynamic_import.import_module(
"ivy.functional.frontends.jax"
)
x = jax_frontend.Array(data[0])
assert x.ndim == data[0].ndim


@given(
dtype_x_shape=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False),
Expand Down Expand Up @@ -2341,6 +2362,50 @@ def test_jax_array_searchsorted(
)



@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="jax.numpy.array",
method_name="reshape",
dtype_and_x_shape=_get_input_and_reshape(),
order=st.sampled_from(["C", "F"]),
input=st.booleans(),
)
def test_jax_array_reshape(
dtype_and_x_shape,
order,
input,
frontend,
frontend_method_data,
init_flags,
method_flags,
on_device,
backend_fw,
):
input_dtype, x, shape = dtype_and_x_shape
if input:
method_flags.num_positional_args = len(shape)
kwargs = {f"{i}": shape[i] for i in range(len(shape))}
else:
kwargs = {"shape": shape}
method_flags.num_positional_args = 1
kwargs["order"] = order
helpers.test_frontend_method(
backend_to_test=backend_fw,
init_input_dtypes=input_dtype,
init_all_as_kwargs_np={
"object": x[0],
},
method_input_dtypes=input_dtype,
method_all_as_kwargs_np=kwargs,
frontend=frontend,
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
on_device=on_device,
)


# repeat
@st.composite
def _repeat_helper(draw):
Expand Down Expand Up @@ -2435,3 +2500,4 @@ def test_jax_repeat(
method_flags=method_flags,
on_device=on_device,
)

Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
# global
from hypothesis import strategies as st, assume
import numpy as np
from jax.numpy import tril, triu
from jax.numpy import tril, triu, r_, c_


# local
import ivy_tests.test_ivy.helpers as helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.helpers import handle_frontend_test, update_backend
from ...test_numpy.test_indexing_routines.test_inserting_data_into_arrays import (
_helper_r_,
_helper_c_,
)
import ivy.functional.frontends.jax.numpy as jnp_frontend


# diagonal
Expand Down Expand Up @@ -459,3 +464,20 @@ def test_jax_numpy_indices(
dtype=dtype[0],
sparse=sparse,
)


@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_r_()) # dummy fn_tree
def test_jax_numpy_r_(inputs, backend_fw):
inputs, *_ = inputs
ret_gt = r_.__getitem__(tuple(inputs))
with update_backend(backend_fw):
ret = jnp_frontend.r_.__getitem__(tuple(inputs))
assert np.allclose(ret.ivy_array, ret_gt)


@handle_frontend_test(fn_tree="jax.numpy.add", inputs=_helper_c_()) # dummy fn_tree
def test_jax_numpy_c_(inputs, backend_fw):
ret_gt = c_.__getitem__(tuple(inputs))
with update_backend(backend_fw):
ret = jnp_frontend.c_.__getitem__(tuple(inputs))
assert np.allclose(ret.ivy_array, ret_gt)
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# local
import ivy_tests.test_ivy.helpers as helpers
import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers
from ivy_tests.test_ivy.helpers import handle_frontend_test
from ivy_tests.test_ivy.helpers import handle_frontend_test, update_backend
import ivy.functional.frontends.numpy as np_frontend


Expand Down Expand Up @@ -148,10 +148,11 @@ def test_numpy_fill_diagonal(


@handle_frontend_test(fn_tree="numpy.add", inputs=_helper_r_()) # dummy fn_tree
def test_numpy_r_(inputs):
def test_numpy_r_(inputs, backend_fw):
inputs, elems_in_last_dim, dim = inputs
ret_gt = np.r_.__getitem__(tuple(inputs))
ret = np_frontend.r_.__getitem__(tuple(inputs))
with update_backend(backend_fw):
ret = np_frontend.r_.__getitem__(tuple(inputs))
if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]:
ret = ret._data
else:
Expand All @@ -160,9 +161,10 @@ def test_numpy_r_(inputs):


@handle_frontend_test(fn_tree="numpy.add", inputs=_helper_c_()) # dummy fn_tree
def test_numpy_c_(inputs):
def test_numpy_c_(inputs, backend_fw):
ret_gt = np.c_.__getitem__(tuple(inputs))
ret = np_frontend.c_.__getitem__(tuple(inputs))
with update_backend(backend_fw):
ret = np_frontend.c_.__getitem__(tuple(inputs))
if isinstance(inputs[0], str) and inputs[0] in ["r", "c"]:
ret = ret._data
else:
Expand Down
Loading