complextorch.nn.masked.linear

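Layers that apply a fixed binary mask to their complex weight in the forward pass. The weight is multiplied elementwise by the mask, so masked-out entries contribute nothing to the output. These layers are used to deploy a learned sparsity pattern at inference time.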
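The masking step can be sketched with plain PyTorch tensor operations. This is a minimal illustration, not the complextorch implementation: the magnitude threshold used to build `mask` is a hypothetical stand-in for however the sparsity pattern was actually learned, and the weight shape follows the `torch.nn.functional.linear` convention `(out_features, in_features)`.

```python
import torch
import torch.nn.functional as F

# Hypothetical illustration, not the complextorch implementation:
# a complex weight, a fixed binary mask, and a masked forward pass.
torch.manual_seed(0)
weight = torch.randn(4, 3, dtype=torch.cfloat)  # (out_features, in_features)

# Assumed way to obtain a sparsity pattern: keep entries whose magnitude
# exceeds a threshold. The real library may derive its mask differently.
threshold = 1.0
mask = (weight.abs() > threshold).to(weight.dtype)  # 1+0j keeps, 0+0j prunes

x = torch.randn(2, 3, dtype=torch.cfloat)  # batch of 2 complex inputs
y = F.linear(x, weight * mask)             # masked entries contribute nothing

# Equivalent dense computation with the pruned entries zeroed explicitly.
pruned = torch.where(mask.real.bool(), weight, torch.zeros_like(weight))
assert torch.allclose(y, x @ pruned.T)
```

Because the mask is applied inside the forward pass rather than baked into the stored weight, the dense weight is kept intact and the same pattern is enforced on every call.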

Classes

BilinearMasked

Complex bilinear with a fixed binary weight mask.

LinearMasked

Complex linear with a fixed binary weight mask.

Module Contents

class complextorch.nn.masked.linear.BilinearMasked(in1_features: int, in2_features: int, out_features: int, bias: bool = True, conjugate: bool = True, device=None, dtype: torch.dtype = torch.cfloat)

Bases: complextorch.nn.masked.base.MaskedWeightMixin, complextorch.nn.masked.base.BaseMasked

Complex bilinear with a fixed binary weight mask.


forward(input1: torch.Tensor, input2: torch.Tensor) → torch.Tensor

conjugate = True

in1_features

in2_features

out_features

weight
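A complex bilinear map with a masked weight can be sketched with `torch.einsum`. The helper `masked_bilinear` below is hypothetical. It assumes the weight has shape `(out_features, in1_features, in2_features)`, as in `torch.nn.Bilinear`, and that `conjugate=True` conjugates the first input. Both are assumptions about `BilinearMasked`, not documented behaviour.

```python
import torch

def masked_bilinear(x1, x2, weight, mask, bias=None, conjugate=True):
    """Hypothetical sketch of a complex bilinear map with a fixed weight mask.

    Assumed shapes: x1 (batch, in1), x2 (batch, in2),
    weight and mask (out, in1, in2), bias (out,).
    Assumed semantics: conjugate=True conjugates x1.
    """
    w = weight * mask  # zero out masked entries
    if conjugate:
        x1 = x1.conj()
    # y[b, k] = sum_{i, j} x1[b, i] * w[k, i, j] * x2[b, j]
    y = torch.einsum("bi,kij,bj->bk", x1, w, x2)
    if bias is not None:
        y = y + bias
    return y

# Usage with random complex data and a random 0/1 mask.
torch.manual_seed(0)
out_f, in1, in2 = 2, 3, 4
weight = torch.randn(out_f, in1, in2, dtype=torch.cfloat)
mask = (torch.rand(out_f, in1, in2) > 0.5).to(torch.cfloat)
x1 = torch.randn(5, in1, dtype=torch.cfloat)
x2 = torch.randn(5, in2, dtype=torch.cfloat)
y = masked_bilinear(x1, x2, weight, mask)
```

Conjugating one input makes each output a sesquilinear form in the inputs, which is the usual reason a complex bilinear layer offers such a flag.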
class complextorch.nn.masked.linear.LinearMasked(in_features: int, out_features: int, bias: bool = True, device=None, dtype: torch.dtype = torch.cfloat)

Bases: complextorch.nn.masked.base.MaskedWeightMixin, complextorch.nn.masked.base.BaseMasked

Complex linear with a fixed binary weight mask.


forward(input: torch.Tensor) → torch.Tensor

in_features

out_features

weight
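A module with the same constructor arguments as `LinearMasked` can be sketched as follows. `ToyLinearMasked` and its `set_mask` helper are hypothetical and are not the `MaskedWeightMixin` API. The sketch assumes the mask is stored as a buffer, so it is saved in `state_dict`, follows `.to()`, and is never updated by an optimizer.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyLinearMasked(nn.Module):
    """Hypothetical stand-in for LinearMasked, for illustration only."""

    def __init__(self, in_features, out_features, bias=True,
                 device=None, dtype=torch.cfloat):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        factory = {"device": device, "dtype": dtype}
        # Complex weight, scaled so outputs have roughly unit variance.
        self.weight = nn.Parameter(
            torch.randn(out_features, in_features, **factory) / math.sqrt(in_features))
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features, **factory))
        else:
            self.register_parameter("bias", None)
        # Assumed storage: a buffer, so the mask is saved but never trained.
        self.register_buffer("mask", torch.ones(out_features, in_features, **factory))

    @torch.no_grad()
    def set_mask(self, mask):
        # Hypothetical helper: copy a 0/1 pattern into the mask buffer.
        self.mask.copy_(mask.to(self.mask.dtype))

    def forward(self, input):
        # The mask is applied on every call; the stored weight stays dense.
        return F.linear(input, self.weight * self.mask, self.bias)

# Usage: prune two of three inputs for each output, then backpropagate.
torch.manual_seed(0)
layer = ToyLinearMasked(3, 2)
layer.set_mask(torch.tensor([[1, 0, 1], [0, 1, 0]]))
x = torch.randn(4, 3, dtype=torch.cfloat)
y = layer(x)
y.abs().sum().backward()  # real-valued loss for complex autograd
```

Since the forward pass multiplies by the mask, the gradient reaching masked weight entries is exactly zero, so the sparsity pattern is preserved even if the layer is fine-tuned.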