# Source code for pygho.backend.Spspmm

import torch
from torch import LongTensor, Tensor
from typing import Optional, Callable, Tuple
from .SpTensor import SparseTensor, indicehash, decodehash
import warnings
from .utils import torch_scatter_reduce

def ptr2batch(ptr: LongTensor, dim_size: int) -> LongTensor:
    """
    Convert a CSR-style pointer tensor to a batch (segment-id) tensor.

    Produces a tensor ``batch`` of shape ``(dim_size,)`` such that
    ``batch[ptr[i]:ptr[i+1]] = i``, i.e. each output element records which
    segment of ``ptr`` it falls into.

    Args:
    - ptr (LongTensor): 1-d pointer tensor with ``ptr[0] == 0`` and
      non-decreasing entries.
    - dim_size (int): size of the output; must equal ``ptr[-1]``.

    Returns:
    - LongTensor: batch tensor of shape ``(dim_size,)``.
    """
    assert ptr.ndim == 1, "ptr should be 1-d"
    assert ptr[0] == 0 and torch.all(
        torch.diff(ptr) >= 0), "should put in a ptr tensor"
    assert ptr[-1] == dim_size, "dim_size should match ptr"
    # Repeat segment id i exactly (ptr[i+1] - ptr[i]) times. This is the
    # direct O(n) construction and replaces the earlier searchsorted-based
    # workaround (see the old "TODO: use torch_scatter gather instead?").
    seg_ids = torch.arange(ptr.shape[0] - 1,
                           device=ptr.device,
                           dtype=ptr.dtype)
    return torch.repeat_interleave(seg_ids, torch.diff(ptr))
def spspmm_ind(ind1: LongTensor,
               dim1: int,
               ind2: LongTensor,
               dim2: int,
               is_k2_sorted: bool = False) -> Tuple[LongTensor, LongTensor]:
    """
    Sparse-sparse matrix multiplication for indices.

    Eliminates `dim1` of `ind1` and `dim2` of `ind2` by matching equal keys,
    and concatenates the remaining sparse dimensions. In tensor perspective:
    (*i_1, k, *i_2), (*j_1, k, *j_2) -> (*i_1, *i_2, *j_1, *j_2).

    Args:
    - ind1 (LongTensor): indices of the first sparse tensor, shape
      `(sparsedim1, M1)`.
    - dim1 (int): the dimension to eliminate in `ind1`.
    - ind2 (LongTensor): indices of the second sparse tensor, shape
      `(sparsedim2, M2)`.
    - dim2 (int): the dimension to eliminate in `ind2`.
    - is_k2_sorted (bool, optional): whether `ind2[dim2]` is already sorted.
      Defaults to `False`.

    Returns:
    - tarind (LongTensor): the coalesced output indices.
    - bcd (LongTensor): shape (3, nnz) rows (b, c, d); for each i,
      val1[c[i]] * val2[d[i]] is to be accumulated into output element b[i].

    Example::

        ind1 = torch.tensor([[0, 1, 1, 2], [2, 1, 0, 2]], dtype=torch.long)
        dim1 = 0
        ind2 = torch.tensor([[2, 1, 0, 1], [1, 0, 2, 2]], dtype=torch.long)
        dim2 = 1
        result = spspmm_ind(ind1, dim1, ind2, dim2)
    """
    assert 0 <= dim1 < ind1.shape[
        0], f"ind1's reduced dim {dim1} is out of range"
    assert 0 <= dim2 < ind2.shape[
        0], f"ind2's reduced dim {dim2} is out of range"
    if dim2 != 0 and not (is_k2_sorted):
        # Sort ind2 along the eliminated dimension, recurse on the sorted
        # copy, then translate the returned d-indices back through the
        # permutation so they index the caller's original ind2 columns.
        perm = torch.argsort(ind2[dim2])
        tarind, bcd = spspmm_ind(ind1, dim1, ind2[:, perm], dim2, True)
        bcd[2] = perm[bcd[2]]
        return tarind, bcd
    else:
        nnz1, nnz2, sparsedim1, sparsedim2 = ind1.shape[1], ind2.shape[
            1], ind1.shape[0], ind2.shape[0]
        # keys to match along the eliminated dimensions
        k1, k2 = ind1[dim1], ind2[dim2]
        assert torch.all(torch.diff(k2) >= 0), "ind2[0] should be sorted"
        # for each k in k1, it can match an interval [lowerbound, upperbound)
        # of k2, as k2 is sorted
        upperbound = torch.searchsorted(k2, k1, right=True)
        lowerbound = torch.searchsorted(k2, k1, right=False)
        matched_num = torch.clamp_min_(upperbound - lowerbound, 0)
        # retptr[i] is the offset at which the pairs of ind1[:, i] with its
        # matched ind2 columns are placed in the output
        retptr = torch.zeros((nnz1 + 1),
                             dtype=matched_num.dtype,
                             device=matched_num.device)
        torch.cumsum(matched_num, dim=0, out=retptr[1:])
        retsize = retptr[-1]
        # fill the output with ptr
        ret = torch.zeros((3, retsize), device=ind1.device, dtype=ind1.dtype)
        # ret[1] (c): which ind1 column each pair came from
        ret[1] = ptr2batch(retptr, retsize)
        # ret[2] (d): position within each matched interval of ind2, built by
        # subtracting each segment's start offset from a global arange
        torch.arange(retsize, out=ret[2], device=ret.device, dtype=ret.dtype)
        offset = (ret[2][retptr[:-1]] - lowerbound)[ret[1]]
        ret[2] -= offset
        # hash the concatenated surviving dimensions of each (c, d) pair to
        # identify which pairs land on the same output element
        combinedind = indicehash(
            torch.concat(
                ((torch.concat((ind1[:dim1], ind1[dim1 + 1:])))[:, ret[1]],
                 torch.concat((ind2[:dim2], ind2[dim2 + 1:]))[:, ret[2]])))
        combinedind, taridx = torch.unique(combinedind,
                                           sorted=True,
                                           return_inverse=True)
        tarind = decodehash(combinedind, sparsedim1 + sparsedim2 - 2)
        # ret[0] (b): index of the output element each pair contributes to
        ret[0] = taridx
        sorted_idx = torch.argsort(ret[0])  # sort is optional
        return tarind, ret[:, sorted_idx]
