"""PGCut losses: RatioCut and NCut with hypergeometric (2F1) envelopes.

RatioCut (s(v) = 1 for all v):
    H(alpha_bar) = 2F1(-m, 1; 2; alpha_bar), identical for every vertex.

NCut (s(v) = d_v, the vertex degree):
    Holder bound with degree binning (Theorem 2). Unlike RatioCut, the
    envelope Phi varies per vertex.
"""
from typing import List, Tuple, Optional
import numpy as np
import torch
from torch import nn, Tensor
from ..hyp2f1.autograd import Hyp2F1
_hyp2f1 = Hyp2F1.apply
# -----------------------------------------------------------
# Binning
# -----------------------------------------------------------
def equal_size_bins(degrees: np.ndarray, num_bins: int) -> List[dict]:
    """Split vertices, ordered by degree, into ``num_bins`` near-equal chunks.

    Args:
        degrees: Vertex degrees, shape (n,).
        num_bins: Number of chunks requested from ``np.array_split``.

    Returns:
        One dict per non-empty chunk, in ascending-degree order, with keys
        ``beta_star`` (smallest degree in the chunk, as float), ``indices``
        (vertex indices ordered by degree) and ``count`` (chunk size).
    """
    order = np.argsort(degrees)
    return [
        {
            "beta_star": float(degrees[chunk].min()),
            "indices": chunk,
            "count": len(chunk),
        }
        for chunk in np.array_split(order, num_bins)
        # array_split produces empty chunks when num_bins exceeds n; skip them.
        if len(chunk)
    ]
def log_kmeans_bins(
    degrees: np.ndarray, num_bins: int, random_state: int = 42
) -> List[dict]:
    """Partition vertices via K-Means on log-degrees.

    Args:
        degrees: Vertex degrees, shape (n,). Array-likes are converted with
            ``np.asarray``.
        num_bins: Requested number of bins. The number of clusters actually
            requested from K-Means is capped at the number of distinct
            log-degree values, so K-Means is never asked for more clusters
            than there are distinct points.
        random_state: Seed forwarded to ``KMeans`` for reproducible bins.

    Returns:
        One dict per non-empty cluster with keys ``beta_star`` (minimum
        degree in the cluster, as float), ``indices`` (vertex indices) and
        ``count`` (cluster size), sorted by ascending ``beta_star``.
        Empty input yields ``[]``, matching ``equal_size_bins``.
    """
    degrees = np.asarray(degrees)
    if degrees.size == 0:
        return []
    # pylint: disable=import-outside-toplevel
    from sklearn.cluster import KMeans

    # The small offset keeps the log finite for zero-degree vertices.
    log_deg = np.log(degrees + 1e-6).reshape(-1, 1)
    num_clusters = min(num_bins, len(np.unique(log_deg)))
    labels = KMeans(
        n_clusters=num_clusters, n_init=10, random_state=random_state
    ).fit_predict(log_deg)
    bins = []
    for label in range(num_clusters):
        indices = np.where(labels == label)[0]
        if len(indices) == 0:
            continue
        bins.append(
            {
                "beta_star": float(degrees[indices].min()),
                "indices": indices,
                "count": len(indices),
            }
        )
    return sorted(bins, key=lambda b: b["beta_star"])
