    JAX
    Updated Feb. 02, 2024


    JAX is a Python package that combines composable function transformations (automatic differentiation, JIT compilation, and vectorization) with NumPy-style array operations executed through accelerated linear algebra (XLA) routines. Although not formally a deep learning framework, it can be used to great effect for any problem that requires fast automatic differentiation. It offers good support for vectorization and parallel computing, and when combined with the extensions below it can be used to train general machine learning models.
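    As a brief illustration (not taken from the cluster examples; all names are arbitrary), the core transforms jax.grad, jax.vmap, and jax.jit compose freely:

    import jax
    import jax.numpy as jnp

    def loss(w, x, y):
        # Squared error of a linear model for a single sample
        return (jnp.dot(x, w) - y) ** 2

    # Compose transforms: differentiate w.r.t. w, vectorize over the batch, JIT-compile via XLA
    batched_grad = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))

    key = jax.random.PRNGKey(0)
    w = jax.random.normal(key, (3,))
    x = jax.random.normal(key, (8, 3))
    y = jnp.ones(8)
    print(batched_grad(w, x, y).shape)  # (8, 3): one gradient per sample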

    JAX is a functionally pure framework: transformed functions should not rely on or modify hidden state. This may be unfamiliar to users of PyTorch or TensorFlow, which are more object-oriented in nature. See here for solutions to common problems and other tips for getting started with JAX.
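    Two practical consequences of this purity (shown here as a minimal illustrative sketch) are that arrays are immutable and that random number generation uses explicit keys instead of a global seed:

    import jax
    import jax.numpy as jnp

    # Arrays are immutable: "in-place" updates return a new array
    x = jnp.zeros(3)
    y = x.at[0].set(1.0)   # x is unchanged; y is a new array

    # Randomness is explicit: keys are created and split by hand
    key = jax.random.PRNGKey(42)
    key, subkey = jax.random.split(key)
    sample = jax.random.normal(subkey, (4,))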

    Extensions

    There are several useful JAX-related Python packages included in the anaconda3/2023.09 module:

    • Haiku - Neural network library for building machine learning models in JAX from object-oriented-style module definitions, which are converted into pure functions.

    • Optax - Gradient-based optimization library for training models in JAX (a combined usage sketch follows below).
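    The following minimal sketch (illustrative only; it is not taken from the cluster examples, and names such as forward and train_step are arbitrary) shows how the two packages fit together: hk.transform turns a module definition into pure init/apply functions, and Optax supplies the optimizer state and update rule:

    import haiku as hk
    import jax
    import jax.numpy as jnp
    import optax

    def forward(x):
        # A small MLP; hk.transform converts this into pure (init, apply) functions
        return hk.nets.MLP([128, 10])(x)

    model = hk.transform(forward)
    params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 784)))

    optimizer = optax.adam(1e-3)
    opt_state = optimizer.init(params)

    def loss_fn(params, x, y):
        logits = model.apply(params, None, x)
        return optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()

    @jax.jit
    def train_step(params, opt_state, x, y):
        grads = jax.grad(loss_fn)(params, x, y)
        updates, opt_state = optimizer.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state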

    Examples

    Examples of CPU, (multi-)GPU, and multi-node training tasks for HPC environments can be found here. Below, the examples for training convolutional neural network image classifiers on the Fashion-MNIST dataset are reproduced.
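    For orientation, a hypothetical Haiku CNN for Fashion-MNIST-sized inputs (28×28 grayscale images, 10 classes) could look as follows; the architectures actually used are defined in the linked repository:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def cnn(x):
        # Small convolutional classifier for (batch, 28, 28, 1) inputs
        net = hk.Sequential([
            hk.Conv2D(32, kernel_shape=3), jax.nn.relu,
            hk.MaxPool(window_shape=2, strides=2, padding="VALID"),
            hk.Conv2D(64, kernel_shape=3), jax.nn.relu,
            hk.Flatten(),
            hk.Linear(10),
        ])
        return net(x)

    model = hk.transform(cnn)
    params = model.init(jax.random.PRNGKey(0), jnp.zeros((1, 28, 28, 1)))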

    Setup (on login node):

    The following commands load the Anaconda module, clone the example repository, and install it into the user's local Python environment:

    $ module load anaconda3/2023.09
    $ conda activate base
    $ git clone https://github.com/Ruunyox/jax-hpc
    $ cd jax-hpc
    $ pip install --user .
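    To verify that the installation is usable before submitting jobs, one can, for example, list the devices JAX sees (on the login node this will only show CPU devices):

    $ python -c "import jax; print(jax.devices())"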

    1. Single node, single GPU:

    We start with a training YAML file (config_local_gpu.yaml). Since only 1 GPU is needed, it is better to use the gpu-a100:shared partition and request just one GPU (gres=gpu:A100:1) rather than queuing for a full node with 4 GPUs. The following SLURM submission script details the options:

    #! /bin/bash
    #SBATCH -J jax_cli_test_gpu
    #SBATCH -o jax_cli_test_gpu.out
    #SBATCH --time=00:30:00
    #SBATCH --partition=gpu-a100
    #SBATCH --nodes=1
    #SBATCH --ntasks-per-node=1
    #SBATCH --gres=gpu:A100:1
    #SBATCH --mem-per-cpu=1G
    #SBATCH --cpus-per-task=4

    module load sw.a100
    module load cuda/11.8
    module load anaconda3/2023.09
    conda activate base

    export XLA_FLAGS=--xla_gpu_cuda_data_dir=/sw/compiler/cuda/11.8/a100/install
    export JAX_PLATFORM_NAME=gpu
    export PYTHONUNBUFFERED=on

    jaxhpc --config config_local_gpu.yaml

    and can be run using:

    $ sbatch cli_test_conv_gpu.sh

    The results can be inspected in the associated output log (jax_cli_test_gpu.out, as set via the #SBATCH -o option).

    2. Multiple GPUs:

    For parallel execution across multiple GPUs, we direct users to the JAX documentation on pmap here.
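    For orientation only (a minimal sketch; function and axis names are arbitrary), a data-parallel step with pmap replicates the parameters, shards the batch across the local devices, and averages the gradients with lax.pmean:

    from functools import partial

    import jax
    import jax.numpy as jnp

    n_dev = jax.local_device_count()

    def loss(w, x, y):
        return jnp.mean((jnp.dot(x, w) - y) ** 2)

    # Each device computes gradients on its shard of the batch; pmean averages
    # them across the named device axis before the parameter update
    @partial(jax.pmap, axis_name="devices")
    def step(w, x, y):
        grads = jax.grad(loss)(w, x, y)
        grads = jax.lax.pmean(grads, axis_name="devices")
        return w - 0.1 * grads

    # Replicate parameters and shard data along a leading device axis
    key = jax.random.PRNGKey(0)
    w = jnp.broadcast_to(jnp.zeros(3), (n_dev, 3))
    x = jax.random.normal(key, (n_dev, 32, 3))
    y = jnp.ones((n_dev, 32))
    w = step(w, x, y)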

    {"serverDuration": 10, "requestCorrelationId": "12ca14d74f3c4c55b5e287005431631a"}