complextorch.nn.masked

Fixed-sparsity-pattern complex layers and helpers for managing their masks across a whole network.

Submodules

Functions

binarize_masks(→ torch.nn.Module)

Binarize every mask attached to a BaseMasked submodule, in place.

deploy_masks(→ torch.nn.Module)

Load a {name: mask} dict into the matching BaseMasked submodules of model.

is_sparse(→ bool)

Return True if layer is a BaseMasked instance whose mask is set.

named_masks(→ collections.abc.Iterator[tuple[str, torch.Tensor]])

Yield a (qualified_name, mask) pair for each mask that is currently set.

Package Contents

complextorch.nn.masked.binarize_masks(model: torch.nn.Module) → torch.nn.Module

Binarize every mask attached to a BaseMasked submodule of model, in place.
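To make the behaviour concrete, here is a minimal sketch of what binarization could look like, written against plain PyTorch rather than complextorch. MaskedLinear is a hypothetical stand-in for a BaseMasked layer (a Linear with an optional mask buffer), and binarize_masks_sketch is likewise hypothetical. The sketch assumes that "binarize" means mapping every nonzero mask entry to 1 and every zero entry to 0; the library may use a different rule, such as a threshold.

```python
import torch
import torch.nn.functional as F
from torch import nn


class MaskedLinear(nn.Linear):
    """Hypothetical stand-in for a BaseMasked layer: Linear plus an optional `mask` buffer."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.register_buffer("mask", None)  # None means "no mask set"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Apply the fixed sparsity pattern by zeroing masked-out weights
        weight = self.weight if self.mask is None else self.weight * self.mask
        return F.linear(x, weight, self.bias)


def binarize_masks_sketch(model: nn.Module) -> nn.Module:
    """Hypothetical: map nonzero mask entries to 1 and zero entries to 0, in place."""
    for module in model.modules():
        if isinstance(module, MaskedLinear) and module.mask is not None:
            # copy_ overwrites the existing buffer's storage instead of rebinding it
            module.mask.copy_((module.mask != 0).to(module.mask.dtype))
    return model


model = nn.Sequential(MaskedLinear(3, 2), nn.ReLU(), MaskedLinear(2, 1))
model[0].mask = torch.tensor([[0.5, 0.0, -2.0], [0.0, 0.1, 0.0]])
binarize_masks_sketch(model)  # model[0].mask now holds only 0.0 and 1.0
```

Writing through copy_ keeps the same tensor object, so any other reference to that mask sees the binarized values; assigning a new tensor to module.mask would not have that effect.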

complextorch.nn.masked.deploy_masks(model: torch.nn.Module, state_dict: dict[str, torch.Tensor], *, strict: bool = True) → torch.nn.Module

Load a {name: mask} dict into the matching BaseMasked submodules of model.

Keys in state_dict are interpreted as fully-qualified mask names: a submodule's dotted path followed by ".mask" (e.g. "encoder.layer1.linear.mask" or just "linear.mask"). Any key ending in ".mask" is matched to the mask buffer of the corresponding submodule.
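The key-matching rule above can be sketched with named_modules(): strip the ".mask" suffix and look up the remaining dotted path. MaskedLinear and deploy_masks_sketch are hypothetical, and the handling of strict is an assumption modelled on torch.nn.Module.load_state_dict (raise on a key that matches no masked submodule when strict=True, skip it otherwise); complextorch's actual strict semantics are not described on this page.

```python
import torch
from torch import nn


class MaskedLinear(nn.Linear):
    """Hypothetical stand-in for a BaseMasked layer: Linear plus an optional `mask` buffer."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.register_buffer("mask", None)  # None means "no mask set"


def deploy_masks_sketch(
    model: nn.Module, state_dict: dict[str, torch.Tensor], *, strict: bool = True
) -> nn.Module:
    """Hypothetical: attach each "<path>.mask" entry to the masked submodule at <path>."""
    modules = dict(model.named_modules())  # dotted path -> submodule
    for key, mask in state_dict.items():
        path = key[: -len(".mask")] if key.endswith(".mask") else None
        module = modules.get(path) if path is not None else None
        if not isinstance(module, MaskedLinear):
            # Assumed semantics: strict=True rejects keys with no masked target
            if strict:
                raise KeyError(f"no masked submodule matches {key!r}")
            continue
        # Clone onto the weight's device so the caller's tensor is not aliased
        module.mask = mask.detach().to(module.weight.device).clone()
    return model


model = nn.Sequential(MaskedLinear(3, 2), nn.ReLU(), MaskedLinear(2, 1))
deploy_masks_sketch(model, {"0.mask": torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])})
```

In this sketch "1.mask" would raise under strict=True, because submodule "1" is a ReLU rather than a masked layer.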

complextorch.nn.masked.is_sparse(layer: torch.nn.Module) → bool

Return True if layer is a BaseMasked instance whose mask is set.
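A sketch of the check, again with a hypothetical MaskedLinear standing in for BaseMasked and a hypothetical is_sparse_sketch in place of the real function. Going by the definition above, "sparse" here means the layer carries a fixed mask; it is not a test for torch.sparse tensor layouts.

```python
import torch
from torch import nn


class MaskedLinear(nn.Linear):
    """Hypothetical stand-in for a BaseMasked layer (forward omitted; only the buffer matters here)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.register_buffer("mask", None)  # None means "no mask set"


def is_sparse_sketch(layer: nn.Module) -> bool:
    """Hypothetical: True only for a masked layer whose mask buffer has been set."""
    return isinstance(layer, MaskedLinear) and layer.mask is not None


model = nn.Sequential(MaskedLinear(3, 2), nn.ReLU(), MaskedLinear(2, 1))
model[0].mask = torch.ones(2, 3)
# Count the layers that currently carry a mask: only model[0] does
n_sparse = sum(is_sparse_sketch(m) for m in model.modules())
```

A layer of the masked type with no mask assigned (model[2] above) is reported as not sparse, as is any ordinary module such as nn.Linear.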

complextorch.nn.masked.named_masks(model: torch.nn.Module) → collections.abc.Iterator[tuple[str, torch.Tensor]]

Yield a (qualified_name, mask) pair for each mask that is currently set in model.
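A generator sketch over named_modules(). MaskedLinear and named_masks_sketch are hypothetical, and the sketch assumes qualified_name takes the "<module path>.mask" form used by deploy_masks keys, so its output could be fed straight back into deploy_masks; the real function may yield bare module paths instead. The usage line computes the fraction of weights each mask keeps.

```python
from collections.abc import Iterator

import torch
from torch import nn


class MaskedLinear(nn.Linear):
    """Hypothetical stand-in for a BaseMasked layer (forward omitted; only the buffer matters here)."""

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__(in_features, out_features)
        self.register_buffer("mask", None)  # None means "no mask set"


def named_masks_sketch(model: nn.Module) -> Iterator[tuple[str, torch.Tensor]]:
    """Hypothetical: yield ("<module path>.mask", mask) for every mask that is set."""
    for path, module in model.named_modules():
        if isinstance(module, MaskedLinear) and module.mask is not None:
            name = f"{path}.mask" if path else "mask"  # root module has path ""
            yield name, module.mask


model = nn.Sequential(MaskedLinear(3, 2), nn.ReLU(), MaskedLinear(2, 1))
model[2].mask = torch.tensor([[1.0, 0.0]])
# Fraction of weights kept by each mask
density = {name: m.count_nonzero().item() / m.numel() for name, m in named_masks_sketch(model)}
# density == {"2.mask": 0.5}
```

Because the sketch is a generator, it yields the live mask buffers; clone them (for example {k: v.clone() for k, v in named_masks_sketch(model)}) before saving if later in-place edits should not affect the snapshot.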