# Source code for pgcuts.losses.prcut

"""PRCut loss modules."""
from typing import Tuple

import torch
from torch import nn


@torch.no_grad()
def offline_gradient(
    w_mat: torch.Tensor, probs: torch.Tensor
) -> torch.Tensor:
    """Compute full PRCut gradient offline.

    The function runs under ``torch.no_grad()``, so the returned tensor
    is not attached to the autograd graph.

    Args:
        w_mat: Weight matrix, shape (n, n).
        probs: Probability matrix, shape (n, k).

    Returns:
        Gradient w.r.t. probs, shape (n, k).
    """
    num_samples = probs.size(0)
    # Per-cluster mean assignment probability, shape (k,).
    ov_p = probs.mean(0)
    # Both terms need W @ P; compute the product once and reuse it.
    w_probs = w_mat.mm(probs)
    # NOTE(review): the degree is taken as a column sum while ``w_probs``
    # is row-oriented; the two only agree when ``w_mat`` is symmetric --
    # confirm that callers always pass a symmetric matrix.
    degree = w_mat.sum(0).unsqueeze(1)
    # Direct dependence of the cut on each row of ``probs``.
    left = (degree - 2 * w_probs) / ov_p
    # Dependence through the ``ov_p`` normaliser; shape (k,), broadcast
    # over rows when added to ``left``.
    right = (
        -(w_probs - w_probs * probs).sum(0)
        / ov_p ** 2
        / num_samples
    )
    return left + right


@torch.no_grad()
def batch_cluster_prcut_loss(
    w_mat: torch.Tensor,
    p_l: torch.Tensor,
    p_r: torch.Tensor,
    ov_p: torch.Tensor,
) -> torch.Tensor:
    """Per-cluster PRCut loss for a batch.

    For each cluster ``c`` this evaluates
    ``sum_ij w_ij * (p_l[i, c] + p_r[j, c] - 2 * p_l[i, c] * p_r[j, c])``
    divided by ``ov_p[c]``. The sum is split into three reductions so
    that no (a, b, k) intermediate tensor is allocated.

    The function runs under ``torch.no_grad()``, so the result is not
    attached to the autograd graph.

    Args:
        w_mat: Weight matrix, shape (a, b).
        p_l: Left probabilities, shape (a, k).
        p_r: Right probabilities, shape (b, k).
        ov_p: Cluster likelihood, shape (k,).

    Returns:
        Per-cluster loss, shape (k,).
    """
    # Matrix products need one dtype; promote explicitly so inputs with
    # mixed dtypes keep working as they did with plain broadcasting.
    common = torch.promote_types(
        torch.promote_types(w_mat.dtype, p_l.dtype), p_r.dtype
    )
    w_mat = w_mat.to(common)
    p_l = p_l.to(common)
    p_r = p_r.to(common)

    # sum_ij w_ij * p_l[i, c]  ==  (row sums of W) @ p_l
    left_mass = w_mat.sum(1) @ p_l
    # sum_ij w_ij * p_r[j, c]  ==  (column sums of W) @ p_r
    right_mass = w_mat.sum(0) @ p_r
    # sum_ij w_ij * p_l[i, c] * p_r[j, c]  ==  sum_i p_l[i, c] * (W @ p_r)[i, c]
    cross = (p_l * w_mat.mm(p_r)).sum(0)
    return (left_mass + right_mass - 2 * cross) / ov_p


