complextorch.nn.masked#
Fixed-sparsity-pattern complex layers and helpers for managing their masks across a whole network.
Submodules#
Functions#
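| binarize_masks(model) | In-place binarize every mask attached to a BaseMasked submodule. |
| deploy_masks(model, state_dict, *, strict) | Load a {name: mask} dict into the matching BaseMasked submodules of model. |
| is_sparse(layer) | True if layer is a BaseMasked with a mask set. |
| named_masks(model) | Yield (qualified_name, mask) for each currently-set mask. |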
Package Contents#
- complextorch.nn.masked.binarize_masks(model: torch.nn.Module) → torch.nn.Module
In-place binarize every mask attached to a BaseMasked submodule.
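The binarization rule is not spelled out here, so the following is only a minimal sketch of the idea under stated assumptions. ToyMasked is a hypothetical stand-in for a BaseMasked layer, binarize_masks_sketch is not the library function, and the 0.5 threshold is an assumption chosen for illustration.

```python
import torch
from torch import nn


class ToyMasked(nn.Module):
    """Hypothetical stand-in for a BaseMasked layer: a complex weight plus a `mask` buffer."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(n, n, dtype=torch.cfloat))
        # A soft (non-binary) mask with values in [0, 1).
        self.register_buffer("mask", torch.rand(n, n))


def binarize_masks_sketch(model: nn.Module, threshold: float = 0.5) -> nn.Module:
    """Assumed rule: entries above `threshold` become 1, all others become 0."""
    with torch.no_grad():
        for module in model.modules():
            if isinstance(module, ToyMasked) and module.mask is not None:
                # copy_ overwrites the existing buffer, so the change is in place.
                module.mask.copy_((module.mask > threshold).to(module.mask.dtype))
    return model


model = nn.Sequential(ToyMasked(4), ToyMasked(4))
same_model = binarize_masks_sketch(model)
```

Returning the same model object mirrors the documented signature, which takes and returns a torch.nn.Module while mutating the masks in place.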
- complextorch.nn.masked.deploy_masks(model: torch.nn.Module, state_dict: dict[str, torch.Tensor], *, strict: bool = True) → torch.nn.Module
Load a {name: mask} dict into the matching BaseMasked submodules of model.
Keys in state_dict are interpreted as fully-qualified module names (e.g. "encoder.layer1.linear.mask" or just "linear.mask"). Any key ending in ".mask" is matched to the corresponding submodule's mask buffer.
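The key-matching rule described above can be sketched as follows. This is an illustration, not the library's implementation: ToyMasked and deploy_masks_sketch are hypothetical names, keys that do not end in ".mask" are assumed to be skipped, and strict=True is assumed to raise KeyError for a key with no matching masked submodule (the actual meaning of strict is not documented in this section).

```python
import torch
from torch import nn

SUFFIX = ".mask"


class ToyMasked(nn.Module):
    """Hypothetical stand-in for a BaseMasked layer with a `mask` buffer."""

    def __init__(self, n: int) -> None:
        super().__init__()
        self.register_buffer("mask", torch.ones(n, n))


def deploy_masks_sketch(model, state_dict, *, strict=True):
    modules = dict(model.named_modules())  # qualified name -> submodule
    for key, mask in state_dict.items():
        if not key.endswith(SUFFIX):
            continue  # assumption: non-mask keys are ignored
        target = modules.get(key[: -len(SUFFIX)])
        if isinstance(target, ToyMasked):
            with torch.no_grad():
                target.mask.copy_(mask)  # shapes must match
        elif strict:
            # Assumed strict behaviour: an unmatched key is an error.
            raise KeyError(f"no masked submodule for key {key!r}")
    return model


model = nn.ModuleDict({"encoder": nn.ModuleDict({"linear": ToyMasked(2)})})
deploy_masks_sketch(model, {"encoder.linear.mask": torch.zeros(2, 2)})
```

Stripping the ".mask" suffix leaves the submodule's qualified name, which is exactly what named_modules() reports, so a single dictionary lookup resolves each key.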
- complextorch.nn.masked.is_sparse(layer: torch.nn.Module) → bool
True if layer is a BaseMasked with a mask set.
- complextorch.nn.masked.named_masks(model: torch.nn.Module) → collections.abc.Iterator[tuple[str, torch.Tensor]]
Yield (qualified_name, mask) for each currently-set mask.
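As a rough illustration of how is_sparse and named_masks might fit together, the sketch below walks a model and reports per-layer sparsity. ToyMasked, is_sparse_sketch and named_masks_sketch are hypothetical stand-ins, and the choice to yield names ending in ".mask" (so they line up with the keys deploy_masks expects) is an assumption.

```python
from collections.abc import Iterator

import torch
from torch import nn


class ToyMasked(nn.Module):
    """Hypothetical stand-in for a BaseMasked layer; `mask` is None until set."""

    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("mask", None)


def is_sparse_sketch(layer: nn.Module) -> bool:
    # True only for masked layers whose mask has actually been set.
    return isinstance(layer, ToyMasked) and layer.mask is not None


def named_masks_sketch(model: nn.Module) -> Iterator[tuple[str, torch.Tensor]]:
    for name, module in model.named_modules():
        if is_sparse_sketch(module):
            # Assumed naming: "<module name>.mask", or "mask" for the root module.
            yield (f"{name}.mask" if name else "mask"), module.mask


model = nn.ModuleDict({"dense": ToyMasked(), "pruned": ToyMasked()})
model["pruned"].mask = torch.tensor([[1.0, 0.0], [0.0, 0.0]])

# Fraction of zeroed entries for every layer that currently has a mask.
sparsity = {name: 1.0 - mask.mean().item() for name, mask in named_masks_sketch(model)}
```

Only the "pruned" layer appears in the result, because the "dense" layer has no mask set and is therefore skipped by the iterator.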