Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful changes 2 #21246

Merged
merged 47 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 42 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
0b397dc
introduced delayed module variable initialisation, that is, initialis…
RickSanchezStoic Jul 13, 2023
05cf1d4
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Jul 13, 2023
3a30214
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 13, 2023
d4ff2a3
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 14, 2023
e042114
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 14, 2023
9bea610
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 17, 2023
06debd8
added weight freezing logic for nested Module class initialisation
RickSanchezStoic Jul 17, 2023
2329b55
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Jul 17, 2023
2909205
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 17, 2023
52dcbeb
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 19, 2023
ae539d2
added a custom class that overrides the __new__ method of module clas…
Jul 19, 2023
7ab4f24
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 21, 2023
75ea5b3
added some fixes for assert_and_assign method in cont_identical, fixe…
RickSanchezStoic Jul 21, 2023
5beb350
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 27, 2023
87f7e3f
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 1, 2023
fd26a4e
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 1, 2023
7432751
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 2, 2023
f4c27c6
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 3, 2023
ee38e02
added docstrings for build_callable and assert_and_assign
RickSanchezStoic Aug 3, 2023
ee04368
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 3, 2023
d4d5d35
Added buffer support to Module, internal methods namely, _set_buffers…
RickSanchezStoic Aug 3, 2023
3bee2ca
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 4, 2023
a0b8273
fixed a bug with buffers and inplace update in container function
RickSanchezStoic Aug 4, 2023
1cd09ce
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 4, 2023
afc6017
fixed a small bug with getting kwargs['buffers']
RickSanchezStoic Aug 4, 2023
489a277
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 4, 2023
0f935f7
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 8, 2023
f96eb44
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 8, 2023
309a361
added support for nested module objects and their buffers
RickSanchezStoic Aug 9, 2023
b881684
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 9, 2023
98225c5
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 9, 2023
165f760
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 10, 2023
229178c
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 10, 2023
5bdc232
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
5931aa4
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
2d2f9be
added tests, made functions public and removed tracker
RickSanchezStoic Aug 11, 2023
3602c38
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 11, 2023
13922c4
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
042c8b9
added buffers as a dictionary, with pseudo variable like behavior, al…
RickSanchezStoic Aug 11, 2023
db004b3
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
196d7ba
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 15, 2023
fea0c82
removed unused function
RickSanchezStoic Aug 15, 2023
52c402a
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 16, 2023
9efd662
minor fixes and reordering if else statements
RickSanchezStoic Aug 16, 2023
2001d77
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 17, 2023
9e18e86
syntax error fix
RickSanchezStoic Aug 17, 2023
ab9cd20
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ivy/data_classes/container/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2845,7 +2845,7 @@ def cont_overwrite_at_key_chains(
)
else:
return_dict[k] = v
return ivy.Container(return_dict, **self._config)
return return_dict
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved

def cont_prune_keys(self, query_keys, ignore_none=True):
"""
Expand Down
98 changes: 90 additions & 8 deletions ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
*args,
device=None,
v=None,
buffers=None,
build_mode="on_init",
compile_on_next_step=False,
store_vars=True,
Expand Down Expand Up @@ -156,7 +157,12 @@ def __init__(

if v or with_partial_v:
# build only if `v` or `with_partial_v`
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(
*args,
dynamic_backend=dynamic_backend,
buffers=buffers,
**kwargs,
)
# we don't want to delete the class variable now
# since there could be other child modules
return
Expand All @@ -171,16 +177,19 @@ def __init__(
# move on
Module._init_var.pop()
return
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(
*args, dynamic_backend=dynamic_backend, buffers=buffers, **kwargs
)
if Module._init_var[-1] == self.__class__.__name__:
# you delete it, only if this is the class that caused it's creation
Module._init_var.pop()

# do a final check if _init_var becomes empty, then delete it all together
del Module._init_var
if not Module._init_var:
del Module._init_var

return
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(*args, dynamic_backend=dynamic_backend, buffers=buffers**kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax error, , to be added

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this throws an error right?


# Private #
# --------#
Expand Down Expand Up @@ -289,6 +298,13 @@ def _build_and_return_v(self, *args, **kwargs):
def _find_child_objects(self, /, *, obj=None, _visited=None):
pass

def _find_buffers(self):
for obj in self.__dict__.keys():
if isinstance(getattr(self, obj), ivy.Module):
# simply fetch it's buffer
if hasattr(getattr(self, obj), "buffers"):
self.buffers.update({obj: getattr(self, obj).buffers})

@staticmethod
def _extract_v(v, keychain_mappings: dict, orig_key_chain, /):
"""
Expand Down Expand Up @@ -428,6 +444,33 @@ def found_dup_callback(x, kc):
vs = vs.cont_prune_key_chain(dup_kc)
return vs, keychain_mappings

