Skip to content

API Reference

Magnifiers

pyevm.ColorMagnifier

Colour-based EVM magnifier.

Parameters:

Name Type Description Default
alpha float

Luminance amplification factor.

50.0
freq_low float

Temporal bandpass lower frequency (Hz).

0.4
freq_high float

Temporal bandpass upper frequency (Hz).

3.0
n_levels int

Gaussian pyramid levels (typically 4–6).

6
chrom_attenuation float

Scale applied to amplified I, Q channels (0 = no chrominance amplification, 1 = same as luma).

0.1
pyramid_level int | None

Which Gaussian level to temporally filter (default n_levels - 1; coarsest, lowest spatial frequency).

None
filter_type str

"ideal" (FFT bandpass) or "butterworth" (IIR, lower memory).

'ideal'
device device | None

Compute device (defaults to CPU if None).

None
dtype dtype

Tensor dtype.

float32
Source code in src/pyevm/magnification/color.py
class ColorMagnifier:
    """Colour-based EVM magnifier.

    Amplifies subtle colour variations (e.g. pulse-induced skin tone changes)
    by temporally bandpass-filtering one Gaussian-pyramid level in YIQ space
    and adding the amplified signal back to the input.

    Args:
        alpha: Luminance amplification factor.
        freq_low: Temporal bandpass lower frequency (Hz).
        freq_high: Temporal bandpass upper frequency (Hz).
        n_levels: Gaussian pyramid levels (typically 4–6).
        chrom_attenuation: Scale applied to amplified I, Q channels
            (0 = no chrominance amplification, 1 = same as luma).
        pyramid_level: Which Gaussian level to temporally filter (default
            ``n_levels - 1``; coarsest, lowest spatial frequency).
        filter_type: ``"ideal"`` (FFT bandpass) or ``"butterworth"`` (IIR,
            lower memory).
        notch_freqs: Centre frequencies (Hz) to notch out of the passband
            (``None`` disables notch filtering).
        notch_width: Width (Hz) of each notch.
        device: Compute device (defaults to CPU when ``None``).
        dtype: Tensor dtype.
    """

    def __init__(
        self,
        alpha: float = 50.0,
        freq_low: float = 0.4,
        freq_high: float = 3.0,
        n_levels: int = 6,
        chrom_attenuation: float = 0.1,
        pyramid_level: int | None = None,
        filter_type: str = "ideal",
        notch_freqs: list[float] | None = None,
        notch_width: float = 1.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.alpha = alpha
        self.freq_low = freq_low
        self.freq_high = freq_high
        self.n_levels = n_levels
        self.chrom_attenuation = chrom_attenuation
        # Default to the coarsest level (lowest spatial frequency band).
        self.pyramid_level = pyramid_level if pyramid_level is not None else n_levels - 1
        self.filter_type = filter_type
        # Normalise None to an empty list (empty list = no notches).
        self.notch_freqs = notch_freqs or []
        self.notch_width = notch_width
        # No device auto-detection here: None simply falls back to CPU.
        self.device = device or torch.device("cpu")
        self.dtype = dtype

        self._pyramid = GaussianPyramid(n_levels=n_levels, device=self.device, dtype=dtype)

        logger.info(
            f"ColorMagnifier: alpha={alpha}, band=[{freq_low}, {freq_high}] Hz, "
            f"pyramid_level={self.pyramid_level}, filter={filter_type}"
        )

    def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
        """Run colour EVM on a video tensor.

        Args:
            frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
            fps: Frames per second.

        Returns:
            Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
        """
        frames = frames.to(device=self.device, dtype=self.dtype)
        T, C, H, W = frames.shape
        logger.info(f"ColorMagnifier.process: {T} frames @ {fps} fps, shape {(C, H, W)}")

        # --- Convert to YIQ ---
        yiq = rgb_to_yiq(frames)  # (T, 3, H, W)

        # --- Build pyramid per frame, extract target level ---
        # NOTE(review): build() is called once per frame here, while
        # process_stream passes a whole (N, 3, H, W) batch to the same
        # method — a batched build would likely avoid this Python loop.
        logger.debug(f"Building Gaussian pyramids (level {self.pyramid_level})…")
        level_stack: list[torch.Tensor] = []
        for t in range(T):
            levels = self._pyramid.build(yiq[t])  # list of (1, 3, h, w)
            level_stack.append(levels[self.pyramid_level].squeeze(0))  # (3, h, w)

        # Stack → (T, 3, h, w)
        level_tensor = torch.stack(level_stack, dim=0)
        logger.debug(f"Pyramid level tensor: {tuple(level_tensor.shape)}")

        # --- Temporal filtering ---
        if self.filter_type == "ideal":
            filt = IdealBandpass(
                fps,
                self.freq_low,
                self.freq_high,
                notch_freqs=self.notch_freqs,
                notch_width=self.notch_width,
            )
            filtered = filt.apply(level_tensor)  # (T, 3, h, w)
        else:
            # Any value other than "ideal" falls through to Butterworth.
            filt = ButterworthBandpass(
                fps,
                self.freq_low,
                self.freq_high,
                notch_freqs=self.notch_freqs,
                notch_width=self.notch_width,
            )
            filtered = filt.apply(level_tensor)

        # --- Amplify: Y × alpha, I/Q × alpha × chrom_attenuation ---
        amplified = filtered.clone()
        amplified[:, 0] *= self.alpha
        amplified[:, 1:] *= self.alpha * self.chrom_attenuation
        logger.debug("Applied amplification to filtered signal")

        # --- Upsample amplified signal back to original resolution ---
        # We upsample the coarsest-level amplified signal to full resolution
        upsampled = torch.nn.functional.interpolate(
            amplified.reshape(T * 3, 1, *amplified.shape[2:]),
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        ).reshape(T, 3, H, W)

        # --- Add to original YIQ ---
        result_yiq = yiq + upsampled

        # --- Convert back to RGB and clamp ---
        result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)
        logger.info("ColorMagnifier.process complete")
        return result_rgb

    def process_stream(
        self,
        frames: Iterable[torch.Tensor],
        fps: float,
        n_frames: int | None = None,
        chunk_size: int = 64,
    ) -> Generator[torch.Tensor, None, None]:
        """Process frames in chunks, yielding each output frame.

        Uses Butterworth IIR regardless of *filter_type* (the ideal FFT filter
        requires all frames at once and cannot be used in streaming mode).

        Args:
            frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
            fps: Frames per second.
            n_frames: Total frame count (optional, used for the progress bar).
            chunk_size: Number of frames to process per GPU batch.

        Yields:
            Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
        """
        # Single stateful IIR filter: its internal state carries across chunk
        # boundaries, so chunks must be fed in order.
        filt = ButterworthBandpass(
            fps,
            self.freq_low,
            self.freq_high,
            notch_freqs=self.notch_freqs,
            notch_width=self.notch_width,
        )

        def _process_chunk(chunk: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
            batch = torch.stack(chunk)  # (N, C, H, W)
            N, C, H, W = batch.shape

            t0 = time.perf_counter()
            yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
            levels = self._pyramid.build(yiq)  # list of (N, 3, h, w)
            level_t = levels[self.pyramid_level]  # (N, 3, h, w)
            t_build = time.perf_counter() - t0

            # Temporal filter (state carried across chunks via filt._zi)
            t1 = time.perf_counter()
            filtered = filt.apply_chunk(level_t)  # (N, 3, h, w)
            t_filter = time.perf_counter() - t1

            # Amplify: Y × alpha, I/Q × alpha × chrom_attenuation
            t2 = time.perf_counter()
            amplified = filtered.clone()
            amplified[:, 0] *= self.alpha
            amplified[:, 1:] *= self.alpha * self.chrom_attenuation

            # Upsample amplified signal back to original resolution
            _, _, h, w = amplified.shape
            upsampled = F.interpolate(
                amplified.reshape(N * 3, 1, h, w),
                size=(H, W),
                mode="bilinear",
                align_corners=False,
            ).reshape(N, 3, H, W)

            result_yiq = yiq + upsampled
            result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)  # (N, C, H, W)
            t_reconstruct = time.perf_counter() - t2

            logger.debug(
                f"  [color chunk N={N}]  "
                f"build={t_build * 1000:.1f}ms  "
                f"filter={t_filter * 1000:.1f}ms  "
                f"reconstruct={t_reconstruct * 1000:.1f}ms"
            )
            yield from result_rgb

        chunk: list[torch.Tensor] = []
        with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
            for frame in frames:
                chunk.append(frame.to(device=self.device, dtype=self.dtype))
                if len(chunk) == chunk_size:
                    t0 = time.perf_counter()
                    for out_frame in _process_chunk(chunk):
                        yield out_frame
                        bar.update(1)
                    elapsed = time.perf_counter() - t0
                    logger.debug(
                        f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                    )
                    chunk = []
            # Flush the final partial chunk, if any.
            if chunk:
                t0 = time.perf_counter()
                for out_frame in _process_chunk(chunk):
                    yield out_frame
                    bar.update(1)
                elapsed = time.perf_counter() - t0
                logger.debug(
                    f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                )

process(frames, fps)

Run colour EVM on a video tensor.

Parameters:

Name Type Description Default
frames Tensor

(T, C, H, W) RGB tensor with values in [0, 1].

required
fps float

Frames per second.

required

Returns:

Type Description
Tensor

Amplified (T, C, H, W) RGB tensor, clamped to [0, 1].

Source code in src/pyevm/magnification/color.py
def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
    """Run colour EVM on a video tensor.

    Args:
        frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
        fps: Frames per second.

    Returns:
        Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
    """
    frames = frames.to(device=self.device, dtype=self.dtype)
    T, C, H, W = frames.shape
    logger.info(f"ColorMagnifier.process: {T} frames @ {fps} fps, shape {(C, H, W)}")

    # Work in YIQ so luma and chroma can be amplified independently.
    yiq = rgb_to_yiq(frames)  # (T, 3, H, W)

    # Extract the chosen Gaussian level from every frame's pyramid.
    logger.debug(f"Building Gaussian pyramids (level {self.pyramid_level})…")
    per_frame = [
        self._pyramid.build(yiq[t])[self.pyramid_level].squeeze(0)  # (3, h, w)
        for t in range(T)
    ]
    level_tensor = torch.stack(per_frame, dim=0)  # (T, 3, h, w)
    logger.debug(f"Pyramid level tensor: {tuple(level_tensor.shape)}")

    # Temporal bandpass: pick the filter class, build it once, apply once.
    bandpass_cls = IdealBandpass if self.filter_type == "ideal" else ButterworthBandpass
    bandpass = bandpass_cls(
        fps,
        self.freq_low,
        self.freq_high,
        notch_freqs=self.notch_freqs,
        notch_width=self.notch_width,
    )
    filtered = bandpass.apply(level_tensor)  # (T, 3, h, w)

    # Amplify: Y scaled by alpha, I/Q by alpha × chrom_attenuation.
    boosted = filtered.clone()
    boosted[:, 0] *= self.alpha
    boosted[:, 1:] *= self.alpha * self.chrom_attenuation
    logger.debug("Applied amplification to filtered signal")

    # Bilinearly upsample the amplified coarse-level signal to full size.
    h, w = boosted.shape[2:]
    upsampled = torch.nn.functional.interpolate(
        boosted.reshape(T * 3, 1, h, w),
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    ).reshape(T, 3, H, W)

    # Superimpose onto the original YIQ signal, return to RGB, clamp.
    result_rgb = yiq_to_rgb(yiq + upsampled).clamp(0.0, 1.0)
    logger.info("ColorMagnifier.process complete")
    return result_rgb

process_stream(frames, fps, n_frames=None, chunk_size=64)

Process frames in chunks, yielding each output frame.

Uses Butterworth IIR regardless of filter_type (the ideal FFT filter requires all frames at once and cannot be used in streaming mode).

Parameters:

Name Type Description Default
frames Iterable[Tensor]

Iterable of (C, H, W) float32 RGB tensors on any device.

required
fps float

Frames per second.

required
n_frames int | None

Total frame count (optional, used for the progress bar).

None
chunk_size int

Number of frames to process per GPU batch.

64

Yields:

Type Description
Tensor

Amplified (C, H, W) float32 RGB tensors, clamped to [0, 1].

