pgcuts.utils
Gradient mixing
class pgcuts.optim.GradientMixer(named_parameters, loss_scale)
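Mix gradients from multiple losses. Each loss's backward pass is run inside the context obtained by calling the mixer with that loss's name, which matches a key of loss_scale.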
Usage:
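    from pgcuts.optim import GradientMixer

    # network, optimizer, cut_loss and balance_loss are assumed to be defined elsewhere.
    grad_mix = GradientMixer(
        network.named_parameters(),
        loss_scale={"cut": 1.0, "balance": 1.0},
    )

    optimizer.zero_grad()
    with grad_mix("cut"):
        # retain_graph=True keeps the autograd graph alive for the second backward pass
        cut_loss.backward(retain_graph=True)
    with grad_mix("balance"):
        balance_loss.backward()
    optimizer.step()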
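The docstring does not say how the mixing is done. As a rough illustration only, the following sketch assumes the simplest plausible scheme: on entering the context for a loss, the gradients accumulated so far are set aside; on exit, the gradients produced by that loss alone are multiplied by its loss_scale entry and added back. SketchGradientMixer, FakeParam and fake_backward are hypothetical names, not part of pgcuts, and the real GradientMixer may combine gradients differently. Parameters are only assumed to expose a mutable .grad attribute (as torch.nn.Parameter does), so the demo uses a plain-Python stand-in and runs without PyTorch.

```python
from contextlib import contextmanager


class SketchGradientMixer:
    """Hypothetical sketch, not the pgcuts implementation.

    Keeps each loss's gradients separate, then adds them back
    scaled by that loss's entry in loss_scale.
    """

    def __init__(self, named_parameters, loss_scale):
        # Materialize the (name, parameter) pairs once; generators can only be iterated once.
        self.params = dict(named_parameters)
        self.loss_scale = dict(loss_scale)

    @contextmanager
    def __call__(self, loss_name):
        scale = self.loss_scale[loss_name]  # KeyError for unknown loss names
        # Set aside gradients accumulated by earlier losses.
        saved = {name: p.grad for name, p in self.params.items()}
        for p in self.params.values():
            p.grad = None
        try:
            yield  # the caller runs loss.backward() here
        finally:
            for name, p in self.params.items():
                new, old = p.grad, saved[name]
                if new is not None:
                    new = new * scale  # scale only this loss's contribution
                if old is None:
                    p.grad = new
                elif new is None:
                    p.grad = old
                else:
                    p.grad = old + new


# Plain-Python stand-ins so the demo runs without PyTorch (hypothetical).
class FakeParam:
    def __init__(self):
        self.grad = None


def fake_backward(param, g):
    # Mimics autograd: accumulate into .grad instead of overwriting it.
    param.grad = g if param.grad is None else param.grad + g


p = FakeParam()
mixer = SketchGradientMixer([("w", p)], {"cut": 1.0, "balance": 0.5})
with mixer("cut"):
    fake_backward(p, 2.0)
with mixer("balance"):
    fake_backward(p, 4.0)
print(p.grad)  # 4.0, i.e. 1.0 * 2.0 + 0.5 * 4.0
```

Under this assumption, scaling a loss's gradients by s gives the same result as calling backward on s times that loss; keeping each loss's gradients isolated inside its own context is what would let a mixer apply per-loss rules beyond a fixed scale.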