Subject: [PATCH] torch frontend: fix lp_pool1d/lp_pool2d stride defaulting and
 kernel normalization

lp_pool{1,2}d now default stride to kernel_size (matching torch), scale the
averaged pool by the full kernel element count (functools.reduce over tuple
kernels instead of only kernel_size[0] * kernel_size[1]), guard the
1/norm_type exponent against norm_type == 0, and cast the result back to the
input dtype.  The tests generate channel-first data directly instead of
permuting after sampling, and draw integer norm types.
---
diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
index b1bacd44fa103..aba9ac22fd619 100644
--- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
+++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py
@@ -1,5 +1,5 @@
 # global
-
+from functools import reduce
 # local
 import ivy
 from ivy import with_unsupported_dtypes
@@ -207,56 +207,47 @@ def adaptive_avg_pool2d(input, output_size):
     "torch",
 )
 @to_ivy_arrays_and_back
 def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     data_format = "NCW"
     padding = "VALID"
-    if stride is not None:
-        out = ivy.avg_pool1d(
-            ivy.pow(input, norm_type),
-            kernel_size,
-            stride,
-            padding,
-            data_format=data_format,
-            ceil_mode=ceil_mode,
-        )
+    if stride is None:
+        stride = kernel_size
+    if not isinstance(kernel_size, int):
+        kernel_mul = reduce(lambda x, y: x * y, kernel_size)
     else:
-        out = ivy.avg_pool1d(
-            ivy.pow(input, norm_type),
-            kernel_size,
-            kernel_size,
-            padding,
-            data_format=data_format,
-            ceil_mode=ceil_mode,
-        )
-
-    return ivy.pow(ivy.multiply(out, kernel_size), ivy.divide(1.0, norm_type))
+        kernel_mul = kernel_size
+    out = ivy.avg_pool1d(
+        ivy.pow(input, norm_type),
+        kernel_size,
+        stride,
+        padding,
+        data_format=data_format,
+        ceil_mode=ceil_mode,
+    )
+    p = ivy.divide(1.0, norm_type) if norm_type != 0 else 1.0
+    return ivy.pow(ivy.multiply(out, kernel_mul), p).astype(input.dtype)
 
 
 @to_ivy_arrays_and_back
 def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
     data_format = "NCHW"
     padding = "VALID"
-    if stride is not None:
-        out = ivy.avg_pool2d(
-            ivy.pow(input, norm_type),
-            kernel_size,
-            stride,
-            padding,
-            data_format=data_format,
-            ceil_mode=ceil_mode,
-        )
-    else:
-        out = ivy.avg_pool2d(
-            ivy.pow(input, norm_type),
-            kernel_size,
-            kernel_size,
-            padding,
-            data_format=data_format,
-            ceil_mode=ceil_mode,
-        )
+    if stride is None:
+        stride = kernel_size
+    out = ivy.avg_pool2d(
+        ivy.pow(input, norm_type),
+        kernel_size,
+        stride,
+        padding,
+        data_format=data_format,
+        ceil_mode=ceil_mode,
+    )
     if not isinstance(kernel_size, int):
-        kernel_size = kernel_size[0] * kernel_size[1]
-    return ivy.pow(ivy.multiply(out, kernel_size), ivy.divide(1.0, norm_type))
+        kernel_mul = reduce(lambda x, y: x * y, kernel_size)
+    else:
+        kernel_mul = kernel_size
+    p = ivy.divide(1.0, norm_type) if norm_type != 0 else 1.0
+    return ivy.pow(ivy.multiply(out, kernel_mul), p).astype(input.dtype)
 
 
 @to_ivy_arrays_and_back
diff --git a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
index e0e5530b5f5da..c59055f4f4a54 100644
--- a/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
+++ b/ivy_tests/test_ivy/test_frontends/test_torch/test_nn/test_functional/test_pooling_functions.py
@@ -428,8 +428,9 @@ def test_torch_adaptive_max_pool2d(
         max_dims=3,
         min_side=1,
         max_side=3,
+        data_format="channel_first",
     ),
-    norm_type=helpers.number(min_value=0.1, max_value=6),
+    norm_type=helpers.ints(min_value=1, max_value=6),
     test_with_out=st.just(False),
 )
 def test_torch_lp_pool1d(
@@ -444,10 +445,6 @@ def test_torch_lp_pool1d(
     on_device,
 ):
     input_dtype, x, kernel_size, stride, _ = dtype_x_k_s
-    # Torch ground truth func expects input to be consistent
-    # with a channels first format i.e. NCW
-    x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], x[0].shape[1]))
-
     helpers.test_frontend_function(
         input_dtypes=input_dtype,
         backend_to_test=backend_fw,
@@ -471,8 +468,9 @@ def test_torch_lp_pool1d(
         max_dims=4,
         min_side=1,
         max_side=4,
+        data_format="channel_first",
     ),
-    norm_type=helpers.number(min_value=0.1, max_value=6),
+    norm_type=helpers.ints(min_value=1, max_value=6),
     test_with_out=st.just(False),
 )
 def test_torch_lp_pool2d(
@@ -486,10 +484,6 @@ def test_torch_lp_pool2d(
     on_device,
 ):
     input_dtype, x, kernel_size, stride, _ = dtype_x_k_s
-    # Torch ground truth func expects input to be consistent
-    # with a channels first format i.e. NCW
-    x[0] = x[0].transpose((0, 3, 1, 2))
-
     helpers.test_frontend_function(
         input_dtypes=input_dtype,
         backend_to_test=backend_fw,