"""Source code for pygho.backend.utils."""

import torch
from torch import Tensor, LongTensor
from typing import Tuple


def torch_scatter_reduce(dim: int, src: Tensor, ind: LongTensor, dim_size: int,
                         aggr: str) -> Tensor:
    """
    Scatter rows of `src` into `dim_size` output rows and reduce collisions.

    This is a thin wrapper around `torch.Tensor.scatter_reduce_`. Row `i` of
    `src` is reduced into row `ind[i]` of the output using the operation named
    by `aggr`. Output rows that no index points to are left as zeros, because
    the reduction runs with `include_self=False` on a zero-initialised tensor.

    Args:
    - dim (int): The dimension along which to scatter elements (only dim=0 is
      currently supported).
    - src (Tensor): The source tensor of shape (nnz, denseshape). May be empty
      (nnz == 0), in which case the result is all zeros.
    - ind (LongTensor): The indices tensor of shape (nnz); values are expected
      to lie in [0, dim_size).
    - dim_size (int): The size of the target dimension for scattering.
    - aggr (str): The reduction operation: 'sum', 'mean', 'min', 'max' or
      'mul'. The aliases 'min'/'max'/'mul' are translated to torch's
      'amin'/'amax'/'prod'; torch's own names are accepted as well. Any other
      value is passed through unchanged and rejected by `scatter_reduce_`.

    Returns:
    - Tensor: A tensor of shape (dim_size, denseshape) with the same dtype and
      device as `src`.

    Raises:
    - AssertionError: If `dim` is not 0, or if `ind` is not 1-dimensional.

    Example:
    ::

        src = torch.tensor([[1, 2], [4, 5], [7, 8], [9, 10]], dtype=torch.float)
        ind = torch.tensor([2, 2, 0, 1], dtype=torch.long)
        torch_scatter_reduce(0, src, ind, 3, 'sum')
        # tensor([[ 7.,  8.],
        #         [ 9., 10.],
        #         [ 5.,  7.]])
    """
    assert dim == 0, "other dim not implemented"
    assert ind.ndim == 1, "indice must be 1-d"
    # Map the user-facing reduction names onto the ones scatter_reduce_ accepts.
    aggr = {"min": "amin", "max": "amax", "mul": "prod"}.get(aggr, aggr)
    trailing = src.ndim - 1
    # new_zeros (instead of indexing src[[0]]) also works when src has no rows.
    ret = src.new_zeros((dim_size, *src.shape[1:]))
    # Broadcast the 1-d index across the trailing dense dimensions of src.
    index = ind.reshape((-1, ) + (1, ) * trailing).expand_as(src)
    ret.scatter_reduce_(dim, index, src, aggr, include_self=False)
    return ret