Source code in src/pyevm/magnification/color.py
def process_stream(
    self,
    frames: Iterable[torch.Tensor],
    fps: float,
    n_frames: int | None = None,
    chunk_size: int = 64,
) -> Generator[torch.Tensor, None, None]:
    """Process frames in chunks, yielding each output frame.

    Uses Butterworth IIR regardless of *filter_type* (the ideal FFT filter
    requires all frames at once and cannot be used in streaming mode).

    Args:
        frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
        fps: Frames per second.
        n_frames: Total frame count (optional, used for the progress bar).
        chunk_size: Number of frames to process per GPU batch.

    Yields:
        Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
    """
    # One stateful IIR filter; its internal state carries across chunks,
    # so chunks must be processed strictly in order.
    filt = ButterworthBandpass(
        fps,
        self.freq_low,
        self.freq_high,
        notch_freqs=self.notch_freqs,
        notch_width=self.notch_width,
    )

    def _magnify(buf: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
        # Batched colour-EVM pass over one chunk of frames.
        batch = torch.stack(buf)  # (N, C, H, W)
        N, C, H, W = batch.shape

        t0 = time.perf_counter()
        yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
        level_t = self._pyramid.build(yiq)[self.pyramid_level]  # (N, 3, h, w)
        t_build = time.perf_counter() - t0

        # Temporal filter (state carried across chunks inside `filt`).
        t1 = time.perf_counter()
        filtered = filt.apply_chunk(level_t)  # (N, 3, h, w)
        t_filter = time.perf_counter() - t1

        # Amplify: Y × alpha, I/Q × alpha × chrom_attenuation.
        t2 = time.perf_counter()
        boosted = filtered.clone()
        boosted[:, 0] *= self.alpha
        boosted[:, 1:] *= self.alpha * self.chrom_attenuation

        # Upsample the amplified coarse signal back to full resolution.
        h, w = boosted.shape[2:]
        upsampled = F.interpolate(
            boosted.reshape(N * 3, 1, h, w),
            size=(H, W),
            mode="bilinear",
            align_corners=False,
        ).reshape(N, 3, H, W)

        result_rgb = yiq_to_rgb(yiq + upsampled).clamp(0.0, 1.0)  # (N, C, H, W)
        t_reconstruct = time.perf_counter() - t2

        logger.debug(
            f"  [color chunk N={N}]  "
            f"build={t_build * 1000:.1f}ms  "
            f"filter={t_filter * 1000:.1f}ms  "
            f"reconstruct={t_reconstruct * 1000:.1f}ms"
        )
        yield from result_rgb

    def _drain(buf: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
        # Yield one chunk's results, ticking the progress bar per frame.
        start = time.perf_counter()
        for out_frame in _magnify(buf):
            yield out_frame
            bar.update(1)
        elapsed = time.perf_counter() - start
        logger.debug(
            f"Chunk {len(buf)} frames: {elapsed:.2f}s ({len(buf) / elapsed:.1f} fps)"
        )

    pending: list[torch.Tensor] = []
    with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
        for frame in frames:
            pending.append(frame.to(device=self.device, dtype=self.dtype))
            if len(pending) == chunk_size:
                yield from _drain(pending)
                pending = []
        if pending:  # flush the final partial chunk
            yield from _drain(pending)

pyevm.MotionMagnifier

Motion-based EVM magnifier.

Parameters:

Name Type Description Default
alpha float

Nominal amplification factor (may be reduced per level).

20.0
freq_low float

Temporal bandpass lower frequency (Hz).

0.4
freq_high float

Temporal bandpass upper frequency (Hz).

3.0
n_levels int

Laplacian pyramid levels.

6
lambda_c float

Spatial wavelength cutoff (pixels) for adaptive scaling (default 16, matching the reference MATLAB code).

16.0
filter_type str

"butterworth" (default, streaming) or "ideal".

'butterworth'
device device | None

Compute device.

None
dtype dtype

Tensor dtype.

float32
Source code in src/pyevm/magnification/motion.py
class MotionMagnifier:
    """Motion-based EVM magnifier.

    Amplifies subtle motions by temporally bandpass-filtering every level of
    a Laplacian pyramid in YIQ space, with per-level adaptive amplification.

    Args:
        alpha: Nominal amplification factor (may be reduced per level).
        freq_low: Temporal bandpass lower frequency (Hz).
        freq_high: Temporal bandpass upper frequency (Hz).
        n_levels: Laplacian pyramid levels.
        lambda_c: Spatial wavelength cutoff (pixels) for adaptive scaling
            (default 16, matching the reference MATLAB code).
        filter_type: ``"butterworth"`` (default, streaming) or ``"ideal"``.
        notch_freqs: Centre frequencies (Hz) to notch out of the passband
            (``None`` disables notch filtering).
        notch_width: Width (Hz) of each notch.
        device: Compute device (defaults to CPU when ``None``).
        dtype: Tensor dtype.
    """

    def __init__(
        self,
        alpha: float = 20.0,
        freq_low: float = 0.4,
        freq_high: float = 3.0,
        n_levels: int = 6,
        lambda_c: float = 16.0,
        filter_type: str = "butterworth",
        notch_freqs: list[float] | None = None,
        notch_width: float = 1.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.alpha = alpha
        self.freq_low = freq_low
        self.freq_high = freq_high
        self.n_levels = n_levels
        self.lambda_c = lambda_c
        self.filter_type = filter_type
        # Normalise None to an empty list (empty list = no notches).
        self.notch_freqs = notch_freqs or []
        self.notch_width = notch_width
        # No device auto-detection here: None simply falls back to CPU.
        self.device = device or torch.device("cpu")
        self.dtype = dtype

        self._pyramid = LaplacianPyramid(n_levels=n_levels, device=self.device, dtype=dtype)

        logger.info(
            f"MotionMagnifier: alpha={alpha}, band=[{freq_low}, {freq_high}] Hz, "
            f"lambda_c={lambda_c}, filter={filter_type}"
        )

    def _alpha_for_level(self, level: int) -> float:
        """Compute spatially-adaptive alpha for *level*.

        The spatial wavelength at level *i* is ``2^(i+1)`` pixels at full
        resolution (level 0 = finest, λ = 2 px; level 5 = coarsest, λ = 64 px).

        Fine levels (λ < λ_c) are suppressed to avoid amplifying noise and
        aliasing artefacts.  Coarser levels receive progressively higher
        amplification up to *alpha*.  The crossover is at λ = λ_c / 8; above
        that the cap grows linearly.
        """
        lambda_at_level = 2 ** (level + 1)
        # Linear cap: alpha_max = 8λ/λ_c − 1; non-positive → level suppressed.
        alpha_max = 8.0 * lambda_at_level / self.lambda_c - 1.0
        alpha_eff = min(self.alpha, alpha_max) if alpha_max > 0 else 0.0
        logger.debug(f"  Level {level}: lambda={lambda_at_level}, alpha_eff={alpha_eff:.2f}")
        return alpha_eff

    def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
        """Run motion EVM on a video tensor.

        Args:
            frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
            fps: Frames per second.

        Returns:
            Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
        """
        frames = frames.to(device=self.device, dtype=self.dtype)
        T, C, H, W = frames.shape
        logger.info(f"MotionMagnifier.process: {T} frames @ {fps} fps, shape {(C, H, W)}")

        # --- Convert to YIQ ---
        yiq = rgb_to_yiq(frames)  # (T, 3, H, W)

        # --- Build Laplacian pyramid per frame ---
        logger.debug("Building Laplacian pyramids…")
        # pyramids[level] = list of (3, h, w) tensors, length T
        pyramids: list[list[torch.Tensor]] = [[] for _ in range(self.n_levels)]
        for t in range(T):
            levels = self._pyramid.build(yiq[t])  # list of (1, 3, h, w)
            for lvl, lev in enumerate(levels):
                pyramids[lvl].append(lev.squeeze(0))  # (3, h, w)

        # --- Temporally filter each level and amplify ---
        filtered_pyramids: list[list[torch.Tensor]] = []
        for lvl in range(self.n_levels):
            # Stack → (T, 3, h, w)
            level_tensor = torch.stack(pyramids[lvl], dim=0)
            logger.debug(f"Level {lvl}: shape {tuple(level_tensor.shape)}")

            if self.filter_type == "ideal":
                filt = IdealBandpass(
                    fps,
                    self.freq_low,
                    self.freq_high,
                    notch_freqs=self.notch_freqs,
                    notch_width=self.notch_width,
                )
                filtered = filt.apply(level_tensor)
            else:
                # Any value other than "ideal" falls through to Butterworth.
                filt = ButterworthBandpass(
                    fps,
                    self.freq_low,
                    self.freq_high,
                    notch_freqs=self.notch_freqs,
                    notch_width=self.notch_width,
                )
                filtered = filt.apply(level_tensor)

            # Per-level adaptive amplification (fine levels suppressed).
            alpha_eff = self._alpha_for_level(lvl)
            filtered = filtered * alpha_eff

            # Back to list of (3, h, w)
            filtered_pyramids.append([filtered[t] for t in range(T)])

        # --- Reconstruct each frame ---
        # NOTE(review): the pyramid is rebuilt here even though all levels were
        # already computed (and stored in `pyramids`) above — the cached levels
        # could be reused to halve the pyramid work.
        logger.debug("Collapsing pyramids…")
        output_frames: list[torch.Tensor] = []
        for t in range(T):
            # Modify original pyramid by adding filtered amplification
            orig_levels = self._pyramid.build(yiq[t])
            modified_levels = [
                orig_levels[lvl] + filtered_pyramids[lvl][t].unsqueeze(0)
                for lvl in range(self.n_levels)
            ]
            reconstructed = self._pyramid.collapse(modified_levels)  # (1, 3, H, W)
            output_frames.append(reconstructed.squeeze(0))  # (3, H, W)

        result_yiq = torch.stack(output_frames, dim=0)  # (T, 3, H, W)

        # --- Convert back to RGB and clamp ---
        result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)
        logger.info("MotionMagnifier.process complete")
        return result_rgb

    def process_stream(
        self,
        frames: Iterable[torch.Tensor],
        fps: float,
        n_frames: int | None = None,
        chunk_size: int = 64,
    ) -> Generator[torch.Tensor, None, None]:
        """Process frames in chunks, yielding each output frame.

        Uses Butterworth IIR regardless of *filter_type* (the ideal FFT filter
        requires all frames at once and cannot be used in streaming mode).

        Args:
            frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
            fps: Frames per second.
            n_frames: Total frame count (optional, used for the progress bar).
            chunk_size: Number of frames to process per GPU batch.

        Yields:
            Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
        """
        # One filter per pyramid level; state carries across chunk boundaries
        filters = [
            ButterworthBandpass(
                fps,
                self.freq_low,
                self.freq_high,
                notch_freqs=self.notch_freqs,
                notch_width=self.notch_width,
            )
            for _ in range(self.n_levels)
        ]
        # Precompute per-level alpha
        alphas = [self._alpha_for_level(lvl) for lvl in range(self.n_levels)]

        def _process_chunk(chunk: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
            N = len(chunk)
            batch = torch.stack(chunk)  # (N, C, H, W)

            t0 = time.perf_counter()
            yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
            levels = self._pyramid.build(yiq)  # list of n_levels × (N, 3, h_l, w_l)
            t_build = time.perf_counter() - t0

            t_filter_lvl: list[float] = []
            modified_levels = []
            for lvl in range(self.n_levels):
                level_t = levels[lvl]  # (N, 3, h_l, w_l)
                if alphas[lvl] == 0.0:
                    # Alpha is zero → the filtered signal would be discarded, so
                    # the IIR call is skipped entirely.  NOTE(review): this means
                    # the filter state for this level is never advanced — harmless,
                    # since that level's output is never used.
                    modified_levels.append(level_t)
                    t_filter_lvl.append(0.0)
                else:
                    tf = time.perf_counter()
                    filtered = filters[lvl].apply_chunk(level_t)  # (N, 3, h_l, w_l)
                    t_filter_lvl.append(time.perf_counter() - tf)
                    modified_levels.append(level_t + filtered * alphas[lvl])

            t1 = time.perf_counter()
            result_yiq = self._pyramid.collapse(modified_levels)  # (N, 3, H, W)
            result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)  # (N, C, H, W)
            t_collapse = time.perf_counter() - t1

            total_filter = sum(t_filter_lvl)
            per_lvl = "  ".join(f"L{i}={ms * 1000:.1f}ms" for i, ms in enumerate(t_filter_lvl))
            logger.debug(
                f"  [motion chunk N={N}]  "
                f"build={t_build * 1000:.1f}ms  "
                f"filter_total={total_filter * 1000:.1f}ms ({per_lvl})  "
                f"collapse={t_collapse * 1000:.1f}ms"
            )
            yield from result_rgb

        chunk: list[torch.Tensor] = []
        with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
            for frame in frames:
                chunk.append(frame.to(device=self.device, dtype=self.dtype))
                if len(chunk) == chunk_size:
                    t0 = time.perf_counter()
                    for out_frame in _process_chunk(chunk):
                        yield out_frame
                        bar.update(1)
                    elapsed = time.perf_counter() - t0
                    logger.debug(
                        f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                    )
                    chunk = []
            # Flush the final partial chunk, if any.
            if chunk:
                t0 = time.perf_counter()
                for out_frame in _process_chunk(chunk):
                    yield out_frame
                    bar.update(1)
                elapsed = time.perf_counter() - t0
                logger.debug(
                    f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                )

process(frames, fps)

Run motion EVM on a video tensor.

Parameters:

Name Type Description Default
frames Tensor

(T, C, H, W) RGB tensor with values in [0, 1].

required
fps float

Frames per second.

required

Returns:

Type Description
Tensor

Amplified (T, C, H, W) RGB tensor, clamped to [0, 1].

Source code in src/pyevm/magnification/motion.py
def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
    """Run motion EVM on a video tensor.

    Args:
        frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
        fps: Frames per second.

    Returns:
        Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
    """
    frames = frames.to(device=self.device, dtype=self.dtype)
    T, C, H, W = frames.shape
    logger.info(f"MotionMagnifier.process: {T} frames @ {fps} fps, shape {(C, H, W)}")

    # --- Convert to YIQ ---
    yiq = rgb_to_yiq(frames)  # (T, 3, H, W)

    # --- Build Laplacian pyramid per frame ---
    # pyramids[level] = list of (3, h, w) tensors, length T.  The levels are
    # kept so reconstruction below can reuse them instead of rebuilding each
    # frame's pyramid a second time (the build is deterministic, so reuse is
    # output-identical and halves the pyramid work).
    logger.debug("Building Laplacian pyramids…")
    pyramids: list[list[torch.Tensor]] = [[] for _ in range(self.n_levels)]
    for t in range(T):
        levels = self._pyramid.build(yiq[t])  # list of (1, 3, h, w)
        for lvl, lev in enumerate(levels):
            pyramids[lvl].append(lev.squeeze(0))  # (3, h, w)

    # --- Temporally filter each level and amplify ---
    filtered_pyramids: list[list[torch.Tensor]] = []
    for lvl in range(self.n_levels):
        # Stack → (T, 3, h, w)
        level_tensor = torch.stack(pyramids[lvl], dim=0)
        logger.debug(f"Level {lvl}: shape {tuple(level_tensor.shape)}")

        # Pick the bandpass implementation once; both share the same ctor.
        filter_cls = IdealBandpass if self.filter_type == "ideal" else ButterworthBandpass
        filt = filter_cls(
            fps,
            self.freq_low,
            self.freq_high,
            notch_freqs=self.notch_freqs,
            notch_width=self.notch_width,
        )
        filtered = filt.apply(level_tensor)

        # Per-level adaptive amplification (fine levels suppressed).
        alpha_eff = self._alpha_for_level(lvl)
        filtered = filtered * alpha_eff

        # Back to list of (3, h, w)
        filtered_pyramids.append([filtered[t] for t in range(T)])

    # --- Reconstruct each frame ---
    # Reuse the cached pyramid levels (previously the pyramid was rebuilt
    # from scratch for every frame here).
    logger.debug("Collapsing pyramids…")
    output_frames: list[torch.Tensor] = []
    for t in range(T):
        # Add the filtered amplification onto the original pyramid levels
        modified_levels = [
            (pyramids[lvl][t] + filtered_pyramids[lvl][t]).unsqueeze(0)
            for lvl in range(self.n_levels)
        ]
        reconstructed = self._pyramid.collapse(modified_levels)  # (1, 3, H, W)
        output_frames.append(reconstructed.squeeze(0))  # (3, H, W)

    result_yiq = torch.stack(output_frames, dim=0)  # (T, 3, H, W)

    # --- Convert back to RGB and clamp ---
    result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)
    logger.info("MotionMagnifier.process complete")
    return result_rgb

process_stream(frames, fps, n_frames=None, chunk_size=64)

Process frames in chunks, yielding each output frame.

Uses Butterworth IIR regardless of filter_type (the ideal FFT filter requires all frames at once and cannot be used in streaming mode).

Parameters:

Name Type Description Default
frames Iterable[Tensor]

Iterable of (C, H, W) float32 RGB tensors on any device.

required
fps float

Frames per second.

required
n_frames int | None

Total frame count (optional, used for the progress bar).

None
chunk_size int

Number of frames to process per GPU batch.

64

Yields:

Type Description
Tensor

Amplified (C, H, W) float32 RGB tensors, clamped to [0, 1].

Source code in src/pyevm/magnification/motion.py
def process_stream(
    self,
    frames: Iterable[torch.Tensor],
    fps: float,
    n_frames: int | None = None,
    chunk_size: int = 64,
) -> Generator[torch.Tensor, None, None]:
    """Process frames in chunks, yielding each output frame.

    Uses Butterworth IIR regardless of *filter_type* (the ideal FFT filter
    requires all frames at once and cannot be used in streaming mode).

    Args:
        frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
        fps: Frames per second.
        n_frames: Total frame count (optional, used for the progress bar).
        chunk_size: Number of frames to process per GPU batch.

    Yields:
        Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
    """
    # One filter per pyramid level; state carries across chunk boundaries
    filters = [
        ButterworthBandpass(
            fps,
            self.freq_low,
            self.freq_high,
            notch_freqs=self.notch_freqs,
            notch_width=self.notch_width,
        )
        for _ in range(self.n_levels)
    ]
    # Precompute per-level alpha
    alphas = [self._alpha_for_level(lvl) for lvl in range(self.n_levels)]

    def _process_chunk(chunk: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
        N = len(chunk)
        batch = torch.stack(chunk)  # (N, C, H, W)

        t0 = time.perf_counter()
        yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
        levels = self._pyramid.build(yiq)  # list of n_levels × (N, 3, h_l, w_l)
        t_build = time.perf_counter() - t0

        t_filter_lvl: list[float] = []
        modified_levels = []
        for lvl in range(self.n_levels):
            level_t = levels[lvl]  # (N, 3, h_l, w_l)
            if alphas[lvl] == 0.0:
                # Zero gain: the amplified contribution would vanish anyway,
                # so the IIR call is skipped entirely.
                # NOTE(review): this does NOT advance the filter state for
                # this level — harmless only because the per-level alphas are
                # fixed for the whole run, so a skipped level is skipped in
                # every chunk.
                modified_levels.append(level_t)
                t_filter_lvl.append(0.0)
            else:
                tf = time.perf_counter()
                filtered = filters[lvl].apply_chunk(level_t)  # (N, 3, h_l, w_l)
                t_filter_lvl.append(time.perf_counter() - tf)
                modified_levels.append(level_t + filtered * alphas[lvl])

        t1 = time.perf_counter()
        result_yiq = self._pyramid.collapse(modified_levels)  # (N, 3, H, W)
        result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)  # (N, C, H, W)
        t_collapse = time.perf_counter() - t1

        total_filter = sum(t_filter_lvl)
        per_lvl = "  ".join(f"L{i}={ms * 1000:.1f}ms" for i, ms in enumerate(t_filter_lvl))
        logger.debug(
            f"  [motion chunk N={N}]  "
            f"build={t_build * 1000:.1f}ms  "
            f"filter_total={total_filter * 1000:.1f}ms ({per_lvl})  "
            f"collapse={t_collapse * 1000:.1f}ms"
        )
        yield from result_rgb

    chunk: list[torch.Tensor] = []
    with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
        for frame in frames:
            chunk.append(frame.to(device=self.device, dtype=self.dtype))
            if len(chunk) == chunk_size:
                t0 = time.perf_counter()
                for out_frame in _process_chunk(chunk):
                    yield out_frame
                    bar.update(1)
                elapsed = time.perf_counter() - t0
                logger.debug(
                    f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                )
                chunk = []
        # Flush the final partial chunk (IIR state in ``filters`` carries over).
        if chunk:
            t0 = time.perf_counter()
            for out_frame in _process_chunk(chunk):
                yield out_frame
                bar.update(1)
            elapsed = time.perf_counter() - t0
            logger.debug(
                f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
            )

pyevm.PhaseMagnifier

Phase-based EVM magnifier.

Parameters:

Name Type Description Default
factor float

Phase amplification factor.

10.0
freq_low float

Temporal bandpass lower frequency (Hz).

0.4
freq_high float

Temporal bandpass upper frequency (Hz).

3.0
n_scales int

Number of pyramid scales.

6
n_orientations int

Number of orientation bands per scale.

8
sigma float

Spatial phase smoothing sigma (pixels, 0 = disabled). Uses amplitude-weighted Gaussian blur (Eq. 17) so that low-amplitude regions do not corrupt the phase of high-amplitude neighbours.

0.0
filter_type str

"ideal" (FFT) or "butterworth" (IIR).

'ideal'
attenuate_motion bool

If True, amplified phase changes that exceed attenuate_mag are wrapped back into [−lim, +lim] rather than applied directly. Large motions (e.g. camera shake) produce large uniform phase changes after amplification and are effectively attenuated, while subtle local motions with |amp| < lim pass through unmodified. Corresponds to the "Attenuate" mode in Fig. 11 of Wadhwa et al. 2013.

False
attenuate_mag float

Threshold for large-motion attenuation (radians). Default π — the largest unambiguous single-step phase change.

pi
device device | None

Compute device.

None
dtype dtype

Real tensor dtype (sub-band coefficients are complex).

float32
Source code in src/pyevm/magnification/phase.py
class PhaseMagnifier:
    """Phase-based EVM magnifier.

    Args:
        factor: Phase amplification factor.
        freq_low: Temporal bandpass lower frequency (Hz).
        freq_high: Temporal bandpass upper frequency (Hz).
        n_scales: Number of pyramid scales.
        n_orientations: Number of orientation bands per scale.
        sigma: Spatial phase smoothing sigma (pixels, ``0`` = disabled).
            Uses amplitude-weighted Gaussian blur (Eq. 17) so that low-amplitude
            regions do not corrupt the phase of high-amplitude neighbours.
        filter_type: ``"ideal"`` (FFT) or ``"butterworth"`` (IIR).
        attenuate_motion: If ``True``, amplified phase changes that exceed
            ``attenuate_mag`` are **wrapped** back into ``[−lim, +lim]`` rather
            than applied directly.  Large motions (e.g. camera shake) produce
            large uniform phase changes after amplification and are effectively
            attenuated, while subtle local motions with ``|amp| < lim`` pass
            through unmodified.  Corresponds to the "Attenuate" mode in Fig. 11
            of Wadhwa et al. 2013.
        attenuate_mag: Threshold for large-motion attenuation (radians).
            Default ``π`` — the largest unambiguous single-step phase change.
        notch_freqs: Optional list of frequencies to notch out of the temporal
            pass-band; forwarded unchanged to the bandpass filters
            (``None`` = no notches).  Presumably in Hz, matching
            ``freq_low``/``freq_high`` — confirm against the filter classes.
        notch_width: Width of each notch, forwarded alongside ``notch_freqs``.
        device: Compute device.
        dtype: Real tensor dtype (sub-band coefficients are complex).
    """

    def __init__(
        self,
        factor: float = 10.0,
        freq_low: float = 0.4,
        freq_high: float = 3.0,
        n_scales: int = 6,
        n_orientations: int = 8,
        sigma: float = 0.0,
        filter_type: str = "ideal",
        attenuate_motion: bool = False,
        attenuate_mag: float = math.pi,
        notch_freqs: list[float] | None = None,
        notch_width: float = 1.0,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.factor = factor
        self.freq_low = freq_low
        self.freq_high = freq_high
        self.n_scales = n_scales
        self.n_orientations = n_orientations
        self.sigma = sigma
        self.filter_type = filter_type
        self.attenuate_motion = attenuate_motion
        self.attenuate_mag = attenuate_mag
        # ``None`` collapses to an empty list so the filters always receive a list.
        self.notch_freqs = notch_freqs or []
        self.notch_width = notch_width
        self.device = device or torch.device("cpu")
        self.dtype = dtype

        self._pyramid = SteerablePyramid(
            n_scales=n_scales,
            n_orientations=n_orientations,
            device=self.device,
            dtype=dtype,
        )
        logger.info(
            f"PhaseMagnifier: factor={factor}, band=[{freq_low}, {freq_high}] Hz, "
            f"scales={n_scales}, orientations={n_orientations}, filter={filter_type}"
            + (f", attenuate_mag={attenuate_mag:.3f}" if attenuate_motion else "")
        )

    def _smooth_phase(self, phase: torch.Tensor, amplitude: torch.Tensor) -> torch.Tensor:
        """Amplitude-weighted spatial Gaussian smoothing (Eq. 17, Wadhwa et al. 2013).

        Computes:   (φ · A) ∗ K_σ  /  (A ∗ K_σ)

        where K_σ is a Gaussian kernel.  Pixels with near-zero amplitude
        contribute negligible weight, preventing their noisy phase values from
        corrupting high-amplitude neighbours.

        Args:
            phase:     ``(H, W)`` or ``(T, H, W)`` real phase tensor.
            amplitude: Same shape — local coefficient amplitude (``abs(coeff)``).

        Returns:
            Smoothed phase with the same shape.
        """
        if self.sigma <= 0:
            return phase

        # 3σ truncation; the Gaussian is separable but is applied here as one
        # full 2-D convolution (kernel rebuilt on every call).
        radius = int(3 * self.sigma)
        size = 2 * radius + 1
        coords = torch.arange(size, dtype=self.dtype, device=self.device) - radius
        g1d = torch.exp(-(coords**2) / (2 * self.sigma**2))
        g1d = g1d / g1d.sum()
        kernel = torch.outer(g1d, g1d).unsqueeze(0).unsqueeze(0)  # (1,1,k,k)

        if phase.dim() == 2:
            ph = phase.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
            am = amplitude.unsqueeze(0).unsqueeze(0)  # (1,1,H,W)
        else:  # (T, H, W) or (N, H, W) — leading dim treated as batch
            ph = phase.unsqueeze(1)  # (T,1,H,W)
            am = amplitude.unsqueeze(1)  # (T,1,H,W)

        numer = F.conv2d(ph * am, kernel, padding=radius)
        denom = F.conv2d(am, kernel, padding=radius).clamp(min=1e-8)
        result = numer / denom

        if phase.dim() == 2:
            return result.squeeze(0).squeeze(0)
        return result.squeeze(1)

    def _apply_attenuation(self, amp_phase: torch.Tensor) -> torch.Tensor:
        """Wrap amplified phase into ``[−lim, +lim]`` to attenuate large motions.

        Implements the "Attenuate" mode from Fig. 11 of Wadhwa et al. 2013.

        The wrapping formula is:
            ``mod(amp_phase + lim, 2·lim) − lim``

        Behaviour by region (lim = π by default):
            * ``|amp| ≤ π``: returned unchanged — small motions amplified normally.
            * ``|amp| = 2π``: maps to 0 — motion fully cancelled.
            * ``|amp| > 2π``: wraps back into [−π, π] — motion attenuated.

        Camera shake produces large uniform phase changes across the frame;
        after amplification these exceed π and are wrapped back, effectively
        cancelling the global motion.  Subtle local vibrations with
        ``|amp| < π`` are unaffected.

        Args:
            amp_phase: Already-amplified phase tensor (any shape).

        Returns:
            Phase tensor with large-motion contributions attenuated.
        """
        lim = self.attenuate_mag
        return (amp_phase + lim) % (2 * lim) - lim

    def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
        """Run phase-based EVM on a video tensor.

        Args:
            frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
            fps: Frames per second.

        Returns:
            Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
        """
        frames = frames.to(device=self.device, dtype=self.dtype)
        T, C, H, W = frames.shape
        logger.info(f"PhaseMagnifier.process: {T} frames @ {fps} fps, shape {(C, H, W)}")

        # --- Convert to YIQ ---
        yiq = rgb_to_yiq(frames)  # (T, 3, H, W)
        luma = yiq[:, 0, :, :]  # (T, H, W)

        # --- Build steerable pyramid for ALL frames in one batched GPU call ---
        logger.debug("Building steerable pyramids for all frames…")
        pyramid = self._pyramid.build(luma)  # subbands[s][o]: (T, H_s, W_s) complex

        # --- Process each sub-band ---
        n_bands = self.n_scales * self.n_orientations
        with tqdm(total=n_bands, desc="Magnifying", unit="band", leave=False) as bar:
            for scale in range(self.n_scales):
                for orient in range(self.n_orientations):
                    coeffs = pyramid["subbands"][scale][orient]  # (T, H_s, W_s) complex
                    amplitude = coeffs.abs()  # (T, H_s, W_s)
                    phase = torch.angle(coeffs)  # (T, H_s, W_s)
                    # Circular wrapping: mod(π + Δ, 2π) − π  (matches MATLAB reference)
                    delta_phase = (phase - phase[0:1] + math.pi) % (2 * math.pi) - math.pi

                    # Step 1: temporal filter — isolate motion frequency band
                    # (a fresh filter instance is built for every band)
                    if self.filter_type == "ideal":
                        filt = IdealBandpass(
                            fps,
                            self.freq_low,
                            self.freq_high,
                            notch_freqs=self.notch_freqs,
                            notch_width=self.notch_width,
                        )
                        filtered_phase = filt.apply(delta_phase)
                    else:
                        filt_bw = ButterworthBandpass(
                            fps,
                            self.freq_low,
                            self.freq_high,
                            notch_freqs=self.notch_freqs,
                            notch_width=self.notch_width,
                        )
                        filtered_phase = filt_bw.apply(delta_phase)

                    # Step 2: amplitude-weighted spatial smoothing (Eq. 17)
                    # Applied AFTER temporal filtering (Fig. 2 order) so that
                    # broadband phase noise is attenuated before smoothing.
                    if self.sigma > 0:
                        filtered_phase = self._smooth_phase(filtered_phase, amplitude)

                    # Step 3: amplify (and optionally attenuate large motions) then shift
                    amp_phase = filtered_phase * self.factor
                    if self.attenuate_motion:
                        amp_phase = self._apply_attenuation(amp_phase)
                    pyramid["subbands"][scale][orient] = torch.polar(amplitude, phase + amp_phase)
                    logger.debug(
                        f"Scale {scale}, orientation {orient}: "
                        f"phase amp range [{amp_phase.min():.3f}, {amp_phase.max():.3f}]"
                    )
                    bar.update(1)

        # --- Reconstruct Y from all modified pyramids in one batched GPU call ---
        logger.debug("Collapsing modified pyramids…")
        luma_out = self._pyramid.collapse(pyramid)  # (T, H, W)

        # --- Recombine with I, Q and convert to RGB ---
        result_yiq = yiq.clone()
        result_yiq[:, 0, :, :] = luma_out
        result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)
        logger.info("PhaseMagnifier.process complete")
        return result_rgb

    def process_stream(
        self,
        frames: Iterable[torch.Tensor],
        fps: float,
        n_frames: int | None = None,
        chunk_size: int = 64,
    ) -> Generator[torch.Tensor, None, None]:
        """Process frames in chunks, yielding each output frame.

        Buffers *chunk_size* frames, then runs the batched pyramid build and
        collapse in one GPU call.  The Butterworth IIR filter state carries
        across chunk boundaries via its ``zi`` parameter.

        Uses Butterworth IIR regardless of *filter_type* (the ideal FFT filter
        requires all frames at once and cannot be used in streaming mode).

        Args:
            frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
            fps: Frames per second.
            n_frames: Total frame count (optional, used for the progress bar).
            chunk_size: Number of frames to process per GPU batch.

        Yields:
            Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
        """
        # One stateful IIR filter per (scale, orientation) band; state persists
        # across chunks for the whole stream.
        filters: dict[tuple[int, int], ButterworthBandpass] = {
            (s, o): ButterworthBandpass(
                fps,
                self.freq_low,
                self.freq_high,
                notch_freqs=self.notch_freqs,
                notch_width=self.notch_width,
            )
            for s in range(self.n_scales)
            for o in range(self.n_orientations)
        }
        # Phase of the very first frame seen, per band — the Δφ reference for
        # the whole stream (set once, in the first chunk).
        ref_phase: dict[tuple[int, int], torch.Tensor] = {}

        def _process_chunk(chunk: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
            N = len(chunk)
            batch = torch.stack(chunk)  # (N, C, H, W)

            t0 = time.perf_counter()
            yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
            luma = yiq[:, 0, :, :]  # (N, H, W)
            pyramid = self._pyramid.build(luma)  # subbands: (N, H_s, W_s)
            t_build = time.perf_counter() - t0

            t1 = time.perf_counter()
            for s in range(self.n_scales):
                for o in range(self.n_orientations):
                    coeffs = pyramid["subbands"][s][o]  # (N, H_s, W_s) complex
                    amplitude = coeffs.abs()  # (N, H_s, W_s)
                    phase = torch.angle(coeffs)  # (N, H_s, W_s)

                    key = (s, o)
                    if key not in ref_phase:
                        ref_phase[key] = phase[0].clone()

                    # Circular wrapping: mod(π + Δ, 2π) − π  (matches MATLAB reference)
                    delta = (phase - ref_phase[key] + math.pi) % (2 * math.pi) - math.pi

                    # Temporal filter first, then amplitude-weighted spatial smooth
                    filtered = filters[key].apply_chunk(delta)
                    if self.sigma > 0:
                        filtered = self._smooth_phase(filtered, amplitude)

                    amp_phase = filtered * self.factor
                    if self.attenuate_motion:
                        amp_phase = self._apply_attenuation(amp_phase)
                    pyramid["subbands"][s][o] = torch.polar(amplitude, phase + amp_phase)
            t_filter = time.perf_counter() - t1

            t2 = time.perf_counter()
            luma_out = self._pyramid.collapse(pyramid)  # (N, H, W)
            result_yiq = yiq.clone()
            result_yiq[:, 0, :, :] = luma_out
            result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)  # (N, C, H, W)
            t_collapse = time.perf_counter() - t2

            logger.debug(
                f"  [phase chunk N={N}]  "
                f"build={t_build * 1000:.1f}ms  "
                f"filter={t_filter * 1000:.1f}ms  "
                f"collapse={t_collapse * 1000:.1f}ms"
            )
            yield from result_rgb  # yield N individual (C, H, W) frames

        chunk: list[torch.Tensor] = []
        with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
            for frame in frames:
                chunk.append(frame.to(device=self.device, dtype=self.dtype))
                if len(chunk) == chunk_size:
                    t0 = time.perf_counter()
                    for out_frame in _process_chunk(chunk):
                        yield out_frame
                        bar.update(1)
                    elapsed = time.perf_counter() - t0
                    logger.debug(
                        f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                    )
                    chunk = []
            if chunk:  # final partial chunk
                t0 = time.perf_counter()
                for out_frame in _process_chunk(chunk):
                    yield out_frame
                    bar.update(1)
                elapsed = time.perf_counter() - t0
                logger.debug(
                    f"Chunk {len(chunk)} frames: {elapsed:.2f}s ({len(chunk) / elapsed:.1f} fps)"
                )

process(frames, fps)

Run phase-based EVM on a video tensor.

Parameters:

Name Type Description Default
frames Tensor

(T, C, H, W) RGB tensor with values in [0, 1].

required
fps float

Frames per second.

required

Returns:

Type Description
Tensor

Amplified (T, C, H, W) RGB tensor, clamped to [0, 1].

Source code in src/pyevm/magnification/phase.py
def process(self, frames: torch.Tensor, fps: float) -> torch.Tensor:
    """Run phase-based EVM on a video tensor.

    Args:
        frames: ``(T, C, H, W)`` RGB tensor with values in ``[0, 1]``.
        fps: Frames per second.

    Returns:
        Amplified ``(T, C, H, W)`` RGB tensor, clamped to ``[0, 1]``.
    """
    frames = frames.to(device=self.device, dtype=self.dtype)
    n_t, n_c, height, width = frames.shape
    logger.info(f"PhaseMagnifier.process: {n_t} frames @ {fps} fps, shape {(n_c, height, width)}")

    # Work in YIQ space; only the luminance channel is phase-amplified.
    yiq = rgb_to_yiq(frames)  # (T, 3, H, W)
    luma = yiq[:, 0, :, :]  # (T, H, W)

    logger.debug("Building steerable pyramids for all frames…")
    pyramid = self._pyramid.build(luma)  # subbands[s][o]: (T, H_s, W_s) complex

    # Both temporal filter classes share one constructor signature, so pick
    # the class once and instantiate it fresh for every band (as before).
    filter_cls = IdealBandpass if self.filter_type == "ideal" else ButterworthBandpass

    total_bands = self.n_scales * self.n_orientations
    with tqdm(total=total_bands, desc="Magnifying", unit="band", leave=False) as bar:
        for scale in range(self.n_scales):
            for orient in range(self.n_orientations):
                band = pyramid["subbands"][scale][orient]  # (T, H_s, W_s) complex
                mag = band.abs()
                ang = torch.angle(band)
                # Phase change relative to frame 0, circularly wrapped:
                # mod(π + Δ, 2π) − π  (matches the MATLAB reference).
                delta = (ang - ang[0:1] + math.pi) % (2 * math.pi) - math.pi

                # 1) Temporal bandpass isolates the motion frequency band.
                band_filter = filter_cls(
                    fps,
                    self.freq_low,
                    self.freq_high,
                    notch_freqs=self.notch_freqs,
                    notch_width=self.notch_width,
                )
                band_pass = band_filter.apply(delta)

                # 2) Amplitude-weighted spatial smoothing (Eq. 17), applied
                #    after the temporal filter (Fig. 2 order).
                if self.sigma > 0:
                    band_pass = self._smooth_phase(band_pass, mag)

                # 3) Amplify, optionally wrap large motions, re-synthesise.
                boosted = band_pass * self.factor
                if self.attenuate_motion:
                    boosted = self._apply_attenuation(boosted)
                pyramid["subbands"][scale][orient] = torch.polar(mag, ang + boosted)
                logger.debug(
                    f"Scale {scale}, orientation {orient}: "
                    f"phase amp range [{boosted.min():.3f}, {boosted.max():.3f}]"
                )
                bar.update(1)

    logger.debug("Collapsing modified pyramids…")
    luma_out = self._pyramid.collapse(pyramid)  # (T, H, W)

    # Re-insert the modified luminance and convert back to clamped RGB.
    result_yiq = yiq.clone()
    result_yiq[:, 0, :, :] = luma_out
    result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)
    logger.info("PhaseMagnifier.process complete")
    return result_rgb

