Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

I added structural similarity index (SSIM) loss. #27134

Merged
merged 12 commits into from
Jul 13, 2024
57 changes: 57 additions & 0 deletions ivy/functional/ivy/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,60 @@ def sparse_cross_entropy(
return ivy.cross_entropy(
true, pred, axis=axis, epsilon=epsilon, reduction=reduction, out=out
)


@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def ssim_loss(
    pred: Union[ivy.Array, ivy.NativeArray], ytrue: Union[ivy.Array, ivy.NativeArray]
) -> ivy.Array:
    """
    Calculate the Structural Similarity Index (SSIM) loss between two images.

    Local statistics (mean, variance, covariance) are estimated with a 3x3
    average pool applied with stride 3 and "SAME" padding; the per-window SSIM
    index is then turned into a loss as ``1 - SSIM`` and averaged over all
    elements.

    Parameters
    ----------
    pred
        A 4D image tensor of shape (batch_size, channels, height, width)
        holding the predicted images.
    ytrue
        A 4D image tensor of shape (batch_size, channels, height, width)
        holding the ground-truth images.

    Returns
    -------
    ret
        The SSIM loss, a scalar measuring the dissimilarity between the two
        images (0 when they are identical).

    Examples
    --------
    With :class:`ivy.Array` input:

    >>> import ivy
    >>> x = ivy.ones((5, 3, 28, 28))
    >>> y = ivy.zeros((5, 3, 28, 28))
    >>> loss = ivy.ssim_loss(x, y)
    >>> print(loss)
    ivy.array(0.99989992)
    """
    # NOTE(review): the documented layout is channels-first, but avg_pool2d is
    # called without an explicit data_format -- confirm the pooling layout
    # matches the intended input layout.
    window = (3, 3)
    strides = (3, 3)
    padding = "SAME"

    def _local_mean(x):
        # Windowed average used as the local expectation operator E[.].
        return ivy.avg_pool2d(x, window, strides, padding)

    # Local means of the two images.
    mu_x = _local_mean(pred)
    mu_y = _local_mean(ytrue)

    # Local variances and covariance via E[ab] - E[a]E[b].
    sigma_x2 = _local_mean(pred * pred) - mu_x * mu_x
    sigma_y2 = _local_mean(ytrue * ytrue) - mu_y * mu_y
    sigma_xy = _local_mean(pred * ytrue) - mu_x * mu_y

    # Small stabilising constants that keep the denominator away from zero.
    C1 = 0.01**2
    C2 = 0.03**2

    # SSIM index for every pooled window.
    ssim = ((2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)) / (
        (mu_x**2 + mu_y**2 + C1) * (sigma_x2 + sigma_y2 + C2)
    )

    # SSIM is a similarity (1 for identical images); use 1 - SSIM so the
    # quantity is minimised when prediction and ground truth agree.
    ssim = 1 - ssim

    # Reduce to a single scalar loss value usable for training.
    loss = ivy.mean(ssim)

    return loss
Loading