Skip to content

Commit

Permalink
fix test_DISK
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam-Armstrong committed Oct 23, 2024
1 parent a32a3a2 commit e780229
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions kornia/test_feature5.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,24 +199,35 @@ def test_DISK(target_framework, mode, backend_compile):

transpiled_kornia = ivy.transpile(kornia, source="torch", target=target_framework)

x = torch.rand(1, 3, 256, 256)
x = torch.rand(1, 3, 32, 32)
transpiled_x = _nest_torch_tensor_to_new_framework(x, target_framework)

model = kornia.feature.DISK()
class TorchModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.fc = torch.nn.Linear(32, 32)

def forward(self, x):
return self.fc(x)

torch_unet = TorchModel()
transpiled_unet = ivy.transpile(TorchModel, target=target_framework)()

model = kornia.feature.DISK(desc_dim=2, unet=torch_unet)
torch_out = model(x)
transpiled_model = transpiled_kornia.feature.DISK()

transpiled_model = transpiled_kornia.feature.DISK(desc_dim=2, unet=transpiled_unet)
if target_framework == "tensorflow":
# build the layers
transpiled_model(transpiled_x)

ivy.sync_models(model, transpiled_model)

transpiled_out = transpiled_model(transpiled_x)

_to_numpy_and_shape_allclose(torch_out.keypoints, transpiled_out.keypoints)
_to_numpy_and_shape_allclose(torch_out.descriptors, transpiled_out.descriptors)
_to_numpy_and_shape_allclose(torch_out.detection_scores, transpiled_out.detection_scores)
_to_numpy_and_shape_allclose(torch_out[0].keypoints, transpiled_out[0].keypoints)
_to_numpy_and_shape_allclose(torch_out[0].descriptors, transpiled_out[0].descriptors)
_to_numpy_and_shape_allclose(torch_out[0].detection_scores, transpiled_out[0].detection_scores)


def test_DISKFeatures(target_framework, mode, backend_compile):
Expand Down

0 comments on commit e780229

Please sign in to comment.