Source code for pygho.backend.MaTensor

import torch
from torch import Tensor, BoolTensor, LongTensor
from typing import Optional, Callable, Iterable, Union
# TODO: merge with the torch.nested or torch.masked API in the long run.


def filterinf(X: Tensor, filled_value: float = 0) -> Tensor:
    """
    Replace positive and negative infinity values in a tensor with a specified value.

    Args:

    - X (Tensor): The input tensor.
    - filled_value (float, optional): The value to replace positive and negative infinity values with (default: 0).

    Returns:

    - Tensor: A tensor with positive and negative infinity values replaced by the specified ``filled_value``.

    Example:

    ::

        input_tensor = torch.tensor([1.0, 2.0, torch.inf, -torch.inf, 3.0])
        result = filterinf(input_tensor, filled_value=999.0)

    """
    return torch.where(torch.isinf(X), filled_value, X)
class MaskedTensor:
    """
    Represents a masked tensor with optional padding values.

    This class allows you to work with tensors that have a mask indicating valid and
    invalid values. You can perform various operations on the masked tensor, such as
    filling masked values, computing sums, means, maximums, minimums, and more.

    Parameters:

    - data (Tensor): The underlying data tensor of shape (\*maskedshape, \*denseshape).
    - mask (BoolTensor): The mask tensor of shape (\*maskedshape), where ``True``
      represents valid values and ``False`` represents invalid values.
    - padvalue (float, optional): The value to use for padding. Defaults to 0.
    - is_filled (bool, optional): Indicates whether the invalid values have already
      been filled with the padvalue. Defaults to False.

    Attributes:

    - data (Tensor): The underlying data tensor.
    - mask (BoolTensor): The mask tensor.
    - fullmask (BoolTensor): The mask tensor broadcast to match the data's dimensions.
    - padvalue (float): The padding value.
    - shape (torch.Size): The shape of the data tensor.
    - masked_dim (int): The number of dimensions in maskedshape.
    - dense_dim (int): The number of dimensions in denseshape.
    - maskedshape (torch.Size): The shape of the tensor up to the masked dimensions.
    - denseshape (torch.Size): The shape of the tensor after the masked dimensions.

    Methods:

    - fill_masked_(self, val: float = 0) -> None: In-place fill of masked values.
    - fill_masked(self, val: float = 0) -> Tensor: Return a tensor with masked values filled with the specified value.
    - to(self, device: torch.DeviceObjType, non_blocking: bool = True): Move the tensor to the specified device.
    - sum(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the sum of valid values along specified dimensions.
    - mean(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the mean of valid values along specified dimensions.
    - max(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the maximum of valid values along specified dimensions.
    - min(self, dims: Union[Iterable[int], int], keepdim: bool = False): Compute the minimum of valid values along specified dimensions.
    - diag(self, dims: Iterable[int]): Extract diagonals from the tensor. The diagonal is taken over the dimensions in dims and placed at dims[0].
    - unpooling(self, dims: Union[int, Iterable[int]], tarX): Perform an unpooling operation along specified dimensions.
    - tuplewiseapply(self, func: Callable[[Tensor], Tensor]): Apply a function to each element of the masked tensor.
    - diagonalapply(self, func: Callable[[Tensor, LongTensor], Tensor]): Apply a function to diagonal elements of the masked tensor.
    - add(self, tarX, samesparse: bool): Add two masked tensors together.
    - catvalue(self, tarX, samesparse: bool): Concatenate the values of two masked tensors.
    """

    def __init__(self,
                 data: Tensor,
                 mask: BoolTensor,
                 padvalue: float = 0.0,
                 is_filled: bool = False):
        # mask: True for valid values, False for invalid values
        assert data.ndim >= mask.ndim, "data must have at least as many dimensions as mask"
        assert data.shape[:mask.ndim] == mask.shape, "data and mask's first dimensions should match"
        self.__data = data
        self.__mask = mask
        self.__masked_dim = mask.ndim
        while mask.ndim < data.ndim:
            mask = mask.unsqueeze(-1)
        self.__fullmask = mask
        if not is_filled:
            # set padvalue to a value different from the target so that
            # fill_masked_ does not return early
            self.__padvalue = torch.inf if padvalue != torch.inf else -torch.inf
            self.fill_masked_(padvalue)
        else:
            self.__padvalue = padvalue
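A minimal construction sketch (assuming ``pygho`` is importable; the shapes and values are illustrative, not taken from the library): ::

    import torch
    from pygho.backend.MaTensor import MaskedTensor

    # two graphs with 2 and 3 nodes, padded to 3 nodes, 4 features each
    x = torch.randn(2, 3, 4)                   # (*maskedshape=(2, 3), *denseshape=(4,))
    mask = torch.tensor([[True, True, False],
                         [True, True, True]])  # True marks real nodes
    mt = MaskedTensor(x, mask)                 # padded entries are filled with 0
    print(mt.maskedshape, mt.denseshape)       # torch.Size([2, 3]) torch.Size([4])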
    def fill_masked_(self, val: float = 0) -> None:
        """
        In-place fill of the masked (invalid) values with ``val``.
        """
        if self.padvalue == val:
            return
        self.__padvalue = val
        self.__data = torch.where(self.fullmask, self.data, val)
    def fill_masked(self, val: float = 0) -> Tensor:
        """
        Return a tensor whose masked (invalid) values are filled with ``val``.
        """
        if self.__padvalue == val:
            return self.data
        return torch.where(self.fullmask, self.data, val)
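The difference between the in-place and out-of-place variants, as a small sketch: ::

    mt = MaskedTensor(torch.ones(1, 3), torch.tensor([[True, True, False]]))
    print(mt.data)                 # tensor([[1., 1., 0.]]) -- pad slot holds the padvalue
    print(mt.fill_masked(-1.0))    # tensor([[ 1.,  1., -1.]]) -- returns a new tensor
    mt.fill_masked_(-1.0)          # in-place: mt.data and mt.padvalue now use -1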
    def to(self, device: torch.DeviceObjType, non_blocking: bool = True):
        """
        Move the data, mask, and fullmask to the given device.
        """
        self.__data = self.__data.to(device, non_blocking=non_blocking)
        self.__mask = self.__mask.to(device, non_blocking=non_blocking)
        self.__fullmask = self.__fullmask.to(device, non_blocking=non_blocking)
        return self
    @property
    def padvalue(self) -> float:
        return self.__padvalue

    @property
    def data(self) -> Tensor:
        return self.__data

    @property
    def mask(self) -> BoolTensor:
        return self.__mask

    @property
    def fullmask(self) -> BoolTensor:
        return self.__fullmask

    @property
    def shape(self) -> torch.Size:
        return self.__data.shape

    @property
    def masked_dim(self):
        return self.__masked_dim

    @property
    def dense_dim(self):
        return len(self.denseshape)

    @property
    def maskedshape(self):
        return self.shape[:self.masked_dim]

    @property
    def denseshape(self):
        return self.shape[self.masked_dim:]
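How the shape splits into masked and dense parts, as a sketch (an all-valid mask; shapes illustrative): ::

    mt = MaskedTensor(torch.zeros(2, 3, 3, 8),
                      torch.ones(2, 3, 3, dtype=torch.bool))
    print(mt.masked_dim)     # 3 -- the leading dims covered by the mask
    print(mt.maskedshape)    # torch.Size([2, 3, 3])
    print(mt.denseshape)     # torch.Size([8])
    print(mt.dense_dim)      # 1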
    def sum(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        return MaskedTensor(torch.sum(self.fill_masked(0),
                                      dim=dims,
                                      keepdim=keepdim),
                            torch.amax(self.mask, dims, keepdim=keepdim),
                            padvalue=0,
                            is_filled=True)
    def mean(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        count = torch.clamp_min_(
            torch.sum(self.fullmask, dim=dims, keepdim=keepdim), 1)
        valsum = self.sum(dims, keepdim)
        return MaskedTensor(valsum.data / count,
                            valsum.mask,
                            padvalue=valsum.padvalue,
                            is_filled=True)
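Unlike a plain reduction over the padded tensor, these reductions exclude invalid entries. A sketch with illustrative values: ::

    x = torch.tensor([[1.0, 3.0, 100.0]])       # 100.0 sits in a padded slot
    mask = torch.tensor([[True, True, False]])
    mt = MaskedTensor(x, mask)
    print(mt.mean(1).data)    # tensor([2.]) -- the pad is excluded
    print(x.mean(dim=1))      # tensor([34.6667]) -- a plain mean counts the pad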
    def max(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        tmp = self.fill_masked(-torch.inf)
        return MaskedTensor(filterinf(
            torch.amax(tmp, dim=dims, keepdim=keepdim), 0),
                            torch.amax(self.mask, dims, keepdim=keepdim),
                            padvalue=0,
                            is_filled=True)
    def min(self, dims: Union[Iterable[int], int], keepdim: bool = False):
        tmp = self.fill_masked(torch.inf)
        # amin over entries filled with +inf yields the masked minimum
        return MaskedTensor(filterinf(
            torch.amin(tmp, dim=dims, keepdim=keepdim), 0),
                            torch.amax(self.mask, dims, keepdim=keepdim),
                            padvalue=0,
                            is_filled=True)
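max and min likewise ignore padded entries by first filling them with -inf or +inf; a sketch: ::

    x = torch.tensor([[5.0, -2.0, 99.0]])       # 99.0 sits in a padded slot
    mask = torch.tensor([[True, True, False]])
    mt = MaskedTensor(x, mask)
    print(mt.max(1).data)    # tensor([5.]) -- the padded 99.0 is ignored
    print(mt.min(1).data)    # tensor([-2.])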
    def diag(self, dims: Iterable[int]):
        """
        Take the diagonal over the dimensions in dims and place it at dims[0].
        """
        assert len(dims) >= 2, "must diag over at least two dims"
        dims = sorted(list(dims))
        tdata = self.data
        tmask = self.mask
        tdata = torch.diagonal(tdata, 0, dims[0], dims[1])
        tmask = torch.diagonal(tmask, 0, dims[0], dims[1])
        for i in range(2, len(dims)):
            # each earlier diagonal removed one dimension before dims[i],
            # so shift the index accordingly
            tdata = torch.diagonal(tdata, 0, dims[i] - i, -1)
            tmask = torch.diagonal(tmask, 0, dims[i] - i, -1)
        tdata = torch.movedim(tdata, -1, dims[0])
        tmask = torch.movedim(tmask, -1, dims[0])
        return MaskedTensor(tdata, tmask, self.padvalue, True)
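A sketch of extracting the diagonal of a batched square tensor (values illustrative): ::

    x = torch.arange(18, dtype=torch.float).reshape(2, 3, 3)
    mask = torch.ones(2, 3, 3, dtype=torch.bool)
    d = MaskedTensor(x, mask).diag(dims=[1, 2])
    print(d.shape)      # torch.Size([2, 3]) -- the diagonal sits at dim 1
    print(d.data[0])    # tensor([0., 4., 8.])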
    def unpooling(self, dims: Union[int, Iterable[int]], tarX):
        if isinstance(dims, int):
            dims = [dims]
        dims = sorted(list(dims))
        tdata = self.data
        for dim in dims:
            tdata = tdata.unsqueeze(dim)
        tdata = tdata.expand(*(-1 if i not in dims else tarX.shape[i]
                               for i in range(tdata.ndim)))
        return MaskedTensor(tdata, tarX.mask, self.padvalue, False)
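A sketch of round-tripping a pooled tensor back to the shape of its source (tarX is the MaskedTensor whose shape and mask the result should take): ::

    full = MaskedTensor(torch.randn(2, 3, 4), torch.ones(2, 3, dtype=torch.bool))
    pooled = full.sum(1)               # shape (2, 4)
    back = pooled.unpooling(1, full)   # broadcast back to shape (2, 3, 4)
    print(back.shape)                  # torch.Size([2, 3, 4])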
    def tuplewiseapply(self, func: Callable[[Tensor], Tensor]):
        # filling masked entries with 0 may cause NaN in the gradient
        # and make AMP unable to update
        ndata = func(self.fill_masked(0))
        return MaskedTensor(ndata, self.mask)
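A sketch of applying a feature transform to every tuple's representation (the Linear layer is a hypothetical stand-in for any map over the dense dims): ::

    lin = torch.nn.Linear(4, 2)
    mt = MaskedTensor(torch.randn(2, 3, 4),
                      torch.tensor([[True, True, False],
                                    [True, False, False]]))
    out = mt.tuplewiseapply(lin)    # Linear acts on the last (feature) dim
    print(out.denseshape)           # torch.Size([2])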
    def diagonalapply(self, func: Callable[[Tensor, LongTensor], Tensor]):
        assert self.masked_dim == 3, "only implemented for batched 2D masked tensors"
        diagonaltype = torch.eye(self.shape[1],
                                 self.shape[2],
                                 dtype=torch.long,
                                 device=self.data.device)
        diagonaltype = diagonaltype.unsqueeze(0).expand_as(self.mask)
        ndata = func(self.data, diagonaltype)
        return MaskedTensor(ndata, self.mask)
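A sketch with a hypothetical func that treats diagonal and off-diagonal entries differently (func receives the data and a 0/1 LongTensor marking the diagonal): ::

    def scale_diag(data: Tensor, diagmark: LongTensor) -> Tensor:
        # hypothetical: double diagonal entries, keep the rest
        return torch.where(diagmark.bool(), 2 * data, data)

    mt = MaskedTensor(torch.ones(1, 3, 3), torch.ones(1, 3, 3, dtype=torch.bool))
    out = mt.diagonalapply(scale_diag)
    print(out.data[0])    # 2.0 on the diagonal, 1.0 elsewhere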
    def add(self, tarX, samesparse: bool):
        if samesparse:
            return MaskedTensor(tarX.data + self.data,
                                self.mask,
                                self.padvalue,
                                is_filled=self.padvalue == tarX.padvalue)
        else:
            return MaskedTensor(
                tarX.fill_masked(0) + self.fill_masked(0),
                torch.logical_or(self.mask, tarX.mask), 0, True)
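A sketch of adding two tensors that share one mask (samesparse=True skips the mask union and refilling): ::

    mask = torch.tensor([[True, True, False]])
    a = MaskedTensor(torch.tensor([[1.0, 2.0, 0.0]]), mask)
    b = MaskedTensor(torch.tensor([[10.0, 20.0, 0.0]]), mask)
    c = a.add(b, samesparse=True)
    print(c.data)    # tensor([[11., 22.,  0.]])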
    def catvalue(self, tarX, samesparse: bool):
        assert samesparse, "must have the same sparsity to concat values"
        if isinstance(tarX, MaskedTensor):
            return self.tuplewiseapply(
                lambda _: torch.concat((self.data, tarX.data), dim=-1))
        elif isinstance(tarX, Iterable):
            return self.tuplewiseapply(
                lambda _: torch.concat([self.data] +
                                       [mt.data for mt in tarX],
                                       dim=-1))
        else:
            raise NotImplementedError
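A sketch of concatenating the features of two masked tensors that share a mask: ::

    mask = torch.tensor([[True, True, False]])
    a = MaskedTensor(torch.randn(1, 3, 4), mask)
    b = MaskedTensor(torch.randn(1, 3, 2), mask)
    c = a.catvalue(b, samesparse=True)   # concat along the feature dim
    print(c.denseshape)                  # torch.Size([6])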