def spsphadamard_ind(tar_ind: LongTensor, ind: LongTensor) -> LongTensor:
    """
    Auxiliary routine for the Hadamard product of two sparse tensors.

    Given indices `tar_ind` of sparse tensor A and indices `ind` of sparse
    tensor B, returns an index array `b2a` of shape `(ind.shape[1],)` such
    that `ind[:, i]` matches `tar_ind[:, b2a[i]]`. Entries with `b2a[i] < 0`
    mark columns of `ind` that have no match in `tar_ind`.

    Args:
    - tar_ind (LongTensor): indices of sparse tensor A (sorted, coalesced).
    - ind (LongTensor): indices of sparse tensor B.

    Returns:
    - LongTensor: the matching array `b2a` described above.

    Example::

        tar_ind = torch.tensor([[0, 1, 1, 2], [2, 1, 0, 2]], dtype=torch.long)
        ind = torch.tensor([[2, 1, 0, 1], [1, 0, 2, 2]], dtype=torch.long)
        b2a = spsphadamard_ind(tar_ind, ind)
    """
    assert tar_ind.shape[0] == ind.shape[0]
    hashed_tar = indicehash(tar_ind)
    assert torch.all(torch.diff(hashed_tar) >
                     0), "tar_ind should be sorted and coalesce"
    hashed_src = indicehash(ind)
    # candidate slot of each hashed_src entry inside the sorted hashed_tar
    pos = torch.searchsorted(hashed_tar, hashed_src, right=True) - 1
    pos = torch.clamp_min_(pos, 0)
    # keep the slot only where the hashes actually agree; otherwise -1
    matched = hashed_tar[pos] == hashed_src
    return torch.where(matched, pos, torch.full_like(pos, -1))
