# Source code for pgcuts.losses.hycut

"""HyCut loss module."""
from typing import Tuple

import torch
from torch import nn

from ..hyp2f1.funct import hyp2f1


class HyCutLoss(nn.Module):
    """Hypergeometric envelope loss for graph cuts.

    Uses 2F1(-m, b; c; z) as an upper bound on the expected graph cut.
    The hypergeometric factor is evaluated at an exponential moving
    average (EMA) of the mean cluster probabilities. That EMA is also
    returned, so the caller can feed it back in on the next call.
    """

    def __init__(
        self,
        m: int,
        b: float = 1.0,
        c: float = 2.0,
        ema_decay: float = 0.9,
    ) -> None:
        """Initialize HyCutLoss.

        Args:
            m: Polynomial degree for 2F1. Must be non-negative.
            b: Second parameter of 2F1.
            c: Third parameter of 2F1.
            ema_decay: EMA decay for cluster proportions, in [0, 1].
                The previous proportions are weighted by ``ema_decay``
                and the current batch mean by ``1 - ema_decay``.

        Raises:
            ValueError: If ``m`` is negative or ``ema_decay`` lies
                outside [0, 1].
        """
        super().__init__()
        if m < 0:
            raise ValueError(f"m must be non-negative, got {m}")
        # Outside [0, 1] the EMA stops being a convex combination, so the
        # proportions passed to 2F1 could leave the probability range.
        # (Written as `not <range>` so NaN is rejected too.)
        if not 0.0 <= ema_decay <= 1.0:
            raise ValueError(
                f"ema_decay must be in [0, 1], got {ema_decay}"
            )
        self.m = m
        self.b = b
        self.c = c
        self.ema_decay = ema_decay

    def forward(
        self,
        p_left: torch.Tensor,
        log_p_right: torch.Tensor,
        weights: torch.Tensor,
        alphas: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the HyCut loss.

        Args:
            p_left: Softmax probs, shape (E, K).
            log_p_right: Log-softmax probs, shape (E, K).
            weights: Edge weights, shape (E,).
            alphas: EMA cluster proportions, shape (K,).

        Returns:
            Tuple of (loss, updated_alphas). ``loss`` is a scalar tensor.
            ``updated_alphas`` is ``alphas`` with this batch's mean of
            ``p_left`` folded in using ``ema_decay``.
        """
        device, dtype = p_left.device, p_left.dtype
        m = torch.tensor(self.m, device=device, dtype=torch.int32)
        b = torch.tensor(self.b, device=device, dtype=dtype)
        c = torch.tensor(self.c, device=device, dtype=dtype)

        # Per-cluster weighted cut term: the mean over edges of
        # w_e * (-p_left * log_p_right). Result has shape (K,).
        cut = (weights.unsqueeze(-1) * (-p_left * log_p_right)).mean(0)

        # Batch cluster proportions. Detached, so the 2F1 factor passes no
        # gradient back into p_left.
        p_mean = p_left.detach().mean(0)

        # A single EMA update serves two purposes: it is the 2F1 argument,
        # and it is the state returned to the caller.
        updated_alphas = (
            alphas * self.ema_decay + (1 - self.ema_decay) * p_mean
        )

        loss = (cut * hyp2f1(-m, b, c, updated_alphas)).sum()
        # NOTE(review): the cut term is already averaged over E by
        # `.mean(0)` and is divided again by the total weight here.
        # Confirm this double normalisation is intended. This produces
        # NaN/inf if the weights sum to zero.
        loss = loss / weights.sum()
        return loss, updated_alphas