Skip to content

Commit

Permalink
use jax.nnx for native module
Browse files Browse the repository at this point in the history
  • Loading branch information
Nightcrab committed Jul 16, 2024
1 parent e8d64c4 commit 981a92d
Showing 1 changed file with 11 additions and 31 deletions.
42 changes: 11 additions & 31 deletions ivy/functional/backends/jax/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from __future__ import annotations
import re
import jax
from flax import linen as nn
from flax import nnx as nn
import jax.tree_util as tree
import jax.numpy as jnp
import functools
Expand Down Expand Up @@ -1012,16 +1012,14 @@ def __call__(
buffers=None,
**kwargs,
):
ret = self.apply(v=v, *args, method=self.__class__._forward, **kwargs)
ret = self._call(v=v, *args, **kwargs)
return ret

def __getattr__(self, name):
if name == "v":
if not super().__getattribute__("_v") and not getattr( # noqa: E501
self, "_built", False
):
print("self._kwargs", self._kwargs)
print("_build_and_return_v")
return self._build_and_return_v(
*self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)
Expand Down Expand Up @@ -1153,9 +1151,6 @@ def _find_variables(
return lambda: getattr(obj, fn)(
*obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs
)
print(type(obj))
print(obj)
print(obj._kwargs)
return getattr(obj, fn)(
*obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs
)
Expand Down Expand Up @@ -1285,29 +1280,14 @@ def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain):
return _fn_with_var_arg_wrapper

def _call(self, *args, v=None, buffers=None, **kwargs):
if not self._built or not self.built:
if not self._built:
first_arr = self._get_first_array(*args, **kwargs)
self.build(
*args,
**kwargs,
from_call=True,
dtype=first_arr.dtype if first_arr is not None else tf.float32,
)

if not self.built:
# Don't use `keras` build method
if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false":
self.inputs = tf.nest.flatten(args)
else:
input_shapes = self._get_input_shapes(*args)
if len(input_shapes) == 0:
input_shapes = tf.TensorShape(None)
elif len(input_shapes) == 1:
input_shapes = input_shapes[0]

super(Model, self).build(tf.TensorShape(None)) # noqa: UP008

if not self._built:
first_arr = self._get_first_array(*args, **kwargs)
self.build(
*args,
**kwargs,
from_call=True,
dtype=first_arr.dtype if first_arr is not None else tf.float32,
)
# If `v` was provided, replace with the module's v
replace_v = False
if v is not None:
Expand All @@ -1331,7 +1311,7 @@ def _call(self, *args, v=None, buffers=None, **kwargs):
return ret
elif hasattr(self.__call__, "wrapped"):
return self.__call__(*args, **kwargs)
return super(Model, self).__call__(*args, **kwargs) # noqa: UP008
return self.forward(*args, **kwargs) # noqa: UP008

def __delattr__(self, name):
if hasattr(self, name):
Expand Down

0 comments on commit 981a92d

Please sign in to comment.