def filterind(tar_ind: LongTensor, ind: LongTensor,
              bcd: LongTensor) -> LongTensor:
    """
    A combination of Hadamard product and sparse matrix multiplication.

    Given indices `tar_ind` of sparse tensor A, indices `ind` of sparse
    tensor BC, and the index array `bcd`, returns an index array `acd` with
    `(A ⊙ (BC)).val[a] = A.val[a] * scatter(B.val[c] * C.val[d], a)`.

    Args:
    - tar_ind (LongTensor): indices of sparse tensor A.
    - ind (LongTensor): indices of sparse tensor BC.
    - bcd (LongTensor): index array describing `(BC).val`.

    Returns:
    - LongTensor: the filtered index array `acd`.

    Example::

        tar_ind = torch.tensor([[0, 1, 1, 2], [2, 1, 0, 2]], dtype=torch.long)
        ind = torch.tensor([[2, 1, 0, 1], [1, 0, 2, 2]], dtype=torch.long)
        bcd = torch.tensor([[3, 2, 1, 0], [6, 5, 4, 3], [9, 8, 7, 6]],
                           dtype=torch.long)
        acd = filterind(tar_ind, ind, bcd)
    """
    # map every BC column onto its position in tar_ind (-1 = no match)
    mapping = spsphadamard_ind(tar_ind, ind)
    out_idx = mapping[bcd[0]]
    # drop contributions whose output element does not exist in tar_ind
    keep = out_idx >= 0
    return torch.stack((out_idx[keep], bcd[1][keep], bcd[2][keep]))
def spsphadamard(A: SparseTensor,
                 B: SparseTensor,
                 b2a: Optional[LongTensor] = None) -> SparseTensor:
    """
    Element-wise Hadamard product between two SparseTensors.

    Args:
    - A (SparseTensor): the first SparseTensor.
    - B (SparseTensor): the second SparseTensor.
    - b2a (LongTensor, optional): auxiliary index produced by
      `spsphadamard_ind`; computed here when not supplied.

    Returns:
    - SparseTensor: the Hadamard product of A and B.

    Notes:
    - Both `A` and `B` must be coalesced SparseTensors.
    - The dense shapes of `A` and `B` must be broadcastable.
    """
    assert A.is_coalesced(), "A should be coalesced"
    assert B.is_coalesced(), "B should be coalesced"
    assert A.sparseshape == B.sparseshape, "A, B should be of the same sparse shape"
    a_ind, a_val = A.indices, A.values
    b_ind, b_val = B.indices, B.values
    if b2a is None:
        b2a = spsphadamard_ind(a_ind, b_ind)
    # columns of B that found a counterpart in A
    keep = b2a >= 0
    if a_val is None:
        out_val = b_val[keep]
    elif b_val is None:
        out_val = a_val[b2a[keep]]
    else:
        out_val = a_val[b2a[keep]] * b_val[keep]
    return SparseTensor(b_ind[:, keep],
                        out_val,
                        shape=A.sparseshape + out_val.shape[1:],
                        is_coalesced=True)
