From 981a92d5bee38f3c1825f1e4513b0ac26dcf0ea6 Mon Sep 17 00:00:00 2001 From: Nightcrab Date: Tue, 16 Jul 2024 11:16:24 +0100 Subject: [PATCH] use jax.nnx for native module --- ivy/functional/backends/jax/module.py | 42 +++++++-------------------- 1 file changed, 11 insertions(+), 31 deletions(-) diff --git a/ivy/functional/backends/jax/module.py b/ivy/functional/backends/jax/module.py index 5599c60bd19e..4829a70a6c98 100644 --- a/ivy/functional/backends/jax/module.py +++ b/ivy/functional/backends/jax/module.py @@ -2,7 +2,7 @@ from __future__ import annotations import re import jax -from flax import linen as nn +from flax import nnx as nn import jax.tree_util as tree import jax.numpy as jnp import functools @@ -1012,7 +1012,7 @@ def __call__( buffers=None, **kwargs, ): - ret = self.apply(v=v, *args, method=self.__class__._forward, **kwargs) + ret = self._call(v=v, *args, **kwargs) return ret def __getattr__(self, name): @@ -1020,8 +1020,6 @@ def __getattr__(self, name): if not super().__getattribute__("_v") and not getattr( # noqa: E501 self, "_built", False ): - print("self._kwargs", self._kwargs) - print("_build_and_return_v") return self._build_and_return_v( *self._args, dynamic_backend=self._dynamic_backend, **self._kwargs ) @@ -1153,9 +1151,6 @@ def _find_variables( return lambda: getattr(obj, fn)( *obj._args, dynamic_backend=self._dynamic_backend, **obj._kwargs ) - print(type(obj)) - print(obj) - print(obj._kwargs) return getattr(obj, fn)( *obj._args, dynamic_backend=obj._dynamic_backend, **obj._kwargs ) @@ -1285,29 +1280,14 @@ def _fn_with_var_arg(self, fn, v_fn, /, keychain_mappings, orig_key_chain): return _fn_with_var_arg_wrapper def _call(self, *args, v=None, buffers=None, **kwargs): - if not self._built or not self.built: - if not self._built: - first_arr = self._get_first_array(*args, **kwargs) - self.build( - *args, - **kwargs, - from_call=True, - dtype=first_arr.dtype if first_arr is not None else tf.float32, - ) - - if not self.built: - # Don't use `keras` build method - if os.environ.get("USE_KERAS_BUILD", "False").lower() == "false": - self.inputs = tf.nest.flatten(args) - else: - input_shapes = self._get_input_shapes(*args) - if len(input_shapes) == 0: - input_shapes = tf.TensorShape(None) - elif len(input_shapes) == 1: - input_shapes = input_shapes[0] - - super(Model, self).build(tf.TensorShape(None)) # noqa: UP008 - + if not self._built: + first_arr = self._get_first_array(*args, **kwargs) + self.build( + *args, + **kwargs, + from_call=True, + dtype=first_arr.dtype if first_arr is not None else tf.float32, + ) # If `v` was provided, replace with the module's v replace_v = False if v is not None: @@ -1331,7 +1311,7 @@ def _call(self, *args, v=None, buffers=None, **kwargs): return ret elif hasattr(self.__call__, "wrapped"): return self.__call__(*args, **kwargs) - return super(Model, self).__call__(*args, **kwargs) # noqa: UP008 + return self.forward(*args, **kwargs) # noqa: UP008 def __delattr__(self, name): if hasattr(self, name):