Native vs. Gauss-trick modules#
Convolution and linear layers in complextorch.nn exist in two variants:
| Native cfloat (recommended) | Gauss-trick (reference) |
|---|---|
| `complextorch.nn.Conv1d` | `complextorch.nn.gauss.Conv1d` |
| `complextorch.nn.Conv2d` | `complextorch.nn.gauss.Conv2d` |
| `complextorch.nn.Conv3d` | `complextorch.nn.gauss.Conv3d` |
| `complextorch.nn.Linear` | `complextorch.nn.gauss.Linear` |
Note
In versions of complextorch before 2.0 the Gauss-trick variants lived at the top level as
SlowConv* / SlowLinear. The prefix was a misleading legacy from when they
were faster than the naive split; they have since been moved to the
complextorch.nn.gauss subpackage and the Slow names removed.
What’s the difference?#
Native cfloat modules are thin wrappers around the corresponding torch.nn
module constructed with dtype=torch.cfloat. They rely on PyTorch’s native
complex kernels (available since PyTorch 2.1) and are the recommended path for
all new code.
```python
import torch
import complextorch as ctorch

x = torch.randn(8, 5, 7, dtype=torch.cfloat)
y = ctorch.nn.Conv1d(5, 16, kernel_size=3)(x)  # native cfloat kernel
```
Gauss-trick modules are the original hand-rolled implementations. They split each complex tensor into real and imaginary parts and apply Gauss' multiplication trick, a three-multiply real-valued formulation: with \(t_1 = R(x_r)\), \(t_2 = I(x_i)\), and \(t_3 = (R + I)(x_r + x_i)\), the result \((t_1 - t_2) + j(t_3 - t_1 - t_2)\) equals the naive four-application lift while needing only three real-valued applications. They predate PyTorch's native complex support and are kept for two reasons (a code sketch follows the list):
- Reference math — the Gauss path is the easiest place to read the real/imag split when you're learning the package internals or implementing a new layer.
- Explicit split parameters — `conv_r`/`conv_i` (or `linear_r`/`linear_i`) are exposed as separate `nn.Module` children, which is useful if you want to apply different parameterizations or constraints to each half.
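As a minimal sketch of the trick, here is what a Gauss-trick linear layer could look like (illustrative only, not the library's actual code; biases are omitted to keep the three-multiply algebra uncluttered):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GaussLinearSketch(nn.Module):
    """Sketch of a Gauss-trick linear layer on cfloat input."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        # Separate real/imag halves, mirroring the linear_r/linear_i split.
        self.linear_r = nn.Linear(in_features, out_features, bias=False)
        self.linear_i = nn.Linear(in_features, out_features, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        xr, xi = x.real, x.imag
        t1 = self.linear_r(xr)  # R(x_r)
        t2 = self.linear_i(xi)  # I(x_i)
        # One matmul with the summed weight: (R + I)(x_r + x_i).
        t3 = F.linear(xr + xi, self.linear_r.weight + self.linear_i.weight)
        # Real part t1 - t2 = R(x_r) - I(x_i);
        # imag part t3 - t1 - t2 = R(x_i) + I(x_r), matching the naive lift.
        return torch.complex(t1 - t2, t3 - t1 - t2)
```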
Which should I use?#
Use the native cfloat variant. The Gauss-trick path no longer offers a
speed advantage since PyTorch 2.1, so its only remaining role is as a
numerically-equivalent reference. The test suite under tests/invariants/
checks the two paths agree to floating-point tolerance on the same weights.
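A hedged sketch of what such an invariant check can look like (the actual tests may differ; this assumes `complextorch.nn.gauss.Conv1d` mirrors the native constructor and exposes `conv_r`/`conv_i` children as described above):

```python
import torch
import complextorch as ctorch

# Hypothetical equivalence check; bias=False sidesteps bias bookkeeping.
native = ctorch.nn.Conv1d(5, 16, kernel_size=3, bias=False)
gauss = ctorch.nn.gauss.Conv1d(5, 16, kernel_size=3, bias=False)

with torch.no_grad():
    # Copy the native cfloat kernel into the split real/imag parameters.
    gauss.conv_r.weight.copy_(native.weight.real)
    gauss.conv_i.weight.copy_(native.weight.imag)

x = torch.randn(8, 5, 7, dtype=torch.cfloat)
torch.testing.assert_close(native(x), gauss(x))
```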
If you’re adding a new layer that has a native PyTorch complex equivalent,
follow the native pattern (wrap torch.nn.X with dtype=torch.cfloat)
rather than reimplementing the real/imag split.
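In code, the pattern is roughly the following (a sketch of the wrapping idea, not the library's exact classes):

```python
import torch
import torch.nn as nn

class Linear(nn.Linear):
    """Sketch of a native-style wrapper: reuse the torch.nn module and
    default its parameters to cfloat via the dtype factory kwarg."""

    def __init__(self, in_features: int, out_features: int, **kwargs):
        kwargs.setdefault("dtype", torch.cfloat)
        super().__init__(in_features, out_features, **kwargs)

layer = Linear(5, 16)
assert layer.weight.dtype == torch.cfloat  # complex parameters, no split
```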
The three composition primitives#
Most non-convolutional layers in complextorch are built on three helpers in
complextorch.nn.functional:
- `apply_complex()` — the “naive” complex linear lift: \((R(x_r) - I(x_i)) + j(R(x_i) + I(x_r))\).
- `apply_complex_split()` — Type-A split: apply two separate functions to the real and imaginary parts independently. Used by `CVSplit*` activations, `Dropout`, `CVSoftMax`, `AdaptiveAvgPool*d`.
- `apply_complex_polar()` — Type-B polar split: apply functions to magnitude and phase separately, recombine via `torch.polar`. Used by `CVPolar*` / `modReLU` activations.
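To make the three shapes concrete, here are illustrative one-liner equivalents (argument order and signatures are assumptions, not the exact ones in `complextorch.nn.functional`):

```python
import torch

def apply_complex(f_r, f_i, x):
    # Naive lift with two real-valued maps f_r = R, f_i = I.
    return torch.complex(f_r(x.real) - f_i(x.imag),
                         f_r(x.imag) + f_i(x.real))

def apply_complex_split(f_r, f_i, x):
    # Type-A: real and imaginary parts are transformed independently.
    return torch.complex(f_r(x.real), f_i(x.imag))

def apply_complex_polar(f_abs, f_angle, x):
    # Type-B: transform magnitude and phase, recombine with torch.polar.
    # e.g. a modReLU-style activation acts only on the magnitude.
    return torch.polar(f_abs(x.abs()), f_angle(x.angle()))
```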
See Activations for the math behind Type-A / Type-B.
Tip
Construct magnitude/phase tensors with torch.polar(abs, angle) — it’s been
a PyTorch builtin since 1.8 and is the idiomatic call. complextorch does
not provide a from_polar helper.
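For example:

```python
import torch

mag = torch.tensor([1.0, 2.0])
phase = torch.tensor([0.0, torch.pi / 2])
z = torch.polar(mag, phase)  # mag * exp(1j * phase), a cfloat tensor
```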