From e5f50f4b4ee37d5717fdafe05e5d803831d05ddf Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 17 Jun 2026 12:10:43 +0200 Subject: [PATCH 1/2] add sinkhorn loss --- docs/source/_rst/loss/sinkhorn_loss.rst | 10 ++ pina/_src/loss/sinkhorn_loss.py | 127 ++++++++++++++++++++++++ pina/loss/__init__.py | 2 + tests/test_loss/test_sinkhorn_loss.py | 83 ++++++++++++++++ 4 files changed, 222 insertions(+) create mode 100644 docs/source/_rst/loss/sinkhorn_loss.rst create mode 100644 pina/_src/loss/sinkhorn_loss.py create mode 100644 tests/test_loss/test_sinkhorn_loss.py diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst new file mode 100644 index 000000000..d997c3ec3 --- /dev/null +++ b/docs/source/_rst/loss/sinkhorn_loss.rst @@ -0,0 +1,10 @@ +Lp Loss +=============== +.. currentmodule:: pina.loss.sinkhorn_loss + +.. automodule:: pina._src.loss.sinkhorn_loss + :no-members: + +.. autoclass:: pina._src.loss.sinkhorn_loss.SinkhornLoss + :members: + :show-inheritance: diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py new file mode 100644 index 000000000..2eb226451 --- /dev/null +++ b/pina/_src/loss/sinkhorn_loss.py @@ -0,0 +1,127 @@ +"""Module for the SinkhornLoss class.""" + +import torch +from pina._src.loss.base_dual_loss import BaseDualLoss +from pina._src.core.utils import check_consistency, check_positive_integer + + +class SinkhornLoss(BaseDualLoss): + r""" + Implementation of the Sinkhorn Loss based on regularized optimal transport. + It measures the regularized Wasserstein distance between the empirical + distributions represented by ``input`` (with :math:`N` samples) and + ``target`` (with :math:`M` samples), each in :math:`\mathbb{R}^D`. + + The loss solves the entropy-regularized optimal transport problem: + + .. 
math:: + W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} + \langle C, \pi \rangle - \varepsilon H(\pi), + + where :math:`C_{ij} = \|x_i - y_j\|_2^p` is the cost matrix, + :math:`H(\pi) = -\sum_{ij} \pi_{ij} \log \pi_{ij}` is the entropy of + the transport plan, and :math:`\varepsilon > 0` is the regularization + strength. The dual objective recovered by the Sinkhorn iterations is: + + .. math:: + W_\varepsilon = \langle a, f^* \rangle + \langle b, g^* \rangle, + + where :math:`a` and :math:`b` are uniform probability weights over the + :math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are + the optimal dual potentials computed via log-space Sinkhorn iterations. + + If ``reduction`` is set to ``"mean"`` or ``"sum"``, the scalar transport + cost is aggregated accordingly (the output is always a scalar, so both + reductions are equivalent): + + .. math:: + \ell(x, y) = + \begin{cases} + \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\ + \operatorname{sum}(L), & \text{if reduction} = \text{``sum''} + \end{cases} + + .. note:: + Unlike pointwise losses, the Sinkhorn loss operates on entire empirical + distributions, so the output is always a scalar regardless of the + number of samples. The ``reduction`` parameter is retained for API + consistency. + + .. note:: + Smaller values of ``eps`` approximate the true Wasserstein distance + more closely but may require more iterations to converge. + + .. note:: + The algorithm is taken from "Sinkhorn AutoEncoders", arXiv:1810.01118. + """ + + def __init__(self, p=2, eps=0.1, max_iter=100, reduction="mean"): + """ + Initialization of the :class:`SinkhornLoss` class. + + :param int p: Exponent of the cost function :math:`\|x_i - y_j\|_2^p`. + Default is ``2``. + :param float eps: Entropy regularization strength + :math:`\varepsilon > 0`. Larger values yield smoother transport + plans. Default is ``0.1``. + :param int max_iter: Number of Sinkhorn iterations. Default is ``100``. 
+ :param str reduction: The reduction method to aggregate the scalar loss. + Available options include: ``"none"``, ``"mean"``, ``"sum"``. + Default is ``"mean"``. + :raises ValueError: If ``p`` is not a numeric value. + :raises ValueError: If ``eps`` is not a positive float. + :raises AssertionError: If ``max_iter`` is not a strictly positive int. + """ + super().__init__(reduction=reduction) + + check_consistency(p, (int, float)) + check_consistency(eps, float) + if eps <= 0: + raise ValueError( + f"eps must be a strictly positive float, got {eps}." + ) + check_positive_integer(max_iter, strict=True) + + self.p = p + self.eps = eps + self.max_iter = max_iter + + def forward(self, input, target): + """ + Forward method of the loss function. + + :param torch.Tensor input: Input tensor of shape :math:`(N, D)`. + :param torch.Tensor target: Target tensor of shape :math:`(M, D)`. + :return: Sinkhorn loss value. + :rtype: torch.Tensor + """ + n = input.shape[0] + m = target.shape[0] + + a = input.new_full((n,), 1.0 / n) + b = target.new_full((m,), 1.0 / m) + + # Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M) + diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D) + C = torch.linalg.norm(diff, ord=2, dim=-1).pow(self.p) # (N, M) + + # Log-space Sinkhorn iterations for numerical stability + log_a = a.log() + log_b = b.log() + f = torch.zeros(n, dtype=input.dtype, device=input.device) + g = torch.zeros(m, dtype=target.dtype, device=target.device) + + for _ in range(self.max_iter): + # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) + f = self.eps * ( + log_a + - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) + ) + # g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps)) + g = self.eps * ( + log_b + - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) + ) + + loss = (a * f).sum() + (b * g).sum() + return self._reduction(loss.unsqueeze(0)) diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 52ed278c7..7966d2019 100644 --- 
a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -5,12 +5,14 @@ "BaseDualLoss", "LpLoss", "PowerLoss", + "SinkhornLoss" ] from pina._src.loss.dual_loss_interface import DualLossInterface from pina._src.loss.base_dual_loss import BaseDualLoss from pina._src.loss.power_loss import PowerLoss from pina._src.loss.lp_loss import LpLoss +from pina._src.loss.sinkhorn_loss import SinkhornLoss # Back-compatibility with version 0.2, to be removed soon import warnings diff --git a/tests/test_loss/test_sinkhorn_loss.py b/tests/test_loss/test_sinkhorn_loss.py new file mode 100644 index 000000000..40e647596 --- /dev/null +++ b/tests/test_loss/test_sinkhorn_loss.py @@ -0,0 +1,83 @@ +import torch +import pytest + +from pina.loss import SinkhornLoss + +# Fixed random tensors for reproducibility +torch.manual_seed(0) +input_ = torch.rand(10, 2) +target_ = torch.rand(8, 2) + + +@pytest.mark.parametrize("p", [1, 2, 3]) +@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0]) +@pytest.mark.parametrize("max_iter", [10, 100]) +@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) +def test_constructor(p, eps, max_iter, reduction): + + SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction=reduction) + + # Should fail if p is not numeric + with pytest.raises(ValueError): + SinkhornLoss(p="invalid", eps=eps, max_iter=max_iter, reduction=reduction) + + # Should fail if eps is not a float + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=1, max_iter=max_iter, reduction=reduction) + + # Should fail if eps is not positive + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter, reduction=reduction) + + # Should fail if max_iter is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=p, eps=eps, max_iter=0, reduction=reduction) + + # Should fail if reduction is invalid + with pytest.raises(ValueError): + SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction="invalid") + + +@pytest.mark.parametrize("reduction", ["mean", "sum", 
"none"]) +def test_forward_shape(reduction): + + loss_fn = SinkhornLoss(reduction=reduction) + value = loss_fn(input_, target_) + assert value.shape == torch.Size([1]) + + +def test_forward_finite(): + + # The (non-debiased) Sinkhorn dual can be negative due to the entropy + # regularization term, but it must always be finite. + loss_fn = SinkhornLoss() + value = loss_fn(input_, target_) + assert torch.isfinite(value).all() + + +def test_forward_same_distribution_smaller(): + + # Sinkhorn loss on identical data should be smaller than on different data + loss_same = SinkhornLoss(eps=1e-3, max_iter=500)(input_, input_) + loss_diff = SinkhornLoss(eps=1e-3, max_iter=500)(input_, target_) + assert loss_same.item() < loss_diff.item() + + +def test_forward_asymmetric_sizes(): + + # input and target may have different numbers of rows + x = torch.rand(5, 3) + y = torch.rand(8, 3) + value = SinkhornLoss()(x, y) + assert value.shape == torch.Size([1]) + assert torch.isfinite(value).all() + + +def test_forward_approaches_wasserstein(): + + # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N + x = torch.tensor([[1.0], [2.0], [3.0]]) + y = torch.tensor([[4.0], [5.0], [6.0]]) + # W_2^2 = ((1-4)^2 + (2-5)^2 + (3-6)^2) / 3 = 9 + value = SinkhornLoss(p=2, eps=1e-3, max_iter=5000)(x, y) + assert abs(value.item() - 9.0) < 0.1 From f88736303fab9dccdcfe88d464303f6a46212283 Mon Sep 17 00:00:00 2001 From: cyberguli Date: Wed, 17 Jun 2026 12:14:45 +0200 Subject: [PATCH 2/2] docs, tests and minor fixes for sinkhorn loss Co-authored-by: GiovanniCanali --- docs/source/_rst/_code.rst | 4 +- docs/source/_rst/loss/sinkhorn_loss.rst | 3 +- pina/_src/loss/sinkhorn_loss.py | 173 +++++++++++++----------- pina/loss/__init__.py | 2 +- tests/test_loss/test_sinkhorn_loss.py | 107 ++++++--------- 5 files changed, 137 insertions(+), 152 deletions(-) diff --git a/docs/source/_rst/_code.rst b/docs/source/_rst/_code.rst index 0c289183e..ecd50ec7d 100644 --- a/docs/source/_rst/_code.rst +++ 
b/docs/source/_rst/_code.rst @@ -330,6 +330,8 @@ Losses BaseDualLoss LpLoss PowerLoss + SinkhornLoss + Weighting Schemas -------------------- @@ -343,4 +345,4 @@ Weighting Schemas Neural-Tangent-Kernel Weighting No Weighting Scalar Weighting - Self-Adaptive Weighting \ No newline at end of file + Self-Adaptive Weighting diff --git a/docs/source/_rst/loss/sinkhorn_loss.rst b/docs/source/_rst/loss/sinkhorn_loss.rst index d997c3ec3..17aa370ad 100644 --- a/docs/source/_rst/loss/sinkhorn_loss.rst +++ b/docs/source/_rst/loss/sinkhorn_loss.rst @@ -1,5 +1,6 @@ -Lp Loss +Sinkhorn Loss =============== + .. currentmodule:: pina.loss.sinkhorn_loss .. automodule:: pina._src.loss.sinkhorn_loss diff --git a/pina/_src/loss/sinkhorn_loss.py b/pina/_src/loss/sinkhorn_loss.py index 2eb226451..9feddc458 100644 --- a/pina/_src/loss/sinkhorn_loss.py +++ b/pina/_src/loss/sinkhorn_loss.py @@ -7,121 +7,132 @@ class SinkhornLoss(BaseDualLoss): r""" - Implementation of the Sinkhorn Loss based on regularized optimal transport. - It measures the regularized Wasserstein distance between the empirical - distributions represented by ``input`` (with :math:`N` samples) and - ``target`` (with :math:`M` samples), each in :math:`\mathbb{R}^D`. + Implementation of the Sinkhorn loss measuring the entropy-regularized + optimal transport distance between two empirical distributions. - The loss solves the entropy-regularized optimal transport problem: + Given an input tensor :math:`x` with :math:`N` samples and a target tensor + :math:`y` with :math:`M` samples, both in :math:`\mathbb{R}^D`, the loss is + defined through the entropy-regularized optimal transport problem: .. 
math:: + W_\varepsilon(\mu, \nu) = \min_{\pi \in \Pi(\mu, \nu)} - \langle C, \pi \rangle - \varepsilon H(\pi), + \langle C, \pi \rangle - \varepsilon H(\pi) + + where :math:`\mu` and :math:`\nu` are the empirical distributions associated + with :math:`x` and :math:`y`, :math:`\pi` is a transport plan, and + :math:`\Pi(\mu, \nu)` is the set of admissible transport plans with + marginals :math:`\mu` and :math:`\nu`. - where :math:`C_{ij} = \|x_i - y_j\|_2^p` is the cost matrix, - :math:`H(\pi) = -\sum_{ij} \pi_{ij} \log \pi_{ij}` is the entropy of - the transport plan, and :math:`\varepsilon > 0` is the regularization - strength. The dual objective recovered by the Sinkhorn iterations is: + The cost matrix is defined as: .. math:: - W_\varepsilon = \langle a, f^* \rangle + \langle b, g^* \rangle, - where :math:`a` and :math:`b` are uniform probability weights over the - :math:`N` and :math:`M` samples respectively, and :math:`f^*, g^*` are - the optimal dual potentials computed via log-space Sinkhorn iterations. + C_{ij} = \left\| x_i - y_j \right\|_p^p + + and the entropy term is: + + .. math:: - If ``reduction`` is set to ``"mean"`` or ``"sum"``, the scalar transport - cost is aggregated accordingly (the output is always a scalar, so both - reductions are equivalent): + H(\pi) = - \sum_{i,j} \pi_{ij} \log \pi_{ij} + + where :math:`\varepsilon > 0` controls the strength of the entropic + regularization. + + The Sinkhorn iterations compute the optimal dual potentials :math:`f^\ast` + and :math:`g^\ast` in log space. The regularized optimal transport cost is + then recovered from the dual formulation as: .. math:: - \ell(x, y) = - \begin{cases} - \operatorname{mean}(L), & \text{if reduction} = \text{``mean''} \\ - \operatorname{sum}(L), & \text{if reduction} = \text{``sum''} - \end{cases} - - .. note:: - Unlike pointwise losses, the Sinkhorn loss operates on entire empirical - distributions, so the output is always a scalar regardless of the - number of samples.
The ``reduction`` parameter is retained for API - consistency. - - .. note:: - Smaller values of ``eps`` approximate the true Wasserstein distance - more closely but may require more iterations to converge. - - .. note:: - The algorithm is taken from "Sinkhorn AutoEncoders", arXiv:1810.01118. + + W_\varepsilon = \langle a, f^\ast \rangle + \langle b, g^\ast \rangle + + where :math:`a` and :math:`b` are uniform probability weights over the + :math:`N` input samples and :math:`M` target samples, respectively. + + Unlike pointwise losses, the Sinkhorn loss compares whole empirical + distributions. Therefore, the output is always a scalar value. + + Smaller values of ``eps`` provide a closer approximation to the true + Wasserstein distance, but may require more Sinkhorn iterations to converge. + + .. seealso:: + + **Original reference:** Patrini, G., Carioni, M., Forré, P., Bhargav, + S., Welling, M., Van den Berg, R., Genewein, T., and Nielsen, F. (2019). + *Sinkhorn AutoEncoders*. + In Proceedings of the 35th Conference on Uncertainty in Artificial + Intelligence. + URL: `<https://arxiv.org/abs/1810.01118>`_. """ - def __init__(self, p=2, eps=0.1, max_iter=100, reduction="mean"): + def __init__(self, p=2, eps=0.1, iterations=100): """ Initialization of the :class:`SinkhornLoss` class. - :param int p: Exponent of the cost function :math:`\|x_i - y_j\|_2^p`. - Default is ``2``. - :param float eps: Entropy regularization strength - :math:`\varepsilon > 0`. Larger values yield smoother transport - plans. Default is ``0.1``. - :param int max_iter: Number of Sinkhorn iterations. Default is ``100``. - :param str reduction: The reduction method to aggregate the scalar loss. - Available options include: ``"none"``, ``"mean"``, ``"sum"``. - Default is ``"mean"``. - :raises ValueError: If ``p`` is not a numeric value. - :raises ValueError: If ``eps`` is not a positive float. - :raises AssertionError: If ``max_iter`` is not a strictly positive int. + :param int p: The exponent of the cost function.
Default is ``2``. + :param eps: The entropy regularization strength. Smaller values provide + a closer approximation to the unregularized Wasserstein distance, + but may require more iterations for convergence. Default is ``0.1``. + :type eps: int | float + :param int iterations: The number of Sinkhorn iterations. + Default is ``100``. + :raises AssertionError: If ``iterations`` is not a positive integer. + :raises AssertionError: If ``p`` is not a positive integer. + :raises ValueError: If ``eps`` is not a positive numeric value. """ - super().__init__(reduction=reduction) + # Initialize the base class with mean reduction + super().__init__(reduction="mean") - check_consistency(p, (int, float)) - check_consistency(eps, float) + # Check consistency + check_positive_integer(iterations, strict=True) + check_positive_integer(p, strict=True) + check_consistency(eps, (int, float)) if eps <= 0: raise ValueError( - f"eps must be a strictly positive float, got {eps}." + f"Expected 'eps' to be strictly positive, but got {eps}." ) - check_positive_integer(max_iter, strict=True) - self.p = p + # Initialize parameters + self.iterations = iterations self.eps = eps - self.max_iter = max_iter + self.p = p def forward(self, input, target): """ Forward method of the loss function. - :param torch.Tensor input: Input tensor of shape :math:`(N, D)`. - :param torch.Tensor target: Target tensor of shape :math:`(M, D)`. - :return: Sinkhorn loss value. + :param torch.Tensor input: The input tensor. + :param torch.Tensor target: The target tensor. + :return: The computed Sinkhorn loss value. 
:rtype: torch.Tensor """ - n = input.shape[0] - m = target.shape[0] - - a = input.new_full((n,), 1.0 / n) - b = target.new_full((m,), 1.0 / m) + # Extract the number of samples in input and target + n, m = input.shape[0], target.shape[0] - # Cost matrix C[i,j] = ||x_i - y_j||_2^p, shape (N, M) - diff = input.unsqueeze(1) - target.unsqueeze(0) # (N, M, D) - C = torch.linalg.norm(diff, ord=2, dim=-1).pow(self.p) # (N, M) + # Initialize log-uniform weights for the empirical distributions + log_a = -input.new_tensor(n).log().expand(n) + log_b = -target.new_tensor(m).log().expand(m) - # Log-space Sinkhorn iterations for numerical stability - log_a = a.log() - log_b = b.log() + # Initialize dual potentials f and g f = torch.zeros(n, dtype=input.dtype, device=input.device) g = torch.zeros(m, dtype=target.dtype, device=target.device) - for _ in range(self.max_iter): - # f_i = eps * (log a_i - logsumexp_j ((g_j - C_ij) / eps)) - f = self.eps * ( - log_a - - torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) - ) - # g_j = eps * (log b_j - logsumexp_i ((f_i - C_ij) / eps)) - g = self.eps * ( - log_b - - torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) - ) + # Define the cost matrix, shape (n, m) + C = torch.cdist(input, target, p=self.p) ** self.p + + # Perform Sinkhorn iterations in log space for numerical stability + for _ in range(self.iterations): + + # Update dual potential f with the softmin operation in log space + softmin_f = torch.logsumexp((g.unsqueeze(0) - C) / self.eps, dim=1) + f = self.eps * (log_a - softmin_f) + + # Update dual potential g with the softmin operation in log space + softmin_g = torch.logsumexp((f.unsqueeze(1) - C) / self.eps, dim=0) + g = self.eps * (log_b - softmin_g) + + # Compute the Sinkhorn loss as the sum of the means of f and g + loss = f.mean() + g.mean() - loss = (a * f).sum() + (b * g).sum() return self._reduction(loss.unsqueeze(0)) diff --git a/pina/loss/__init__.py b/pina/loss/__init__.py index 7966d2019..280cbf76a 100644 
--- a/pina/loss/__init__.py +++ b/pina/loss/__init__.py @@ -5,7 +5,7 @@ "BaseDualLoss", "LpLoss", "PowerLoss", - "SinkhornLoss" + "SinkhornLoss", ] from pina._src.loss.dual_loss_interface import DualLossInterface diff --git a/tests/test_loss/test_sinkhorn_loss.py b/tests/test_loss/test_sinkhorn_loss.py index 40e647596..86eb4de62 100644 --- a/tests/test_loss/test_sinkhorn_loss.py +++ b/tests/test_loss/test_sinkhorn_loss.py @@ -1,83 +1,54 @@ import torch import pytest - from pina.loss import SinkhornLoss -# Fixed random tensors for reproducibility -torch.manual_seed(0) -input_ = torch.rand(10, 2) -target_ = torch.rand(8, 2) +@pytest.mark.parametrize("p", [1, 2]) +@pytest.mark.parametrize("eps", [0.01, 1]) +@pytest.mark.parametrize("iterations", [2, 5]) +def test_constructor(p, eps, iterations): -@pytest.mark.parametrize("p", [1, 2, 3]) -@pytest.mark.parametrize("eps", [0.01, 0.1, 1.0]) -@pytest.mark.parametrize("max_iter", [10, 100]) -@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) -def test_constructor(p, eps, max_iter, reduction): + # Define the loss + SinkhornLoss(p=p, eps=eps, iterations=iterations) - SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction=reduction) + # Should fail if iterations is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=p, eps=eps, iterations=0) - # Should fail if p is not numeric - with pytest.raises(ValueError): - SinkhornLoss(p="invalid", eps=eps, max_iter=max_iter, reduction=reduction) + # Should fail if p is not a positive integer + with pytest.raises(AssertionError): + SinkhornLoss(p=0, eps=eps, iterations=iterations) - # Should fail if eps is not a float + # Should fail if eps is not numeric with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=1, max_iter=max_iter, reduction=reduction) + SinkhornLoss(p=p, eps="invalid", iterations=iterations) # Should fail if eps is not positive with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=-0.1, max_iter=max_iter, reduction=reduction) - - 
# Should fail if max_iter is not a positive integer - with pytest.raises(AssertionError): - SinkhornLoss(p=p, eps=eps, max_iter=0, reduction=reduction) - - # Should fail if reduction is invalid - with pytest.raises(ValueError): - SinkhornLoss(p=p, eps=eps, max_iter=max_iter, reduction="invalid") - - -@pytest.mark.parametrize("reduction", ["mean", "sum", "none"]) -def test_forward_shape(reduction): - - loss_fn = SinkhornLoss(reduction=reduction) - value = loss_fn(input_, target_) - assert value.shape == torch.Size([1]) - - -def test_forward_finite(): - - # The (non-debiased) Sinkhorn dual can be negative due to the entropy - # regularization term, but it must always be finite. - loss_fn = SinkhornLoss() - value = loss_fn(input_, target_) - assert torch.isfinite(value).all() - - -def test_forward_same_distribution_smaller(): - - # Sinkhorn loss on identical data should be smaller than on different data - loss_same = SinkhornLoss(eps=1e-3, max_iter=500)(input_, input_) - loss_diff = SinkhornLoss(eps=1e-3, max_iter=500)(input_, target_) - assert loss_same.item() < loss_diff.item() - - -def test_forward_asymmetric_sizes(): - - # input and target may have different numbers of rows - x = torch.rand(5, 3) - y = torch.rand(8, 3) - value = SinkhornLoss()(x, y) + SinkhornLoss(p=p, eps=-0.1, iterations=iterations) + + +@pytest.mark.parametrize("p", [2, 3]) +@pytest.mark.parametrize("eps", [0.1, 1]) +@pytest.mark.parametrize("iterations", [2, 5]) +@pytest.mark.parametrize( + "input, target", + [ + (torch.rand(10, 2), torch.rand(8, 2)), + (torch.rand(5, 3), torch.rand(5, 3)), + (torch.rand(1, 4), torch.rand(7, 4)), + (torch.rand(6, 4), torch.rand(1, 4)), + (torch.rand(3, 1), torch.rand(4, 1)), + ], +) +def test_forward(p, eps, iterations, input, target): + + # Define the loss + loss = SinkhornLoss(p=p, eps=eps, iterations=iterations) + + # Forward pass + value = loss(input, target) + + # Check shape assert value.shape == torch.Size([1]) assert torch.isfinite(value).all() - - 
-def test_forward_approaches_wasserstein(): - - # For 1-D sorted distributions, W_2^2 = sum |x_i - y_i|^2 / N - x = torch.tensor([[1.0], [2.0], [3.0]]) - y = torch.tensor([[4.0], [5.0], [6.0]]) - # W_2^2 = ((1-4)^2 + (2-5)^2 + (3-6)^2) / 3 = 9 - value = SinkhornLoss(p=2, eps=1e-3, max_iter=5000)(x, y) - assert abs(value.item() - 9.0) < 0.1