Source code for pgcuts.hyp2f1.autograd

"""PyTorch autograd wrapper for 2F1(-m, b; c; z).

Derivative identity:
    d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z)

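Backend selection (made once, at import time):
    - Default: Triton (portable across GPU vendors)
    - Opt-in: CUDA (NVIDIA only, JIT-compiled)

Set HYP2F1_BACKEND='cuda' (case-insensitive) to select the CUDA
backend. Any other value selects Triton; there is no automatic
fallback from one backend to the other.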
"""

import os
import torch

# Backend name is read from the environment once, when this module is
# imported, and lower-cased so the comparison below is case-insensitive.
_BACKEND = os.environ.get(
    "HYP2F1_BACKEND", "triton"
).lower()

# Bind the selected kernel to one private name so the rest of the module
# does not care which backend is active. Only the exact value "cuda"
# selects the CUDA kernel; every other value (including misspellings)
# takes the Triton branch.
if _BACKEND == "cuda":
    from .cuda_kernels import (
        par_hyp2f1_precomp as _hyp2f1_kernel,
    )
else:
    from .triton_kernels import (
        triton_hyp2f1 as _hyp2f1_kernel,
    )


def _reduce_broadcast(grad, target_shape):
    """Collapse ``grad`` back onto ``target_shape`` by summing.

    Leading dimensions that ``target_shape`` does not have are summed
    away one at a time. Then every dimension where ``target_shape`` is
    1 but ``grad`` is larger is summed with ``keepdim=True``.
    """
    # Remove surplus leading dimensions; a non-positive count is a no-op.
    surplus = grad.dim() - len(target_shape)
    for _ in range(surplus):
        grad = grad.sum(0)

    # Axes whose target size is 1 while the gradient is larger.
    expanded_axes = [
        axis
        for axis, (grad_size, target_size) in enumerate(
            zip(grad.shape, target_shape)
        )
        if target_size == 1 and grad_size > 1
    ]
    for axis in expanded_axes:
        grad = grad.sum(axis, keepdim=True)
    return grad


class Hyp2F1(torch.autograd.Function):
    """Differentiable 2F1(-m, b; c; z) with backward.

    Uses Triton backend by default (portable).

    Only the gradient with respect to ``z`` is provided; ``a``, ``b``
    and ``c`` always receive ``None`` from ``backward``.

    Usage::

        result = Hyp2F1.apply(a, b, c, z)
        result.sum().backward()  # computes z.grad
    """

    @staticmethod
    def forward(ctx, a, b, c, z):
        """Compute 2F1(a, b; c; z) forward pass.

        Args:
            ctx: Autograd context used to keep state for ``backward``.
            a: First parameter. It is converted with ``int()``, so a
                non-integer value is truncated toward zero.
            b: Second parameter, a Python scalar or a tensor.
            c: Third parameter, a Python scalar or a tensor.
            z: Tensor at which the function is evaluated.

        Returns:
            The result of the selected backend kernel called with
            ``(int(a), b, c, z)``.
        """
        a_val = int(a)
        b_is_tensor = isinstance(b, torch.Tensor)
        c_is_tensor = isinstance(c, torch.Tensor)

        # Tensors needed in backward go through save_for_backward so
        # autograd can detect in-place modification and manage their
        # lifetime; None entries are accepted and returned as None.
        ctx.save_for_backward(
            z,
            b if b_is_tensor else None,
            c if c_is_tensor else None,
        )
        # Plain (non-tensor) parameters are kept as ctx attributes.
        ctx.a_val = a_val
        ctx.b_scalar = None if b_is_tensor else b
        ctx.c_scalar = None if c_is_tensor else c
        return _hyp2f1_kernel(a_val, b, c, z)

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradient w.r.t. z.

        Uses d/dz 2F1(a, b; c; z) = (a*b/c) * 2F1(a+1, b+1; c+1; z).

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_output: Gradient of the loss w.r.t. the forward output.

        Returns:
            A 4-tuple matching ``forward``'s inputs ``(a, b, c, z)``.
            The first three entries are always ``None``; the last is
            the gradient for ``z`` reduced to ``z``'s shape, or ``None``
            when ``z`` does not need a gradient.
        """
        # z is the only differentiable input (index 3 of forward's
        # arguments); skip the kernel launch when it is not needed.
        if not ctx.needs_input_grad[3]:
            return None, None, None, None

        z, b_saved, c_saved = ctx.saved_tensors
        b = ctx.b_scalar if b_saved is None else b_saved
        c = ctx.c_scalar if c_saved is None else c_saved
        a_val = ctx.a_val

        # With a == 0 the derivative prefactor a*b/c vanishes.
        if a_val == 0:
            return None, None, None, torch.zeros_like(z)

        # d/dz 2F1(a,b;c;z) = (a*b/c) * 2F1(a+1,b+1;c+1;z)
        df = _hyp2f1_kernel(a_val + 1, b + 1, c + 1, z)

        # Scalars are promoted to tensors matching z's dtype and device
        # so the prefactor is computed in z's precision.
        device, dtype = z.device, z.dtype
        if isinstance(b, torch.Tensor):
            b_t = b
        else:
            b_t = torch.tensor(b, dtype=dtype, device=device)
        if isinstance(c, torch.Tensor):
            c_t = c
        else:
            c_t = torch.tensor(c, dtype=dtype, device=device)

        deriv_scale = a_val * b_t / c_t
        grad_full = grad_output * deriv_scale * df
        # The product may have a larger shape than z (e.g. if b or c are
        # tensors that broadcast against it); sum back down to z's shape.
        grad_z = _reduce_broadcast(grad_full, z.shape)
        return None, None, None, grad_z