Skip to content

Commit

Permalink
fix: minor fixes for torch max_pool2d frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Jul 16, 2024
1 parent 22757f4 commit b3f5ba7
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def gather(input, dim, index, *, sparse_grad=False, out=None):

dim = dim % len(input.shape)
all_indices = ivy.argwhere(ivy.full(index.shape, True))
gather_locations = ivy.reshape(index, [ivy.prod(ivy.array(index.shape))])
gather_locations = ivy.reshape(index, [ivy.prod(ivy.array(index.shape), dtype=torch_frontend.int64)])

gather_indices = []
for axis in range(len(index.shape)):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ def max_pool2d(
)
# torch pad takes width padding first, then height padding
padding = (padding[1], padding[0])
pad_array = ivy.flatten(padding)
pad_list = list(ivy.flatten(padding))

in_shape = input.shape
H = in_shape[-2]
Expand All @@ -329,13 +329,13 @@ def max_pool2d(
# find the indices of the max value for each position of the sliding window
input = torch_frontend.nn.functional.pad(
input,
pad_array,
pad_list,
value=float("-inf"),
)

input_indices = torch_frontend.nn.functional.pad(
input_indices,
pad_array,
pad_list,
value=0,
)

Expand Down

0 comments on commit b3f5ba7

Please sign in to comment.