-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
26 lines (22 loc) · 1.27 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import torch
import torch.nn.functional as F
class PSNRLoss(torch.nn.Module):
    """Peak signal-to-noise ratio (PSNR, in dB) between two tensors.

    Despite the class name, ``forward`` returns the PSNR itself, which grows as
    the tensors become more similar; negate it if it is used as a quantity to
    minimise during training.

    Args:
        data_range: Peak signal value used in the PSNR numerator. When ``None``
            (the default), the maximum of ``tensorB`` is used, as before.
    """

    def __init__(self, data_range: "float | None" = None) -> None:
        super().__init__()
        # Mean squared error over all elements (MSELoss default reduction="mean").
        self.MSE = torch.nn.MSELoss()
        self.data_range = data_range

    def forward(self, tensorA: torch.Tensor, tensorB: torch.Tensor) -> torch.Tensor:
        """Return ``20*log10(peak) - 10*log10(MSE(tensorA, tensorB))``.

        ``peak`` is ``data_range`` if given, otherwise ``torch.max(tensorB)``.
        Identical inputs give MSE == 0 and therefore a result of +inf; a
        non-positive peak gives -inf/NaN from the logarithm.
        """
        mse = self.MSE(tensorA, tensorB)
        if self.data_range is None:
            peak = torch.max(tensorB)
        else:
            # Match the MSE's dtype/device so the subtraction below is valid.
            peak = torch.as_tensor(self.data_range, dtype=mse.dtype, device=mse.device)
        return 20 * torch.log10(peak) - 10 * torch.log10(mse)
class Rotation_Network(torch.nn.Module):
    """Resample a batch of 2D images or 3D volumes through a fixed rotation.

    The sampling grid is built by ``F.affine_grid`` from the matrix
    ``[[cos, -sin], [sin, cos]]`` acting on the last two spatial axes (W, H);
    in the 3D case the depth axis is left unchanged. The rotation is applied in
    ``affine_grid``'s normalised [-1, 1] coordinates, so for non-square inputs
    it is not a rigid rotation in pixel space. Samples falling outside the
    input are filled with zeros.

    Args:
        dimension: 2 for inputs shaped (N, C, H, W), 3 for (N, C, D, H, W).
        theta: Rotation angle in radians (a Python number or a scalar tensor).
        device: Stored for backward compatibility; ``forward`` now builds the
            grid on the input's own device.

    Raises:
        ValueError: If ``dimension`` is not 2 or 3.
    """

    def __init__(self, dimension: int, theta: float, device: torch.device) -> None:
        super().__init__()
        # torch.cos/sin reject Python floats, so convert first. Detach so the
        # cached matrix never holds an autograd graph reused across iterations.
        angle = torch.as_tensor(theta, dtype=torch.float32).detach()
        cos, sin = torch.cos(angle), torch.sin(angle)
        zero, one = torch.zeros_like(angle), torch.ones_like(angle)
        if dimension == 2:
            rows = [[cos, -sin, zero],
                    [sin, cos, zero]]
        elif dimension == 3:
            rows = [[cos, -sin, zero, zero],
                    [sin, cos, zero, zero],
                    [zero, zero, one, zero]]
        else:
            raise ValueError(f"dimension must be 2 or 3, got {dimension!r}")
        # Shape (1, 2, 3) for 2D or (1, 3, 4) for 3D; batch axis added here.
        self.rotation_tensor = torch.stack([torch.stack(row) for row in rows]).unsqueeze(0)
        self.device = device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` resampled through the rotation grid (same shape as ``x``)."""
        # affine_grid needs one matrix per batch element and grid_sample needs
        # the grid on x's device/dtype, so adapt the cached matrix to x.
        theta = self.rotation_tensor.to(device=x.device, dtype=x.dtype)
        theta = theta.repeat(x.size(0), 1, 1)
        # align_corners=False is PyTorch's default; passing it explicitly keeps
        # the same numerics and silences the default-change warning.
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, padding_mode='zeros', align_corners=False)