import cmath
import math
from typing import Literal
import numpy as np
import fourier_toolkit.linalg as ftkl
import fourier_toolkit.typing as ftkt
import fourier_toolkit.util as ftku
__all__ = ["CZT", "DFT", "u2u"]

# Admissible signs of the complex exponent in the Fourier kernel.
ExponentSign = Literal[-1, +1]
def u2u(
    x_spec: ftku.UniformSpec,
    v_spec: ftku.UniformSpec,
    w: ftkt.ArrayRC,
    isign: ExponentSign,
) -> ftkt.ArrayC:
    r"""
    Multi-dimensional Uniform-to-Uniform Fourier Transform. (:math:`\tuu`)

    Computes the Fourier sum

    .. math::

       \bbz_{n} = \sum_{m} w_{m} \ee^{ \pm \cj 2\pi \innerProduct{\bbx_{m}}{\bbv_{n}} },

    where :math:`(\bbx_{m}, \bbv_{n})` lie on the regular lattice

    .. math::

       \bbx_{\bbm} &= \bbx_{0} + \Delta_{\bbx} \odot \bbm, \qquad [\bbm]_{d} \in \discreteRange{0}{M_{d}-1}, \\
       \bbv_{\bbn} &= \bbv_{0} + \Delta_{\bbv} \odot \bbn, \qquad [\bbn]_{d} \in \discreteRange{0}{N_{d}-1},

    with :math:`M = \prod_{d} M_{d}` and :math:`N = \prod_{d} N_{d}`.

    Parameters
    ----------
    x_spec: UniformSpec
        :math:`\bbx_{m}` lattice.
    v_spec: UniformSpec
        :math:`\bbv_{n}` lattice.
    w: ArrayRC
        (..., M1,...,MD) weights :math:`w_{m} \in \bC`.
    isign: +1, -1
        Exponent sign.

    Returns
    -------
    z: ArrayC
        (..., N1,...,ND) weights :math:`z_{n} \in \bC`.

    Notes
    -----
    :math:`\tuu` transforms for arbitrary (x_spec,v_spec) can be implemented using the CZT algorithm (using 2 FFTs),
    but a single FFT can be used in some cases.
    This implementation chooses the (FFT, CZT) per axis to maximize efficiency.

    The ``isign=+1`` transform is obtained from the ``isign=-1`` one by conjugation:
    :math:`\sum_{m} w_{m} \ee^{+\cj \theta} = \overline{\sum_{m} \overline{w_{m}} \ee^{-\cj \theta}}`.
    """
    # Normalize outside the assert: an assignment embedded in an `assert`
    # statement is skipped entirely when Python runs with -O.
    isign = int(isign)
    assert isign in (-1, +1)

    op = _U2U(x_spec=x_spec, v_spec=v_spec)
    if isign == -1:
        z = op.apply(w)
    else:
        # Reuse the '-' exponent operator via conjugation (see Notes).
        z = op.apply(w.conj()).conj()
    return z