process_stream(frames, fps, n_frames=None, chunk_size=64)

Process frames in chunks, yielding each output frame.

Buffers chunk_size frames, then runs the batched pyramid build and collapse in one GPU call. The Butterworth IIR filter state carries across chunk boundaries via its zi parameter.

Uses Butterworth IIR regardless of filter_type (the ideal FFT filter requires all frames at once and cannot be used in streaming mode).

Parameters:

Name Type Description Default
frames Iterable[Tensor]

Iterable of (C, H, W) float32 RGB tensors on any device.

required
fps float

Frames per second.

required
n_frames int | None

Total frame count (optional, used for the progress bar).

None
chunk_size int

Number of frames to process per GPU batch.

64

Yields:

Type Description
Tensor

Amplified (C, H, W) float32 RGB tensors, clamped to [0, 1].

Source code in src/pyevm/magnification/phase.py
def process_stream(
    self,
    frames: Iterable[torch.Tensor],
    fps: float,
    n_frames: int | None = None,
    chunk_size: int = 64,
) -> Generator[torch.Tensor, None, None]:
    """Stream frames through the magnifier in fixed-size batches.

    Buffers *chunk_size* frames, then runs the batched pyramid build and
    collapse in a single GPU call.  Streaming always uses the Butterworth
    IIR filter (its state carries across chunk boundaries via ``zi``),
    because the ideal FFT filter needs the whole clip at once — so
    *filter_type* is ignored here.

    Args:
        frames: Iterable of ``(C, H, W)`` float32 RGB tensors on any device.
        fps: Frames per second.
        n_frames: Total frame count, if known — only drives the progress bar.
        chunk_size: Frames per GPU batch.

    Yields:
        Amplified ``(C, H, W)`` float32 RGB tensors, clamped to ``[0, 1]``.
    """
    # One stateful IIR filter per (scale, orientation) band.
    filters: dict[tuple[int, int], ButterworthBandpass] = {
        (s, o): ButterworthBandpass(
            fps,
            self.freq_low,
            self.freq_high,
            notch_freqs=self.notch_freqs,
            notch_width=self.notch_width,
        )
        for s in range(self.n_scales)
        for o in range(self.n_orientations)
    }
    # Reference phase (first frame ever seen) per band, filled lazily.
    ref_phase: dict[tuple[int, int], torch.Tensor] = {}

    def _run_chunk(buffered: list[torch.Tensor]) -> Generator[torch.Tensor, None, None]:
        count = len(buffered)
        batch = torch.stack(buffered)  # (N, C, H, W)

        tic = time.perf_counter()
        yiq = rgb_to_yiq(batch)  # (N, 3, H, W)
        luma = yiq[:, 0, :, :]  # (N, H, W)
        pyramid = self._pyramid.build(luma)  # subbands: (N, H_s, W_s)
        build_s = time.perf_counter() - tic

        tic = time.perf_counter()
        # dict preserves insertion order, so bands are visited exactly as the
        # nested scale/orientation loops would visit them.
        for key, band_filter in filters.items():
            s, o = key
            band = pyramid["subbands"][s][o]  # (N, H_s, W_s) complex
            mag = band.abs()
            ang = torch.angle(band)

            if key not in ref_phase:
                ref_phase[key] = ang[0].clone()

            # Circular wrap: mod(π + Δ, 2π) − π  (matches MATLAB reference)
            delta = (ang - ref_phase[key] + math.pi) % (2 * math.pi) - math.pi

            # Temporal filter first, then amplitude-weighted spatial smooth.
            filtered = band_filter.apply_chunk(delta)
            if self.sigma > 0:
                filtered = self._smooth_phase(filtered, mag)

            boosted = filtered * self.factor
            if self.attenuate_motion:
                boosted = self._apply_attenuation(boosted)
            pyramid["subbands"][s][o] = torch.polar(mag, ang + boosted)
        filter_s = time.perf_counter() - tic

        tic = time.perf_counter()
        luma_out = self._pyramid.collapse(pyramid)  # (N, H, W)
        result_yiq = yiq.clone()
        result_yiq[:, 0, :, :] = luma_out
        result_rgb = yiq_to_rgb(result_yiq).clamp(0.0, 1.0)  # (N, C, H, W)
        collapse_s = time.perf_counter() - tic

        logger.debug(
            f"  [phase chunk N={count}]  "
            f"build={build_s * 1000:.1f}ms  "
            f"filter={filter_s * 1000:.1f}ms  "
            f"collapse={collapse_s * 1000:.1f}ms"
        )
        yield from result_rgb  # N individual (C, H, W) frames

    def _drain(buffered: list[torch.Tensor], bar) -> Generator[torch.Tensor, None, None]:
        # Run one chunk, yield its frames, and log throughput.
        started = time.perf_counter()
        for produced in _run_chunk(buffered):
            yield produced
            bar.update(1)
        took = time.perf_counter() - started
        logger.debug(
            f"Chunk {len(buffered)} frames: {took:.2f}s ({len(buffered) / took:.1f} fps)"
        )

    buffered: list[torch.Tensor] = []
    with tqdm(total=n_frames, desc="Magnifying", unit="frame", position=1, leave=True) as bar:
        for frame in frames:
            buffered.append(frame.to(device=self.device, dtype=self.dtype))
            if len(buffered) == chunk_size:
                yield from _drain(buffered, bar)
                buffered = []
        if buffered:  # final partial chunk
            yield from _drain(buffered, bar)