@torch.no_grad()
def batch_gradient(
    w_mat: torch.Tensor,
    p_l: torch.Tensor,
    p_r: torch.Tensor,
    ov_p: torch.Tensor,
    n: float,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Evaluate the PRCut gradient for one batch.

    Runs under ``torch.no_grad()``; both returned tensors are detached
    from the autograd graph.

    Args:
        w_mat: Weight matrix, shape (a, b).
        p_l: Left probabilities, shape (a, k).
        p_r: Right probabilities, shape (b, k).
        ov_p: Cluster likelihood, shape (k,).
        n: Number of samples.

    Returns:
        Pair ``(grad_left, grad_right)`` holding the gradients with
        respect to ``p_l`` and ``p_r``.
    """
    # Direct terms: each side is weighted by the opposite side's
    # assignments (W for the left rows, W^T for the right rows).
    direct_left = torch.mm(w_mat, 1 - 2 * p_r) / ov_p
    direct_right = torch.mm(w_mat.t(), 1 - 2 * p_l) / ov_p

    # Term shared by both sides, derived from the per-cluster loss; it is
    # broadcast over rows when added below.
    per_cluster = batch_cluster_prcut_loss(w_mat, p_l, p_r, ov_p)
    shared = -per_cluster / ov_p / n

    grad_left = direct_left + shared
    grad_right = direct_right + shared
    return grad_left, grad_right


class PRCutGradLoss(nn.Module):
    """PRCut loss with analytical gradient.

    Uses the analytical gradient to construct a surrogate loss for
    backpropagation.
    """

    def forward(
        self,
        w_mat: torch.Tensor,
        p_l: torch.Tensor,
        p_r: torch.Tensor,
        ov_p: torch.Tensor,
        n: float,
    ) -> torch.Tensor:
        """Compute surrogate PRCut loss.

        Args:
            w_mat: Weight tensor, shape (a, b).
            p_l: Left probabilities, shape (a, k).
            p_r: Right probabilities, shape (b, k).
            ov_p: Cluster likelihood, shape (k,).
            n: Number of samples.

        Returns:
            Scalar loss.
        """
        # ``batch_gradient`` runs under ``torch.no_grad()``, so the two
        # gradient tensors act as constants here: differentiating the
        # returned scalar w.r.t. ``p_l`` / ``p_r`` yields them back.
        p_l_grad, p_r_grad = batch_gradient(w_mat, p_l, p_r, ov_p, n)
        return (p_l_grad * p_l).sum() + (p_r_grad * p_r).sum()


class PRCutBatchLoss(nn.Module):
    """PRCut batch estimate with EMA tracking.

    Maintains a running estimate of cluster probabilities.
    """

    def __init__(self, num_clusters: int, gamma: float) -> None:
        """Initialize PRCutBatchLoss.

        Args:
            num_clusters: Number of clusters.
            gamma: EMA decay rate.
        """
        super().__init__()
        self.num_clusters = num_clusters
        self.gamma = gamma
        # Start from the uniform distribution over clusters.
        # NOTE(review): the running estimate is stored as a trainable
        # ``nn.Parameter``, so it shows up in ``parameters()`` and can
        # receive gradients when used outside ``no_grad`` code -- confirm
        # this is intended rather than a buffer / ``requires_grad=False``.
        self.clusters_p = nn.Parameter(
            torch.ones(num_clusters) / num_clusters
        )

    @torch.no_grad()
    def update_cluster_p(self, probs: torch.Tensor) -> None:
        """Update cluster probability estimates.

        Applies ``clusters_p <- (1 - gamma) * clusters_p
        + gamma * probs.mean(0)`` in place.

        Args:
            probs: Assignment probs, shape (n, k).
        """
        self.clusters_p.data.mul_(1 - self.gamma)
        self.clusters_p.data.add_(probs.detach().mean(0) * self.gamma)

    @property
    def cluster_likelihood(self) -> torch.Tensor:
        """Return current cluster likelihood."""
        return self.clusters_p

    @torch.no_grad()
    def forward(
        self,
        w_mat: torch.Tensor,
        p_l: torch.Tensor,
        p_r: torch.Tensor,
    ) -> torch.Tensor:
        """Compute PRCut batch loss.

        Sums, over clusters ``c``,
        ``sum_ij w_ij * (p_l[i, c] + p_r[j, c] - 2 * p_l[i, c] * p_r[j, c])``
        divided by the current running estimate ``clusters_p[c]``.

        Runs under ``torch.no_grad()``, so the returned value carries no
        gradient.

        Args:
            w_mat: Weight tensor, shape (a, b).
            p_l: Left probabilities, shape (a, k).
            p_r: Right probabilities, shape (b, k).

        Returns:
            Scalar loss.
        """
        # Matrix products need one dtype; promote explicitly so inputs
        # with mixed dtypes keep working as with plain broadcasting.
        common = torch.promote_types(
            torch.promote_types(w_mat.dtype, p_l.dtype), p_r.dtype
        )
        w_mat = w_mat.to(common)
        p_l = p_l.to(common)
        p_r = p_r.to(common)

        # Three (k,)-shaped reductions replace the (a, b, k) broadcast:
        # row sums against p_l, column sums against p_r, and the cross
        # term through W @ p_r.
        left_mass = w_mat.sum(1) @ p_l
        right_mass = w_mat.sum(0) @ p_r
        cross = (p_l * w_mat.mm(p_r)).sum(0)
        per_cluster = (left_mass + right_mass - 2 * cross) / self.clusters_p
        return per_cluster.sum()


class SimplexL2Loss(nn.Module):
    """Simplex L2 regularization loss."""

    def forward(
        self,
        probs: torch.Tensor,
        ov_p: torch.Tensor,
    ) -> torch.Tensor:
        """Compute simplex L2 loss.

        Evaluates ``sum_c (ov_p[c] - 1 / k) * probs.mean(0)[c]`` where
        ``k = probs.size(1)``.

        Args:
            probs: Assignment probs, shape (n, k).
            ov_p: Cluster likelihood, shape (k,).

        Returns:
            Scalar loss.
        """
        num_clusters = probs.size(1)
        # Deviation of each cluster likelihood from the uniform value 1/k.
        deviation = ov_p - 1 / num_clusters
        return (deviation * probs.mean(0)).sum()