# Helper routines (internal) ---------------------------------------------------
class DFT:
    r"""
    Multi-dimensional Discrete Fourier Transform (DFT) :math:`F: \bC^{N_{1} \times\cdots\times N_{D}} \to \bC^{N_{1} \times\cdots\times N_{D}}`.

    The 1D DFT is defined as:

    .. math::

       \bby[k]
       =
       (F \, \bbx)[k]
       =
       \sum_{n=0}^{N-1} \bbx[n] \ee^{-\cj \frac{2\pi}{N} nk},

    where :math:`\bbx \in \bC^{N}`, and :math:`k \in \discreteRange{0}{N-1}`.

    A D-dimensional DFT corresponds to taking a 1D DFT along each transform axis.

    This implementation is a thin shell around the FFT algorithm intended to simplify multi-backend (CPU, GPU) use:
    the FFT routines are taken from the input's own array namespace (``x.__array_namespace__()``).
    The transform acts on the last D axes; any leading axes are treated as stack dimensions.
    """

    def __init__(self, D: int):
        """
        Parameters
        ----------
        D: int
            Dimension of the transform (number of trailing axes transformed).
        """
        assert D >= 1
        self.cfg = ftku.as_namedtuple(
            D=D,
        )

    def _axes(self) -> tuple[int, ...]:
        # Transform axes: the last D axes of the input.
        return tuple(range(-self.cfg.D, 0))

    def apply(self, x: ftkt.ArrayRC) -> ftkt.ArrayC:
        r"""
        Compute :math:`\bby = F \bbx`.

        Parameters
        ----------
        x: ArrayRC
            (..., N1,...,ND) input :math:`\bbx \in \bC^{N_{1} \times\cdots\times N_{D}}`.

        Returns
        -------
        y: ArrayC
            (..., N1,...,ND) output :math:`\bby \in \bC^{N_{1} \times\cdots\times N_{D}}`.
        """
        xp = x.__array_namespace__()
        # norm="backward": the forward FFT is unscaled, matching the definition of F.
        y = xp.fft.fftn(x, axes=self._axes(), norm="backward")
        return y

    def adjoint(self, y: ftkt.ArrayRC) -> ftkt.ArrayC:
        r"""
        Compute :math:`\bbx = F^{\adj} \bby`.

        Parameters
        ----------
        y: ArrayRC
            (..., N1,...,ND) input :math:`\bby \in \bC^{N_{1} \times\cdots\times N_{D}}`.

        Returns
        -------
        x: ArrayC
            (..., N1,...,ND) output :math:`\bbx \in \bC^{N_{1} \times\cdots\times N_{D}}`.
        """
        xp = y.__array_namespace__()
        # norm="forward" moves the 1/N factor to the forward FFT, so ifftn is
        # left unscaled: sum_k y[k] e^{+j 2pi nk/N}, i.e. the adjoint of F.
        x = xp.fft.ifftn(y, axes=self._axes(), norm="forward")
        return x

    def inverse(self, y: ftkt.ArrayRC) -> ftkt.ArrayC:
        r"""
        Inverse transform :math:`\bbx = F^{-1} \bby`.

        Parameters
        ----------
        y: ArrayRC
            (..., N1,...,ND) input :math:`\bby \in \bC^{N_{1} \times\cdots\times N_{D}}`.

        Returns
        -------
        x: ArrayC
            (..., N1,...,ND) output :math:`\bbx \in \bC^{N_{1} \times\cdots\times N_{D}}`.
        """
        xp = y.__array_namespace__()
        # norm="backward": ifftn carries the 1/N factor, so inverse(apply(x)) == x.
        x = xp.fft.ifftn(y, axes=self._axes(), norm="backward")
        return x