Video I/O

pyevm.io.video.VideoReader

Read a video file into a tensor.

Parameters:

Name Type Description Default
path str | Path

Path to the video file.

required
device device | None

Tensor device (decoding itself runs on the CPU; the returned tensor is always moved to device after reading).

None
max_frames int | None

Limit number of frames read (None = all).

None
Source code in src/pyevm/io/video.py
class VideoReader:
    """Read a video file into a tensor.

    Args:
        path: Path to the video file.
        device: Target tensor device.  torchcodec decodes directly onto
            *device*; the OpenCV fallback decodes on CPU and the result is
            moved to *device* after reading.
        max_frames: Limit number of frames read (``None`` = all).
    """

    def __init__(
        self,
        path: str | Path,
        device: torch.device | None = None,
        max_frames: int | None = None,
    ) -> None:
        self.path = Path(path)
        self.device = device or torch.device("cpu")
        self.max_frames = max_frames
        self._meta: dict | None = None  # lazily-populated metadata cache

    @property
    def metadata(self) -> dict:
        """Return ``{"fps": float, "n_frames": int, "height": int, "width": int}``."""
        if self._meta is None:
            self._meta = self._read_metadata()
        return self._meta

    def _frame_limit(self) -> int:
        """Number of frames to read, honouring ``max_frames``."""
        total = self.metadata["n_frames"]
        return min(self.max_frames, total) if self.max_frames is not None else total

    def _read_metadata(self) -> dict:
        """Probe the container with OpenCV (no frames are decoded)."""
        cap = cv2.VideoCapture(str(self.path))
        if not cap.isOpened():
            raise OSError(f"Cannot open video: {self.path}")
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        finally:
            cap.release()
        return {"fps": fps, "n_frames": n_frames, "height": height, "width": width}

    def read(self) -> tuple[torch.Tensor, float]:
        """Read video frames.

        Returns:
            ``(frames, fps)`` where *frames* is ``(T, C, H, W)`` float32
            tensor in ``[0, 1]`` on *self.device*.
        """
        try:
            frames, fps = self._read_torchcodec()
            logger.info(f"VideoReader: read {frames.shape[0]} frames via torchcodec")
        except Exception as exc:
            logger.debug(f"torchcodec unavailable ({exc}), falling back to OpenCV")
            frames, fps = self._read_opencv()
            logger.info(f"VideoReader: read {frames.shape[0]} frames via OpenCV")

        # The OpenCV path decodes on CPU; honour the documented contract that
        # the result lives on self.device (a no-op for the torchcodec path,
        # which already decodes onto the target device).
        return frames.to(self.device), fps

    def stream(self) -> Generator[torch.Tensor, None, None]:
        """Yield frames one at a time as ``(C, H, W)`` float32 tensors on *self.device*.

        Memory cost is constant — only one decoded frame lives in RAM at a
        time, regardless of video length.  Use this for large videos where
        :meth:`read` would exhaust available memory.

        Uses *torchcodec* when available (GPU-accelerated), falling back to
        OpenCV.
        """
        yielded = False
        try:
            for frame in self._stream_torchcodec():
                yielded = True
                yield frame
            return
        except Exception as exc:
            if yielded:
                # Frames were already delivered downstream; restarting from
                # frame 0 with OpenCV would silently duplicate them.
                raise
            logger.debug(f"torchcodec stream unavailable ({exc}), falling back to OpenCV")

        yield from self._stream_opencv()

    def _read_torchcodec(self) -> tuple[torch.Tensor, float]:
        """Decode all frames at once via torchcodec (raises if unavailable)."""
        from torchcodec.decoders import VideoDecoder  # noqa: PLC0415

        fps = self.metadata["fps"]
        limit = self._frame_limit()
        decoder = VideoDecoder(str(self.path), device=str(self.device))
        # Slice returns (T, C, H, W) uint8 tensor already on self.device
        frames = decoder[0:limit].float() * (1.0 / 255.0)
        return frames, fps

    def _stream_torchcodec(self) -> Generator[torch.Tensor, None, None]:
        """Yield frames one by one via torchcodec (raises if unavailable)."""
        from torchcodec.decoders import VideoDecoder  # noqa: PLC0415

        limit = self._frame_limit()
        decoder = VideoDecoder(str(self.path), device=str(self.device))
        with tqdm(total=limit, desc="   Reading", unit="frame", position=2, leave=True) as bar:
            for i in range(limit):
                # decoder[i] → (C, H, W) uint8 tensor on self.device
                yield decoder[i].float() * (1.0 / 255.0)
                bar.update(1)

    def _stream_opencv(self) -> Generator[torch.Tensor, None, None]:
        """Yield frames one by one via OpenCV (CPU decode, then move to device)."""
        limit = self._frame_limit()
        cap = cv2.VideoCapture(str(self.path))
        if not cap.isOpened():
            raise OSError(f"Cannot open video: {self.path}")
        count = 0
        try:
            with tqdm(total=limit, desc="   Reading", unit="frame", position=2, leave=True) as bar:
                while count < limit:
                    ret, frame = cap.read()
                    if not ret:
                        break  # container may report more frames than are decodable
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    yield (
                        torch.from_numpy(frame_rgb.astype(np.float32) * (1.0 / 255.0))
                        .permute(2, 0, 1)
                        .to(self.device)
                    )
                    count += 1
                    bar.update(1)
        finally:
            cap.release()

        if count == 0:
            raise RuntimeError(f"No frames could be read from {self.path}")

    def _read_opencv(self) -> tuple[torch.Tensor, float]:
        """Decode all frames via OpenCV into a pre-allocated CPU tensor."""
        meta = self.metadata
        fps = meta["fps"]
        height, width = meta["height"], meta["width"]
        limit = self._frame_limit()
        cap = cv2.VideoCapture(str(self.path))
        if not cap.isOpened():
            raise OSError(f"Cannot open video: {self.path}")

        # Pre-allocate the full output tensor so we never hold a separate list of
        # uint8 arrays AND a float32 copy at the same time.  Each frame is decoded
        # straight into its slot; peak extra RAM is just one frame's worth of CV
        # buffer (~6 MB at 1080p) instead of the entire video twice over.
        frames = torch.empty(limit, 3, height, width, dtype=torch.float32)
        count = 0
        try:
            with tqdm(total=limit, desc="Reading", unit="frame", leave=False) as bar:
                while count < limit:
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames[count] = torch.from_numpy(
                        frame_rgb.astype(np.float32) * (1.0 / 255.0)
                    ).permute(2, 0, 1)
                    count += 1
                    bar.update(1)
        finally:
            cap.release()

        if count == 0:
            raise RuntimeError(f"No frames could be read from {self.path}")

        return frames[:count], fps

