Skip to content

Commit

Permalink
PR reformated
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohammed committed Aug 15, 2023
1 parent 2affa3d commit d424afc
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 39 deletions.
31 changes: 4 additions & 27 deletions ivy/functional/frontends/torch/nn/functional/pooling_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,30 +114,15 @@ def max_pool1d(
):
if stride is None:
stride = kernel_size
kernel_size = _broadcast_pooling_helper(kernel_size, "1d", name="kernel_size")
stride = _broadcast_pooling_helper(stride, "1d", name="stride")
padding = _broadcast_pooling_helper(padding, "1d", name="padding")
kernel_pads = zip(kernel_size, padding)

data_format = "NCW"
if not all([pad <= kernel / 2 for kernel, pad in kernel_pads]):
raise ValueError(
"pad should be smaller than or equal to half of kernel size, "
f"but got padding={padding}, kernel_size={kernel_size}. "
)
# figure out whether to apply padding
if sum(padding) == 0:
padding_str = "VALID"
elif all([pad == ivy.ceil((kernel - 1) / 2) for kernel, pad in kernel_pads]):
padding_str = "SAME"
else:
padding_str = "VALID"
return ivy.max_pool1d(
input,
kernel_size,
stride,
padding_str,
padding,
data_format=data_format,
dilation=dilation,
ceil_mode=ceil_mode,
)


Expand All @@ -152,14 +137,9 @@ def max_pool2d(
ceil_mode=False,
return_indices=False,
):
# ToDo: Add return_indices once superset in implemented
dim_check = False
if input.ndim == 3:
input = input.expand_dims()
dim_check = True
if stride is None:
stride = kernel_size
ret = ivy.max_pool2d(
return ivy.max_pool2d(
input,
kernel_size,
stride,
Expand All @@ -168,9 +148,6 @@ def max_pool2d(
dilation=dilation,
ceil_mode=ceil_mode,
)
if dim_check:
return ret.squeeze(0)
return ret


@to_ivy_arrays_and_back
Expand Down
2 changes: 1 addition & 1 deletion ivy/stateful/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import os
import abc
import copy
import dill
#import dill
from typing import Optional, Tuple, Dict

# local
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,7 @@ def test_torch_avg_pool3d(
max_dims=3,
min_side=1,
max_side=3,
explicit_or_str_padding=False,
only_explicit_padding=False,
only_explicit_padding=True,
data_format="channel_first",
),
test_with_out=st.just(False),
Expand All @@ -216,13 +215,7 @@ def test_torch_max_pool1d(
on_device,
):
input_dtype, x, kernel_size, stride, padding = dtype_x_k_s
x_shape = [x[0].shape[2]]
# Torch ground truth func also takes padding input as an integer
# or a tuple of integers, not a string
if padding == "SAME":
padding = calculate_same_padding(kernel_size, stride, x_shape)
elif padding == "VALID":
padding = (0,)
padding = (padding[0][0], )
helpers.test_frontend_function(
input_dtypes=input_dtype,
backend_to_test=backend_fw,
Expand All @@ -245,10 +238,10 @@ def test_torch_max_pool1d(
max_dims=4,
min_side=1,
max_side=4,
explicit_or_str_padding=True,
only_explicit_padding=True,
return_dilation=True,
data_format="channel_first",
).filter(lambda x: x[4] != "VALID" and x[4] != "SAME"),
),
test_with_out=st.just(False),
ceil_mode=st.just(True),
)
Expand Down

0 comments on commit d424afc

Please sign in to comment.