Skip to content

Commit

Permalink
refactor torch frontend lp pooling functions and their tests (#21967)
Browse files Browse the repository at this point in the history
Co-authored-by: @AnnaTz
  • Loading branch information
mohame54 authored Aug 16, 2023
1 parent 075eecd commit bc23efe
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 49 deletions.
71 changes: 32 additions & 39 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global

from functools import reduce
# local
import ivy
from ivy import with_unsupported_dtypes
Expand Down Expand Up @@ -207,56 +207,49 @@ def adaptive_avg_pool2d(input, output_size):
"torch",
)
@to_ivy_arrays_and_back
def lp_pool1d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
    """Apply 1D power-average (Lp) pooling over a channels-first input.

    Mirrors ``torch.nn.functional.lp_pool1d``: the input is raised to the
    power ``norm_type``, average-pooled, rescaled by the window size to
    recover the windowed sum, and finally taken to the ``1 / norm_type``
    root.

    Parameters
    ----------
    input
        Input array in NCW (channels-first) layout.
    norm_type
        Exponent ``p`` of the Lp norm. When it equals 0 the final root is
        skipped (exponent 1.0) to avoid a division by zero.
    kernel_size
        Pooling window size — an int or a sequence of ints.
    stride
        Window stride. Defaults to ``kernel_size`` when ``None``, matching
        PyTorch's behavior.
    ceil_mode
        Whether to use ceiling instead of floor when computing the output
        length.

    Returns
    -------
    ret
        The pooled array.
    """
    data_format = "NCW"
    padding = "VALID"
    # PyTorch semantics: a missing stride falls back to the kernel size.
    if stride is None:
        stride = kernel_size
    # Number of elements in the pooling window: avg_pool divides by it, so
    # we multiply back afterwards to obtain the windowed sum.
    if isinstance(kernel_size, int):
        kernel_mul = kernel_size
    else:
        kernel_mul = reduce(lambda x, y: x * y, kernel_size)
    out = ivy.avg_pool1d(
        ivy.pow(input, norm_type),
        kernel_size,
        stride,
        padding,
        data_format=data_format,
        ceil_mode=ceil_mode,
    )
    # Guard the 1/p root against norm_type == 0.
    p = 1.0 / norm_type if norm_type != 0 else 1.0
    return ivy.pow(ivy.multiply(out, kernel_mul), p)


@to_ivy_arrays_and_back
def lp_pool2d(input, norm_type, kernel_size, stride=None, ceil_mode=False):
    """Apply 2D power-average (Lp) pooling over a channels-first input.

    Mirrors ``torch.nn.functional.lp_pool2d``: the input is raised to the
    power ``norm_type``, average-pooled, rescaled by the window element
    count to recover the windowed sum, and taken to the ``1 / norm_type``
    root. The result is cast back to the input's dtype.

    Parameters
    ----------
    input
        Input array in NCHW (channels-first) layout.
    norm_type
        Exponent ``p`` of the Lp norm. When it equals 0 the final root is
        skipped (exponent 1.0) to avoid a division by zero.
    kernel_size
        Pooling window size — an int or a sequence of ints.
    stride
        Window stride. Defaults to ``kernel_size`` when ``None``, matching
        PyTorch's behavior.
    ceil_mode
        Whether to use ceiling instead of floor when computing the output
        spatial dimensions.

    Returns
    -------
    ret
        The pooled array, with the same dtype as ``input``.
    """
    data_format = "NCHW"
    padding = "VALID"
    # PyTorch semantics: a missing stride falls back to the kernel size.
    if stride is None:
        stride = kernel_size
    out = ivy.avg_pool2d(
        ivy.pow(input, norm_type),
        kernel_size,
        stride,
        padding,
        data_format=data_format,
        ceil_mode=ceil_mode,
    )
    # Number of elements in the pooling window: avg_pool divides by it, so
    # we multiply back afterwards to obtain the windowed sum.
    if isinstance(kernel_size, int):
        kernel_mul = kernel_size
    else:
        kernel_mul = reduce(lambda x, y: x * y, kernel_size)
    # Guard the 1/p root against norm_type == 0; pow may upcast, so restore
    # the caller's dtype.
    p = ivy.divide(1.0, norm_type) if norm_type != 0 else 1.0
    return ivy.pow(ivy.multiply(out, kernel_mul), p).astype(input.dtype)


@to_ivy_arrays_and_back
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,9 @@ def test_torch_adaptive_max_pool2d(
max_dims=3,
min_side=1,
max_side=3,
data_format="channel_first"
),
norm_type=helpers.number(min_value=0.1, max_value=6),
norm_type=helpers.ints(min_value=1, max_value=6),
test_with_out=st.just(False),
)
def test_torch_lp_pool1d(
Expand All @@ -444,10 +445,6 @@ def test_torch_lp_pool1d(
):
input_dtype, x, kernel_size, stride, _ = dtype_x_k_s

# Torch ground truth func expects input to be consistent
# with a channels first format i.e. NCW
x[0] = x[0].reshape((x[0].shape[0], x[0].shape[-1], x[0].shape[1]))

helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
Expand All @@ -471,8 +468,9 @@ def test_torch_lp_pool1d(
max_dims=4,
min_side=1,
max_side=4,
data_format="channel_first",
),
norm_type=helpers.number(min_value=0.1, max_value=6),
norm_type=helpers.ints(min_value=1, max_value=6),
test_with_out=st.just(False),
)
def test_torch_lp_pool2d(
Expand All @@ -486,10 +484,6 @@ def test_torch_lp_pool2d(
on_device,
):
input_dtype, x, kernel_size, stride, _ = dtype_x_k_s
# Torch ground truth func expects input to be consistent
# with a channels first format i.e. NCW
x[0] = x[0].transpose((0, 3, 1, 2))

helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
Expand Down

0 comments on commit bc23efe

Please sign in to comment.