Accelerated JAX training on Mac

Metal plug-in

JAX uses the new Metal plug-in to provide Metal acceleration on Mac platforms. The Metal plug-in uses the OpenXLA compiler and the PJRT runtime to accelerate JAX machine learning workloads on the GPU. The OpenXLA compiler lowers JAX graphs to the StableHLO format, which is then converted to MPSGraph executables and Metal runtime API calls that dispatch work to the GPU.
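
To make the lowering step more concrete, the sketch below prints the intermediate representation that JAX produces for a small function before any backend compiles it. The function name scaled_sum and the helper lowered_text are illustrative and not part of the plug-in. The exact dialect printed (StableHLO or the older MHLO) depends on the installed JAX version, and the later conversion to MPSGraph happens inside the plug-in and is not visible here.

```python
import jax
import jax.numpy as jnp


def scaled_sum(x):
    # A tiny computation to lower: double every element, then sum.
    return jnp.sum(x * 2.0)


def lowered_text(fn, example):
    # jit(...).lower(...) traces and lowers the function without running it;
    # as_text() returns the lowered module as MLIR text.
    return jax.jit(fn).lower(example).as_text()


if __name__ == "__main__":
    print(lowered_text(scaled_sum, jnp.arange(4.0)))
```

The printed module is what the compiler stack receives; on a Mac with the plug-in installed, that module is what gets handed on for conversion to GPU executables.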

Requirements

  • Mac computers with Apple silicon or AMD GPUs
  • macOS 13.4 or later
  • Python 3.9 or later
  • Xcode command-line tools: xcode-select --install
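
The operating system and Python requirements above can be checked from Python before installing anything. This is a minimal sketch: the helper names parse_version and requirement_problems are hypothetical, and the script does not check the GPU type or whether the Xcode command-line tools are installed.

```python
import platform
import sys

MIN_MACOS = (13, 4)   # macOS 13.4 or later
MIN_PYTHON = (3, 9)   # Python 3.9 or later


def parse_version(text):
    # Turn "13.4.1" into (13, 4, 1); non-numeric parts are skipped.
    return tuple(int(part) for part in text.split(".") if part.isdigit())


def requirement_problems(system, macos_version, python_version):
    # Return a list of human-readable problems; an empty list means
    # the checked requirements are met.
    problems = []
    if system != "Darwin":
        problems.append(f"not running on macOS (found {system})")
    elif parse_version(macos_version) < MIN_MACOS:
        problems.append(f"macOS {macos_version} is older than 13.4")
    if tuple(python_version[:2]) < MIN_PYTHON:
        problems.append("Python is older than 3.9")
    return problems


if __name__ == "__main__":
    issues = requirement_problems(
        platform.system(), platform.mac_ver()[0], sys.version_info
    )
    print("\n".join(issues) if issues else "OS and Python requirements met")
```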

The table below tracks jax-metal versions and compatible versions of macOS, jax, and jaxlib:

jax-metal    macOS                                jaxlib     jax
0.0.4        Sonoma 14.0                          v0.4.11    v0.4.11
0.0.3        Ventura 13.4.1+, Sonoma 14.0 beta    v0.4.10    v0.4.11
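
As a rough aid, the table can be encoded as data and compared against the packages installed in the current environment. The sketch below is illustrative only: COMPATIBILITY, installed_version, and mismatches are hypothetical names, the macOS column is left out, and it assumes that installed packages report plain version strings such as 0.4.10.

```python
from importlib import metadata

# jaxlib and jax versions from the table above, keyed by jax-metal version.
COMPATIBILITY = {
    "0.0.4": {"jaxlib": "0.4.11", "jax": "0.4.11"},
    "0.0.3": {"jaxlib": "0.4.10", "jax": "0.4.11"},
}


def installed_version(package):
    # Return the installed version string, or None if the package is absent.
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def mismatches(jax_metal_version, installed):
    # Compare installed versions with the table row for this plug-in version.
    expected = COMPATIBILITY.get(jax_metal_version)
    if expected is None:
        return [f"jax-metal {jax_metal_version} is not in the table"]
    return [
        f"{pkg}: expected {want}, found {installed.get(pkg)}"
        for pkg, want in expected.items()
        if installed.get(pkg) != want
    ]


if __name__ == "__main__":
    plugin = installed_version("jax-metal")
    if plugin is None:
        print("jax-metal is not installed")
    else:
        found = {pkg: installed_version(pkg) for pkg in ("jaxlib", "jax")}
        print("\n".join(mismatches(plugin, found)) or "versions match the table")
```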

Get started

1. Set up

python3 -m venv ~/jax-metal
source ~/jax-metal/bin/activate
python -m pip install -U pip
python -m pip install numpy wheel ml-dtypes==0.2.0

2. Installation

jax-metal 0.0.4 or later

A custom build of jaxlib isn’t required for these versions, as they rely on a pinned version of jax and jaxlib through package dependencies.

python -m pip install jax-metal

jax-metal 0.0.3 or earlier

First, build a compatible jaxlib from source. This version of the plug-in works with jax v0.4.11 and is pinned to jaxlib v0.4.10. To enable plug-in loading with that jaxlib version, you must build jaxlib from source.

For the general prerequisites and scripts for building JAX from source, visit https://jax.readthedocs.io/en/latest/developer.html. Use the steps below to build a jaxlib that is compatible with the Metal plug-in:

# obtain JAX source code
git clone https://github.com/google/jax.git --branch jaxlib-v0.4.10 --single-branch
cd jax
# build jaxlib from source, with capability to load plugin
python build/build.py --bazel_options=--@xla//xla/python:enable_tpu=true
# install jaxlib
python -m pip install dist/*.whl
# install the jax release that matches this plug-in version
python -m pip install jax==0.4.11

Then, install the jax-metal plug-in.

python -m pip install jax-metal==0.0.3

3. Verification

python -c 'import jax; print(jax.numpy.arange(10))'
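
For a slightly more informative check than the one-liner, the following sketch also reports which backend and devices JAX selected and runs a small jit-compiled computation. The helper name smoke_test is hypothetical, and the exact device name printed for the Metal plug-in is not shown here because it may vary by version; on a working installation the device list would be expected to include a GPU device rather than only the CPU.

```python
import jax
import jax.numpy as jnp


def smoke_test():
    # Sum of squares of 0..9, compiled with jit so that the full
    # compile-and-dispatch path is exercised on the default backend.
    x = jnp.arange(10)
    return int(jax.jit(lambda v: jnp.sum(v * v))(x))


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
    print("sum of squares:", smoke_test())
```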

Questions and feedback

To ask questions and share feedback about the jax-metal plug-in, visit the Apple Developer Forums.