AlvaTech

PyTorch on Apple Silicon (M1) Macs


Starting with the official PyTorch 1.12 release, PyTorch supports Apple’s Metal Performance Shaders (MPS) backend, which lets it run on the GPU of Apple Silicon Macs.

Install the Miniconda package manager

curl -O https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh
sh Miniconda3-latest-MacOSX-arm64.sh

Follow the on-screen instructions to complete the installation. Then close and reopen the terminal so that the conda command is available before moving on to the next step.
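The arm64 Miniconda installer gives you a native Apple Silicon Python, which MPS needs; PyTorch's MPS notes also list macOS 12.3 or later as a requirement. Before installing PyTorch, a quick standard-library check like the one below can catch a Python running under Rosetta (x86_64) or an older macOS. This is a minimal sketch: the function names are hypothetical, and the 12.3 threshold is an assumption taken from PyTorch's documentation, not from this post.

```python
import platform

# Minimum macOS release for the MPS backend (assumption of this sketch, per PyTorch's MPS notes)
MIN_MACOS = (12, 3)


def parse_release(release: str) -> tuple:
    """Turn a macOS release string such as '12.3.1' into (12, 3, 1)."""
    return tuple(int(part) for part in release.split(".") if part.isdigit())


def mps_prerequisite_problems(machine=None, mac_release=None) -> list:
    """List reasons MPS would not work; an empty list means the basics look fine.

    Both arguments default to the running interpreter, so calling this with
    no arguments checks the current machine.
    """
    if machine is None:
        machine = platform.machine()
    if mac_release is None:
        mac_release = platform.mac_ver()[0]  # '' when not running on macOS
    problems = []
    if machine != "arm64":
        problems.append(f"Python reports {machine!r}, not a native arm64 build")
    if not mac_release:
        problems.append("not running on macOS")
    elif parse_release(mac_release) < MIN_MACOS:
        problems.append(f"macOS {mac_release} is older than 12.3")
    return problems


if __name__ == "__main__":
    issues = mps_prerequisite_problems()
    print("Prerequisites look fine" if not issues else "\n".join(issues))
```

Save it as a file and run it with the python from the new Miniconda installation; any line it prints other than the success message points at something to fix first.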

Install PyTorch

conda install pytorch torchvision torchaudio -c pytorch-nightly
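After the install finishes, it can help to confirm which PyTorch version conda actually picked, since the nightly channel produces version strings such as 1.13.0.dev20220620 rather than plain release numbers. The sketch below compares the major.minor part against 1.12 and, if that passes, asks torch.backends.mps.is_built() whether the build was compiled with MPS. The helper name meets_min_version is hypothetical, and the guarded import only exists so the helper can be tried without PyTorch installed.

```python
import re

try:
    import torch
except ImportError:  # lets the helper below run even where PyTorch is missing
    torch = None


def meets_min_version(version: str, minimum=(1, 12)) -> bool:
    """Compare the major.minor part of a PyTorch version string to `minimum`.

    Works for release strings ('1.12.1') and nightly strings ('1.13.0.dev20220620').
    """
    match = re.match(r"(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return (int(match.group(1)), int(match.group(2))) >= tuple(minimum)


if torch is not None:
    ok = meets_min_version(torch.__version__)
    print(f"PyTorch {torch.__version__}, 1.12 or newer: {ok}")
    if ok:
        # is_built() says whether this build was compiled with MPS support;
        # is_available() additionally needs a supported Mac at run time
        print("Built with MPS:", torch.backends.mps.is_built())
```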

Install JupyterLab (Optional)

conda install -c conda-forge jupyterlab
conda install -c conda-forge nb_conda_kernels
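With nb_conda_kernels installed, conda environments that have a kernel available show up as separate kernels in Jupyter, so it is easy to end up in a notebook that is not using the Python where PyTorch lives. Running something like the snippet below in a notebook cell is one way to check; describe_interpreter is a hypothetical helper, and the values it prints depend entirely on your setup.

```python
import platform
import sys


def describe_interpreter() -> dict:
    """Summarise the Python interpreter the current kernel is running on."""
    return {
        "executable": sys.executable,
        # For a conda environment, sys.prefix is the root directory of that environment
        "prefix": sys.prefix,
        # 'arm64' for a native Apple Silicon build, 'x86_64' under Rosetta
        "machine": platform.machine(),
    }


for key, value in describe_interpreter().items():
    print(f"{key}: {value}")
```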

Verify the Installation

You can verify MPS support with a short Python script:

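import torch

# Check whether this PyTorch build can reach the Apple GPU through MPS
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
    # Allocate a one-element tensor directly on the MPS device
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    print("MPS device not found.")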

On a Mac where MPS is available, the script should print:

tensor([1.], device='mps:0')

If you installed JupyterLab, you can also run the script above in a notebook. Start JupyterLab with

jupyter lab
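Once MPS is confirmed, a common pattern in training code is to choose the device once and fall back to the CPU otherwise, so the same script also runs on machines without MPS. The sketch below assumes PyTorch 1.12 or newer; pick_device_name is a hypothetical helper, the CUDA branch is only there for portability, and the small nn.Linear layer merely stands in for a real model.

```python
try:
    import torch
except ImportError:  # lets the helper below be used (and tested) without PyTorch
    torch = None


def pick_device_name(mps_available: bool, cuda_available: bool = False) -> str:
    """Prefer Apple's MPS backend, then CUDA, then fall back to the CPU."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


if torch is not None:
    name = pick_device_name(torch.backends.mps.is_available(), torch.cuda.is_available())
    device = torch.device(name)
    # Move the model's parameters and create the input on the same device
    model = torch.nn.Linear(4, 2).to(device)
    x = torch.randn(8, 4, device=device)
    y = model(x)
    print("Output is on:", y.device)
```

Keeping the model and its inputs on the same device matters: mixing an MPS tensor with a CPU tensor in one operation raises an error rather than copying data implicitly.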