Source code for pygho.honn.utils

'''
A general MLP class
'''
import torch.nn as nn
from torch import Tensor
from typing import Callable
from torch import Tensor

# Norms for subgraph GNN


class NormMomentumScheduler:
    """
    Epoch-wise scheduler for the ``momentum`` attribute of normalization layers.

    At every call of :meth:`step` the momentum ``initmomentum * mfunc(epoch)``
    is written to all matching layers of a model and the internal epoch
    counter is advanced by one.

    Args:

    - mfunc (Callable): Maps the current epoch index (starting at 0) to a multiplicative ratio.
    - initmomentum (float): Base momentum that the ratio scales.
    - normtype: Module class whose instances are updated. Only modules whose type is exactly this class are touched; subclasses are skipped. Defaults to ``nn.BatchNorm1d``.
    """

    def __init__(self,
                 mfunc: Callable,
                 initmomentum: float,
                 normtype=nn.BatchNorm1d) -> None:
        super().__init__()
        self.normtype = normtype
        self.mfunc = mfunc
        self.epoch = 0
        self.initmomentum = initmomentum

    def step(self, model: nn.Module):
        """
        Apply the momentum of the current epoch to ``model`` and advance one epoch.

        Args:

        - model (nn.Module): Model whose sub-modules of type ``normtype`` get their ``momentum`` attribute overwritten.

        Returns:

        - float: The momentum that was written to the matching layers.
        """
        ratio = self.mfunc(self.epoch)
        # The counter must advance on every call. Returning before this
        # increment when the ratio is ~1 froze the schedule for any mfunc
        # with mfunc(0) == 1, because the same epoch was queried forever.
        self.epoch += 1
        if 1 - 1e-6 < ratio < 1 + 1e-6:
            # Snap to the exact base value instead of a product that is off
            # by rounding noise.
            curm = self.initmomentum
        else:
            curm = self.initmomentum * ratio
        # Always write the value, including the ratio ~1 case: otherwise a
        # schedule that returns to 1 would leave the previous epoch's
        # momentum on the layers while reporting initmomentum to the caller.
        for mod in model.modules():
            if type(mod) is self.normtype:
                mod.momentum = curm
        return curm
class NoneNorm(nn.Module):
    """
    Identity module that can stand in where a normalization layer is expected.

    Args:

    - dim: Stored as ``num_features``; it has no effect on the forward pass.
    - normparam: Accepted but unused.
    """

    def __init__(self, dim=0, normparam=0) -> None:
        super().__init__()
        # Kept only as metadata; forward never reads it.
        self.num_features = dim

    def forward(self, x):
        """Return the input object unchanged (no copy is made)."""
        return x
class BatchNorm(nn.Module):
    """
    ``nn.BatchNorm1d`` applied over the last dimension of the input.

    Inputs with more than two dimensions have all leading dimensions merged
    into one before normalization and are restored to their original shape
    afterwards.

    Args:

    - dim: Size of the last (feature) dimension; also stored as ``num_features``.
    - normparam: Passed to ``nn.BatchNorm1d`` as ``momentum``.
    """

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.BatchNorm1d(dim, momentum=normparam)

    def forward(self, x: Tensor):
        """
        Normalize ``x`` along its last dimension.

        Raises:

        - NotImplementedError: If ``x`` has fewer than two dimensions.
        """
        ndim = x.dim()
        if ndim < 2:
            raise NotImplementedError
        if ndim == 2:
            return self.norm(x)
        # Merge every leading dimension into a single batch axis so the 1d
        # batch norm sees a 2-D input, then restore the original layout.
        original_shape = x.shape
        flat = x.flatten(0, -2)
        return self.norm(flat).reshape(original_shape)
class LayerNorm(nn.Module):
    """
    Thin wrapper around ``nn.LayerNorm`` over the last dimension.

    Args:

    - dim: Size of the normalized (last) dimension; also stored as ``num_features``.
    - normparam: Accepted but unused.
    """

    def __init__(self, dim, normparam=0.1) -> None:
        super().__init__()
        self.num_features = dim
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor):
        """Return ``x`` after layer normalization."""
        normalized = self.norm(x)
        return normalized
# Lookup table from a normalization name to the wrapper class implementing it.
normdict = dict(bn=BatchNorm, ln=LayerNorm, none=NoneNorm)

# Lookup table from an activation name to a module instance. Each value is
# created once here, so every lookup of the same key returns the same object.
# NOTE: the "ELU" key is upper-case, unlike the other two keys.
act_dict = dict(
    relu=nn.ReLU(inplace=True),
    ELU=nn.ELU(inplace=True),
    silu=nn.SiLU(inplace=True),
)
class MLP(nn.Module):
    """
    Multi-Layer Perceptron (MLP) module with customizable layers and activation functions.

    The network is ``numlayer - 1`` hidden blocks (linear, normalization,
    optional dropout, activation) followed by one output linear layer, which
    is optionally followed by normalization, dropout and activation as well.

    Args:

    - hiddim (int): Input feature size and width of every hidden layer.
    - outdim (int): Output feature size of the final linear layer.
    - numlayer (int): Number of linear layers. With 0 the module is the identity, which requires ``hiddim == outdim``.
    - tailact (bool): Whether normalization, dropout and activation are applied after the final linear layer.
    - dp (float): Dropout probability; dropout layers are only added when it is greater than 0.
    - norm (str): Key into ``normdict`` selecting the normalization (e.g., "bn" for BatchNorm).
    - act (str): Key into ``act_dict`` selecting the activation (e.g., "relu").
    - tailbias (bool): Whether the final linear layer has a bias term.
    - normparam (float): Second argument given to the normalization constructor (e.g., momentum for BatchNorm).

    Methods:

    - forward(x: Tensor) -> Tensor: Forward pass of the MLP.
    """

    def __init__(self,
                 hiddim: int,
                 outdim: int,
                 numlayer: int,
                 tailact: bool,
                 dp: float = 0,
                 norm: str = "bn",
                 act: str = "relu",
                 tailbias=True,
                 normparam: float = 0.1) -> None:
        super().__init__()
        assert numlayer >= 0
        if numlayer == 0:
            # No linear layer at all: the module is an identity mapping.
            assert hiddim == outdim
            self.lins = NoneNorm()
            return

        use_dropout = dp > 0

        # Output block, created first and in forward order.
        tail = [nn.Linear(hiddim, outdim, bias=tailbias)]
        if tailact:
            tail.append(normdict[norm](outdim, normparam))
            if use_dropout:
                tail.append(nn.Dropout(dp, inplace=True))
            tail.append(act_dict[act])

        # Layers are collected back-to-front and reversed once at the end.
        # Hidden blocks are thus created after the output block, from the
        # last hidden block to the first, each one activation-first.
        back_to_front = tail[::-1]
        for _ in range(numlayer - 1):
            back_to_front.append(act_dict[act])
            if use_dropout:
                back_to_front.append(nn.Dropout(dp, inplace=True))
            back_to_front.append(normdict[norm](hiddim, normparam))
            back_to_front.append(nn.Linear(hiddim, hiddim))

        self.lins = nn.Sequential(*reversed(back_to_front))

    def forward(self, x: Tensor):
        """Run ``x`` through the stacked layers and return the result."""
        return self.lins(x)