From b3f5ba78912c5dff983b4c5593cbb68832845105 Mon Sep 17 00:00:00 2001 From: Sam-Armstrong Date: Tue, 16 Jul 2024 23:15:15 +0100 Subject: [PATCH] fix: minor fixes for torch max_pool2d frontend --- .../torch/indexing_slicing_joining_mutating_ops.py | 2 +- .../frontends/torch/nn/functional/pooling_functions.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py index 5e46854039dcf..19e9b8a2db840 100644 --- a/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py +++ b/ivy/functional/frontends/torch/indexing_slicing_joining_mutating_ops.py @@ -125,7 +125,7 @@ def gather(input, dim, index, *, sparse_grad=False, out=None): dim = dim % len(input.shape) all_indices = ivy.argwhere(ivy.full(index.shape, True)) - gather_locations = ivy.reshape(index, [ivy.prod(ivy.array(index.shape))]) + gather_locations = ivy.reshape(index, [ivy.prod(ivy.array(index.shape), dtype=torch_frontend.int64)]) gather_indices = [] for axis in range(len(index.shape)): diff --git a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py index fd38133cadcdf..5474baa7d1b51 100644 --- a/ivy/functional/frontends/torch/nn/functional/pooling_functions.py +++ b/ivy/functional/frontends/torch/nn/functional/pooling_functions.py @@ -314,7 +314,7 @@ def max_pool2d( ) # torch pad takes width padding first, then height padding padding = (padding[1], padding[0]) - pad_array = ivy.flatten(padding) + pad_list = list(ivy.flatten(padding)) in_shape = input.shape H = in_shape[-2] @@ -329,13 +329,13 @@ def max_pool2d( # find the indices of the max value for each position of the sliding window input = torch_frontend.nn.functional.pad( input, - pad_array, + pad_list, value=float("-inf"), ) input_indices = torch_frontend.nn.functional.pad( input_indices, - pad_array, + pad_list, value=0, )