Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stateful changes 2 #21246

Merged
merged 47 commits into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from 34 commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
0b397dc
introduced delayed module variable initialisation, that is, initialis…
RickSanchezStoic Jul 13, 2023
05cf1d4
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Jul 13, 2023
3a30214
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 13, 2023
d4ff2a3
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 14, 2023
e042114
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 14, 2023
9bea610
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 17, 2023
06debd8
added weight freezing logic for nested Module class initialisation
RickSanchezStoic Jul 17, 2023
2329b55
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Jul 17, 2023
2909205
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 17, 2023
52dcbeb
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 19, 2023
ae539d2
added a custom class that overrides the __new__ method of module clas…
Jul 19, 2023
7ab4f24
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 21, 2023
75ea5b3
added some fixes for assert_and_assign method in cont_identical, fixe…
RickSanchezStoic Jul 21, 2023
5beb350
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Jul 27, 2023
87f7e3f
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 1, 2023
fd26a4e
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 1, 2023
7432751
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 2, 2023
f4c27c6
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 3, 2023
ee38e02
added docstrings for build_callable and assert_and_assign
RickSanchezStoic Aug 3, 2023
ee04368
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 3, 2023
d4d5d35
Added buffer support to Module, internal methods namely, _set_buffers…
RickSanchezStoic Aug 3, 2023
3bee2ca
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 4, 2023
a0b8273
fixed a bug with buffers and inplace update in container function
RickSanchezStoic Aug 4, 2023
1cd09ce
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 4, 2023
afc6017
fixed a small bug with getting kwargs['buffers']
RickSanchezStoic Aug 4, 2023
489a277
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 4, 2023
0f935f7
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 8, 2023
f96eb44
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 8, 2023
309a361
added support for nested module objects and their buffers
RickSanchezStoic Aug 9, 2023
b881684
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 9, 2023
98225c5
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 9, 2023
165f760
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 10, 2023
229178c
Merge remote-tracking branch 'upstream/master' into stateful-changes-2
RickSanchezStoic Aug 10, 2023
5bdc232
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
5931aa4
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
2d2f9be
added tests, made functions public and removed tracker
RickSanchezStoic Aug 11, 2023
3602c38
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 11, 2023
13922c4
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
042c8b9
added buffers as a dictionary, with pseudo variable like behavior, al…
RickSanchezStoic Aug 11, 2023
db004b3
Merge branch 'unifyai:master' into stateful-changes-2
RickSanchezStoic Aug 11, 2023
196d7ba
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 15, 2023
fea0c82
removed unused function
RickSanchezStoic Aug 15, 2023
52c402a
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 16, 2023
9efd662
minor fixes and reordering if else statements
RickSanchezStoic Aug 16, 2023
2001d77
Merge branch 'unifyai:main' into stateful-changes-2
RickSanchezStoic Aug 17, 2023
9e18e86
syntax error fix
RickSanchezStoic Aug 17, 2023
ab9cd20
Merge branch 'stateful-changes-2' of https://github.com/RickSanchezSt…
RickSanchezStoic Aug 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ivy/data_classes/container/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2845,7 +2845,7 @@ def cont_overwrite_at_key_chains(
)
else:
return_dict[k] = v
return ivy.Container(return_dict, **self._config)
return return_dict
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved

def cont_prune_keys(self, query_keys, ignore_none=True):
"""
Expand Down
128 changes: 121 additions & 7 deletions ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
*args,
device=None,
v=None,
buffers=None,
build_mode="on_init",
compile_on_next_step=False,
store_vars=True,
Expand Down Expand Up @@ -144,6 +145,9 @@ def __init__(
self._target = None
self._lazy_compiled = False
self._dynamic_backend = dynamic_backend
if hasattr(self, "_create_buffers"):
vedpatwardhan marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like we're still using _create_buffers and _buffer_tracker, could you please remove it as we'd only need to allow users to register buffers "on-spot"?

self._create_buffers = self._buffer_tracker(self._create_buffers)

if build_mode != "on_init":
return
if hasattr(Module, "_init_var"):
Expand All @@ -156,7 +160,12 @@ def __init__(

if v or with_partial_v:
# build only if `v` or `with_partial_v`
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(
*args,
dynamic_backend=dynamic_backend,
buffers=buffers,
**kwargs,
)
# we don't want to delete the class variable now
# since there could be other child modules
return
Expand All @@ -171,16 +180,19 @@ def __init__(
# move on
Module._init_var.pop()
return
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(
*args, dynamic_backend=dynamic_backend, buffers=buffers, **kwargs
)
if Module._init_var[-1] == self.__class__.__name__:
# you delete it, only if this is the class that caused it's creation
Module._init_var.pop()

# do a final check if _init_var becomes empty, then delete it all together
del Module._init_var
if not Module._init_var:
del Module._init_var

return
self.build(*args, dynamic_backend=dynamic_backend, **kwargs)
self.build(*args, dynamic_backend=dynamic_backend, buffers=buffers**kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

syntax error, , to be added

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this throws an error right?


# Private #
# --------#
Expand Down Expand Up @@ -289,6 +301,13 @@ def _build_and_return_v(self, *args, **kwargs):
def _find_child_objects(self, /, *, obj=None, _visited=None):
pass

def _find_buffers(self):
for obj in self.__dict__.keys():
if isinstance(getattr(self, obj), ivy.Module):
# simply fetch it's buffer
if hasattr(getattr(self, obj), "buffers"):
self.buffers.update({obj: getattr(self, obj).buffers})

@staticmethod
def _extract_v(v, keychain_mappings: dict, orig_key_chain, /):
"""
Expand Down Expand Up @@ -428,6 +447,76 @@ def found_dup_callback(x, kc):
vs = vs.cont_prune_key_chain(dup_kc)
return vs, keychain_mappings

