스트리밍은 대부분의 브라우저와
Developer 앱에서 사용할 수 있습니다.
-
Apple GPU에서 머신 러닝 및 AI 모델 학습시키기
PyTorch, JAX, TensorFlow용 Metal을 사용하여 Apple Silicon에서 모델을 학습시키는 방법을 알아보세요. 새로운 어텐션 작업 및 양자화 지원을 활용하여 기기의 트랜스포머 모델 성능을 향상해 보세요.
챕터
- 0:00 - Introduction
- 1:36 - Training frameworks on Apple silicon
- 4:16 - PyTorch improvements
- 11:26 - ExecuTorch
- 13:19 - JAX features
리소스
관련 비디오
WWDC24
WWDC23
WWDC22
WWDC21
-
다운로드
안녕하세요, 저는 Yona Havocainen입니다 저는 GPU, Graphics 및 Display Software 팀의 엔지니어입니다 오늘은 Apple Silicon GPU로 머신 러닝 및 AI 모델을 학습시키는 방법을 보여 드리고 올해 도입된 새 기능을 소개하겠습니다
Apple Silicon은 기기에서 머신 러닝을 위한 여러 놀라운 기능을 선사합니다 강력한 GPU는 최신 신경망을 최적화하는 데 필요한 유형의 계산에서 우수한 성능을 발휘합니다
통합 메모리 아키텍처를 채택한 GPU는 상당한 양의 메모리에 바로 액세스할 수 있습니다
대량의 메모리를 통해 기기에서 로컬로 큰 모델을 학습시키고 실행할 수 있죠
또한 학습 중에 더 큰 배치 크기를 사용해 수렴이 빨라지는 경우가 많습니다
모델 가중치를 여러 머신에 걸쳐 배포하지 않아도 되므로 학습에서 배포로의 과정이 더 간결해집니다
학습은 Apple 플랫폼에 모델을 배포하는 첫 번째 단계입니다 모델을 학습시킨 후에는 기기에 배포하기 위해 준비해야 합니다
준비된 모델은 애플리케이션에 통합될 수 있습니다
머신 러닝 모델 배포의 전반적인 흐름을 다루는 세션을 아직 시청하지 않았다면 Apple 기기의 머신 러닝 작업 흐름에 대한 동영상을 참고해 주세요
이 세션에서는 학습에 집중할 것이며 Apple Silicon의 고유한 컴퓨팅 기능을 활용할 수 있는 몇 가지 프레임워크를 설명해 드리겠습니다
강력한 GPU를 활용하려면 인기 많은 머신 러닝용 프레임워크 중 하나에서 Metal 백엔드를 사용하면 됩니다 이러한 프레임워크로는 TensorFlow PyTorch, Jax, MLX가 있죠
TensorFlow는 다양한 산업에서 신뢰받는 프레임워크입니다
Metal 백엔드가 대규모 프로젝트를 위한 분산 학습, 혼합 정밀도와 같은 기능을 지원하므로 학습 성능이 개선됩니다 이제 아주 쉽게 TensorFlow에 대해 Metal 백엔드를 활성화할 수 있죠 pip와 같은 패키지 관리자로 TensorFlow를 설치한 다음 프로젝트에 가져오면 됩니다
WWDC21 동영상에서 TensorFlow용 Metal 백엔드에 대해 자세히 알아보세요
PyTorch 또한 많이 사용되는 프레임워크입니다 Metal 백엔드가 맞춤형 작업과 프로파일링과 같은 기능을 지원하므로 네트워크 성능을 쉽게 벤치마킹하고 개선할 수 있죠 PyTorch에서 Metal 백엔드를 시작하는 방법도 간단합니다 프로젝트에서 PyTorch를 설치한 후 가져오고 기존 기기를 mps로 설정하세요
PyTorch용 Metal 백엔드에 대한 자세한 내용을 알아보려면 WWDC22 동영상을 시청해 보세요
JAX 프레임워크에는 Metal 백엔드 지원이 최근에 추가되었죠
JIT 컴파일과 같은 기능을 지원하고 배열 삭제를 위한 Numpy 같은 인터페이스를 제공합니다
JAX용 Metal 백엔드를 시작하려면 jax-metal을 설치한 후 프로젝트에 JAX를 가져오면 됩니다
JAX용 Metal 백엔드는 WWDC23 동영상에서 자세히 다루었습니다
MLX는 Metal 백엔드를 통해 지원되는 최신 프레임워크입니다
MLX는 Apple Silicon에 맞게 설계되고 최적화되었습니다 Numpy와 유사한 API JIT 컴파일 분산 학습, 통합 메모리와 같은 기능을 기본적으로 지원합니다
이 프레임워크는 Python, Swift, C 및 C++에 대한 바인딩을 제공하죠
Apple의 코드 리포지토리에서 트랜스포머 모델의 미세 튜닝 이미지 생성, 오디오 전사와 같은 머신 러닝 작업을 실행하기 위한 예시를 찾을 수 있습니다
MLX도 다른 프레임워크처럼 쉽게 시작할 수 있습니다 Python 환경에 MLX를 설치한 후 프로젝트에 가져오면 됩니다
MLX 프레임워크에 대해 자세히 알아보려면 Apple에서 제공하는 문서와 코드 리포지토리를 확인해 보세요
이제 Apple Silicon을 학습시킬 준비가 되었으므로 오늘의 주요 주제로 넘어가겠습니다 특히 두 가지 프레임워크에 적용된 새로운 기능과 개선 사항을 소개해 드리고자 합니다 PyTorch와 JAX 프레임워크죠
PyTorch부터 살펴보겠습니다
1년 전 WWDC23에서 MPS 백엔드의 개발은 베타 단계에 도달했죠
그 후 맞춤형 커널과 넓어진 op 적용 범위 통합 메모리 아키텍처에
대한 지원이 추가되었습니다 이 밖에도 성능과 기능 모두에 대해 여러 개선과 수정 사항이 도입되었습니다 이 중 상당수에는 PyTorch 중심의 오픈 소스 커뮤니티가 기여했죠
다양한 네트워크에 대한 적용 범위도 한 해 동안 개선되었습니다 예를 들어 최첨단 트랜스포머 모델을 위한 HuggingFace 리포지토리에서 PyTorch-MPS 백엔드는 이제 가장 인기 있는 상위 50개 네트워크를 즉시 가속화할 수 있습니다 여기에는 Stable Diffusion과 Meta LLaMA 모델 Gemma 등 올해 유명해진 여러 모델이 포함됩니다
개선 사항에 대해 이야기해 볼게요 특히 트랜스포머 모델에 영향을 주는 3가지 개선 사항이 있습니다 8비트 및 4비트 정수 양자화가 지원되어 대규모의 모델도 기기 메모리에 담을 수 있습니다
융합 Scaled Dot-Product Attention이 지원되어 흔히 사용되는 여러 모델의 성능이 개선됩니다
마지막으로 통합 메모리 지원이 있어 GPU에 컴퓨팅을 디스패치할 때 텐서의 중복 사본이 삭제됩니다
이제 각 주제에 대해 자세히 설명해 드리겠습니다 32비트 부동 소수점이나 화면에 표시된 16비트 부동 소수점과 같은 데이터 형식은 모델을 학습시킬 때 흔히 사용되죠 맨 앞의 1비트는 값이 양수인지 음수인지 표시하고 5비트는 지수를 표시하며 10비트는 가수를 표시합니다 이러한 정밀도는 학습 중에 매개변수를 업데이트할 때 유용하죠 학습을 완료한 후에는 양자화라는 기법을 통해 매개변수에 필요한 메모리를 줄일 수 있습니다
동일한 값을 8비트 정수로 나타내면 필요한 메모리를 반으로 줄일 수 있죠 이렇게 하면 모델의 메모리 사용량 감소와 컴퓨팅 처리량 향상 등의 이점을 제공하며 모델에 따라 이를 달성해도 출력 정확도가 조금만 저하되거나 아예 저하되지 않습니다
Scaled Dot-Product Attention은 많은 트랜스포머 모델에 중요하죠 이 작업은 토큰화된 텍스트의 입력으로 시작됩니다
입력은 query, key 및 value 텐서로 불리는 3개의 텐서로 분할됩니다
3개의 텐서는 일련의 행렬 곱셈, 스케일링, 소프트맥스 작업을 통해 처리됩니다 일련의 작업을 단일 커널 호출로 융합하면 여러 작은 컴퓨팅을 GPU로 디스패치할 때 발생하는 오버헤드를 피할 수 있어 많은 네트워크에서 성능이 전반적으로 개선되죠
제가 마지막으로 강조하고 싶은 성능 업데이트는 Apple 기기의 통합 메모리 아키텍처가 제공하는 이점입니다 덕분에 메모리의 한 영역에서 다른 영역으로 비트를 복사할 필요 없이 텐서를 메인 메모리에 두어도 CPU와 GPU 모두에서 텐서에 접근할 수 있습니다
다음으로 PyTorch에 대한 설명을 마무리하기 위해 언어 모델을 가져와 이를 맞춤화하고 사용 사례에 맞게 미세 튜닝한 후 기기에서 실행하는 종단간 워크플로를 보여 드리겠습니다
우선 torch를 가져오고 재현 가능한 결과를 위해 random seed를 고정하겠습니다
인기 있는 transformers 라이브러리를 통해 모델과 토크나이저를 다운로드하고 설정합니다 이 라이브러리는 HuggingFace 리포지토리에서 모델을 가져올 때 편리합니다
저는 30억 개의 매개변수를 갖는 Open-Llama 버전 2를 작업의 기본 모델로 활용하고 있습니다 모델을 학습시킬 때 사용된 토크나이저도 필요하죠
모델에 미세 튜닝을 위한 어댑터를 연결하기 위해 lora_config를 포함하는 peft 라이브러리를 사용할게요 어댑터에 대한 매개변수를 정의하고 기본 모델 및 구성을 사용하여 새로운 PeftModel을 만들게요
이제 컴퓨팅 기기인 MPS에 모델을 전송할 수 있습니다
다음으로 튜닝에 사용할 데이터를 선택해야 합니다 Andrej Karpathy의 tinyshakespeare 데이터세트를 학습 입력으로 사용하겠습니다 이는 편의를 위해 셰익스피어의 작품을 단일 파일로 결합한 데이터세트입니다
데이터세트가 로드되면 dataset 객체에 로드할 수 있습니다 이는 데이터에 대해 어떤 토크나이저를 사용할지 지시하죠
튜닝을 위해 몇 가지 학습 매개변수를 설정해야 합니다 Trainer 클래스를 통해 배치 크기나 사용할 training epoch의 수와 같은 인수를 설정할 수 있죠
DataCollator 객체는 Trainer 객체를 위한 학습 배치를 형성합니다
이제 모델, 인수, DataCollator 및 학습 데이터세트를 전달하여 Trainer 객체를 만들 수 있습니다
학습을 시작하기 전에 모델이 미세 튜닝 전에 어떤 내용을 출력할지 볼게요 입력 텍스트를 가져와서 모델이 사용할 수 있도록 토큰화하고 출력을 생성한 다음 사람이 읽을 수 있는 텍스트로 다시 해독하는 작은 편의용 함수를 추가할게요
셰익스피어의 몇 문장으로 테스트하고 어떤 응답이 나오는지 보겠습니다
튜닝하기 전에는 모델이 인용문을 정확하게 인용하는 것으로 시작해서 가장에 대한 토론으로 넘어가는 일종의 사전 내용처럼 작동하네요
사전과 대화하는 것은 재미없으니 약간의 미세 튜닝으로 활기를 추가해 볼게요
Trainer 클래스로 학습을 시작하겠습니다 앞서 정의한 매개변수를 사용하여 데이터세트를 자세히 살펴보죠
데이터세트에서 10개의 epoch만큼 실행한 후 학습이 거의 끝나갑니다
이제 이전과 동일한 입력을 사용해 볼게요
메네니우스의 흥미로운 대사네요 미세 튜닝으로 상당한 변화를 이끌어 냈습니다
이제 나중에 사용할 수 있도록 모델을 저장하겠습니다 간편한 사용을 위해 어댑터와 기본 모델을 단일 항목으로 병합할게요 모델과 함께 토크나이저도 저장해야 합니다
이제 모델을 학습시켰으니 기기에 배포하여 활용해 보고 싶네요
배포를 위해 대부분의 네트워크에서는 CoreML을 사용하는 것을 선호합니다
이는 기기에 모델을 배포하는 것에 대한 세션에서 자세히 다룰게요
이 경우 PyTorch 생태계에 남고 싶으므로 새로운 ExecuTorch 프레임워크로 모델을 배포할 수 있습니다
ExecuTorch는 추론을 위해 다양한 기기에 PyTorch 모델을 배포하는 용도로 설계되었죠 PyTorch 학습 중에 정의한 모든 맞춤형 작업은 ExecuTorch를 통한 배포에 원활하게 사용할 수 있습니다
ExecuTorch는 MPS 파티셔너로 컴퓨팅 그래프를 분석하고 MPS 기기를 통해 인식된 패턴을 가속화합니다
이러한 방식으로 로컬 기기에서 ExecuTorch를 설정할 수 있습니다
먼저 머신에 리포지토리를 복제합니다
그다음 하위 모듈을 업데이트합니다
마지막으로, ExecuTorch를 빌드할 때 MPS 바인딩을 사용하는 옵션을 전달하며 설치 스크립트를 실행합니다 이제 ExecuTorch에서 모델을 배포하는 방법을 보여 드릴게요 ExecuTorch 리포지토리의 예시에 따라 진행하겠습니다 테스트 모델로는 Meta LLaMA 2 모델을 사용할게요 이 모델은 그룹별 양자화 방법을 통해 4비트 정수 데이터 유형으로 변환되었습니다 덕분에 모델이 작아지고 속도도 빨라졌습니다
macOS에서 리포지토리를 통해 iOS를 위한 데모 앱을 빌드하고 iPad Pro를 배포 대상으로 사용할게요
앱이 빌드되면 사용할 모델과 모델을 학습시킬 때 사용한 토크나이저를 선택할게요
다음으로 라자냐를 만드는 방법을 모델에 물어볼게요
이 쿼리를 위해 LLaMA 2 프롬프트 템플릿을 사용하고 있어요 이 모델은 챗봇처럼 작동하도록 미세 튜닝되었으며 이 형식을 예상하기 때문이죠
ExecuTorch를 통해 iPad에서 로컬로 실행한 결과 모델이 저에게 저녁 식사에 좋은 메뉴를 추천하고 있네요 물론 저라면 토마토와 치즈도 추가하겠지만요
지금까지 새 기능과 개선 사항을 통해 PyTorch 워크플로 속도를 높이는 방법을 모두 다루었습니다 다음으로, MPS 그래프에서 지원되는 또 다른 인기 있는 머신 러닝 프레임워크인 JAX에 추가된 새 기능을 설명할게요
JAX-Metal 플러그인은 작년 WWDC23에 개발자를 대상으로 출시되었죠 그 이후로 플러그인은 계속 발전했으며 여기에 많은 기능 및 성능 관련 업데이트가 추가되었습니다
도입된 변경 사항에는 향상된 고급 배열 인덱싱
공식 JAX 리포지토리에서의 CI 러너 워크플로의 채택
혼합 정밀도가 포함됩니다
여기에서 출시 이후 JAX-Metal 백엔드를 채택한 사례를 2가지 소개해 드리고자 합니다 첫 번째는 MuJoCo입니다 이는 로봇 공학, 생체 역학 등 빠르고 정확한 시뮬레이션이 필요한 사용 사례를 위한 오픈 소스 프레임워크입니다
JAX-Metal 백엔드를 사용한 결과 Mac 플랫폼의 사용자에게 최고의 성능을 제공할 수 있게 되었죠
두 번째는 AXLearn입니다 대규모 딥러닝 모델 개발을 위한 라이브러리죠 Metal 백엔드를 사용한 덕분에 로컬 기기에서 워크플로를 빠르게 학습시키고 테스트할 수 있습니다
두 라이브러리를 확인해 보고 Mac 기기에서 우수한 경험을 제공하는 데 JAX-Metal 백엔드가 어떻게 기여하는지 테스트해 보세요
다음으로 JAX-Metal 백엔드에 추가된 몇 가지 개선 사항을 자세히 살펴보겠습니다 혼합 정밀도와 NDArray 인덱싱 JAX의 패딩에 대해 설명해 드릴게요
올해 도입된 업데이트 중 하나로 이제 JAX-Metal 프레임워크에서 BFloat16 데이터 유형이 지원됩니다
이 데이터 유형은 넓은 동적 범위의 부동 소수점 값을 나타내며 혼합 정밀도 훈련과 같은 사용 사례에 적합합니다
BFloat16의 활용 방식은 다른 JAX 데이터 유형과 동일합니다
예를 들어, BFloat16 데이터 유형으로 텐서를 만들 수 있습니다
또 다른 개선 사항은 NDArray 인덱싱 및 업데이트가 지원되어 numpy와 유사한 구문으로 배열을 조작할 수 있다는 것입니다
예를 들어, 행 2개와 열 2개로 구성된 작은 배열을 만들었다면 numpy 인덱싱 구문으로 첫 번째 열을 10으로 나눌 수 있죠
JAX를 통해 패딩 정책을 정의할 수 있습니다 이러한 패딩 정책이 이제 JAX-Metal 백엔드에서 지원됩니다
덕분에 요소 간에 패딩을 추가할 수 있는데 이를 다일레이션이라고 하죠
이를 위해서는 pad 함수를 호출하세요 이 함수는 차원별로 3개의 매개변수를 허용합니다
또한 네거티브 패딩으로 텐서에서 요소를 삭제할 수도 있죠
padding 구성에서 음수 값을 전달하면 됩니다
JAX 사용에 대한 간단한 예시로 JAX에 대한 설명을 마무리할게요 앞서 논의한 AXLearn 라이브러리를 사용하겠습니다 여기에서 fuji 7B 매개변수 모델을 선택하고 실행할게요 또한 앞서 논의한 BFloat16 데이터 유형을 모델에서 사용하겠습니다
스크립트가 모델에 전달할 작은 임의의 입력을 생성하고 다음 토큰을 생성하도록 모델에 요청합니다
출력에서 로그와 출력 토큰을 제공할 것입니다
예측이 완료되면 같은 스크립트를 다시 실행하되 환경 변수를 사용하여 JAX가 CPU에서 실행되도록 설정할게요
CPU 출력이 추론에 대한 Metal 백엔드의 출력과 일치하므로 데모를 마무리할 수 있습니다
이것으로 JAX에 대한 설명과 이 WWDC 세션을 마무리하겠습니다 오늘 다룬 내용을 요약해 볼게요
Apple Silicon에는 통합 메모리 아키텍처가 탑재되어 있어 머신 러닝 작업을 위한 여러 이점을 제공합니다 더 큰 모델과 더 큰 배치 크기를 사용할 수 있을 뿐만 아니라 CPU와 GPU 모두 동일한 메모리에 액세스할 수 있으므로 CPU와 GPU에 사본을 만들 필요가 없습니다
인기 있는 프레임워크인 PyTorch, JAX, TensorFlow, MLX를 위한 Metal 백엔드를 통해 Apple Silicon GPU가 선사하는 강력한 성능을 활용할 수 있습니다 올해에는 인기가 많은 트랜스포머 클래스 모델을 지원하기 위해 여러 성능 개선 사항이 도입되었습니다
이러한 업데이트를 활용하려면 프레임워크의 최신 버전을 사용하고 macOS를 업데이트하는 것을 잊지 마세요
시청해 주셔서 감사합니다 오늘 다룬 새 기능이 많은 도움이 되기를 바랍니다
-
-
찾고 계신 콘텐츠가 있나요? 위에 주제를 입력하고 원하는 내용을 바로 검색해 보세요.
쿼리를 제출하는 중에 오류가 발생했습니다. 인터넷 연결을 확인하고 다시 시도해 주세요.