--- title: DJLのPyTorchバックエンドでMPS (Metal Performance Shaders) を使うメモ tags: ["Java", "DJL", "Machine Learning", "PyTorch", "MPS"] categories: ["Dev", "Java", "ai", "djl"] date: 2023-10-31T08:10:27Z updated: 2023-10-31T09:15:27Z --- [DJL (Deep Java Library)](https://github.com/deepjavalibrary/djl) 0.20.0以降で [MPS](https://developer.apple.com/metal/pytorch/) が使えるようになっていました。 サンプルが見当たらなかったので試したメモ。 サンプルコードは [こちら](https://github.com/making/hello-djl-pytorch) です。Apple M2 Pro、メモリ32 GB、macOS 13.5.2で試しました。 次のように`Device`インスタンスを`Device.of("mps", 0)`で作れば良いようです。 ```java import ai.djl.Device; import ai.djl.ndarray.NDArray; import ai.djl.ndarray.NDManager; import ai.djl.ndarray.types.Shape; public class Main { public static void main(String[] args) { int dimension = 1024; Device device = Device.of("mps", 0); //Device device = Device.cpu(); System.out.println(device.isGpu()); // false try (NDManager manager = NDManager.newBaseManager(device)) { NDArray array1 = manager.randomUniform(0, 1, new Shape(dimension, dimension)); NDArray array2 = manager.randomUniform(0, 1, new Shape(dimension, dimension)); NDArray result = array1.add(array2).mul(10).matMul(array1.transpose()).div(5); System.out.println(result); } } } ``` MPS自体はGPUではないですが、MPSのAPIを使うことで、GPUが利用されるため、MPSを使ったコードを実行するとアクティビティモニタで `% GPU` の数字が0より大きくなります。