forked from ad50810344/compression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
93 lines (82 loc) · 4.11 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import numpy as np
import os
import torch
import torchvision.models as models
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import sys
import math
import torch.nn.init as init
import logging
from torch.nn.parameter import Parameter
from models import *
def save_model(model, iter, name):
torch.save(model.state_dict(), os.path.join(name, "iter_{}.pth.tar".format(iter)))
def load_model(model, f):
with open(f, 'rb') as f:
pretrained_dict = torch.load(f)
model_dict = model.state_dict()
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
f = str(f)
if f.find('iter_') != -1 and f.find('.pth') != -1:
st = f.find('iter_') + 5
ed = f.find('.pth', st)
return int(f[st:ed])
else:
return 0
class ImageCompressor(nn.Module):
def __init__(self, out_channel_N=192, out_channel_M=320):
super(ImageCompressor, self).__init__()
self.Encoder = Analysis_net(out_channel_N=out_channel_N, out_channel_M=out_channel_M)
self.Decoder = Synthesis_net(out_channel_N=out_channel_N, out_channel_M=out_channel_M)
self.priorEncoder = Analysis_prior_net(out_channel_N=out_channel_N, out_channel_M=out_channel_M)
self.priorDecoder = Synthesis_prior_net(out_channel_N=out_channel_N, out_channel_M=out_channel_M)
self.bitEstimator_z = BitEstimator(out_channel_N)
self.out_channel_N = out_channel_N
self.out_channel_M = out_channel_M
def forward(self, input_image):
quant_noise_feature = torch.zeros(input_image.size(0), self.out_channel_M, input_image.size(2) // 16, input_image.size(3) // 16).cuda()
quant_noise_z = torch.zeros(input_image.size(0), self.out_channel_N, input_image.size(2) // 64, input_image.size(3) // 64).cuda()
quant_noise_feature = torch.nn.init.uniform_(torch.zeros_like(quant_noise_feature), -0.5, 0.5)
quant_noise_z = torch.nn.init.uniform_(torch.zeros_like(quant_noise_z), -0.5, 0.5)
feature = self.Encoder(input_image)
batch_size = feature.size()[0]
z = self.priorEncoder(feature)
if self.training:
compressed_z = z + quant_noise_z
else:
compressed_z = torch.round(z)
recon_sigma = self.priorDecoder(compressed_z)
feature_renorm = feature
if self.training:
compressed_feature_renorm = feature_renorm + quant_noise_feature
else:
compressed_feature_renorm = torch.round(feature_renorm)
recon_image = self.Decoder(compressed_feature_renorm)
# recon_image = prediction + recon_res
clipped_recon_image = recon_image.clamp(0., 1.)
# distortion
mse_loss = torch.mean((recon_image - input_image).pow(2))
def feature_probs_based_sigma(feature, sigma):
mu = torch.zeros_like(sigma)
sigma = sigma.clamp(1e-10, 1e10)
gaussian = torch.distributions.laplace.Laplace(mu, sigma)
probs = gaussian.cdf(feature + 0.5) - gaussian.cdf(feature - 0.5)
total_bits = torch.sum(torch.clamp(-1.0 * torch.log(probs + 1e-10) / math.log(2.0), 0, 50))
return total_bits, probs
def iclr18_estimate_bits_z(z):
prob = self.bitEstimator_z(z + 0.5) - self.bitEstimator_z(z - 0.5)
total_bits = torch.sum(torch.clamp(-1.0 * torch.log(prob + 1e-10) / math.log(2.0), 0, 50))
return total_bits, prob
total_bits_feature, _ = feature_probs_based_sigma(compressed_feature_renorm, recon_sigma)
total_bits_z, _ = iclr18_estimate_bits_z(compressed_z)
im_shape = input_image.size()
bpp_feature = total_bits_feature / (batch_size * im_shape[2] * im_shape[3])
bpp_z = total_bits_z / (batch_size * im_shape[2] * im_shape[3])
bpp = bpp_feature + bpp_z
return clipped_recon_image, mse_loss, bpp_feature, bpp_z, bpp