Skip to content

matthewlai/JaxVidFlow

Repository files navigation

JaxVidFlow

Video processing pipeline using JAX. Especially suitable for experimenting with custom video processing operations.

Why?

  • FFmpeg has done great things for the community over the past decades, and has been the go-to for automated video processing, but -
    • It's really hard to extend with filters. Writing high performance C code is hard, and very hardware-dependent
    • Only a handful of filters have GPU implementations, and generally using them requires very messy command line options
    • Every new CPU or GPU architecture requires new code
    • Floating point formats not supported, requiring carefully choosing pixel formats to minimize quality loss for multi-stage processing, and different filters support different formats.
  • JAX is a high performance and very user-friendly array computing library.
    • Write simple Numpy expressions, get well-optimised native code performance on CPU/GPU/TPU
    • Very easy to implement most custom operations. As long as you can express it as matrix operations, you are 80% of the way there!
    • Code generation for:
      • CPUs (compiled into Eigen operations with good vectorization for x86, ARM, and other architectures)
      • GPU (NVIDIA CUDA is best supported, AMD ROCm experimental, Intel oneAPI also experimental, Apple Metal Performance Shaders Graph on all Apple GPUs also experimental)
      • Google TPUs
      • Future architectures as they come out, without having to change our code (in theory)
    • We can do everything in floating point. FP is the state of the art for minimal-loss multi-stage video processing, and we can do it very fast with GPUs (and CPUs with SIMD).

Current Status

You can run benchmarks.py to see how fast things are, but it doesn't really have a UI yet (not even a CLI). It's just a library. See examples/process_dive_video.py for a typical pipeline setup for filtering a video.

Implemented functions

  • Decode / encode pipeline using FFmpeg (through PyAV)
    • Reasonably optimised - only about 10% slower than using FFmpeg directly for a straight transcode with hardware encoding
      • For a straight transcode, use FFmpeg instead
    • Supports hardware encoders

Transforms

  • YUV to RGB and back, including chroma subsampling/supersampling
  • Rec709 to linear and back
  • 3D LUT application with trilinear filtering
  • Resizing with Lanczos interpolation (ok, this is really just a one line call to jax.image.resize())
  • Denoising using NL-Means (both pixelwise and blockwise variants implemented)
    • ~50 fps at 4K on NVIDIA RTX 3060 Ti, compared to ~2 fps with FFmpeg's CPU implementation

Installation Instructions

Windows (no GPU-accelerated Jax or encoding)

# Install JaxVidFlow.
pip install JaxVidFlow

Mac

# Install JaxVidFlow.
pip install JaxVidFlow

# Optional: Install JAX with the experimental METAL backend.
pip install jax-metal

Linux (Debian/Ubuntu) or Windows (with WSL2, in an Ubuntu VM)

# Install dependencies.
sudo apt install pkg-config python3 libavformat-dev libavcodec-dev libavdevice-dev \
  libavutil-dev libswscale-dev libswresample-dev libavfilter-dev

# Setup a virtual environment (optional).
python3 -m venv venv

# Activate the virtual environment (optional).
source venv/bin/activate

# Install PyAV from source (this links against system ffmpeg libraries, which is necessary to get hardware-accelerated decoders and encoders).
pip3 install av --no-binary av

# Install JAX (with NVIDIA GPU support).
pip3 install jax[cuda12]

# Or, install CPU-only JAX.
pip3 install jax

# See Jax documentation for installing JAX with experimental backends (eg. AMD ROCm):
# https://jax.readthedocs.io/en/latest/installation.html

Acknowledgements

About

Experimental video processing pipeline using Jax

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published