Source code for pygho.backend.Spmamm

from .MaTensor import MaskedTensor, filterinf
import torch
from torch import BoolTensor, Tensor
from typing import Optional
from .SpTensor import SparseTensor
from .utils import torch_scatter_reduce

# Identity element of each supported aggregation: the value written into
# masked-out slots so they cannot affect the reduction result.
filled_value_dict = dict(sum=0, max=-torch.inf, min=torch.inf)
# Aggregations whose identity element is +/-inf; their outputs may contain
# infinities (from fully-masked slots) that must be filtered afterwards.
filter_inf_ops = ["max", "min"]


def spmamm(A: SparseTensor,
           dim1: int,
           B: MaskedTensor,
           dim2: int,
           mask: Optional[BoolTensor] = None,
           aggr: str = "sum") -> MaskedTensor:
    r"""
    SparseTensor-MaskedTensor multiplication.

    Multiplies a SparseTensor ``A`` with a MaskedTensor ``B``, contracting
    dimension ``dim1`` of ``A`` against dimension ``dim2`` of ``B``. The
    contracted entries are combined with the reduction ``aggr`` and the
    result is returned as a MaskedTensor.

    Args:
    - A (SparseTensor): sparse operand of shape (B, n, m, \*shape);
      must have exactly 3 sparse dims.
    - dim1 (int): sparse dimension of ``A`` to contract (1 or 2).
    - B (MaskedTensor): dense operand of shape (B, m, \*shape).
    - dim2 (int): dimension of ``B`` to contract.
    - mask (BoolTensor, optional): mask for the resulting MaskedTensor.
      Defaults to ``B.mask`` when None.
    - aggr (str, optional): reduction over the contracted dimension
      ("sum", "max", "min"; "mean" is not implemented). Default "sum".

    Returns:
    - MaskedTensor: result of shape (B, n, \*denseshape) with the given
      (or inherited) mask.

    Raises:
    - NotImplementedError: if ``dim1`` is not 1 or 2.
    """
    assert A.sparse_dim == 3, f"A should have 3 sparse dims, but input has {A.sparse_dim}"
    assert aggr != "mean", "not implemented"
    # Pick which sparse dim of A is contracted; the other sparse dim (size n)
    # indexes the output. tar_ind flattens (batch, output-index) pairs into a
    # single flat index for the scatter reduction below.
    if dim1 == 1:
        b, n = A.shape[0], A.shape[2]
        bij = A.indices[0], A.indices[1]
        tar_ind = n * A.indices[0] + A.indices[2]
    elif dim1 == 2:
        b, n = A.shape[0], A.shape[1]
        bij = A.indices[0], A.indices[2]
        tar_ind = n * A.indices[0] + A.indices[1]
    else:
        raise NotImplementedError
    Aval = A.values
    # Bring the contracted dim of B next to the batch dim so the
    # (batch, contracted-index) pairs in bij can gather rows directly.
    tB = torch.movedim(B.data, dim2, 1)
    tBmask = torch.movedim(B.mask, dim2, 1)
    if Aval is not None:
        mult = Aval.unsqueeze(1) * tB[bij[0], bij[1]]
    else:
        # A carries no explicit values: treat its nonzeros as 1.
        mult = tB[bij[0], bij[1]]
    validmask = tBmask[bij[0], bij[1]]
    # BUG FIX: the original called out-of-place `masked_fill` and discarded
    # its result, so masked entries of B leaked into the reduction. Fill in
    # place with the aggregation's identity element (mult is always a fresh
    # tensor here, so in-place mutation is safe).
    mult.masked_fill_(torch.logical_not(validmask), filled_value_dict[aggr])
    val = torch_scatter_reduce(0, mult, tar_ind, b * n, aggr)
    ret = val.unflatten(0, (b, n))
    ret = torch.movedim(ret, 1, dim2)
    # max/min over a slot with no valid entries produces +/-inf; filter it.
    if aggr in filter_inf_ops:
        ret = filterinf(ret)
    return MaskedTensor(ret, mask if mask is not None else B.mask)