complextorch.nn.masked.linear
Layers that apply a fixed binary mask to their complex weight at forward time. They are used to deploy a learned sparsity pattern at inference.
Classes
BilinearMasked | Complex bilinear with a fixed binary weight mask.
LinearMasked | Complex linear with a fixed binary weight mask.
Module Contents
- class complextorch.nn.masked.linear.BilinearMasked(in1_features: int, in2_features: int, out_features: int, bias: bool = True, conjugate: bool = True, device=None, dtype: torch.dtype = torch.cfloat)
Bases: complextorch.nn.masked.base.MaskedWeightMixin, complextorch.nn.masked.base.BaseMasked
Complex bilinear with a fixed binary weight mask.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(input1: torch.Tensor, input2: torch.Tensor) → torch.Tensor
- conjugate = True
- in1_features
- in2_features
- out_features
- weight
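As a rough illustration of what a masked complex bilinear map computes, the sketch below uses plain Python complex numbers rather than complextorch. It assumes a weight of shape (out_features, in1_features, in2_features), a mask of the same shape multiplied elementwise into the weight, and that conjugate=True conjugates the first input. None of these details is confirmed by this page, and the helper name masked_bilinear is hypothetical.

```python
from typing import List, Optional

Vector = List[complex]


def masked_bilinear(
    x1: Vector,
    x2: Vector,
    weight: List[List[Vector]],   # assumed layout: [out][in1][in2]
    mask: List[List[List[int]]],  # same layout, entries 0 or 1
    bias: Optional[Vector] = None,
    conjugate: bool = True,
) -> Vector:
    """Hypothetical helper: y[k] = sum_ij a[i] * (W[k][i][j] * M[k][i][j]) * x2[j] + b[k]."""
    # Assumption: conjugate=True conjugates the first input before the product.
    a = [v.conjugate() for v in x1] if conjugate else list(x1)
    out = []
    for k, (w_k, m_k) in enumerate(zip(weight, mask)):
        acc = bias[k] if bias is not None else 0j
        for i, a_i in enumerate(a):
            for j, b_j in enumerate(x2):
                # Masked entries (m == 0) contribute nothing to the sum.
                acc += a_i * (w_k[i][j] * m_k[i][j]) * b_j
        out.append(acc)
    return out


x1 = [1 + 1j, 2j]
x2 = [1j, 1 + 0j]
weight = [[[1 + 0j, 2 + 0j], [3 + 0j, 4 + 0j]]]  # one output feature
mask = [[[1, 0], [0, 1]]]                        # keep only the diagonal weights
y = masked_bilinear(x1, x2, weight, mask)
print(y)  # [(1-7j)]
```

Because the off-diagonal weights are masked out, changing their values would leave y unchanged, which is the property a fixed sparsity pattern relies on.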
- class complextorch.nn.masked.linear.LinearMasked(in_features: int, out_features: int, bias: bool = True, device=None, dtype: torch.dtype = torch.cfloat)
Bases: complextorch.nn.masked.base.MaskedWeightMixin, complextorch.nn.masked.base.BaseMasked
Complex linear with a fixed binary weight mask.
Initialize internal Module state, shared by both nn.Module and ScriptModule.
- forward(input: torch.Tensor) → torch.Tensor
- in_features
- out_features
- weight
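The arithmetic behind a masked complex linear layer can be sketched without the library, using Python's built-in complex type. This is only an illustration under stated assumptions: the weight is taken to have shape (out_features, in_features) as in torch.nn.Linear, the binary mask is assumed to be multiplied elementwise into the weight before the matrix-vector product, and masked_linear is a hypothetical helper, not part of complextorch.

```python
from typing import List, Optional

Vector = List[complex]


def masked_linear(
    x: Vector,
    weight: List[Vector],   # assumed layout: [out][in]
    mask: List[List[int]],  # same layout, entries 0 or 1
    bias: Optional[Vector] = None,
) -> Vector:
    """Hypothetical helper: y[k] = sum_i (W[k][i] * M[k][i]) * x[i] + b[k]."""
    out = []
    for k, (w_k, m_k) in enumerate(zip(weight, mask)):
        acc = bias[k] if bias is not None else 0j
        for w, m, x_i in zip(w_k, m_k, x):
            # The mask zeroes the weight itself, so pruned entries are ignored.
            acc += (w * m) * x_i
        out.append(acc)
    return out


x = [1 + 2j, 3 - 1j]
weight = [[1j, 2 + 0j], [1 + 1j, -1j]]
mask = [[1, 0], [1, 1]]  # prune weight[0][1]
bias = [0.5 + 0j, 0j]
y = masked_linear(x, weight, mask, bias)
print(y)  # [(-1.5+1j), (-2+0j)]
```

The first output ignores weight[0][1] entirely, so the result depends only on the unmasked entries and the bias.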