From 5056de6928f0771464e1cdd2d4bd3d0755c83de2 Mon Sep 17 00:00:00 2001
From: Ved Patwardhan <54766411+VedPatwardhan@users.noreply.github.com>
Date: Wed, 16 Aug 2023 09:42:51 +0000
Subject: [PATCH] Removed x_dilations assertion from ivy.conv as it's unused,
 added missing tests for ivy.conv, updated the x_and_filters helper to not
 generate x_dilations for transposed convolutions

---
 ivy/functional/ivy/layers.py                  |  21 +-
 .../test_functional/test_nn/test_layers.py    | 228 ++++++++++++++++--
 2 files changed, 214 insertions(+), 35 deletions(-)

diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py
index dc768aa854055..52a9f93e81aa0 100644
--- a/ivy/functional/ivy/layers.py
+++ b/ivy/functional/ivy/layers.py
@@ -109,7 +109,7 @@ def linear(
     ret
         Result array of the linear transformation.
         *[outer_batch_shape,inner_batch_shape,out_features]*
-    
+
     Both the description and the type hints above assumes an array input for
     simplicity, but this function is *nestable*, and therefore also accepts
     :class:`ivy.Container` instances in place of any of the arguments.
@@ -123,7 +123,7 @@ def linear(
     >>> y = ivy.linear(x, w)
     >>> print(y)
     ivy.array([1.])
-    
+
     >>> x = ivy.array([[0.666, -0.4269, 1.911]])
     >>> w = ivy.array([[1., 0., 0.], [0., 0., 1.]])
     >>> y = ivy.zeros((1, 2))
@@ -143,7 +143,7 @@ def linear(
     ivy.array([[ 34.98495483, 101.0293808 ],
                [ 28.0159359 ,  83.74752808],
                [ 37.20942307, 108.3205719 ]])
-    
+
     With :class:`ivy.Container` input:
 
     >>> x = ivy.Container(a=ivy.array([[1., 2., 3.],
@@ -181,7 +181,7 @@ def linear(
         b: ivy.array([[15.1, 32., 47.9],
                       [85., 196., 306.]])
     }
-    
+
     """
     outer_batch_shape = list(weight.shape[:-2])
     num_outer_batch_dims = len(outer_batch_shape)
@@ -590,8 +590,10 @@ def scaled_dot_product_attention(
     ...                   b=ivy.array([[[3.2, 1.], [2.2, 3.6], [4.0, 5.6]]]))
     >>> v = ivy.Container(a=ivy.array([[[5.2, 1.], [2.1, 3.], [4.4, 5.6]]]),
     ...                   b=ivy.array([[[0.2, 1.], [2.2, 3.], [4.4, 5.6]]]))
-    >>> mask = ivy.Container(a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
-    ...                      b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]]))
+    >>> mask = ivy.Container(
+    ...     a=ivy.array([[[1.0, 1.0, 1.0],[1.0, 1.0, 1.0],[1.0, 1.0, 1.0]]]),
+    ...     b=ivy.array([[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0,1.0]]])
+    ... )
     >>> result = ivy.scaled_dot_product_attention(q,k,v,scale=1,mask=mask)
     >>> print(result)
     {
@@ -1602,10 +1604,10 @@ def conv3d(
         while "NCDHW" corresponds to input with shape (batch_size, channels,
         depth, height, width).
     filter_format
-        Either "channel_first" or "channel_last". "channel_first" corresponds 
+        Either "channel_first" or "channel_last". "channel_first" corresponds
         to "OIDHW",input data formats, while "channel_last" corresponds to "DHWIO".
     x_dilations
-        The dilation factor for each dimension of input. (Default value = 1) 
+        The dilation factor for each dimension of input. (Default value = 1)
     dilations
         The dilation factor for each dimension of input. (Default value = 1)
     bias
@@ -1983,8 +1985,8 @@ def conv_general_transpose(
 @handle_exceptions
 @handle_array_like_without_promotion
 @handle_out_argument
-@handle_array_function
 @inputs_to_native_shapes
+@handle_array_function
 def conv(
     x: Union[ivy.Array, ivy.NativeArray],
     filters: Union[ivy.Array, ivy.NativeArray],
@@ -2053,7 +2055,6 @@ def conv(
         The result of the transpose or dilated convolution operation.
     """
     if transpose:
-        assert x_dilations == 1, "x_dilations must be 1 for transpose convolutions."
         return conv_general_transpose(
             x,
             filters,
diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
index 8bebc64f12f14..c5b9f4391b066 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
@@ -644,14 +644,14 @@ def x_and_filters(
     )
     if general:
         data_format = "channel_first" if channel_first else "channel_last"
-
-        x_dilation = draw(
-            st.one_of(
-                st.integers(1, 3),
-                st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
+        if not transpose:
+            x_dilation = draw(
+                st.one_of(
+                    st.integers(1, 3),
+                    st.lists(st.integers(1, 3), min_size=dim, max_size=dim),
+                )
             )
-        )
-        dilations = (dilations, x_dilation)
+            dilations = (dilations, x_dilation)
     if filter_format is not None:
         filter_format = draw(filter_format)
         if filter_format == "channel_first":
@@ -694,9 +694,18 @@ def _assume_tf_dilation_gt_1(backend_fw, on_device, dilations):
     ground_truth_backend="jax",
 )
 def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
     # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
@@ -730,9 +739,18 @@ def test_conv1d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
     ground_truth_backend="jax",
 )
 def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
         input_dtypes=dtype,
@@ -765,9 +783,18 @@ def test_conv1d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic
     ground_truth_backend="jax",
 )
 def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
     # ToDo: Enable gradient tests for dilations > 1 when tensorflow supports it.
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
@@ -802,9 +829,18 @@ def test_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
     ground_truth_backend="jax",
 )
 def test_conv2d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
 
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
@@ -870,9 +906,18 @@ def test_depthwise_conv2d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_devic
     ground_truth_backend="jax",
 )
 def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, fc, ff_format, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        fc,
+        ff_format,
+        bias,
+    ) = x_f_d_df
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
         input_dtypes=dtype,
@@ -905,9 +950,18 @@ def test_conv3d(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
     ground_truth_backend="jax",
 )
 def test_conv3d_transpose(*, x_f_d_df, test_flags, backend_fw, fn_name, on_device):
-    dtype, x, filters, dilations, data_format, stride, pad, output_shape, fc, bias = (
-        x_f_d_df
-    )
+    (
+        dtype,
+        x,
+        filters,
+        dilations,
+        data_format,
+        stride,
+        pad,
+        output_shape,
+        fc,
+        bias,
+    ) = x_f_d_df
     _assume_tf_dilation_gt_1(backend_fw, on_device, dilations[0])
     helpers.test_function(
         input_dtypes=dtype,
@@ -1026,6 +1080,130 @@ def test_conv_general_transpose(
     )
 
 
+
+# filter_format not in conv_general_transpose
+# output_shape not in conv_general_dilated
+@st.composite
+def x_and_filters_and_transpose(
+    draw,
+    dim: int = 2,
+    general=False,
+    bias=False,
+    filter_format=None,
+):
+    transpose = draw(st.booleans())
+    if not transpose:
+        filter_format = st.sampled_from(["channel_last", "channel_first"])
+    all_args = draw(
+        x_and_filters(
+            dim=dim,
+            general=general,
+            bias=bias,
+            filter_format=filter_format,
+            transpose=transpose,
+        )
+    )
+    output_shape = None
+    filter_format = "channel_last"
+    if transpose:
+        (
+            dtype,
+            x,
+            filters,
+            dilations,
+            data_format,
+            stride,
+            pad,
+            output_shape,
+            fc,
+            bias,
+        ) = all_args
+    else:
+        (
+            dtype,
+            x,
+            filters,
+            dilations,
+            data_format,
+            stride,
+            pad,
+            fc,
+            filter_format,
+            bias,
+        ) = all_args
+    return (
+        dtype,
+        x,
+        filters,
+        stride,
+        pad,
+        transpose,
+        output_shape,
+        data_format,
+        filter_format,
+        fc,
+        dilations,
+        bias,
+    )
+
+
+# conv
+@handle_test(
+    fn_tree="functional.ivy.conv",
+    dims=st.shared(st.integers(1, 3), key="dims"),
+    x_f_d_df_tr=x_and_filters_and_transpose(
+        dim=st.shared(st.integers(1, 3), key="dims"),
+        general=True,
+        bias=True,
+    ),
+    ground_truth_backend="jax",
+)
+def test_conv(*, dims, x_f_d_df_tr, test_flags, backend_fw, fn_name, on_device):
+    # unpack the drawn convolution test case
+    (
+        dtype,
+        x,
+        filters,
+        stride,
+        pad,
+        transpose,
+        output_shape,
+        data_format,
+        filter_format,
+        fc,
+        dilations,
+        bias,
+    ) = x_f_d_df_tr
+    tf_dilations = dilations
+    if not transpose:
+        tf_dilations = tf_dilations[0]
+        dilations, x_dilations = dilations
+    else:
+        x_dilations = None
+    _assume_tf_dilation_gt_1(backend_fw, on_device, tf_dilations)
+    helpers.test_function(
+        input_dtypes=dtype,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        fn_name=fn_name,
+        on_device=on_device,
+        rtol_=1e-2,
+        atol_=1e-2,
+        x=x,
+        filters=filters,
+        strides=stride,
+        padding=pad,
+        transpose=transpose,
+        dims=dims,
+        output_shape=output_shape,
+        data_format=data_format,
+        filter_format=filter_format,
+        feature_group_count=fc,
+        x_dilations=x_dilations,
+        dilations=dilations,
+        bias=bias,
+    )
+
+
 # LSTM #
 # -----#
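
For context, a minimal sketch (not part of the patch) of the two ivy.conv code
paths that the new test_conv exercises. The argument names are taken from the
test call above; the numpy backend, the shapes, and the channel_last filter
layout are illustrative assumptions rather than requirements of the API.

    import ivy

    ivy.set_backend("numpy")  # any installed backend would do

    # dims=1, channel_last data: x is (batch, width, in_channels)
    x = ivy.random_normal(shape=(1, 16, 3))
    # filters assumed as (filter_width, in_channels, out_channels); in and out
    # channels are kept equal here so the sketch does not depend on the exact
    # transpose filter convention of the installed ivy version
    filters = ivy.random_normal(shape=(3, 3, 3))

    # transpose=False dispatches to conv_general_dilated, where x_dilations
    # dilates the input itself, independently of the filter dilations
    y = ivy.conv(x, filters, 1, "SAME", transpose=False, dims=1, x_dilations=2)

    # transpose=True dispatches to conv_general_transpose, which does not use
    # x_dilations; per the commit message, that is why the removed assertion
    # was unused
    y_t = ivy.conv(x, filters, 1, "SAME", transpose=True, dims=1)

Matching this, the updated x_and_filters helper only draws x_dilation for the
non-transposed case, so the transpose path is never handed an x_dilations
value by the new test.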