class CZT:
r"""
Multi-dimensional Chirp Z-Transform (CZT) :math:`C: \bC^{N_{1} \times\cdots\times N_{D}} \to
\bC^{M_{1} \times\cdots\times M_{D}}`.
The 1D CZT of parameters :math:`(A, W, M)` is defined as:
.. math::
\bby[k]
=
(C \, \bbx)[k]
=
\sum_{n=0}^{N-1} \bbx[n] A^{-n} W^{nk},
where :math:`\bbx \in \bC^{N}`, :math:`(A, W) \in \bC`, and :math:`k \in \discreteRange{0}{M-1}`.
A D-dimensional CZT corresponds to taking a 1D CZT along each transform axis.
For stability reasons, this implementation assumes :math:`\abs{A} = \abs{W} = 1`.
"""
def __init__(
self,
N: tuple[int],
M: tuple[int],
A: tuple[complex],
W: tuple[complex],
):
r"""
Parameters
----------
N: tuple[int]
(N1,...,ND) dimensions of the input :math:`\bbx`.
M: tuple[int]
(M1,...,MD) dimensions of the output :math:`\bby = (C \, \bbx)`.
A: tuple[complex]
(D,) circular offsets.
W: tuple[complex]
(D,) circular spacings between transform points.
"""
M = ftku.broadcast_seq(M, None, int)
D = len(M)
N = ftku.broadcast_seq(N, D, int)
A = ftku.broadcast_seq(A, D, complex)
W = ftku.broadcast_seq(W, D, complex)
assert all(m > 0 for m in M)
assert all(n > 0 for n in N)
assert all(math.isclose(abs(a), 1) for a in A)
assert all(math.isclose(abs(w), 1) for w in W)
self.cfg = ftku.as_namedtuple(
M=M,
D=D,
N=N,
A=A,
W=W,
L=tuple(ftku.next_fast_len(n + m - 1) for (n, m) in zip(N, M)),
)
def apply(self, x: ftkt.ArrayRC) -> ftkt.ArrayC:
r"""
Compute :math:`\bby = C \bbx`.
Parameters
----------
x: ArrayRC
(..., N1,...,ND) input :math:`\bbx \in \bC^{N_{1} \times\cdots\times N_{D}}`.
Returns
-------
y: ArrayC
(..., M1,...,MD) output :math:`\bby \in \bC^{M_{1} \times\cdots\times M_{D}}`.
"""
AWk2, FWk2, Wk2, extract = self._mod_params_apply(x)
pad_width = [(0, 0)] * (x.ndim - self.cfg.D) # stack dimensions
pad_width += [ # core dimensions
(0, l - n)
for (l, n) in zip(self.cfg.L, self.cfg.N) # noqa: E741
]
_x = ftkl.hadamard_outer(x, *AWk2)
_x = np.pad(_x, pad_width)
_x = DFT(self.cfg.D).apply(_x)
_x = ftkl.hadamard_outer(_x, *FWk2)
_x = DFT(self.cfg.D).inverse(_x)
y = ftkl.hadamard_outer(_x[..., *extract], *Wk2)
return y
def adjoint(self, y: ftkt.ArrayRC) -> ftkt.ArrayC:
r"""
Compute :math:`\bbx = C^{\adj} \bby`.
Parameters
----------
y: ArrayRC
(..., M1,...,MD) input :math:`\bby \in \bC^{M_{1} \times\cdots\times M_{D}}`.
Returns
-------
x: ArrayC
(..., N1,...,ND) output :math:`\bbx \in \bC^{N_{1} \times\cdots\times N_{D}}`.
"""
# CZT^{\adjoint}(y,M,A,W)[n] = CZT(y,N,A=1,W=W*)[n] * A^{n}
czt = CZT(
N=self.cfg.M,
M=self.cfg.N,
A=1,
W=tuple(w.conjugate() for w in self.cfg.W),
)
An = self._mod_params_adjoint(y)
_y = czt.apply(y)
x = ftkl.hadamard_outer(_y, *An)
return x
# Helper routines (internal) ----------------------------------------------
def _mod_params_apply(self, x: ftkt.ArrayRC):
"""
Parameters
----------
x: ArrayRC
Returns
-------
AWk2: ArrayC
(N1,),...,(ND,) pre-FFT modulation vectors.
FWk2: ArrayC
(L1,),...,(LD,) FFT of convolution filters.
Wk2: ArrayC
(M1,),...,(MD,) post-FFT modulation vectors.
extract: list[slice]
(slice1,...,sliceD) FFT interval to extract.
"""
translate = ftku.TranslateDType(x.dtype)
cdtype = translate.to_complex()
# Build modulation vectors (Wk2, AWk2, FWk2).
Wk2 = [None] * self.cfg.D
AWk2 = [None] * self.cfg.D
FWk2 = [None] * self.cfg.D
for d in range(self.cfg.D):
A = self.cfg.A[d]
W = self.cfg.W[d]
N = self.cfg.N[d]
M = self.cfg.M[d]
L = self.cfg.L[d]
k = np.arange(max(M, N), dtype=int, like=x)
_Wk2 = W ** ((k**2) / 2)
_AWk2 = (A ** -k[:N]) * _Wk2[:N]
_FWk2 = np.fft.fft(
np.concatenate([_Wk2[(N - 1) : 0 : -1], _Wk2[:M]]).conj(),
n=L,
)
_Wk2 = _Wk2[:M]
Wk2[d] = _Wk2.astype(cdtype)
AWk2[d] = _AWk2.astype(cdtype)
FWk2[d] = _FWk2.astype(cdtype)
# Build (extract,)
extract = [slice(None)] * self.cfg.D
for d in range(self.cfg.D):
N = self.cfg.N[d]
M = self.cfg.M[d]
L = self.cfg.L[d]
extract[d] = slice(N - 1, N + M - 1)
return AWk2, FWk2, Wk2, extract
def _mod_params_adjoint(self, y: ftkt.ArrayRC):
"""
Parameters
----------
y: ArrayRC
Returns
-------
An: ArrayC
(N1,),...,(ND,) vectors.
"""
translate = ftku.TranslateDType(y.dtype)
cdtype = translate.to_complex()
An = [None] * self.cfg.D
for d in range(self.cfg.D):
_A = self.cfg.A[d]
_N = self.cfg.N[d]
_An = _A ** np.arange(_N, dtype=int, like=y)
An[d] = _An.astype(cdtype)
return An
class _U2U:
    r"""
    Object-oriented interface to a :math:`\tuu` transform, with exponent sign set to :math:`-1`.

    For internal use only.

    The algorithm is chosen per axis in ``__init__``:

    * FFT axis: input and output have the same length :math:`N` and :math:`\abs{\Delta_{x} \Delta_{v}} N = 1`.
      The kernel then reduces to a pre/post-modulated length-:math:`N` DFT, or to its unscaled inverse when
      :math:`\Delta_{x} \Delta_{v} < 0`.
    * CZT axis: every other axis, evaluated with :py:class:`CZT` followed by a post-modulation.
    """

    def __init__(
        self,
        x_spec: ftku.UniformSpec,
        v_spec: ftku.UniformSpec,
    ):
        r"""
        Parameters
        ----------
        x_spec: UniformSpec
            :math:`\bbx_{m}` lattice.
        v_spec: UniformSpec
            :math:`\bbv_{n}` lattice.
        """
        assert x_spec.ndim == v_spec.ndim
        D = x_spec.ndim

        # Decide on algorithm per axis.
        fft_axes = []
        czt_axes = []
        for d in range(D):
            dx = x_spec.step[d]
            dv = v_spec.step[d]
            M = x_spec.num[d]
            N = v_spec.num[d]
            if (M == N) and math.isclose(abs(dx * dv) * N, 1):
                fft_axes.append(d)
            else:
                czt_axes.append(d)

        self.cfg = ftku.as_namedtuple(
            D=D,
            x_spec=x_spec,
            v_spec=v_spec,
            fft_axes=tuple(fft_axes),
            czt_axes=tuple(czt_axes),
        )

    def apply(self, w: ftkt.ArrayRC) -> ftkt.ArrayC:
        r"""
        Compute :math:`\bbz = U \bbw`.

        Parameters
        ----------
        w: ArrayRC
            (..., M1,...,MD) weights :math:`w_{m} \in \bC`.

        Returns
        -------
        z: ArrayC
            (..., N1,...,ND) weights :math:`z_{n} \in \bC`.
        """
        _w = w

        # Processing FFT axes: move them last, pre-modulate, FFT, post-modulate, move back.
        if self.cfg.fft_axes:
            (ax_fft, Cp, fft, Bp, ax_ifft) = self._fft_params(w)
            _w = _w.transpose(ax_fft)
            _w = ftkl.hadamard_outer(_w, *Cp)
            _w = fft(_w)
            _w = ftkl.hadamard_outer(_w, *Bp)
            _w = _w.transpose(ax_ifft)

        # Processing CZT axes: move them last, CZT, post-modulate, move back.
        if self.cfg.czt_axes:
            (ax_czt, czt, B, ax_iczt) = self._czt_params(w)
            _w = _w.transpose(ax_czt)
            _w = czt.apply(_w)
            _w = ftkl.hadamard_outer(_w, *B)
            _w = _w.transpose(ax_iczt)

        z = _w
        return z

    # Helper routines (internal) ----------------------------------------------
    def _axis_permutation(
        self,
        y: ftkt.ArrayRC,
        lead_axes: tuple[int],
        tail_axes: tuple[int],
    ) -> tuple[tuple[int], tuple[int]]:
        """
        Build the permutation ordering the axes of `y` as (stack, lead_axes, tail_axes), and its inverse.

        `lead_axes` and `tail_axes` index the D core dimensions (0 = first core axis).
        """
        n_stk = len(y.shape[: -self.cfg.D])  # number of stack dimensions
        stk = tuple(range(n_stk))
        lead = tuple(ax + n_stk for ax in lead_axes)
        tail = tuple(ax + n_stk for ax in tail_axes)
        fwd = (*stk, *lead, *tail)
        bwd = tuple(np.argsort(fwd))
        return fwd, bwd

    def _fft_params(self, y: ftkt.ArrayRC):
        """
        Parameters
        ----------
        y: ArrayRC

        Returns
        -------
        ax_fft: tuple[int]
            Permutation tuple to move FFT axes to end of `y`.
        Cp: list[ArrayC]
            Pre-FFT modulation vectors, one per FFT axis (length ``x_spec.num[ax]``).
        fft: callable
            Local function computing the (unscaled) forward/backward DFTs along the trailing FFT axes.
        Bp: list[ArrayC]
            Post-FFT modulation vectors, one per FFT axis (length ``v_spec.num[ax]``).
        ax_ifft: tuple[int]
            Permutation tuple to undo initial axis transposition.
        """
        translate = ftku.TranslateDType(y.dtype)
        cdtype = translate.to_complex()

        # Build (ax_fft, ax_ifft): order is (stack, CZT axes, FFT axes).
        ax_fft, ax_ifft = self._axis_permutation(y, self.cfg.czt_axes, self.cfg.fft_axes)

        # Loop-invariant lattice parameters.
        fft_axes = self.cfg.fft_axes
        x0, dx, nx = self.cfg.x_spec.start, self.cfg.x_spec.step, self.cfg.x_spec.num
        v0, dv, nv = self.cfg.v_spec.start, self.cfg.v_spec.step, self.cfg.v_spec.num
        D_fft = len(fft_axes)

        # Build FFT operator. After transposition the FFT axes are the last D_fft
        # axes, in `fft_axes` order. dx*dv > 0 gives an e^{-j...} kernel (fftn),
        # dx*dv < 0 an e^{+j...} kernel (unscaled ifftn).
        neg_axes = []  # '-' exponent axes
        pos_axes = []  # '+' exponent axes
        for d, ax in enumerate(fft_axes):
            if dx[ax] * dv[ax] > 0:
                neg_axes.append(-D_fft + d)
            else:
                pos_axes.append(-D_fft + d)

        def fft(a: ftkt.ArrayRC) -> ftkt.ArrayC:
            # basically DFT.apply(), but with `axes` modified.
            xp = a.__array_namespace__()
            out = xp.fft.fftn(a, axes=neg_axes, norm="backward")
            out = xp.fft.ifftn(out, axes=pos_axes, norm="forward")
            return out

        # Build modulation vectors (Cp, Bp):
        #   Cp[m] = exp(-j 2pi dx v0 m),  Bp[n] = exp(-j 2pi x0 v_n).
        Cp = [None] * D_fft
        Bp = [None] * D_fft
        for d, ax in enumerate(fft_axes):
            phase_scale_c = -2 * math.pi * dx[ax] * v0[ax]
            m = np.arange(nx[ax], dtype=int, like=y)
            _Cp = np.exp(1j * phase_scale_c * m)
            Cp[d] = _Cp.astype(cdtype)

            phase_scale_b = -2 * math.pi * x0[ax]
            v = v0[ax] + dv[ax] * np.arange(nv[ax], dtype=int, like=y)
            _Bp = np.exp(1j * phase_scale_b * v)
            Bp[d] = _Bp.astype(cdtype)

        return (ax_fft, Cp, fft, Bp, ax_ifft)

    def _czt_params(self, y: ftkt.ArrayRC):
        """
        Parameters
        ----------
        y: ArrayRC

        Returns
        -------
        ax_czt: tuple[int]
            Permutation tuple to move CZT axes to end of `y`.
        czt: CZT
            CZT(A,W,M,N) instance.
        B: list[ArrayC]
            Post-CZT modulation vectors, one per CZT axis (length ``v_spec.num[ax]``).
        ax_iczt: tuple[int]
            Permutation tuple to undo initial axis transposition.
        """
        translate = ftku.TranslateDType(y.dtype)
        cdtype = translate.to_complex()

        # Build (ax_czt, ax_iczt): order is (stack, FFT axes, CZT axes).
        ax_czt, ax_iczt = self._axis_permutation(y, self.cfg.fft_axes, self.cfg.czt_axes)

        # Loop-invariant lattice parameters.
        czt_axes = self.cfg.czt_axes
        x0, dx, nx = self.cfg.x_spec.start, self.cfg.x_spec.step, self.cfg.x_spec.num
        v0, dv, nv = self.cfg.v_spec.start, self.cfg.v_spec.step, self.cfg.v_spec.num
        D_czt = len(czt_axes)

        # Build CZT operator parameters and post-modulation vectors (B,):
        #   A^{-m} = exp(-j 2pi dx v0 m),  W^{mn} = exp(-j 2pi dx dv m n),
        #   B[n]   = exp(-j 2pi x0 v_n).
        N = [None] * D_czt
        M = [None] * D_czt
        A = [None] * D_czt
        W = [None] * D_czt
        B = [None] * D_czt
        for d, ax in enumerate(czt_axes):
            N[d] = nx[ax]
            M[d] = nv[ax]
            A[d] = cmath.exp(+1j * 2 * math.pi * dx[ax] * v0[ax])
            W[d] = cmath.exp(-1j * 2 * math.pi * dx[ax] * dv[ax])

            phase_scale = -2 * math.pi * x0[ax]
            v = v0[ax] + dv[ax] * np.arange(nv[ax], dtype=int, like=y)
            _B = np.exp(1j * phase_scale * v)
            B[d] = _B.astype(cdtype)

        czt = CZT(N, M, A, W)
        return ax_czt, czt, B, ax_iczt