def _buffer_tracker(self, func):
"""Tracks the variables defined as buffer variables and stores them."""

def wrapper(*args, **kwargs):
initial_object_snapshot = self.__dict__.copy()
# initialise the buffers
func(*args, **kwargs)
final_object_snapshot = self.__dict__
getattr(self, "buffers", {}).update(
set(final_object_snapshot.keys()).difference(
set(initial_object_snapshot.keys())
)
)

wrapper.buffers_tracked = True
return wrapper

def _get_buffers(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why we need this method actually, given that we are already using _find_buffers while building, won't self.buffers be a container or arrays anyway?

"""Return the buffer variables, if any, of the given Module class instance."""
if self.buffers:
buffer_dict = {}
for buffer in self.buffers:
buffer_dict[buffer] = (
getattr(self, buffer)
if not isinstance(getattr(self, buffer), ivy.Module)
else getattr(self, buffer)._get_buffers()
)
return buffer_dict
return {}

def _set_buffers(self, buffers, override=False, nesting=None):
"""
Set the buffers of the given class instance, according to the buffers passed.

Parameters
----------
buffers
a dictionary with variable names and corresponding values

override
if true, sets the variable as an attribute even if it doesn't exist
"""
for buffer in buffers:
if hasattr(self, buffer):
# check if this value is another nested dictionary, if yes
# we recurse
if isinstance(buffers[buffer], dict):
getattr(self, buffer)._set_buffers(
buffers=buffers[buffer], override=override
)
else:
setattr(self, buffer, buffers[buffer])
elif not override or (override and isinstance(buffers[buffer], dict)):
# do a quick check to see if this is nested case
# we can't set buffers for module objects not yet
# created, even if override=True
raise ivy.exceptions.IvyNotImplementedException(
f"{buffer} hasn't been defined for the given Module structure"
)
else:
setattr(self, buffer, buffers[buffer])
if hasattr(self, "buffers"):
self.buffers.update({buffer})
else:
setattr(self, "buffers", {buffer})

def _register_buffers(self, var_name, value):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this public, and remove variables.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you please rename register_buffers to register_buffer, set_buffers to set_buffer and others and make the necessary ones public?

"""Set the buffer variables at any place within the class."""
self._set_buffers({var_name: value}, override=True)

# Overridable #

# noinspection PyMethodMayBeStatic,PyUnusedLocal
Expand All @@ -448,6 +537,17 @@ def _create_variables(self, *, device=None, dtype=None):
"""
return {}

def _create_buffers(self):
RickSanchezStoic marked this conversation as resolved.
Show resolved Hide resolved
"""
Create buffers for this class.

Returns
-------
ret
An empty set.
"""
return {}

def _build(self, *args, **kwargs) -> bool:
"""
Build the internal layers and variables for this module. Overridable.
Expand Down Expand Up @@ -494,7 +594,7 @@ def _forward_with_tracking(self, *args, **kwargs):
self._check_submod_ret()
return ret

def _call(self, *args, v=None, **kwargs):
def _call(self, *args, v=None, buffers=None, **kwargs):
"""
Compute forward pass of the layer, treating layer instance as callable function.

Expand All @@ -516,6 +616,9 @@ def _call(self, *args, v=None, **kwargs):
from_call=True,
dtype=_get_first_array(*args, **kwargs).dtype,
)
if buffers:
buffers_orig = self._get_buffers()
self._set_buffers(buffers)
if v is not None:
v_orig = self.v
self.v = (
Expand All @@ -525,6 +628,8 @@ def _call(self, *args, v=None, **kwargs):
)
ret = self._forward_with_tracking(*args, **kwargs)
self.v = v_orig
if buffers:
self._set_buffers(buffers_orig)
return ret
elif hasattr(self.__call__, "wrapped"):
return self.__call__(*args, **kwargs)
Expand All @@ -536,6 +641,7 @@ def __call__(
self,
*args,
v=None,
buffers=None,
stateful=None,
arg_stateful_idxs=None,
kwarg_stateful_idxs=None,
Expand Down Expand Up @@ -597,7 +703,7 @@ def __call__(

# convert variables to native arrays so that they can be tracked
v = ivy.to_native(v)
ret = self._call(*args, v=v, **kwargs)
ret = self._call(*args, v=v, buffers=buffers, **kwargs)
self._unset_submod_flags()
return ret

Expand All @@ -624,6 +730,7 @@ def build(
device=None,
dtype=None,
dynamic_backend=None,
buffers=None,
**kwargs,
):
"""
Expand All @@ -646,8 +753,9 @@ def build(
True for successfully built a module.
"""
self._dev = ivy.default(device, self._dev)
# build buffers if any
self._create_buffers()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that ivy.Module._create_buffers has been removed, we should also remove this

# return False if not from_call but build_mode is on_call

if not from_call and self._build_mode == "on_call":
return self.v
if dtype:
Expand Down Expand Up @@ -749,6 +857,12 @@ def build(
if not self._store_vars:
# ToDo: verify variables in self.v are released once this method exits
self.v = ivy.Container()

# once all variables built, find and assign buffers
if buffers:
self._set_buffers(buffers=buffers)
self._find_buffers()

return v_ret if bool(v_ret) or isinstance(built, bool) else built

def __repr__(self):
Expand Down
Loading