# -----------------------------------------------------------
# Shared
# -----------------------------------------------------------
def edge_source_weights(w_mat: Tensor, probs: Tensor) -> Tensor:
"""M_il(P) = sum_j W_ij P_il (1 - P_jl)."""
return probs * torch.mm(w_mat, 1.0 - probs)
# -----------------------------------------------------------
# RatioCut -- Theorem 1 (homogeneous beta = 1)
# -----------------------------------------------------------
class RatioCutLoss(nn.Module):
    """Probabilistic RatioCut loss with a hypergeometric envelope.

    Each cluster ``l`` gets the envelope factor
    ``H_l = 2F1(-m, 1; 2; alpha_bar_l)`` with ``m = n``, where
    ``alpha_bar_l`` is the mean assignment probability of cluster ``l``.
    """

    def __init__(self, n: int, ema_decay: float = 0.0) -> None:
        """Store the loss configuration.

        Args:
            n: Dataset size; used (negated) as the first 2F1 parameter,
                i.e. as the polynomial degree.
            ema_decay: Weight given to ``alpha_ema`` when smoothing the
                per-cluster means in ``forward``; ``0`` disables smoothing.
        """
        super().__init__()
        self.n = n
        self.ema_decay = ema_decay

    def forward(
        self,
        w_mat: Tensor,
        probs: Tensor,
        alpha_ema: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Evaluate the RatioCut envelope loss.

        Args:
            w_mat: Adjacency matrix, shape (n, n).
            probs: Assignment probabilities, shape (n, K).
            alpha_ema: Previous per-cluster means, shape (K,); only used
                when ``ema_decay > 0``.

        Returns:
            ``(loss, alpha_bar)`` where ``alpha_bar`` is the (possibly
            smoothed) per-cluster mean, shape (K,) -- presumably passed back
            as ``alpha_ema`` on the next call.
        """
        # probs is detached, so the envelope does not back-propagate into it.
        current = probs.detach().mean(0)
        decay = self.ema_decay
        if decay > 0 and alpha_ema is not None:
            current = decay * alpha_ema + (1 - decay) * current
        envelope = _hyp2f1(-self.n, 1.0, 2.0, current.clamp(1e-7, 1 - 1e-7))
        cluster_mass = edge_source_weights(w_mat, probs).sum(0)
        total_weight = w_mat.sum().clamp(min=1e-9)
        return (cluster_mass * envelope).sum() / total_weight, current
# -----------------------------------------------------------
# NCut -- Theorem 2 (Holder bound, beta = d_i)
# -----------------------------------------------------------
class NCutLoss(nn.Module):
    """Probabilistic NCut loss with a Holder-binned hypergeometric envelope.

    Vertices are partitioned into degree bins. Each bin ``j`` contributes a
    degree lower bound ``beta_star_j`` (the smallest degree in the bin) and a
    Holder weight ``w_j`` (the bin's fraction of all vertices). Unlike
    RatioCut, the envelope Phi varies per vertex because it depends on the
    vertex's own degree.
    """

    # Floor applied to vertex degrees and to bin degree bounds so that
    # zero-degree vertices never cause a division by zero.
    _MIN_DEGREE = 1e-6

    def __init__(
        self,
        degrees: np.ndarray,
        num_bins: int = 16,
        binning: str = "equal",
        ema_decay: float = 0.0,
    ) -> None:
        """Build the degree bins and register the derived buffers.

        Args:
            degrees: Vertex degrees, shape (n,). Array-likes are converted
                with ``np.asarray``.
            num_bins: Requested number of Holder bins (empty bins are
                dropped by the binning helpers, so fewer may result).
            binning: ``'equal'`` (equal-size bins over sorted degrees) or
                ``'log_kmeans'`` (K-Means on log-degrees).
            ema_decay: Weight given to ``alpha_bars_ema`` when smoothing the
                per-bin means in ``forward``; ``0`` disables smoothing.

        Raises:
            ValueError: If ``binning`` is not a supported strategy.
        """
        super().__init__()
        self.ema_decay = ema_decay
        degrees = np.asarray(degrees)
        n = len(degrees)
        if binning == "equal":
            self.bins = equal_size_bins(degrees, num_bins)
        elif binning == "log_kmeans":
            self.bins = log_kmeans_bins(degrees, num_bins)
        else:
            raise ValueError(f"Unknown binning: {binning}")
        self.register_buffer(
            "degrees_t",
            torch.tensor(degrees, dtype=torch.float32),
        )
        beta_stars = torch.tensor(
            [b["beta_star"] for b in self.bins],
            dtype=torch.float32,
        )
        self.register_buffer("beta_stars", beta_stars)
        counts = torch.tensor(
            [b["count"] for b in self.bins],
            dtype=torch.float32,
        )
        # Normalised so the Holder weights sum to 1.
        self.register_buffer("bin_weights", counts / counts.sum())
        node_to_bin = np.zeros(n, dtype=np.int64)
        for j, b in enumerate(self.bins):
            node_to_bin[b["indices"]] = j
        self.register_buffer(
            "node_to_bin",
            torch.tensor(node_to_bin, dtype=torch.long),
        )
        # Plain (non-buffer) CPU tensors; moved to the probs device on use.
        self._bin_indices = [
            torch.tensor(b["indices"], dtype=torch.long) for b in self.bins
        ]

    def _bin_means(self, probs: Tensor) -> Tensor:
        """Per-bin mean assignments, shape (d, K)."""
        num_bins = len(self.bins)
        num_clusters = probs.shape[1]
        alpha_bars = torch.zeros(
            num_bins,
            num_clusters,
            device=probs.device,
            dtype=probs.dtype,
        )
        for j, idx in enumerate(self._bin_indices):
            idx = idx.to(probs.device)
            alpha_bars[j] = probs[idx].mean(0)
        return alpha_bars

    def compute_phi(
        self,
        q: Tensor,
        alpha_bars: Tensor,
        m: int,
    ) -> Tensor:
        """Compute the per-vertex Holder envelope.

        ``Phi[v, l] = prod_j (2F1(-m, 1; q_v / beta_j + 1; alpha_bars[j, l])
        / q_v) ** w_j``, evaluated in log space.

        Args:
            q: Per-vertex degrees, shape (num_v,).
            alpha_bars: Per-bin means, shape (d, K).
            m: Polynomial degree for 2F1.

        Returns:
            Phi values, shape (num_v, K).
        """
        num_v = q.shape[0]
        num_bins, num_clusters = alpha_bars.shape
        device = q.device
        # A bin containing zero-degree vertices has beta_star == 0, which
        # would make c infinite; use the same floor that forward applies to q.
        beta = self.beta_stars.to(device).clamp(min=self._MIN_DEGREE)
        w = self.bin_weights.to(device)
        c = q.unsqueeze(1) / beta.unsqueeze(0) + 1.0
        z = alpha_bars.clamp(1e-7, 1 - 1e-7)
        c_3d = c.unsqueeze(2).expand(num_v, num_bins, num_clusters)
        z_3d = z.unsqueeze(0).expand(num_v, num_bins, num_clusters)
        f_val = _hyp2f1(-m, 1.0, c_3d, z_3d)
        h_val = f_val / q.view(num_v, 1, 1)
        # Weighted product over bins (a geometric mean, as the weights sum
        # to 1), computed as exp(sum_j w_j * log h_j) for stability.
        log_h = torch.log(h_val.clamp(min=1e-30))
        w_3d = w.view(1, num_bins, 1)
        log_phi = (w_3d * log_h).sum(dim=1)
        return torch.exp(log_phi)

    def forward(
        self,
        w_mat: Tensor,
        probs: Tensor,
        alpha_bars_ema: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """Compute the NCut envelope loss.

        Args:
            w_mat: Adjacency matrix, shape (n, n).
            probs: Assignment probabilities, shape (n, K). Phi is computed
                for every vertex in ``degrees``, so this assumes one row per
                vertex (full-graph batch).
            alpha_bars_ema: Previous per-bin means, shape (d, K); only used
                when ``ema_decay > 0``.

        Returns:
            ``(loss, alpha_bars)`` where ``alpha_bars`` are the (possibly
            smoothed) per-bin means, shape (d, K).
        """
        n = probs.shape[0]
        device = probs.device
        # Means use detached probs, so the envelope does not back-propagate
        # into probs.
        alpha_bars = self._bin_means(probs.detach())
        if alpha_bars_ema is not None and self.ema_decay > 0:
            alpha_bars = (
                self.ema_decay * alpha_bars_ema + (1 - self.ema_decay) * alpha_bars
            )
        q = self.degrees_t.to(device).clamp(min=self._MIN_DEGREE)
        m = n
        phi = self.compute_phi(q, alpha_bars, m)
        m_weights = edge_source_weights(w_mat, probs)
        loss = (m_weights * phi).sum() / w_mat.sum().clamp(min=1e-9)
        return loss, alpha_bars
# -----------------------------------------------------------
# Edge-pair interface for minibatch training
# -----------------------------------------------------------
def compute_ncut_bin_phi(
    q_stars: Tensor,
    alpha_bars: Tensor,
    beta_stars: Tensor,
    bin_weights: Tensor,
    m: int,
) -> Tensor:
    """Compute the per-bin Holder envelope Phi (without the 1/q factor).

    Bin ``i`` stands in for a vertex of degree ``q_stars[i]``:
    ``Phi[i, l] = prod_j 2F1(-m, 1; q_i / beta_j + 1; alpha_bars[j, l]) ** w_j``,
    evaluated in log space.

    Args:
        q_stars: Per-bin representative degrees, shape (d,); floored at 1e-6.
        alpha_bars: Per-bin means, shape (d, K).
        beta_stars: Per-bin degree lower bounds, shape (d,) -- e.g. the
            minimum degree of each bin as in ``NCutLoss.beta_stars``;
            floored at 1e-6.
        bin_weights: Holder weights, shape (d,).
        m: Polynomial degree for 2F1.

    Returns:
        Phi values, shape (d, K).
    """
    num_bins, num_clusters = alpha_bars.shape
    q = q_stars.clamp(min=1e-6)
    # A bin containing zero-degree vertices has beta == 0, which would make
    # c infinite; apply the same floor used for q.
    beta = beta_stars.clamp(min=1e-6)
    c = q.unsqueeze(1) / beta.unsqueeze(0) + 1.0
    z = alpha_bars.clamp(1e-7, 1 - 1e-7)
    # Flatten the (bin_i, bin_j) pair grid into rows of a 2-D batch.
    c_2d = (
        c.unsqueeze(2)
        .expand(num_bins, num_bins, num_clusters)
        .reshape(num_bins * num_bins, num_clusters)
    )
    z_2d = (
        z.unsqueeze(0)
        .expand(num_bins, num_bins, num_clusters)
        .reshape(num_bins * num_bins, num_clusters)
    )
    f_val = _hyp2f1(-m, 1.0, c_2d, z_2d).view(num_bins, num_bins, num_clusters)
    # Weighted product over j computed as exp(sum_j w_j * log f_ij).
    log_f = torch.log(f_val.clamp(min=1e-30))
    w = bin_weights.view(1, num_bins, 1)
    log_phi = (w * log_f).sum(dim=1)
    return torch.exp(log_phi)