metadata property

Return {"fps": float, "n_frames": int, "height": int, "width": int}.

read()

Read video frames.

Returns:

Type Description
tuple[Tensor, float]

(frames, fps) — frames is a (T, C, H, W) float32 tensor in [0, 1] on self.device; fps is the source frame rate.

Source code in src/pyevm/io/video.py
def read(self) -> tuple[torch.Tensor, float]:
    """Read video frames.

    Returns:
        ``(frames, fps)`` where *frames* is ``(T, C, H, W)`` float32
        tensor in ``[0, 1]`` on *self.device*.
    """
    # Prefer torchcodec; any failure (including the package being absent)
    # falls back to OpenCV decoding.
    try:
        frames, fps = self._read_torchcodec()
        logger.info(f"VideoReader: read {frames.shape[0]} frames via torchcodec")
    except Exception as exc:
        logger.debug(f"torchcodec unavailable ({exc}), falling back to OpenCV")
        frames, fps = self._read_opencv()
        # NOTE(review): the OpenCV path builds its tensor on CPU — confirm the
        # result really ends up on self.device as the docstring promises.
        logger.info(f"VideoReader: read {frames.shape[0]} frames via OpenCV")

    return frames, fps

stream()

Yield frames one at a time as (C, H, W) float32 tensors on self.device.

Memory cost is constant — only one decoded frame lives in RAM at a time, regardless of video length. Use this for large videos where :meth:read would exhaust available memory.

Uses torchcodec when available (GPU-accelerated), falling back to OpenCV.

Source code in src/pyevm/io/video.py
def stream(self) -> Generator[torch.Tensor, None, None]:
    """Yield frames one at a time as ``(C, H, W)`` float32 tensors on *self.device*.

    Memory cost is constant — only one decoded frame lives in RAM at a
    time, regardless of video length.  Use this for large videos where
    :meth:`read` would exhaust available memory.

    Uses *torchcodec* when available (GPU-accelerated), falling back to
    OpenCV.
    """
    # NOTE(review): if torchcodec fails *after* yielding some frames, the
    # OpenCV fallback restarts from frame 0 and duplicates those frames —
    # consider re-raising once the first frame has been delivered.
    try:
        yield from self._stream_torchcodec()
        return
    except Exception as exc:
        logger.debug(f"torchcodec stream unavailable ({exc}), falling back to OpenCV")

    yield from self._stream_opencv()

pyevm.io.video.VideoWriter

Write a tensor to a video file.

Prefers piping frames through FFmpeg for better codec support and hardware-accelerated encoding. Falls back to cv2.VideoWriter.

Parameters:

Name Type Description Default
path str | Path

Output file path (.mp4 recommended).

required
fps float

Frames per second.

required
use_ffmpeg bool

Try FFmpeg first (default True).

