from .SpTensor import SparseTensor
from torch import Tensor
import torch
from .utils import torch_scatter_reduce
def _unsqueeze_after_nnz(t: Tensor, ndim: int) -> Tensor:
    """Insert size-1 dims right after dim 0 of ``t`` until it has ``ndim`` dims.

    Dim 0 is the per-nonzero axis, so padding there (instead of on the left,
    as plain broadcasting would) keeps each nonzero paired with its own row
    while the trailing dense dims broadcast against each other.
    """
    missing = ndim - t.dim()
    if missing <= 0 or t.dim() == 0:
        # Already at full rank, or a 0-dim tensor that broadcasts on its own.
        return t
    return t.reshape(tuple(t.shape[:1]) + (1,) * missing + tuple(t.shape[1:]))


def spmm(A: SparseTensor, dim1: int, X: Tensor, aggr: str = "sum") -> Tensor:
    """
    SparseTensor, Tensor matrix multiplication.

    Multiplies the 2-dim SparseTensor `A` with the dense tensor `X`,
    contracting sparse dimension `dim1` of `A` with dimension 0 of `X`.
    For each nonzero of `A`, the row of `X` selected by its `dim1` index is
    gathered and multiplied by the nonzero's value (if `A` has values).
    The products are then merged into the output row given by the nonzero's
    other index. The `aggr` parameter selects the reduction used for merging.

    Args:
    - A (SparseTensor): The SparseTensor used for multiplication. Must have
      exactly 2 sparse dims.
    - dim1 (int): The sparse dimension of `A` that is reduced. Must be one of
      0, 1, -2 or -1 (negative values count from the end, so -2 == 0 and
      -1 == 1).
    - X (Tensor): The dense tensor to be multiplied with `A`. Its dim 0 is
      reduced and is indexed with `A`'s indices along `dim1`.
    - aggr (str, optional): The reduction used for merging products that land
      on the same output row ("sum", "min", "max", "mean"). It is forwarded
      to `torch_scatter_reduce`. Defaults to "sum".

    Returns:
    - Tensor: A dense tensor produced by `torch_scatter_reduce` along dim 0.
      The output size along dim 0 is passed as `A.shape[1 - dim1]`. The
      trailing dims presumably follow the broadcast dense shape; confirm
      against `torch_scatter_reduce`.

    Raises:
    - AssertionError: If `A` does not have exactly 2 sparse dims. This is
      checked with `assert`, so it is skipped under `python -O`.
    - ValueError: If `dim1` is not one of -2, -1, 0, 1.

    Notes:
    - The dense shapes of `A` (its values, excluding the nnz axis) and `X`
      (excluding dim 0) must be broadcastable. If they differ in rank,
      size-1 dims are inserted right after the nnz axis of the lower-rank
      operand before multiplying, so broadcasting never aligns the nnz axis
      with a dense axis.
    """
    assert A.sparse_dim == 2, "can only use 2-dim sparse tensor"
    if not -2 <= dim1 <= 1:
        raise ValueError(
            f"dim1 must be one of -2, -1, 0, 1 for a 2-dim sparse tensor, got {dim1}"
        )
    # Map -2 -> 0 and -1 -> 1. Previously every non-zero value was treated as 1.
    dim1 %= 2
    other = 1 - dim1

    srcind = A.indices[dim1]   # indices into dim 0 of X (the contracted axis)
    tarind = A.indices[other]  # output row each product is merged into
    tarshape = A.shape[other]

    gathered = X[srcind]
    val = A.values
    if val is None:
        # No stored values: every nonzero acts as 1, so just gather rows of X.
        mult = gathered
    else:
        ndim = max(val.dim(), gathered.dim())
        mult = _unsqueeze_after_nnz(val, ndim) * _unsqueeze_after_nnz(gathered, ndim)
    return torch_scatter_reduce(0, mult, tarind, tarshape, aggr)