Source code for pygho.hodata.SpData

'''
utilities for sparse high order data
'''
from torch_geometric.data import Data as PygData, Batch as PygBatch
import torch
from typing import Any, List, Callable, Union, Tuple, Iterable
from torch import Tensor
from ..backend.Spspmm import spspmm_ind, filterind
from ..backend.SpTensor import SparseTensor
from ..honn.SpOperator import KEYSEP
from torch_geometric.utils import coalesce


def parseop(op: str) -> str:
    '''
    Get the increment key for a tensor when combining graphs.

    Args:

    - op (str): The operator string.

    Returns:

    - str: The name of the attribute holding the increment.

    Raises:

    - NotImplementedError: If the operator is not implemented.
    '''
    if op[0] == "X":
        return f"num_tuples{op[1:]}"
    elif op == "A":
        return "num_edges"
    else:
        raise NotImplementedError(f"operator name {op} not implemented now")

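# Illustrative use of parseop (a sketch; the annotation suffix "2" below is an
# arbitrary example, not a name mandated by the library):
#
#   parseop("X")    # -> "num_tuples"   (un-annotated tuple tensor)
#   parseop("X2")   # -> "num_tuples2"  (tuple tensor annotated with "2")
#   parseop("A")    # -> "num_edges"    (adjacency matrix)
#   parseop("B")    # raises NotImplementedError
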
def parsekey(key: str) -> Tuple[str, str, int, str, int]:
    '''
    Parse the operators in precomputation keys.

    Args:

    - key (str): The precomputation key.

    Returns:

    - Tuple[str, str, int, str, int]: A tuple containing the parsed operators and dimensions.
    '''
    assert len(key.split(KEYSEP)) == 5, "key format does not match"
    op0, op1, dim1, op2, dim2 = key.split(KEYSEP)
    dim1 = int(dim1)
    dim2 = int(dim2)
    # validate the three operator names
    parseop(op0)
    parseop(op1)
    parseop(op2)
    return op0, op1, dim1, op2, dim2

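# Illustrative use of parsekey (a sketch; the key is assembled with KEYSEP so
# the example does not assume the separator's literal value):
#
#   key = KEYSEP.join(["X", "A", "1", "X", "0"])
#   parsekey(key)   # -> ("X", "A", 1, "X", 0)
#
# As used in sp_datapreprocess below, such a key asks to multiply operand "A"
# (contracting its dimension 1) with operand "X" (contracting its dimension 0)
# and to keep only entries in the nonzero pattern of operand "X".
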
class SpHoData(PygData):
    '''
    A data class for sparse high order graph data.
    '''

    def __inc__(self, key: str, value: Any, *args, **kwargs):
        if key.startswith('tupleid'):
            return getattr(self, "tupleshape" +
                           key.removeprefix("tupleid")).reshape(-1, 1)
        if key.endswith(f"{KEYSEP}acd"):
            key = key.removesuffix(f"{KEYSEP}acd")
            op0, op1, _, op2, _ = parsekey(key)
            return torch.tensor(
                [[getattr(self, parseop(op0))], [getattr(self, parseop(op1))],
                 [getattr(self, parseop(op2))]],
                dtype=torch.long)
        return super().__inc__(key, value, *args, **kwargs)

    def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any:
        if key.startswith('tupleid') or key.endswith(f"{KEYSEP}acd"):
            return 1
        return super().__cat_dim__(key, value, *args, **kwargs)

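# How PyG batching uses the overrides above (a sketch): when several SpHoData
# objects are collated, tensors whose key starts with "tupleid" are
# concatenated along dim 1 (__cat_dim__) and the indices of each later graph
# are offset by the earlier graphs' "tupleshape..." values (__inc__), in the
# same way PyG offsets edge_index by num_nodes.
#
#   Example: two graphs with tupleshape [[3, 3]] and [[4, 4]]; the second
#   graph's tupleid columns are shifted by [[3], [3]], so a tuple (0, 2) in
#   the second graph becomes (3, 5) in the batched index tensor. Keys ending
#   in f"{KEYSEP}acd" are offset row-wise by the num_tuples/num_edges of the
#   three operands parsed from the key.
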
def batch2sparse(batch: PygBatch, keys: List[str] = [""]) -> PygBatch:
    '''
    A main wrapper for converting data in a batch object to SparseTensor.

    Args:

    - batch (PygBatch): The batch object containing graph data.
    - keys (List[str]): The list of keys to convert to SparseTensor.

    Returns:

    - PygBatch: The batch object with converted data.
    '''
    # build the sparse adjacency matrix of the whole batch
    batch.A = SparseTensor(
        batch.edge_index,
        batch.edge_attr,
        [batch.num_nodes, batch.num_nodes] if batch.edge_attr is None else
        [batch.num_nodes, batch.num_nodes] + list(batch.edge_attr.shape[1:]),
        is_coalesced=True)
    # build one sparse tuple tensor per precomputed key
    for key in keys:
        totaltupleshape = getattr(batch,
                                  f"tupleshape{key}").sum(dim=0).tolist()
        tupleid = getattr(batch, f"tupleid{key}")
        tuplefeat = getattr(batch, f"tuplefeat{key}")
        X = SparseTensor(tupleid,
                         tuplefeat,
                         shape=totaltupleshape if tuplefeat is None else
                         totaltupleshape + list(tuplefeat.shape[1:]),
                         is_coalesced=True)
        setattr(batch, f"X{key}", X)
    return batch

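# Illustrative use of batch2sparse (a sketch; `dataset` is a hypothetical
# collection of SpHoData objects produced by sp_datapreprocess below):
#
#   from torch_geometric.loader import DataLoader
#   batch = next(iter(DataLoader(dataset, batch_size=32)))
#   batch = batch2sparse(batch, keys=[""])
#   batch.A   # SparseTensor adjacency matrix of the batched graph
#   batch.X   # SparseTensor of tuple features (annotation "" -> attribute "X")
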
def sp_datapreprocess(data: PygData,
                      tuplesamplers: List[Callable[[PygData], SparseTensor]],
                      annotate: List[str] = [""],
                      keys: List[str] = [""]) -> SpHoData:
    '''
    A wrapper for preprocessing ordinary PyG data into sparse high order graph data.

    Args:

    - data (PygData): The input data in PyG Data format.
    - tuplesamplers (List[Callable[[PygData], SparseTensor]]): A list of tuple sampling functions.
    - annotate (List[str]): A list of annotation strings for tuple sampling.
    - keys (List[str]): A list of precomputation keys.

    Returns:

    - SpHoData: The preprocessed sparse high order data in SpHoData format.
    '''
    data.edge_index, data.edge_attr = coalesce(data.edge_index,
                                               data.edge_attr,
                                               num_nodes=data.num_nodes)
    assert len(tuplesamplers) == len(
        annotate), "number of tuple samplers should match the number of annotations"
    datadict = data.to_dict()
    datadict.update({
        "num_nodes": data.num_nodes,
        "num_edges": data.edge_index.shape[1],
        "x": data.x,
        "edge_index": data.edge_index,
        "edge_attr": data.edge_attr,
    })
    # sample high order tuples and store their indices, features, and shapes
    for i, tuplesampler in enumerate(tuplesamplers):
        feat = tuplesampler(data)
        tupleid, tuplefeat, tupleshape = feat.indices, feat.values, feat.sparseshape
        num_tuples = tupleid.shape[1]
        datadict.update({
            f"tupleid{annotate[i]}": tupleid,
            f"tuplefeat{annotate[i]}": tuplefeat,
            f"tupleshape{annotate[i]}":
            torch.LongTensor(tupleshape).reshape(1, -1),
            f"num_tuples{annotate[i]}": num_tuples
        })
    # precompute the index tensors needed for sparse-sparse matrix products
    for key in keys:
        op0, op1, dim1, op2, dim2 = parsekey(key)
        datadict[key + f"{KEYSEP}acd"] = filterind(
            datadict[f"tupleid{op0[1:]}"]
            if op0[0] == "X" else datadict["edge_index"],
            *spspmm_ind(
                datadict[f"tupleid{op1[1:]}"]
                if op1[0] == "X" else datadict["edge_index"], dim1,
                datadict[f"tupleid{op2[1:]}"]
                if op2[0] == "X" else datadict["edge_index"], dim2))
    return SpHoData(**datadict)
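
# Illustrative use of sp_datapreprocess (a sketch; dummy_sampler and the chosen
# precomputation key are hypothetical examples, not shipped defaults). A tuple
# sampler only needs to return a SparseTensor whose indices, values, and
# sparseshape are read back by sp_datapreprocess:
#
#   def dummy_sampler(data: PygData) -> SparseTensor:
#       # mark every ordered node pair (i, j) as a tuple with feature 1
#       n = data.num_nodes
#       ij = torch.stack(
#           torch.meshgrid(torch.arange(n), torch.arange(n),
#                          indexing="ij")).reshape(2, -1)
#       return SparseTensor(ij,
#                           torch.ones(ij.shape[1], dtype=torch.long),
#                           shape=[n, n],
#                           is_coalesced=True)
#
#   data = PygData(x=torch.randn(3, 8),
#                  edge_index=torch.tensor([[0, 1, 2], [1, 2, 0]]))
#   hodata = sp_datapreprocess(data, [dummy_sampler],
#                              annotate=[""],
#                              keys=[KEYSEP.join(["X", "A", "1", "X", "0"])])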