True
Source code in src/pyevm/io/video.py
class VideoWriter:
    """Write a tensor to a video file.

    Prefers piping frames through FFmpeg for better codec support and
    hardware-accelerated encoding.  Falls back to ``cv2.VideoWriter``.

    Args:
        path: Output file path (.mp4 recommended).
        fps: Frames per second.
        use_ffmpeg: Try FFmpeg first (default ``True``).
    """

    def __init__(
        self,
        path: str | Path,
        fps: float,
        use_ffmpeg: bool = True,
    ) -> None:
        self.path = Path(path)
        self.fps = fps
        self.use_ffmpeg = use_ffmpeg

    def write(self, frames: torch.Tensor) -> None:
        """Write *frames* to disk.

        Args:
            frames: ``(T, C, H, W)`` float tensor in ``[0, 1]`` or uint8 in
                ``[0, 255]``.
        """
        frames = frames.cpu()
        if frames.is_floating_point():
            frames_u8 = (frames.clamp(0, 1) * 255).byte()
        else:
            frames_u8 = frames.byte()

        T, C, H, W = frames_u8.shape

        if self.use_ffmpeg and shutil.which("ffmpeg") is not None:
            self._write_ffmpeg(frames_u8, H, W)
        else:
            logger.debug("FFmpeg not found; falling back to OpenCV VideoWriter")
            self._write_opencv(frames_u8, H, W)

    def write_stream(
        self,
        frames: Iterable[torch.Tensor],
        height: int,
        width: int,
        n_frames: int | None = None,
    ) -> None:
        """Write frames from a generator to disk without buffering the full video.

        Args:
            frames: Iterable of ``(C, H, W)`` float32 or uint8 tensors.
            height: Frame height in pixels (needed to open the encoder upfront).
            width: Frame width in pixels.
            n_frames: Total frame count, used only for the progress bar.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        if self.use_ffmpeg and shutil.which("ffmpeg") is not None:
            self._stream_ffmpeg(frames, height, width, n_frames)
        else:
            logger.debug("FFmpeg not found; falling back to OpenCV VideoWriter")
            self._stream_opencv(frames, height, width, n_frames)

    def _stream_ffmpeg(
        self,
        frames: Iterable[torch.Tensor],
        H: int,
        W: int,
        n_frames: int | None,
    ) -> None:
        """Open an FFmpeg pipe and feed frames one at a time."""
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "rawvideo",
            "-vcodec",
            "rawvideo",
            "-pix_fmt",
            "rgb24",
            "-s",
            f"{W}x{H}",
            "-r",
            str(self.fps),
            "-i",
            "pipe:0",
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-crf",
            "18",
            str(self.path),
        ]
        logger.debug(f"VideoWriter stream: FFmpeg command: {' '.join(cmd)}")
        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
        count = 0
        try:
            with tqdm(
                total=n_frames, desc="  Writing", unit="frame", position=0, leave=True
            ) as bar:
                for frame in frames:
                    frame_u8 = (frame.cpu().clamp(0, 1) * 255).byte()
                    frame_np = frame_u8.permute(1, 2, 0).numpy()
                    proc.stdin.write(frame_np.tobytes())  # type: ignore[union-attr]
                    count += 1
                    bar.update(1)
            proc.stdin.close()  # type: ignore[union-attr]
            _, stderr = proc.communicate()
            if proc.returncode != 0:
                raise RuntimeError(f"FFmpeg failed: {stderr.decode()}")
        except Exception:
            proc.kill()
            raise
        logger.info(f"VideoWriter: saved {count} frames to {self.path}")

    def _stream_opencv(
        self,
        frames: Iterable[torch.Tensor],
        H: int,
        W: int,
        n_frames: int | None,
    ) -> None:
        """Write frames one at a time via cv2.VideoWriter."""
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(self.path), fourcc, self.fps, (W, H))
        if not writer.isOpened():
            raise OSError(f"Cannot open VideoWriter for {self.path}")
        count = 0
        with tqdm(total=n_frames, desc="  Writing", unit="frame", position=0, leave=True) as bar:
            for frame in frames:
                frame_u8 = (frame.cpu().clamp(0, 1) * 255).byte()
                frame_np = frame_u8.permute(1, 2, 0).numpy()
                frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
                writer.write(frame_bgr)
                count += 1
                bar.update(1)
        writer.release()
        logger.info(f"VideoWriter(OpenCV): saved {count} frames to {self.path}")

    def _write_ffmpeg(self, frames_u8: torch.Tensor, H: int, W: int) -> None:
        """Pipe raw RGB24 frames to FFmpeg."""
        logger.debug(f"VideoWriter: writing {frames_u8.shape[0]} frames via FFmpeg → {self.path}")
        self.path.parent.mkdir(parents=True, exist_ok=True)

        # Try hardware-accelerated encoder; fall back to libx264
        cmd = [
            "ffmpeg",
            "-y",
            "-f",
            "rawvideo",
            "-vcodec",
            "rawvideo",
            "-pix_fmt",
            "rgb24",
            "-s",
            f"{W}x{H}",
            "-r",
            str(self.fps),
            "-i",
            "pipe:0",
            "-vcodec",
            "libx264",
            "-pix_fmt",
            "yuv420p",
            "-crf",
            "18",
            str(self.path),
        ]
        logger.debug(f"FFmpeg command: {' '.join(cmd)}")

        proc = subprocess.Popen(cmd, stdin=subprocess.PIPE, stderr=subprocess.PIPE)
        try:
            with tqdm(
                total=frames_u8.shape[0], desc="  Writing", unit="frame", position=0, leave=True
            ) as bar:
                for t in range(frames_u8.shape[0]):
                    frame_np = frames_u8[t].permute(1, 2, 0).numpy()
                    proc.stdin.write(frame_np.tobytes())  # type: ignore[union-attr]
                    bar.update(1)
            proc.stdin.close()  # type: ignore[union-attr]
            _, stderr = proc.communicate()
            if proc.returncode != 0:
                raise RuntimeError(f"FFmpeg failed: {stderr.decode()}")
        except Exception:
            proc.kill()
            raise

        logger.info(f"VideoWriter: saved {frames_u8.shape[0]} frames to {self.path}")

    def _write_opencv(self, frames_u8: torch.Tensor, H: int, W: int) -> None:
        """Write using cv2.VideoWriter (fallback)."""
        self.path.parent.mkdir(parents=True, exist_ok=True)
        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(str(self.path), fourcc, self.fps, (W, H))
        if not writer.isOpened():
            raise OSError(f"Cannot open VideoWriter for {self.path}")
        with tqdm(
            total=frames_u8.shape[0], desc="  Writing", unit="frame", position=0, leave=True
        ) as bar:
            for t in range(frames_u8.shape[0]):
                frame_np = frames_u8[t].permute(1, 2, 0).numpy()  # (H, W, C) RGB
                frame_bgr = cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR)
                writer.write(frame_bgr)
                bar.update(1)
        writer.release()
        logger.info(f"VideoWriter(OpenCV): saved {frames_u8.shape[0]} frames to {self.path}")

write(frames)

Write frames to disk.

Parameters:

Name Type Description Default
frames Tensor

(T, C, H, W) float tensor in [0, 1] or uint8 in [0, 255].

required
Source code in src/pyevm/io/video.py
def write(self, frames: torch.Tensor) -> None:
    """Write *frames* to disk.

    Args:
        frames: ``(T, C, H, W)`` float tensor in ``[0, 1]`` or uint8 in
            ``[0, 255]``.
    """
    frames = frames.cpu()
    # Normalise to uint8: floats are assumed to be in [0, 1]; integer
    # input is taken as already being in [0, 255].
    if frames.is_floating_point():
        frames_u8 = (frames.clamp(0, 1) * 255).byte()
    else:
        frames_u8 = frames.byte()

    T, C, H, W = frames_u8.shape

    # FFmpeg gives better codec support when present on PATH.
    if self.use_ffmpeg and shutil.which("ffmpeg") is not None:
        self._write_ffmpeg(frames_u8, H, W)
    else:
        logger.debug("FFmpeg not found; falling back to OpenCV VideoWriter")
        self._write_opencv(frames_u8, H, W)

write_stream(frames, height, width, n_frames=None)

Write frames from a generator to disk without buffering the full video.

Parameters:

Name Type Description Default
frames Iterable[Tensor]

Iterable of (C, H, W) float32 or uint8 tensors.

required
height int

Frame height in pixels (needed to open the encoder upfront).

required
width int

Frame width in pixels.

required
n_frames int | None

Total frame count, used only for the progress bar.

None
Source code in src/pyevm/io/video.py
def write_stream(
    self,
    frames: Iterable[torch.Tensor],
    height: int,
    width: int,
    n_frames: int | None = None,
) -> None:
    """Write frames from a generator to disk without buffering the full video.

    Args:
        frames: Iterable of ``(C, H, W)`` float32 or uint8 tensors.
        height: Frame height in pixels (needed to open the encoder upfront).
        width: Frame width in pixels.
        n_frames: Total frame count, used only for the progress bar.
    """
    # Ensure the output directory exists before either backend opens a file.
    self.path.parent.mkdir(parents=True, exist_ok=True)
    if self.use_ffmpeg and shutil.which("ffmpeg") is not None:
        self._stream_ffmpeg(frames, height, width, n_frames)
    else:
        logger.debug("FFmpeg not found; falling back to OpenCV VideoWriter")
        self._stream_opencv(frames, height, width, n_frames)

Pyramids

pyevm.pyramids.gaussian.GaussianPyramid

Multi-scale Gaussian pyramid.

Frames are expected as (B, C, H, W) float tensors, values in [0, 1].

Parameters:

Name Type Description Default
n_levels int

Number of pyramid levels (including the original).

6
device device | None

Compute device.

None
dtype dtype

Floating-point dtype (default torch.float32).

float32
Source code in src/pyevm/pyramids/gaussian.py
class GaussianPyramid:
    """Multi-scale Gaussian pyramid.

    Frames are expected as ``(B, C, H, W)`` float tensors, values in ``[0, 1]``.

    Args:
        n_levels: Number of pyramid levels (including the original).
        device: Compute device.
        dtype: Floating-point dtype (default ``torch.float32``).
    """

    def __init__(
        self,
        n_levels: int = 6,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.n_levels = n_levels
        self.device = device if device is not None else torch.device("cpu")
        self.dtype = dtype
        # Fixed binomial smoothing kernel, built once on the target device.
        self._kernel = _gaussian_kernel(self.device, self.dtype)
        logger.debug(f"GaussianPyramid: {n_levels} levels on {self.device}")

    def build(self, frame: torch.Tensor) -> list[torch.Tensor]:
        """Decompose *frame* into a Gaussian pyramid.

        Args:
            frame: ``(B, C, H, W)`` or ``(C, H, W)`` tensor.

        Returns:
            List of tensors from finest (level 0 = original) to coarsest.
        """
        if frame.dim() == 3:
            frame = frame.unsqueeze(0)  # promote single frame to a batch of 1

        level = frame.to(device=self.device, dtype=self.dtype)
        pyramid = [level]
        # Each step blurs then halves the previous level.
        for idx in range(1, self.n_levels):
            level = _blur_downsample(level, self._kernel)
            pyramid.append(level)
            logger.debug(f"Gaussian level {idx}: {tuple(level.shape)}")
        return pyramid

    def collapse(self, levels: list[torch.Tensor]) -> torch.Tensor:
        """Reconstruct from pyramid by upsampling the coarsest level.

        This simply returns the upsampled coarsest level (level 0 = original
        resolution).  For reconstruction with residuals use
        :class:`LaplacianPyramid`.

        Returns:
            ``(B, C, H, W)`` tensor at original resolution.
        """
        out = levels[-1]
        # Walk from the coarsest level back to the finest, matching each
        # level's exact spatial size (shapes may be odd).
        for target in reversed(levels[:-1]):
            out = _upsample_blur(out, self._kernel, target.shape[2], target.shape[3])
        return out

build(frame)

Decompose frame into a Gaussian pyramid.

Parameters:

Name Type Description Default
frame Tensor

(B, C, H, W) or (C, H, W) tensor.

required

Returns:

Type Description
list[Tensor]

List of tensors from finest (level 0 = original) to coarsest.

Source code in src/pyevm/pyramids/gaussian.py
def build(self, frame: torch.Tensor) -> list[torch.Tensor]:
    """Decompose *frame* into a Gaussian pyramid.

    Args:
        frame: ``(B, C, H, W)`` or ``(C, H, W)`` tensor.

    Returns:
        List of tensors from finest (level 0 = original) to coarsest.
    """
    if frame.dim() == 3:
        frame = frame.unsqueeze(0)  # promote single frame to a batch of 1

    frame = frame.to(device=self.device, dtype=self.dtype)
    levels: list[torch.Tensor] = [frame]

    # Each level blurs then downsamples the previous one, halving resolution.
    current = frame
    for i in range(1, self.n_levels):
        current = _blur_downsample(current, self._kernel)
        levels.append(current)
        logger.debug(f"Gaussian level {i}: {tuple(current.shape)}")

    return levels

collapse(levels)

Reconstruct from pyramid by upsampling the coarsest level.

This simply returns the upsampled coarsest level (level 0 = original resolution). For reconstruction with residuals use :class:LaplacianPyramid.

Returns:

Type Description
Tensor

(B, C, H, W) tensor at original resolution.

Source code in src/pyevm/pyramids/gaussian.py
def collapse(self, levels: list[torch.Tensor]) -> torch.Tensor:
    """Reconstruct from pyramid by upsampling the coarsest level.

    This simply returns the upsampled coarsest level (level 0 = original
    resolution).  For reconstruction with residuals use
    :class:`LaplacianPyramid`.

    Returns:
        ``(B, C, H, W)`` tensor at original resolution.
    """
    # Walk from the coarsest level back to the finest, upsampling to each
    # level's exact spatial size (shapes may be odd, so we can't just double).
    result = levels[-1]
    for i in range(len(levels) - 2, -1, -1):
        target = levels[i]
        result = _upsample_blur(result, self._kernel, target.shape[2], target.shape[3])
    return result

pyevm.pyramids.laplacian.LaplacianPyramid

Multi-scale Laplacian pyramid (difference-of-Gaussians).

Each level stores the band-pass detail image; the coarsest level stores the low-pass residual (a Gaussian level).

Parameters:

Name Type Description Default
n_levels int

Number of pyramid levels.

6
device device | None

Compute device.

None
dtype dtype

Floating-point dtype.

float32
Source code in src/pyevm/pyramids/laplacian.py
class LaplacianPyramid:
    """Multi-scale Laplacian pyramid (difference-of-Gaussians).

    Each level stores the *band-pass* detail image; the coarsest level stores
    the low-pass residual (a Gaussian level).

    Args:
        n_levels: Number of pyramid levels.
        device: Compute device.
        dtype: Floating-point dtype.
    """

    def __init__(
        self,
        n_levels: int = 6,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.n_levels = n_levels
        self.device = device if device is not None else torch.device("cpu")
        self.dtype = dtype
        # Fixed binomial smoothing kernel, built once on the target device.
        self._kernel = _gaussian_kernel(self.device, self.dtype)
        logger.debug(f"LaplacianPyramid: {n_levels} levels on {self.device}")

    def build(self, frame: torch.Tensor) -> list[torch.Tensor]:
        """Decompose *frame* into a Laplacian pyramid.

        Args:
            frame: ``(B, C, H, W)`` or ``(C, H, W)`` tensor.

        Returns:
            List of ``n_levels`` tensors. Levels 0 … n-2 are band-pass detail
            images; level n-1 is the low-pass Gaussian residual.
        """
        if frame.dim() == 3:
            frame = frame.unsqueeze(0)  # promote single frame to a batch of 1
        frame = frame.to(device=self.device, dtype=self.dtype)

        # First decompose into a plain Gaussian pyramid.
        gaussian = [frame]
        for _ in range(self.n_levels - 1):
            gaussian.append(_blur_downsample(gaussian[-1], self._kernel))

        # Each detail level is Gaussian[i] minus the upsampled next level.
        pyramid: list[torch.Tensor] = []
        for i, (fine, coarse) in enumerate(zip(gaussian[:-1], gaussian[1:])):
            upsampled = _upsample_blur(coarse, self._kernel, fine.shape[2], fine.shape[3])
            detail = fine - upsampled
            pyramid.append(detail)
            logger.debug(f"Laplacian level {i}: {tuple(detail.shape)}")

        # The coarsest Gaussian level rides along as the low-pass residual.
        pyramid.append(gaussian[-1])
        logger.debug(f"Laplacian residual (level {self.n_levels - 1}): {tuple(gaussian[-1].shape)}")

        return pyramid

    def collapse(self, levels: list[torch.Tensor]) -> torch.Tensor:
        """Reconstruct frame from Laplacian pyramid.

        Args:
            levels: Pyramid returned by :meth:`build` (possibly modified).

        Returns:
            ``(B, C, H, W)`` tensor at original resolution.
        """
        # Start from the low-pass residual; at each finer level, upsample to
        # that level's exact size and add back the band-pass detail.
        out = levels[-1]
        for detail in reversed(levels[:-1]):
            out = _upsample_blur(out, self._kernel, detail.shape[2], detail.shape[3])
            out = out + detail
        return out

build(frame)

Decompose frame into a Laplacian pyramid.

Parameters:

Name Type Description Default
frame Tensor

(B, C, H, W) or (C, H, W) tensor.

required

Returns:

Type Description
list[Tensor]

List of n_levels tensors. Levels 0 … n-2 are band-pass detail images; level n-1 is the low-pass Gaussian residual.

Source code in src/pyevm/pyramids/laplacian.py
def build(self, frame: torch.Tensor) -> list[torch.Tensor]:
    """Decompose *frame* into a Laplacian pyramid.

    Args:
        frame: ``(B, C, H, W)`` or ``(C, H, W)`` tensor.

    Returns:
        List of ``n_levels`` tensors. Levels 0 … n-2 are band-pass detail
        images; level n-1 is the low-pass Gaussian residual.
    """
    if frame.dim() == 3:
        frame = frame.unsqueeze(0)  # promote single frame to a batch of 1

    frame = frame.to(device=self.device, dtype=self.dtype)

    # Build Gaussian pyramid first
    gaussian: list[torch.Tensor] = [frame]
    current = frame
    for _ in range(1, self.n_levels):
        current = _blur_downsample(current, self._kernel)
        gaussian.append(current)

    # Laplacian levels = Gaussian[i] − upsample(Gaussian[i+1])
    laplacian: list[torch.Tensor] = []
    for i in range(self.n_levels - 1):
        # Upsample to the exact finer-level size (shapes may be odd).
        g_up = _upsample_blur(
            gaussian[i + 1],
            self._kernel,
            gaussian[i].shape[2],
            gaussian[i].shape[3],
        )
        lap = gaussian[i] - g_up
        laplacian.append(lap)
        logger.debug(f"Laplacian level {i}: {tuple(lap.shape)}")

    # Append coarsest Gaussian as residual
    laplacian.append(gaussian[-1])
    logger.debug(f"Laplacian residual (level {self.n_levels - 1}): {tuple(gaussian[-1].shape)}")

    return laplacian

collapse(levels)

Reconstruct frame from Laplacian pyramid.

Parameters:

Name Type Description Default
levels list[Tensor]

Pyramid returned by :meth:build (possibly modified).

required

Returns:

Type Description
Tensor

(B, C, H, W) tensor at original resolution.

Source code in src/pyevm/pyramids/laplacian.py
def collapse(self, levels: list[torch.Tensor]) -> torch.Tensor:
    """Reconstruct frame from Laplacian pyramid.

    Args:
        levels: Pyramid returned by :meth:`build` (possibly modified).

    Returns:
        ``(B, C, H, W)`` tensor at original resolution.
    """
    # Start from the low-pass residual and, per finer level, upsample to that
    # level's exact size and add back the band-pass detail.
    result = levels[-1]
    for i in range(len(levels) - 2, -1, -1):
        result = _upsample_blur(result, self._kernel, levels[i].shape[2], levels[i].shape[3])
        result = result + levels[i]
    return result

pyevm.pyramids.steerable.SteerablePyramid

Complex steerable pyramid (tight frame, near-perfect reconstruction).

Parameters:

Name Type Description Default
n_scales int

Number of octave levels.

4
n_orientations int

Oriented sub-bands per scale (2, 4, 6, or 8).

6
device device | None

Compute device.

None
dtype dtype

Real floating-point dtype.

float32
Source code in src/pyevm/pyramids/steerable.py
class SteerablePyramid:
    """Complex steerable pyramid (tight frame, near-perfect reconstruction).

    Both analysis (:meth:`build`) and synthesis (:meth:`collapse`) work in the
    frequency domain: the frame's DFT is multiplied by radial and oriented
    filter masks, and reconstruction re-applies the same masks so each filter
    contributes its squared magnitude (tight-frame design).

    Args:
        n_scales: Number of octave levels.
        n_orientations: Oriented sub-bands per scale (2, 4, 6, or 8).
        device: Compute device.
        dtype: Real floating-point dtype.
    """

    def __init__(
        self,
        n_scales: int = 4,
        n_orientations: int = 6,
        device: torch.device | None = None,
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self.n_scales = n_scales
        self.n_orientations = n_orientations
        # No auto-detection here: defaults to CPU unless a device is given.
        self.device = device or torch.device("cpu")
        self.dtype = dtype
        logger.debug(
            f"SteerablePyramid: {n_scales} scales × {n_orientations} orientations on {self.device}"
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def build(self, frame: torch.Tensor) -> dict:
        """Decompose one or more single-channel frames into the complex steerable pyramid.

        Args:
            frame: ``(H, W)``, ``(1, H, W)`` / ``(1, 1, H, W)`` for a single
                   frame, or ``(T, H, W)`` to process a batch of *T* frames in
                   one GPU call.

        Returns:
            Dictionary with keys:

            * ``"highpass"``  – ``(..., H, W)`` complex tensor (outer HP residual).
            * ``"lowpass"``   – ``(..., H', W')`` real tensor (coarsest LP residual).
            * ``"subbands"``  – ``[n_scales][n_orientations]`` complex tensors of
                               shape ``(..., H_s, W_s)``.
            * ``"sizes"``     – ``(H_s, W_s)`` at each scale.
        """
        frame = frame.to(device=self.device, dtype=self.dtype)
        # Collapse trivial leading singleton dims so a single (1,H,W) or (1,1,H,W)
        # still behaves as (H,W); a genuine (T,H,W) batch is left untouched.
        while frame.dim() > 2 and frame.shape[0] == 1:
            frame = frame.squeeze(0)

        H, W = frame.shape[-2], frame.shape[-1]
        dft = torch.fft.fft2(frame)  # (..., H, W) complex

        # --- Outer LP/HP split ---
        # NOTE(review): _polar_grid is assumed to return (radius, angle) grids
        # of shape (H, W) in normalised frequency coordinates — confirm in the
        # filter-helper module.
        radius0, _ = _polar_grid(H, W, self.device, self.dtype)
        lo0_vals = _lo0(radius0)
        hi0_vals = _hi0(lo0_vals)

        highpass = torch.fft.ifft2(dft * hi0_vals)  # (..., H, W) complex
        logger.debug(f"Highpass residual: {tuple(highpass.shape)}")

        # Scale loop starts from the LP component only
        current_dft = dft * lo0_vals
        current_h, current_w = H, W

        subbands: list[list[torch.Tensor]] = []
        sizes: list[tuple[int, int]] = []

        for scale in range(self.n_scales):
            radius_s, angle_s = _polar_grid(current_h, current_w, self.device, self.dtype)
            lp_vals = _lp(radius_s)

            # Oriented band-pass sub-bands at this scale; ifft2 of the masked
            # spectrum yields complex coefficients.
            scale_bands: list[torch.Tensor] = []
            for orient in range(self.n_orientations):
                filt = _oriented_filter(radius_s, angle_s, orient, self.n_orientations, lp_vals)
                subband = torch.fft.ifft2(current_dft * filt)
                scale_bands.append(subband)
                logger.debug(f"Scale {scale}, orientation {orient}: {tuple(subband.shape)}")

            subbands.append(scale_bands)
            sizes.append((current_h, current_w))

            # Pass LP component to next scale (downsample)
            # Integer halving: odd sizes round down; the DFT helpers below
            # keep DC centred for those cases.
            current_h //= 2
            current_w //= 2
            current_dft = self._downsample_dft(current_dft * lp_vals, current_h, current_w)

        # Coarsest residual is stored as a real spatial-domain image.
        lowpass = torch.fft.ifft2(current_dft).real
        logger.debug(f"Lowpass residual: {tuple(lowpass.shape)}")

        return {
            "highpass": highpass,
            "lowpass": lowpass,
            "subbands": subbands,
            "sizes": sizes,
        }

    def collapse(self, pyramid: dict) -> torch.Tensor:
        """Reconstruct a frame from a (possibly phase-modified) pyramid.

        Args:
            pyramid: Dictionary as returned by :meth:`build`.

        Returns:
            Reconstructed ``(H, W)`` real tensor.
        """
        subbands = pyramid["subbands"]
        sizes = pyramid["sizes"]
        highpass = pyramid["highpass"]
        lowpass = pyramid["lowpass"]

        # --- Inner scale reconstruction (coarse → fine) ---
        # lowpass is real; fft2 promotes it to a complex spectrum.
        current_dft = torch.fft.fft2(lowpass.to(dtype=self.dtype))

        for scale in range(self.n_scales - 1, -1, -1):
            target_h, target_w = sizes[scale]
            current_dft = self._upsample_dft(current_dft, target_h, target_w)

            radius_s, angle_s = _polar_grid(target_h, target_w, self.device, self.dtype)
            lp_vals = _lp(radius_s)

            # LP² component: apply inner LP to the upsampled LP signal
            current_dft = current_dft * lp_vals

            # BP component: sum oriented sub-band contributions.
            # Analysis filter H_k is one-sided (primary lobe only), so
            # Σ_k H_k² = bp² · Σ_k g_k² = bp² · (1/2).
            # The factor-of-2 here compensates → net contribution = bp²·X_s,
            # matching reconSCFpyrGen.m: tempDFT = 2 * fft2(bandVals) * filter.
            for orient in range(self.n_orientations):
                filt = _oriented_filter(radius_s, angle_s, orient, self.n_orientations, lp_vals)
                sb_dft = torch.fft.fft2(subbands[scale][orient])
                current_dft = current_dft + 2.0 * sb_dft * filt

        # --- Outer LP² contribution ---
        # Scale reconstruction gave dft·lo0; apply lo0 again → dft·lo0²
        H, W = sizes[0]
        radius0, _ = _polar_grid(H, W, self.device, self.dtype)
        lo0_vals = _lo0(radius0)
        current_dft = current_dft * lo0_vals

        # --- Outer HP² contribution: dft·hi0² = dft·(1 − lo0²) ---
        # build() stored highpass = ifft2(dft·hi0); one more multiplication
        # by hi0 gives the required squared weighting.
        hi0_vals = _hi0(lo0_vals)
        current_dft = current_dft + torch.fft.fft2(highpass) * hi0_vals

        # Discard the (numerical-noise) imaginary part of the synthesis.
        result = torch.fft.ifft2(current_dft).real
        logger.debug(f"Collapsed pyramid: {tuple(result.shape)}")
        return result

    # ------------------------------------------------------------------
    # DFT up/downsample helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _downsample_dft(dft: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
        """Crop DFT spectrum to ``(new_h, new_w)`` — bandlimited downsampling.

        Works for both ``(H, W)`` and batched ``(..., H, W)`` inputs.
        """
        H, W = dft.shape[-2], dft.shape[-1]
        # fftshift centres DC so the retained band is a plain central crop.
        shifted = torch.fft.fftshift(dft, dim=(-2, -1))
        ch, cw = H // 2, W // 2  # DC position after fftshift
        half_h, half_w = new_h // 2, new_w // 2
        cropped = shifted[
            ...,
            ch - half_h : ch - half_h + new_h,
            cw - half_w : cw - half_w + new_w,
        ]
        return torch.fft.ifftshift(cropped, dim=(-2, -1))

    @staticmethod
    def _upsample_dft(dft: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
        """Zero-pad DFT spectrum to ``(new_h, new_w)`` — bandlimited upsampling.

        The DC component (after fftshift) sits at ``H // 2`` in the source and
        must land at ``new_h // 2`` in the destination.  Using
        ``ph = new_h // 2 - H // 2`` (rather than ``(new_h - H) // 2``)
        handles the case where *H* is odd and *new_h = 2 H* correctly — the
        naive formula is off by one for that combination, which creates a
        one-pixel phase error that propagates as visible artefacts (e.g. a
        black bar across 1080 p video where height reaches 135 px at scale 3).

        Works for both ``(H, W)`` and batched ``(..., H, W)`` inputs.
        """
        H, W = dft.shape[-2], dft.shape[-1]
        shifted = torch.fft.fftshift(dft, dim=(-2, -1))
        # Zero-padded destination spectrum (complex dtype inherited from dft).
        padded = torch.zeros(*dft.shape[:-2], new_h, new_w, dtype=dft.dtype, device=dft.device)
        ph = new_h // 2 - H // 2
        pw = new_w // 2 - W // 2
        padded[..., ph : ph + H, pw : pw + W] = shifted
        return torch.fft.ifftshift(padded, dim=(-2, -1))

build(frame)

Decompose one or more single-channel frames into the complex steerable pyramid.

Parameters:

Name Type Description Default
frame Tensor

(H, W), (1, H, W) / (1, 1, H, W) for a single frame, or (T, H, W) to process a batch of T frames in one GPU call.

required

Returns:

Type Description
dict

Dictionary with keys:

dict
  • "highpass" – (..., H, W) complex tensor (outer HP residual).
  • "lowpass" – (..., H', W') real tensor (coarsest LP residual).
  • "subbands" – [n_scales][n_orientations] complex tensors of shape (..., H_s, W_s).
  • "sizes" – (H_s, W_s) at each scale.
Source code in src/pyevm/pyramids/steerable.py
def build(self, frame: torch.Tensor) -> dict:
    """Decompose one or more single-channel frames into the complex steerable pyramid.

    Args:
        frame: ``(H, W)``, ``(1, H, W)`` / ``(1, 1, H, W)`` for a single
               frame, or ``(T, H, W)`` to process a batch of *T* frames in
               one GPU call.

    Returns:
        Dictionary with keys:

        * ``"highpass"``  – ``(..., H, W)`` complex tensor (outer HP residual).
        * ``"lowpass"``   – ``(..., H', W')`` real tensor (coarsest LP residual).
        * ``"subbands"``  – ``[n_scales][n_orientations]`` complex tensors of
                           shape ``(..., H_s, W_s)``.
        * ``"sizes"``     – ``(H_s, W_s)`` at each scale.
    """
    frame = frame.to(device=self.device, dtype=self.dtype)
    # Collapse trivial leading singleton dims so a single (1,H,W) or (1,1,H,W)
    # still behaves as (H,W); a genuine (T,H,W) batch is left untouched.
    while frame.dim() > 2 and frame.shape[0] == 1:
        frame = frame.squeeze(0)

    H, W = frame.shape[-2], frame.shape[-1]
    dft = torch.fft.fft2(frame)  # (..., H, W) complex

    # --- Outer LP/HP split ---
    # NOTE(review): _polar_grid presumably returns (radius, angle) grids of
    # shape (H, W) in normalised frequency coordinates — confirm in the
    # filter-helper module.
    radius0, _ = _polar_grid(H, W, self.device, self.dtype)
    lo0_vals = _lo0(radius0)
    hi0_vals = _hi0(lo0_vals)

    highpass = torch.fft.ifft2(dft * hi0_vals)  # (..., H, W) complex
    logger.debug(f"Highpass residual: {tuple(highpass.shape)}")

    # Scale loop starts from the LP component only
    current_dft = dft * lo0_vals
    current_h, current_w = H, W

    subbands: list[list[torch.Tensor]] = []
    sizes: list[tuple[int, int]] = []

    for scale in range(self.n_scales):
        radius_s, angle_s = _polar_grid(current_h, current_w, self.device, self.dtype)
        lp_vals = _lp(radius_s)

        # Oriented band-pass sub-bands at this scale (complex coefficients).
        scale_bands: list[torch.Tensor] = []
        for orient in range(self.n_orientations):
            filt = _oriented_filter(radius_s, angle_s, orient, self.n_orientations, lp_vals)
            subband = torch.fft.ifft2(current_dft * filt)
            scale_bands.append(subband)
            logger.debug(f"Scale {scale}, orientation {orient}: {tuple(subband.shape)}")

        subbands.append(scale_bands)
        sizes.append((current_h, current_w))

        # Pass LP component to next scale (downsample)
        # Integer halving: odd sizes round down, handled by _downsample_dft.
        current_h //= 2
        current_w //= 2
        current_dft = self._downsample_dft(current_dft * lp_vals, current_h, current_w)

    # Coarsest residual is stored as a real spatial-domain image.
    lowpass = torch.fft.ifft2(current_dft).real
    logger.debug(f"Lowpass residual: {tuple(lowpass.shape)}")

    return {
        "highpass": highpass,
        "lowpass": lowpass,
        "subbands": subbands,
        "sizes": sizes,
    }

collapse(pyramid)

Reconstruct a frame from a (possibly phase-modified) pyramid.

Parameters:

Name Type Description Default
pyramid dict

Dictionary as returned by :meth:build.

required

Returns:

Type Description
Tensor

Reconstructed (H, W) real tensor.

Source code in src/pyevm/pyramids/steerable.py
def collapse(self, pyramid: dict) -> torch.Tensor:
    """Reconstruct a frame from a (possibly phase-modified) pyramid.

    Args:
        pyramid: Dictionary as returned by :meth:`build`.

    Returns:
        Reconstructed ``(H, W)`` real tensor.
    """
    subbands = pyramid["subbands"]
    sizes = pyramid["sizes"]
    highpass = pyramid["highpass"]
    lowpass = pyramid["lowpass"]

    # --- Inner scale reconstruction (coarse → fine) ---
    # lowpass is real; fft2 promotes it to a complex spectrum.
    current_dft = torch.fft.fft2(lowpass.to(dtype=self.dtype))

    for scale in range(self.n_scales - 1, -1, -1):
        target_h, target_w = sizes[scale]
        current_dft = self._upsample_dft(current_dft, target_h, target_w)

        radius_s, angle_s = _polar_grid(target_h, target_w, self.device, self.dtype)
        lp_vals = _lp(radius_s)

        # LP² component: apply inner LP to the upsampled LP signal
        current_dft = current_dft * lp_vals

        # BP component: sum oriented sub-band contributions.
        # Analysis filter H_k is one-sided (primary lobe only), so
        # Σ_k H_k² = bp² · Σ_k g_k² = bp² · (1/2).
        # The factor-of-2 here compensates → net contribution = bp²·X_s,
        # matching reconSCFpyrGen.m: tempDFT = 2 * fft2(bandVals) * filter.
        for orient in range(self.n_orientations):
            filt = _oriented_filter(radius_s, angle_s, orient, self.n_orientations, lp_vals)
            sb_dft = torch.fft.fft2(subbands[scale][orient])
            current_dft = current_dft + 2.0 * sb_dft * filt

    # --- Outer LP² contribution ---
    # Scale reconstruction gave dft·lo0; apply lo0 again → dft·lo0²
    H, W = sizes[0]
    radius0, _ = _polar_grid(H, W, self.device, self.dtype)
    lo0_vals = _lo0(radius0)
    current_dft = current_dft * lo0_vals

    # --- Outer HP² contribution: dft·hi0² = dft·(1 − lo0²) ---
    # build() stored highpass = ifft2(dft·hi0); one more multiplication by
    # hi0 yields the required squared weighting.
    hi0_vals = _hi0(lo0_vals)
    current_dft = current_dft + torch.fft.fft2(highpass) * hi0_vals

    # Discard the (numerical-noise) imaginary part of the synthesis.
    result = torch.fft.ifft2(current_dft).real
    logger.debug(f"Collapsed pyramid: {tuple(result.shape)}")
    return result

Filters

pyevm.filters.temporal.IdealBandpass

FFT-based ideal bandpass filter over the time axis, with optional notch stops.

Parameters:

Name Type Description Default
fps float

Frames per second of the video.

required
freq_low float

Lower cut-off frequency in Hz.

required
freq_high float

Upper cut-off frequency in Hz.

required
notch_freqs list[float] | None

Frequencies to notch out (Hz). Each notch zeros a symmetric window of width notch_width centred on the frequency.

None
notch_width float

Width of each notch in Hz (default 1.0).

1.0
Source code in src/pyevm/filters/temporal.py
class IdealBandpass:
    """FFT-based ideal bandpass filter over the time axis, with optional notch stops.

    Args:
        fps: Frames per second of the video.
        freq_low: Lower cut-off frequency in Hz.
        freq_high: Upper cut-off frequency in Hz.
        notch_freqs: Frequencies to notch out (Hz).  Each notch zeros a
            symmetric window of width ``notch_width`` centred on the frequency.
        notch_width: Width of each notch in Hz (default 1.0).
    """

    def __init__(
        self,
        fps: float,
        freq_low: float,
        freq_high: float,
        notch_freqs: list[float] | None = None,
        notch_width: float = 1.0,
    ) -> None:
        self.fps = fps
        self.freq_low = freq_low
        self.freq_high = freq_high
        self.notch_freqs = notch_freqs or []
        self.notch_width = notch_width
        logger.debug(
            f"IdealBandpass: {freq_low}{freq_high} Hz @ {fps} fps"
            + (f", notches={self.notch_freqs} ±{notch_width / 2} Hz" if self.notch_freqs else "")
        )

    def apply(self, signal: torch.Tensor) -> torch.Tensor:
        """Filter *signal* along its first (time) dimension.

        Args:
            signal: ``(T, ...)`` tensor where ``T`` is the number of frames.

        Returns:
            Bandpass-filtered (and notch-filtered) tensor with the same shape.
        """
        T = signal.shape[0]
        nyq = self.fps / 2.0
        if self.freq_low >= nyq:
            raise ValueError(
                f"freq_low={self.freq_low} Hz is at or above the Nyquist frequency "
                f"({nyq} Hz for fps={self.fps}). Use a lower frequency or a higher-fps video."
            )
        freqs = torch.fft.rfftfreq(T, d=1.0 / self.fps, device=signal.device)

        # FFT along time axis (dim 0)
        spectrum = torch.fft.rfft(signal.float(), dim=0)

        # Build combined mask: bandpass × notch stops
        mask = (freqs >= self.freq_low) & (freqs <= self.freq_high)
        half = self.notch_width / 2.0
        for nf in self.notch_freqs:
            mask = mask & ~((freqs >= nf - half) & (freqs <= nf + half))
        mask_f = mask.float()
        for _ in range(signal.dim() - 1):
            mask_f = mask_f.unsqueeze(-1)

        filtered_spectrum = spectrum * mask_f
        result = torch.fft.irfft(filtered_spectrum, n=T, dim=0)
        logger.debug(
            f"IdealBandpass: filtered signal shape {tuple(signal.shape)}{tuple(result.shape)}"
        )
        return result.to(dtype=signal.dtype)

apply(signal)

Filter signal along its first (time) dimension.

Parameters:

Name Type Description Default
signal Tensor

(T, ...) tensor where T is the number of frames.

required

Returns:

Type Description
Tensor

Bandpass-filtered (and notch-filtered) tensor with the same shape.

Source code in src/pyevm/filters/temporal.py
def apply(self, signal: torch.Tensor) -> torch.Tensor:
    """Filter *signal* along its first (time) dimension.

    Args:
        signal: ``(T, ...)`` tensor where ``T`` is the number of frames.

    Returns:
        Bandpass-filtered (and notch-filtered) tensor with the same shape.
    """
    n_frames = signal.shape[0]
    nyquist = self.fps / 2.0
    if self.freq_low >= nyquist:
        raise ValueError(
            f"freq_low={self.freq_low} Hz is at or above the Nyquist frequency "
            f"({nyquist} Hz for fps={self.fps}). Use a lower frequency or a higher-fps video."
        )
    # Physical frequency (Hz) of each rfft bin.
    freqs = torch.fft.rfftfreq(n_frames, d=1.0 / self.fps, device=signal.device)

    # Real FFT over the leading (time) axis.
    spectrum = torch.fft.rfft(signal.float(), dim=0)

    # Keep passband bins, then drop any bin inside a notch window.
    passband = (freqs >= self.freq_low) & (freqs <= self.freq_high)
    half_width = self.notch_width / 2.0
    for centre in self.notch_freqs:
        inside_notch = (freqs >= centre - half_width) & (freqs <= centre + half_width)
        passband = passband & ~inside_notch
    # Reshape to (bins, 1, ..., 1) so the mask broadcasts over spatial dims.
    gain = passband.float().view(-1, *([1] * (signal.dim() - 1)))

    result = torch.fft.irfft(spectrum * gain, n=n_frames, dim=0)
    logger.debug(
        f"IdealBandpass: filtered signal shape {tuple(signal.shape)}{tuple(result.shape)}"
    )
    return result.to(dtype=signal.dtype)

pyevm.filters.temporal.ButterworthBandpass

Butterworth IIR bandpass filter applied causally chunk-by-chunk, with optional IIR notch stops cascaded after the bandpass.

On CUDA/MPS devices the filter runs entirely on the accelerator (torch.jit.script loop, no CPU↔device roundtrip). On CPU the original scipy.signal.sosfilt path is used.

Filter state is maintained between :meth:apply_chunk calls so the result is numerically identical to processing the whole video at once.

Parameters:

Name Type Description Default
fps float

Frames per second of the video.

required
freq_low float

Lower cut-off frequency in Hz.

required
freq_high float

Upper cut-off frequency in Hz.

required
order int

Filter order (default 1 — matches reference MATLAB code).

1
notch_freqs list[float] | None

Frequencies to notch out (Hz). Each notch is a 2nd-order IIR notch filter with Q = freq / notch_width.

None
notch_width float

Bandwidth of each notch in Hz (default 1.0).

1.0
Source code in src/pyevm/filters/temporal.py
class ButterworthBandpass:
    """Butterworth IIR bandpass filter applied causally chunk-by-chunk, with
    optional IIR notch stops cascaded after the bandpass.

    On CUDA/MPS devices the filter runs entirely on the accelerator
    (``torch.jit.script`` loop, no CPU↔device roundtrip).  On CPU the
    original ``scipy.signal.sosfilt`` path is used.

    Filter state is maintained between :meth:`apply_chunk` calls so the
    result is numerically identical to processing the whole video at once.

    Args:
        fps: Frames per second of the video.
        freq_low: Lower cut-off frequency in Hz.
        freq_high: Upper cut-off frequency in Hz.
        order: Filter order (default 1 — matches reference MATLAB code).
        notch_freqs: Frequencies to notch out (Hz).  Each notch is a
            2nd-order IIR notch filter with Q = ``freq / notch_width``.
        notch_width: Bandwidth of each notch in Hz (default 1.0).
    """

    def __init__(
        self,
        fps: float,
        freq_low: float,
        freq_high: float,
        order: int = 1,
        notch_freqs: list[float] | None = None,
        notch_width: float = 1.0,
    ) -> None:
        self.fps = fps
        self.freq_low = freq_low
        self.freq_high = freq_high
        self.order = order
        self.notch_freqs = notch_freqs or []
        self.notch_width = notch_width

        nyq = fps / 2.0
        low = freq_low / nyq
        high = min(freq_high / nyq, 1.0 - 1e-6)
        if low >= 1.0:
            raise ValueError(
                f"freq_low={freq_low} Hz is at or above the Nyquist frequency "
                f"({nyq} Hz for fps={fps}). Use a lower frequency or a higher-fps video."
            )
        if low >= high:
            raise ValueError(
                f"freq_high={freq_high} Hz is too close to or above the Nyquist frequency "
                f"({nyq} Hz for fps={fps}). The passband [{freq_low}, {freq_high}] Hz is "
                f"not representable. Use a lower freq_high or a higher-fps video."
            )
        self._sos = butter(order, [low, high], btype="bandpass", output="sos")

        # Cascade 2nd-order IIR notch sections
        for nf in self.notch_freqs:
            w0 = nf / nyq
            if w0 >= 1.0:
                raise ValueError(
                    f"Notch frequency {nf} Hz is at or above the Nyquist frequency "
                    f"({nyq} Hz for fps={fps})."
                )
            Q = max(nf / notch_width, 0.5)  # quality factor; clamp to avoid degenerate filter
            b_n, a_n = iirnotch(w0, Q)
            self._sos = np.vstack([self._sos, tf2sos(b_n, a_n)])

        # CPU state (scipy)
        self._zi: np.ndarray | None = None
        # GPU state: one (s1, s2) pair per SOS section
        self._zi_gpu: list[tuple[torch.Tensor, torch.Tensor]] | None = None

        logger.debug(
            f"ButterworthBandpass order={order}: {freq_low}{freq_high} Hz @ {fps} fps"
            + (f", notches={self.notch_freqs} ±{notch_width / 2} Hz" if self.notch_freqs else "")
        )

    # ------------------------------------------------------------------
    # Batch mode (all frames at once — no persistent state)
    # ------------------------------------------------------------------

    def apply(self, signal: torch.Tensor) -> torch.Tensor:
        """Filter *signal* along its first (time) dimension (batch mode).

        Args:
            signal: ``(T, ...)`` tensor.

        Returns:
            Filtered tensor with the same shape.
        """
        if _GPU_IIR_AVAILABLE and signal.device.type != "cpu":
            return self._apply_gpu(signal)
        return self._apply_cpu(signal)

    def _apply_cpu(self, signal: torch.Tensor) -> torch.Tensor:
        original_shape = signal.shape
        T = signal.shape[0]
        flat = np.from_dlpack(signal.detach().float().cpu().contiguous()).reshape(T, -1)  # (T, N)

        flat_t = np.ascontiguousarray(flat.T)  # (N, T)
        filtered_t = sosfilt(self._sos, flat_t, axis=-1)
        filtered_c = np.ascontiguousarray(filtered_t.T, dtype=np.float32)  # (T, N)
        result = (
            torch.from_dlpack(filtered_c)
            .reshape(original_shape)
            .to(device=signal.device, dtype=signal.dtype)
        )
        logger.debug(f"ButterworthBandpass (CPU batch): filtered {tuple(signal.shape)}")
        return result

    def _apply_gpu(self, signal: torch.Tensor) -> torch.Tensor:
        """GPU batch mode: run the JIT loop from zero initial state."""
        assert _sos_step_gpu is not None
        original_shape = signal.shape
        T = signal.shape[0]
        x = signal.float().reshape(T, -1)  # (T, N)
        N = x.shape[1]

        y = x
        for sec in range(self._sos.shape[0]):
            b0, b1, b2 = (
                float(self._sos[sec, 0]),
                float(self._sos[sec, 1]),
                float(self._sos[sec, 2]),
            )
            a1, a2 = float(self._sos[sec, 4]), float(self._sos[sec, 5])
            s1 = torch.zeros(N, device=signal.device, dtype=torch.float32)
            s2 = torch.zeros(N, device=signal.device, dtype=torch.float32)
            y, s1, s2 = _sos_step_gpu(y, b0, b1, b2, a1, a2, s1, s2)

        logger.debug(f"ButterworthBandpass (GPU batch): filtered {tuple(signal.shape)}")
        return y.reshape(original_shape).to(dtype=signal.dtype)

    # ------------------------------------------------------------------
    # Streaming / chunk mode (state carries across calls)
    # ------------------------------------------------------------------

    def apply_chunk(self, signal: torch.Tensor) -> torch.Tensor:
        """Filter a chunk of frames along the time dimension, updating state.

        Equivalent to calling :meth:`step` T times in sequence; the IIR state
        is updated so the next call picks up seamlessly.

        Uses the GPU JIT path on CUDA/MPS devices (no PCIe roundtrip).

        Args:
            signal: ``(T, ...)`` tensor where ``T`` is the chunk length.

        Returns:
            Filtered tensor with the same shape.
        """
        if _GPU_IIR_AVAILABLE and signal.device.type != "cpu":
            return self._apply_chunk_gpu(signal)
        return self._apply_chunk_cpu(signal)

    def _apply_chunk_cpu(self, signal: torch.Tensor) -> torch.Tensor:
        original_shape = signal.shape
        T = signal.shape[0]
        if self._zi is None:
            self.reset(signal.shape[1:])

        flat = np.from_dlpack(signal.detach().float().cpu().contiguous()).reshape(T, -1)  # (T, P)

        filtered, self._zi = sosfilt(self._sos, flat, axis=0, zi=self._zi)
        out_c = np.ascontiguousarray(filtered, dtype=np.float32)
        return (
            torch.from_dlpack(out_c)
            .reshape(original_shape)
            .to(device=signal.device, dtype=signal.dtype)
        )

    def _apply_chunk_gpu(self, signal: torch.Tensor) -> torch.Tensor:
        """GPU streaming mode: JIT loop with persistent state across chunks."""
        assert _sos_step_gpu is not None
        original_shape = signal.shape
        T = signal.shape[0]
        x = signal.float().reshape(T, -1)  # (T, N)
        N = x.shape[1]

        # Lazy-init GPU state (one (s1, s2) pair per SOS section)
        if self._zi_gpu is None:
            self._zi_gpu = [
                (
                    torch.zeros(N, device=signal.device, dtype=torch.float32),
                    torch.zeros(N, device=signal.device, dtype=torch.float32),
                )
                for _ in range(self._sos.shape[0])
            ]

        y = x
        new_states: list[tuple[torch.Tensor, torch.Tensor]] = []
        for sec in range(self._sos.shape[0]):
            b0, b1, b2 = (
                float(self._sos[sec, 0]),
                float(self._sos[sec, 1]),
                float(self._sos[sec, 2]),
            )
            a1, a2 = float(self._sos[sec, 4]), float(self._sos[sec, 5])
            s1, s2 = self._zi_gpu[sec]
            y, s1, s2 = _sos_step_gpu(y, b0, b1, b2, a1, a2, s1, s2)
            new_states.append((s1, s2))

        self._zi_gpu = new_states
        return y.reshape(original_shape).to(dtype=signal.dtype)

    # ------------------------------------------------------------------
    # Streaming mode (one frame at a time) — CPU only
    # ------------------------------------------------------------------

    def reset(self, signal_shape: tuple[int, ...]) -> None:
        """Initialise CPU filter state for streaming on a signal of *signal_shape*.

        Call once before the first :meth:`step` call.
        """
        n_signals = int(np.prod(signal_shape))
        zi_base = np.zeros((self._sos.shape[0], 2))  # (n_sections, 2)
        self._zi = np.stack([zi_base] * n_signals, axis=-1)  # (sections, 2, N)
        logger.debug(f"ButterworthBandpass: reset state for {n_signals} signals")

    def step(self, frame: torch.Tensor) -> torch.Tensor:
        """Filter a single frame, updating internal state.

        Args:
            frame: ``(...)`` tensor (one frame of the signal, no time dim).

        Returns:
            Filtered frame with the same shape.
        """
        if self._zi is None:
            self.reset(frame.shape)

        original_shape = frame.shape
        flat = np.from_dlpack(frame.detach().float().cpu().contiguous()).flatten()  # (N,)
        filtered_flat, self._zi = sosfilt(self._sos, flat[np.newaxis, :], axis=0, zi=self._zi)
        out_c = np.ascontiguousarray(filtered_flat[0].reshape(original_shape), dtype=np.float32)
        result = torch.from_dlpack(out_c).to(device=frame.device, dtype=frame.dtype)
        return result

apply(signal)

Filter signal along its first (time) dimension (batch mode).

Parameters:

Name Type Description Default
signal Tensor

(T, ...) tensor.

required

Returns:

Type Description
Tensor

Filtered tensor with the same shape.

Source code in src/pyevm/filters/temporal.py
def apply(self, signal: torch.Tensor) -> torch.Tensor:
    """Filter *signal* along its first (time) dimension (batch mode).

    Dispatches to the accelerator JIT path when it is available and the
    tensor lives off-CPU; otherwise uses the scipy path.

    Args:
        signal: ``(T, ...)`` tensor.

    Returns:
        Filtered tensor with the same shape.
    """
    use_gpu = _GPU_IIR_AVAILABLE and signal.device.type != "cpu"
    return self._apply_gpu(signal) if use_gpu else self._apply_cpu(signal)

apply_chunk(signal)

Filter a chunk of frames along the time dimension, updating state.

Equivalent to calling :meth:step T times in sequence; the IIR state is updated so the next call picks up seamlessly.

Uses the GPU JIT path on CUDA/MPS devices (no PCIe roundtrip).

Parameters:

Name Type Description Default
signal Tensor

(T, ...) tensor where T is the chunk length.

required

Returns:

Type Description
Tensor

Filtered tensor with the same shape.

Source code in src/pyevm/filters/temporal.py
def apply_chunk(self, signal: torch.Tensor) -> torch.Tensor:
    """Filter a chunk of frames along the time dimension, updating state.

    Equivalent to calling :meth:`step` T times in sequence; the IIR state
    is updated so the next call picks up seamlessly.

    Uses the GPU JIT path on CUDA/MPS devices (no PCIe roundtrip).

    Args:
        signal: ``(T, ...)`` tensor where ``T`` is the chunk length.

    Returns:
        Filtered tensor with the same shape.
    """
    use_gpu = _GPU_IIR_AVAILABLE and signal.device.type != "cpu"
    if use_gpu:
        return self._apply_chunk_gpu(signal)
    return self._apply_chunk_cpu(signal)

reset(signal_shape)

Initialise CPU filter state for streaming on a signal of signal_shape.

Call once before the first :meth:step call.

Source code in src/pyevm/filters/temporal.py
def reset(self, signal_shape: tuple[int, ...]) -> None:
    """Initialise CPU filter state for streaming on a signal of *signal_shape*.

    Call once before the first :meth:`step` call.

    Args:
        signal_shape: Shape of one frame of the signal (no time dimension);
            one 2-tap delay line is tracked per scalar signal and SOS section.
    """
    n_signals = int(np.prod(signal_shape))
    # Allocate the zero state directly in the (n_sections, 2, N) layout
    # scipy.signal.sosfilt expects for axis=0 filtering.  (The previous
    # list-materialise-and-stack built n_signals copies of a (sections, 2)
    # array — O(n) temporaries — and crashed for n_signals == 0.)
    self._zi = np.zeros((self._sos.shape[0], 2, n_signals))
    logger.debug(f"ButterworthBandpass: reset state for {n_signals} signals")

step(frame)

Filter a single frame, updating internal state.

Parameters:

Name Type Description Default
frame Tensor

(...) tensor (one frame of the signal, no time dim).

required

Returns:

Type Description
Tensor

Filtered frame with the same shape.

Source code in src/pyevm/filters/temporal.py
def step(self, frame: torch.Tensor) -> torch.Tensor:
    """Filter a single frame, updating internal state.

    Args:
        frame: ``(...)`` tensor (one frame of the signal, no time dim).

    Returns:
        Filtered frame with the same shape.
    """
    # Lazily initialise the per-element IIR state on the first call.
    if self._zi is None:
        self.reset(frame.shape)

    original_shape = frame.shape
    # Hand the frame to NumPy via DLPack (after detaching and moving to
    # CPU float32), then flatten so every element is an independent
    # 1-D signal.  NOTE(review): .cpu()/.flatten() may copy — the DLPack
    # hop avoids one conversion, not all copies.
    flat = np.from_dlpack(frame.detach().float().cpu().contiguous()).flatten()  # (N,)
    # sosfilt over a (1, N) array along axis 0: advances all N signals by
    # exactly one time sample, threading the section state through ``zi``
    # (shape (sections, 2, N), as set up by reset()).
    filtered_flat, self._zi = sosfilt(self._sos, flat[np.newaxis, :], axis=0, zi=self._zi)
    out_c = np.ascontiguousarray(filtered_flat[0].reshape(original_shape), dtype=np.float32)
    # Return to torch via DLPack, restoring the caller's device and dtype.
    result = torch.from_dlpack(out_c).to(device=frame.device, dtype=frame.dtype)
    return result

Device

pyevm.device

Device detection and management for GPU/MPS/CPU compute.

get_device(force=None)

Return the best available compute device.

Priority: CUDA > MPS > CPU, unless force overrides.

Parameters:

Name Type Description Default
force str | None

One of "cuda", "mps", or "cpu". When None the best available device is selected automatically.

None

Returns:

Name Type Description
device

A torch.device ready to use.

Source code in src/pyevm/device.py
def get_device(force: str | None = None) -> torch.device:
    """Return the best available compute device.

    Priority: CUDA > MPS > CPU, unless *force* overrides.

    Args:
        force: One of ``"cuda"``, ``"mps"``, or ``"cpu"``. When ``None``
               the best available device is selected automatically.

    Returns:
        A :class:`torch.device` ready to use.
    """
    # An explicit override short-circuits auto-detection entirely.
    if force is not None:
        chosen = torch.device(force)
        logger.debug(f"Forced compute device: {chosen}")
        return chosen

    # Auto-detect, preferring discrete CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        chosen = torch.device("cuda")
        logger.info(f"Using CUDA GPU: {torch.cuda.get_device_name(0)}")
        return chosen

    if torch.backends.mps.is_available():
        chosen = torch.device("mps")
        logger.info("Using Apple Silicon GPU (MPS)")
        return chosen

    chosen = torch.device("cpu")
    logger.info("No GPU detected — using CPU")
    return chosen

device_info(device)

Return a human-readable summary of device capabilities.

Source code in src/pyevm/device.py
def device_info(device: torch.device) -> dict[str, str]:
    """Return a human-readable summary of *device* capabilities."""
    summary: dict[str, str] = {"device": str(device)}
    kind = device.type
    if kind == "cuda":
        summary["name"] = torch.cuda.get_device_name(device)
        total_bytes = torch.cuda.get_device_properties(device).total_memory
        # Decimal gigabytes (1e9), matching marketing VRAM figures.
        summary["vram_gb"] = f"{total_bytes / 1e9:.1f}"
    else:
        summary["name"] = "Apple Silicon (MPS)" if kind == "mps" else "CPU"
    return summary