def spspmm(A: SparseTensor,
           dim1: int,
           B: SparseTensor,
           dim2: int,
           aggr: str = "sum",
           bcd: Optional[LongTensor] = None,
           tar_ind: Optional[LongTensor] = None,
           acd: Optional[LongTensor] = None) -> SparseTensor:
    """
    SparseTensor-SparseTensor matrix multiplication at a sparse dimension.

    Multiplies `A` and `B` along sparse dimensions `dim1` and `dim2`,
    merging values with the reduction `aggr`.

    Args:
    - A (SparseTensor): the first SparseTensor.
    - dim1 (int): the dimension along which `A` is multiplied.
    - B (SparseTensor): the second SparseTensor.
    - dim2 (int): the dimension along which `B` is multiplied.
    - aggr (str, optional): reduction for merging values
      ("sum", "min", "max", "mean"). Defaults to "sum".
    - bcd (LongTensor, optional): auxiliary index array from `spspmm_ind`.
    - tar_ind (LongTensor, optional): target output indices; computed when
      not provided.
    - acd (LongTensor, optional): auxiliary index array from `filterind`.

    Returns:
    - SparseTensor: the result of the matrix multiplication.

    Notes:
    - Both `A` and `B` must be coalesced SparseTensors.
    - The dense shapes of `A` and `B` must be broadcastable.
    - Precomputed `bcd`/`tar_ind`/`acd` can be passed for performance.
    """
    assert A.is_coalesced(), "A should be coalesced"
    assert B.is_coalesced(), "B should be coalesced"
    if acd is not None:
        # fast path: all auxiliary indices are ready, just gather-multiply
        # and scatter-reduce into the target layout
        assert tar_ind is not None
        if A.values is None:
            mult = B.values[acd[2]]
        elif B.values is None:
            mult = A.values[acd[1]]
        else:
            mult = A.values[acd[1]] * B.values[acd[2]]
        retval = torch_scatter_reduce(0, mult, acd[0], tar_ind.shape[1], aggr)
        return SparseTensor(tar_ind,
                            retval,
                            shape=A.sparseshape[:dim1] +
                            A.sparseshape[dim1 + 1:] +
                            B.sparseshape[:dim2] +
                            B.sparseshape[dim2 + 1:] + retval.shape[1:],
                            is_coalesced=True)
    else:
        warnings.warn("acd is not found")
        if bcd is None:
            ind, bcd = spspmm_ind(A.indices, dim1, B.indices, dim2)
        else:
            # Bug fix: previously `ind` was only assigned when bcd was None,
            # so passing bcd without acd raised NameError below. The product
            # indices are deterministic for the same inputs (unique-sorted),
            # so recomputing them here is consistent with the supplied bcd.
            ind, _ = spspmm_ind(A.indices, dim1, B.indices, dim2)
        if tar_ind is not None:
            acd = filterind(tar_ind, ind, bcd)
            return spspmm(A, dim1, B, dim2, aggr, acd=acd, tar_ind=tar_ind)
        else:
            warnings.warn("tar_ind is not found")
            # with no target filtering, bcd's first row already indexes the
            # freshly computed output indices, so it serves directly as acd
            return spspmm(A, dim1, B, dim2, aggr, acd=bcd, tar_ind=ind)
def spspmpnn(A: SparseTensor, dim1: int, B: SparseTensor, dim2: int,
             C: SparseTensor, acd: LongTensor,
             message_func: Callable[[Tensor, Tensor, Tensor, LongTensor],
                                    Tensor],
             aggr: str = "sum") -> SparseTensor:
    """
    Message-passing generalization of sparse-sparse matrix multiplication.

    Extends the multiplication of `A` and `B` at sparse dimensions `dim1`
    and `dim2` by computing per-pair messages with `message_func` and
    reducing them onto the target indices supplied by `C`.

    Args:
    - A (SparseTensor): the first SparseTensor.
    - dim1 (int): the dimension along which `A` is multiplied.
    - B (SparseTensor): the second SparseTensor.
    - dim2 (int): the dimension along which `B` is multiplied.
    - C (SparseTensor): provides the target indices of the output.
    - acd (LongTensor): auxiliary index array from a previous operation.
    - message_func (Callable): computes messages from the gathered values of
      `A`, `B`, `C` and the target index `acd[0]`.
    - aggr (str, optional): reduction for merging messages
      ("sum", "min", "max", "mul", "any"). Defaults to "sum".

    Returns:
    - SparseTensor: the result of the message-passing multiplication.

    Notes:
    - Both `A` and `B` must be coalesced SparseTensors.
    - The dense shapes of `A`, `B`, and `C` must be broadcastable.
    """
    def _gather(vals, idx):
        # index stored values when present; propagate None otherwise
        return None if vals is None else vals[idx]

    messages = message_func(_gather(A.values, acd[1]),
                            _gather(B.values, acd[2]),
                            _gather(C.values, acd[0]), acd[0])
    out_ind = C.indices
    out_val = torch_scatter_reduce(0, messages, acd[0], out_ind.shape[1],
                                   aggr)
    out_shape = (A.sparseshape[:dim1] + A.sparseshape[dim1 + 1:] +
                 B.sparseshape[:dim2] + B.sparseshape[dim2 + 1:] +
                 out_val.shape[1:])
    return SparseTensor(out_ind, out_val, shape=out_shape, is_coalesced=True)