def _set_buffers(self, buffers):
"""
Set the buffers of the given class instance, according to the buffers passed.

Parameters
----------
buffers
a dictionary with variable names and corresponding values

override
if true, sets the variable as an attribute even if it doesn't exist
"""
for buffer in buffers:
if hasattr(self, buffer):
# check if this value is another nested dictionary, if yes
# we recurse
if isinstance(buffers[buffer], dict):
getattr(self, buffer)._set_buffers(buffers=buffers[buffer])
else:
setattr(self, buffer, buffers[buffer])
else:
if hasattr(self, "buffers"):
self.buffers.update({buffer: buffers[buffer]})
else:
setattr(self, "buffers", {buffer: buffers[buffer]})
setattr(self, buffer, buffers[buffer])

# Overridable #

# noinspection PyMethodMayBeStatic,PyUnusedLocal
Expand Down Expand Up @@ -494,7 +537,7 @@ def _forward_with_tracking(self, *args, **kwargs):
self._check_submod_ret()
return ret

def _call(self, *args, v=None, **kwargs):
def _call(self, *args, v=None, buffers=None, **kwargs):
"""
Compute forward pass of the layer, treating layer instance as callable function.

Expand All @@ -516,6 +559,10 @@ def _call(self, *args, v=None, **kwargs):
from_call=True,
dtype=_get_first_array(*args, **kwargs).dtype,
)
if buffers:
buffers_orig = self.buffers.copy()
self.buffers = {}
self._set_buffers(buffers)
if v is not None:
v_orig = self.v
self.v = (
Expand All @@ -525,7 +572,11 @@ def _call(self, *args, v=None, **kwargs):
)
ret = self._forward_with_tracking(*args, **kwargs)
self.v = v_orig
if buffers:
self.buffers = {}
self._set_buffers(buffers_orig)
return ret

elif hasattr(self.__call__, "wrapped"):
return self.__call__(*args, **kwargs)
return self._forward_with_tracking(*args, **kwargs)
Expand All @@ -536,6 +587,7 @@ def __call__(
self,
*args,
v=None,
buffers=None,
stateful=None,
arg_stateful_idxs=None,
kwarg_stateful_idxs=None,
Expand Down Expand Up @@ -597,7 +649,7 @@ def __call__(

# convert variables to native arrays so that they can be tracked
v = ivy.to_native(v)
ret = self._call(*args, v=v, **kwargs)
ret = self._call(*args, v=v, buffers=buffers, **kwargs)
self._unset_submod_flags()
return ret

Expand All @@ -624,6 +676,7 @@ def build(
device=None,
dtype=None,
dynamic_backend=None,
buffers=None,
**kwargs,
):
"""
Expand All @@ -646,8 +699,9 @@ def build(
True for successfully built a module.
"""
self._dev = ivy.default(device, self._dev)
# build buffers if any
self._create_buffers()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that ivy.Module._create_buffers has been removed, we should also remove this

# return False if not from_call but build_mode is on_call

if not from_call and self._build_mode == "on_call":
return self.v
if dtype:
Expand Down Expand Up @@ -749,8 +803,18 @@ def build(
if not self._store_vars:
# ToDo: verify variables in self.v are released once this method exits
self.v = ivy.Container()

# once all variables built, find and assign buffers
if buffers:
self._set_buffers(buffers=buffers)
self._find_buffers()

return v_ret if bool(v_ret) or isinstance(built, bool) else built

def register_buffer(self, var_name, value):
"""Set the buffer at any place within the class."""
self._set_buffers({var_name: value})

def __repr__(self):
return object.__repr__(self)

Expand Down Expand Up @@ -798,9 +862,27 @@ def __getattribute__(self, name):
self._build_and_return_v(
self._args, dynamic_backend=self._dynamic_backend, **self._kwargs
)

if name == "buffers":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This if block is redundant right?

return super().__getattribute__(name)
elif hasattr(self, "buffers"):
if name in self.buffers:
return self.buffers[name]
return super().__getattribute__(name)

def __setattr__(self, name, value):
if hasattr(self, "buffers"):
if name in self.buffers:
self.buffers[name] = value
return

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra blank line

return super().__setattr__(name, value)

def __delattr__(self, name):
if name in self.buffers:
del self.buffers
else:
super().__delattr__(name)

def compile(
self,
args: Optional[Tuple] = None,
Expand Down
30 changes: 30 additions & 0 deletions ivy_tests/test_ivy/test_stateful/test_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,3 +1059,33 @@ def loss_fn(v_):
assert ivy.Container.all(loaded_module.v == module.v).cont_all_true()

os.remove(save_filepath)


class ModuleWithBuffer(ivy.Module):
def __init__(self, *args, **kwargs):
pass

def _forward(*args, **kwargs):
pass


@given(
buffer=st.just(
[
{
"var1": [
ivy.ones((1, 2)),
]
}
]
)
)
def test_get_buffers(buffer):
module = ModuleWithBuffer()
buffers = {}
for item in buffer:
buffers.update(item)
for key in item:
module.register_buffer(key, item[key])

assert module.buffers == buffers
Loading