ストリーミングはほとんどのブラウザと
Developerアプリで視聴できます。
-
Apple GPUでの機械学習モデルとAIモデルのトレーニング
PyTorch、JAX、TensorFlow向けのMetalを使用し、Appleシリコンでモデルをトレーニングする方法を解説します。新しいアテンション操作と量子化のサポートを利用して、デバイス上でのTransformerモデルのパフォーマンスを向上させましょう。
関連する章
- 0:00 - Introduction
- 1:36 - Training frameworks on Apple silicon
- 4:16 - PyTorch improvements
- 11:26 - ExecuTorch
- 13:19 - JAX features
リソース
関連ビデオ
WWDC24
WWDC23
WWDC22
WWDC21
-
ダウンロード
こんにちは Yona Havocainenです GPU, Graphics and Display Software チームのソフトウェアエンジニアです 今日はAppleシリコンGPUで 機械学習モデルとAIモデルを トレーニングする方法や 今年追加された新機能をご紹介します
Appleシリコンは デバイスでの機械学習に対応する 優れた機能を多数備えています この強力なGPUは 最新のニューラルネットワークの 最適化に必要な 演算処理に秀でています
これをユニファイドメモリアーキテクチャ と組み合わせることで GPUから直接 大量のメモリにアクセスできます
大容量メモリにより デバイスでローカルに 大規模モデルをトレーニングし実行できます
またトレーニング時に大きなバッチサイズを 使えるため 一般に収束が速くなります
さらにモデルの重みを 複数のマシンに分散する必要がないので トレーニングから導入までの プロセスが簡単になります
トレーニングはAppleのプラットフォームに モデルを導入するための最初のステップです モデルのトレーニングが完了したら デバイスへの導入を準備する必要があります
準備ができると モデルはアプリに統合可能になります
機械学習モデルを導入するための 全体的なフローの解説については Appleデバイスでの 機械学習ワークフローに関する ビデオをご覧ください
このセッションでは トレーニングに焦点を当て Appleシリコン独自の演算能力を活用できる フレームワークをいくつか紹介します
この強力なGPUにアクセスするには Metalバックエンドを 機械学習でよく使われる フレームワークのいずれかで使用します TensorFlow、PyTorch、 JAX、MLXです
TensorFlowは多くの業界アプリで使われる 信頼性の高いフレームワークです
Metalバックエンドで サポートされる機能としては 大規模なプロジェクトでの 分散トレーニングや トレーニングのパフォーマンスを高める 混合精度などがあります TensorFlowでのMetalバックエンドの 有効化はこれまでになく簡単です Pipなどのパッケージマネージャを使って TensorFlowをインストールし プロジェクトにインポートするだけです
TensorFlowのMetalバックエンドの詳細は WWDC21のビデオでご確認ください
もう1つ 広く使われている フレームワークがPyTorchです Metalバックエンドはカスタム操作や プロファイリングなどの機能に 対応しているので ネットワーク性能を 簡単にベンチマークして改善できます PyTorchでMetalバックエンドを 使い始めるのも簡単です PyTorchをプロジェクトにインポートして デフォルトのデバイスをmpsに設定します
PyTorchの Metalバックエンドの詳細については WWDC22のビデオをご覧ください
JAXは最近 Metalバックエンドの サポート対象フレームワークに追加され
サポートされる機能には ジャストインタイムコンパイルや Numpyのような 使いやすいインターフェイスがあります
JAXのMetalバックエンドを使用するには jax-metalをインストールして JAXをプロジェクトにインポートします
JAXのMetalバックエンドについては WWDC23のビデオで詳しく説明しています
MLXはMetalバックエンドで サポートされる最新のフレームワークです
MLXはAppleシリコン向けに 設計され最適化されています サポートされる機能には NumpyのようなAPIや ジャストインタイムコンパイル 分散トレーニング ネイティブの ユニファイドメモリがあります
Python、Swift、C、C++用の バインディングも用意されています
Transformerモデルの微調整 画像生成 音声の書き起こしなど 機械学習タスクを実行するためのサンプルは コードリポジトリに用意されています
MLXを使い始めるのは 他のフレームワークと同じように簡単です ホイールをPython環境にインストールして プロジェクトにインポートするだけです
MLXフレームワークの詳細については こちらのドキュメントと コードリポジトリをご覧ください
これでトレーニングに関する Appleシリコンの基本がわかったので 今日のメイントピックに移りましょう 新機能と改善点を いくつかご紹介したいと思いますが 特に2つのフレームワークを 中心にお話しします PyTorchとJAXです
まずPyTorchから始めましょう
1年前のWWDC23で MPSバックエンドの 開発はベータ版段階に進みました
それ以降 カスタムカーネル 広範な操作への対応 ユニファイドメモリアーキテクチャの
サポートが追加されました またパフォーマンスと機能の両面で 数多くの改善と修正が加えられました これはPyTorch関連のオープンソース コミュニティに依るところが大きいです
様々なネットワークへの対応も この1年で強化されました 例えば 最先端のTransformerモデル用の HuggingFaceリポジトリでは現在 人気の高い上位50のネットワークを PyTorch-MPSバックエンドで すぐに高速化できます これには 今年有名になった 多数のモデルが含まれます Stable Diffusion、 Meta LLaMAモデル、Gemmaなどです
改善点については 特に影響力の大きい 3つのTransformerモデルを取り上げます まず 8ビットと4ビットの 整数量子化のサポートにより 大規模なモデルでも デバイスのメモリに格納できます
融合型のスケーリングされた ドット積アテンションにより 多くの一般的なモデルの パフォーマンスが向上します
そしてユニファイドメモリのサポートにより GPUに演算処理をディスパッチする時に 不要なテンソルのコピーがなくなります
では これらのトピックについて それぞれ詳しくお話ししましょう 32ビット浮動小数点数や ご覧の16ビット浮動小数点数などの データ形式は モデルのトレーニングでよく使われます 1ビットは値の符号 5ビットは指数 10ビットは小数を表します 精度はトレーニング中に パラメータを更新する時に役立ちます トレーニング後に 量子化という手法を用いると パラメータに必要なメモリを 減らすことができます
同じ値を8ビットの整数として表すことで 必要なメモリを半分に削減できます その利点として モデルのメモリ占有量が小さくなり 演算処理のスループットが向上し モデルによっては 出力精度がほとんどあるいは まったく低下することなく これを実現できます
スケーリングされたドット積アテンションは 多くのTransformerモデルの中核です この操作の起点となるのは トークン化されたテキストの入力です
この入力は クエリ、キー、バリューという 3つのテンソルに分割されます
その後 3つのテンソルは 一連の行列乗算 スケーリング Softmax演算を通じて操作されます 一連の操作を 1つのカーネル呼び出しに融合することで 多数の小さな演算処理をGPUに ディスパッチした時のオーバーヘッドを避け 多くのネットワークの全体的な パフォーマンスを改善できます
最後に取り上げる パフォーマンス面の改善点は Appleデバイスのユニファイド メモリアーキテクチャがもたらす利点です それによりメインメモリ内に テンソルを単純に保持して メモリの領域間で ビットをコピーする必要なしに CPUとGPUの両方から メモリ内のテンソルにアクセスできます
次にPyTorchに関する 説明の締めくくりとして 言語モデルを取得してカスタマイズし ユースケースに合わせて微調整し デバイスで実行するための ワークフロー全体を紹介します
まずトーチをインポートし 結果を再現できるように ランダムシートをロックします
人気のtransformersライブラリを使って モデルとトークナイザーを ダウンロードして設定します この方法により HuggingFaceリポジトリから モデルを簡単に取得できます
タスクのベースモデルとして 30億のパラメータを持つ OpenLLaMAバージョン2を使用します また モデルのトレーニングに使った 対応するトークナイザーも必要です
微調整アダプタを モデルにアタッチするために peftライブラリとLoraConfigを使用します アダプタのパラメータを定義してから ベースモデルと設定を使って 新しいPeftModelを作成します
これで演算デバイスのMPSに モデルを送ることができます
次に 調整に使うデータを 選択する必要があります ここではトレーニングの入力として Andrej Karpathyの tinyshakespeareデータセットを使います これはシェイクスピアの作品が 1つにまとめられたファイルです
データセットの読み込み後 それを データセットオブジェクトに読み込んで このデータに使用する トークナイザーを指定します
調整のためにトレーニングパラメータを いくつか設定する必要があります Trainerクラスを使って バッチサイズや トレーニングエポック数などの 引数を設定します
データコレクタオブジェクトは トレーナーのオブジェクトの トレーニングバッチを形成します
これでモデル、引数、データコレクタ、 トレーニングデータセットを渡して トレーナーオブジェクトを作成できます
トレーニングを始める前に 微調整前のモデルの 出力内容を確認しましょう ちょっとした便利な関数を追加して 入力テキストを受け取り モデルで使えるようにトークン化し 出力を生成し トークン化を解除して 人間が読めるテキストに戻すようにします
シェイクスピアの文章で試して どのような応答が 返ってくるか見てみましょう
調整前のモデルは 辞書の項目のように 動作している感じに見えます まず引用文の出所を正しく示してから 唐突に家長に関する説明に移っています
辞書と話してもあまり面白くないので 微調整によってモデルに 少し活気を与えてみましょう
trainerクラスでトレーニングを開始します 先ほど定義したパラメータを使って データセットを処理します
しばらくすると トレーニングがデータセットに対して 10エポック実行されて終了します
では 前と同じ入力で試してみましょう
メニーニアスの興味深いセリフですね 微調整によって明らかに成果が得られました
では 後で使えるようにモデルを保存します 使いやすくするために アダプタとベースモデルを 1つのエンティティにマージし トークナイザーも モデルと一緒に保存しておきます
モデルのトレーニングが完了したので デバイスに導入して試してみたいと思います
ほとんどのネットワークで 推奨される方法は Core MLを使って導入することです
この詳細については デバイスへの モデルの導入に関する解説をご覧ください
ここでは PyTorchエコシステム内にとどまり 新しいExecuTorchフレームワークを使って モデルを導入したいと思います
ExecuTorchの目的は 推論のため様々なデバイスに PyTorchモデルを導入することです PyTorchトレーニングで定義した カスタムの操作は ExecuTorchでの導入で シームレスに使用できます
ExecuTorchでは MPS Partitionerで計算グラフが分析され MPSデバイスを使って 認識パターンが高速化されます
こちらがローカルデバイスで ExecuTorchを設定する方法です
まずリポジトリのクローンを マシンに作成します
次にサブモジュールを更新します
最後にExecuTorchのビルド時にMPSの バインディングを使うオプションを渡して インストールスクリプトを実行します では ExecuTorchで モデルを導入する方法をお見せしましょう ExecuTorchリポジトリの 例に沿って進めていきます テストモデルには Meta LLaMA2モデルを使用します モデルはグループごとの量子化法を使用して 4ビット整数データ型に 変換してあります そのため よりコンパクトで 高速になっています
macOSでiOS向けの デモアプリをリポジトリにビルドし iPad Proを 導入ターゲットとして使用します
アプリのビルド後 使用するモデルと モデルのトレーニングに使った 対応するトークナイザーを選択します
次に ラザニアの作り方を モデルに尋ねてみます
ここでのクエリには LLaMA2プロンプト テンプレートを使っています このモデルはチャットボットのように 動作するよう微調整されていて この形式を想定しているからです
ExecuTorchを介してiPadで ローカルに実行しているこのモデルは 夕食に良いレシピを いくつか提案してくれています ただ トマトとチーズが足りないようです
新機能と改善点を利用して PyTorchワークフローを 高速化する方法については以上です 次にJAXに追加された 新機能について説明します JAXはMPS Graphでサポートされる 一般的な機械学習フレームワークです
JAX-MetalプラグインはWWDC23で デベロッパ向けにリリースされました それ以来このプラグインは進化を続け 多くの機能とパフォーマンス関連の 更新が追加されています
このような更新には 改良された高度な配列のインデックス作成
JAXの公式リポジトリでの CIランナーワークフローの採用
混合精度のサポートなどがあります
リリース以降に JAX-Metalバックエンドを採用している ユーザーを紹介したいと思います 最初はMuJoCoです ロボット工学や生体力学など 高速で正確なシミュレーションを 必要とするユースケース向けの オープンソースのフレームワークです
JAX Metalバックエンドを活用して Macプラットフォームを利用するユーザーに 最高のパフォーマンスを提供しています
次はAXLearnです これは大規模な深層学習モデルを 開発するためのライブラリです Metalバックエンドにより ローカルデバイスでの ワークフローの迅速な やり取りとテストを実現しています
これらのライブラリを確認して どのようにJAX-Metalバックエンドが Macデバイスで優れた体験を実現するか 試してみてください
次に JAX-Metalバックエンドに加えられた 改善点について詳しく見ていきましょう JAXでの混合精度 NDArrayのインデックス作成 パディングについて説明します
今年の更新内容の1つとして JAX-Metalフレームワークで BFloat16データ型がサポートされました
このデータ型は 浮動小数点値の広い動的範囲を表し 混合精度トレーニングなどの ユースケースに適しています
この新しいデータ型は JAXの他のデータ型と同じように使えます
例えば この新しいデータ型を使って テンソルを作成できます
もう1つの改善点として NDArrayの インデックス作成と更新のサポートにより Numpyのような構文で 配列を操作できるようになりました
例えば 2行2列の 小さな配列を作成する場合 Numpyのインデックス作成構文を使って 1列目を10で割ることができます
JAXでは パディングポリシーを定義できますが そのパディングポリシーが JAX-Metal バックエンドでもサポートされました
これを使うと ダイレーションと呼ばれる パディングを要素間に追加できます
これはpad関数を呼び出して行います この関数は次元ごとに 3つのパラメータを受け取ります
ネガティブパディングで テンソルから要素を削除することもできます
これを行うには パディング設定で負の値を渡します
JAXセクションの締めくくりとして JAXの使い方の簡単な例を紹介します 先ほど説明した AXLearnライブラリを使います そこからfujiの70億パラメータの モデルを選択して実行し 先ほど説明した BFloat16データ型をモデルに使います
このスクリプトは 小さな入力をランダムに作成し それをモデルに渡して 次のトークンを 生成するようモデルに要求します
出力では ロジットと結果のトークンが示されます
予測が終わったら 同じスクリプトをもう一度実行します ただし 今度は環境変数を使って CPUで実行するようにJAXを設定します
ご覧のように この推論で CPUの出力がMetalバックエンドの出力と 一致していることが確認できたので デモを終わります
JAXとこのWWDCに関する 今回のプレゼンテーションは これで終了です 今日お話しした内容をまとめましょう
Appleシリコンで利用できる ユニファイドメモリアーキテクチャは 様々な機械学習のタスクに 重要なメリットをもたらします より大きなモデルとバッチサイズを 使えるようになるうえ CPUとGPUで 同じメモリにアクセスできるため CPUとGPU間のコピーも不要になります
PyTorch、JAX、 TensorFlow、MLXといった 人気のあるフレームワーク用の Metalバックエンドを通じて 強力なAppleシリコンGPUを使用できます 今年は 人気の高いTransformer クラスモデルのサポートについて 様々なパフォーマンスの強化が 行われています
そうした更新を活用するためにも ぜひフレームワークの最新リリースを 使用していることを確認し macOSも忘れずに更新してください
ご視聴ありがとうございました これらの新機能が 皆さんのお役に立てば幸いです
-
-
特定のトピックをお探しの場合は、上にトピックを入力すると、関連するトピックにすばやく移動できます。
クエリの送信中にエラーが発生しました。インターネット接続を確認して、もう一度お試しください。