본문 바로가기
🟣 AI & ML

애플의 AI 인텔리전스는 어떻게 학습되었을까? Apple Foundation Model(AFM) 논문 살펴보기 - 01

by 제리강 2024. 11. 17.

Image generated by Midjourney

 

TL;DR

지난 10월 애플 인텔리전스(Apple intelligence)를 탑재한 iOS 신규 버전이 출시되었습니다. 애플 인텔리전스를 통해 각종 텍스트 작업이나 앱 간 상호작용, 이미지 편집 작업 경험을 개선할 수 있다고 합니다. 애플 인텔리전스의 타겟 시장이 주로 모바일 디바이스인 만큼, 경량의 온디바이스(on-device) 모델 구축에 특히 집중한 것으로 보입니다. 이번 포스트에서는 애플 인텔리전스의 파운데이션 모델에 대한 기술 보고서라 할 수 있는 논문 'Apple Intelligence Foundation Language Models' 를 살펴보며 애플이 모델을 어떻게 구성했는 지 살펴보겠습니다.
본 주제는 3회에 나누어 다룰 예정되며, 이번 포스트에선 사전 학습(pre-training) 파트를 먼저 살펴보겠습니다.

 

 

AFM paper link: https://arxiv.org/abs/2407.21075

 

 

1. Introduction

  • Introduction 파트에서는 모델 구조와 학습 프로세스를 대략적으로 소개하고 있습니다.
  • 학습 과정은 기존 LLM 모델과 크게 다르지 않습니다. 데이터 수집 및 전처리, 사전 학습, 사후 학습, 최적화로 이루어집니다.
  • 개인 정보나 윤리 규약을 철저히 준수했다는 점을 특히 강조하고 있습니다. 이는 논문 전반에 걸쳐 반복해서 강조되며, 최근 미국 정부가 AI 윤리를 강조하고 있는 점과 관련이 있어 보입니다.

Image from AFM paper

 

  • 애플 인텔리전스의 파운데이션 모델(Apple Foundation Model, 이하 AFM)은 '서버(AFM-server)' 파트와 '온디바이스(AFM-on-device)' 나누어 구축됩니다. 
  • AFM-on-device는 약 3B 규모이며, 모바일 기기에서의 사용을 목적으로 구축됩니다. 
  • AFM-server의 모델 크기는 공개되어 있지 않으며, AFM-on-device 구축의 바탕이 되는 대규모 모델입니다.
  • 논문에 자세히 설명되진 않지만, 시각 데이터를 처리하는 멀티모달 기능도 갖추고 있음을 언급하고 있습니다.

 

2. Architecture

  • AFM도 대부분의 LLM 모델의 바탕이 되는 decoder-only transformers 구조를 바탕으로 합니다.
  • 모델에 적용된 주목할 만한 요소들을 다음과 같이 소개하고 있습니다.

 

1. 입출력 공유 임베딩 매트릭스(Shared Input/Output Embedding Matrix - Press and Wolf, 2016)

  • 모델의 입력과 출력을 동일한 임베딩 매트릭스를 공유하도록 설계하여 모델이 사용하는 파라미터 수를 줄이고 메모리 효율성을 높입니다.

 

Image from X - @rasbt

 

 

2. Pre-Normalization & RMSNorm(Nguyen and Salazar, 2019; Zhang and Sennrich, 2019)

  • Pre-Normalization은 입력이 모델의 각 레이어로 들어가기 전에 정규화하는 방식으로, 훈련 초기 단계부터 안정적인 학습을 가능하게 합니다.
  • RMSNorm은 평균 제곱근을 이용한 정규화 방식으로, 각 뉴런의 출력 분포가 일정하게 유지되도록 합니다.

 

3. Query/Key 정규화(Wortsman et al., 2023)

  • Transformer의 핵심 구성 요소인 Query와 Key의 값을 정규화하여, 훈련 중 지나치게 큰 값이 나오는 것을 방지하고 모델이 안정적으로 작동하게 합니다.


4. Grouped Query Attention(GQA - Ainslie et al., 2023)

  • Transformer의 Attention 메커니즘에서 Query를 여러 그룹으로 나눈 뒤, 각 그룹이 별도로 Key/Value와 상호작용하도록 설계한 방식입니다.
  • Attention 메커니즘의 메모리 사용량(KV-cache footprint)을 줄입니다.

 

Image from Ainslie et al., 2023

 

 

 

5. SwiGLU Activation(Shazeer, 2020)

  • 기존의 ReLU나 GELU 같은 활성화 함수 대신, 두 개의 활성화 함수인 SWish와 GLU(Gated Linear Units)를 결합한 방식입니다.
  • 이는 모델이 정보의 흐름을 보다 효율적으로 처리할 수 있도록 돕습니다.

 

ReLU와 SwiGLU 비교. Image from velog - @alstjsdlr0321.

 


6. RoPE Positional Embeddings(Su et al., 2024)

  • Transformer 모델에서 위치 정보를 추가하는 방식 중 하나로, sinusoidal 방식을 변형하여 더 긴 context 지원이 가능하도록 설계되었습니다.
  • AFM에서는 RoPE의 기본 주파수(frequency)를 500k로 설정하여 장기적 의존성을 처리할 수 있도록 합니다.

 

RoPE 구현 양상. Image from Su et al., 2024.

 

 

3. Pre-training

Pre-training 파트에서는 데이터 수집과 사전 학습 프로세스를 소개하며, 다음과 같은 점을 주로 강조하고 있습니다

  • 데이터셋에는 라이선스를 받은 데이터와 공개된 데이터셋, Applebot을 통해 수집한 공개 정보가 포함됩니다.
  • Applebot은 자체 구축한 크롤링 시스템으로 보이며, 미리 설정한 여러 지침을 준수하여 데이터를 수집한다고 합니다.
  • Apple 사용자 개인 데이터는 포함하지 않습니다.
  • 비속어, 안전하지 않은 자료, 개인 식별 정보 등을 포함하지 않도록 철저히 검토합니다.
  • 객관적인 모델 평가를 위해 벤치마크 데이터셋에 포함된 데이터는 학습에서 제외합니다.
  • 데이터의 양보다는 품질을 더욱 중시함.

 

3.1 Data

각 데이터셋 및 데이터셋 처리에 사용된 Tokenizer에 대한 설명입니다.

 

3.1.1 ) Web pages

웹 페이지에서의 데이터 추출 및 필터링 과정의 주요 요소를 다음과 같이 설명하고 있습니다.

  • 본문 추출: Safari의 리더 모드와 Boilerpipe 알고리즘을 결합하여 문서의 본문만을 추출함.
  • 안전 및 비속어 필터링: 휴리스틱과 모델 기반 분류기를 사용해 비속어와 위험 요소가 있는 페이지를 걸러냄.
  • 중복 제거: locality-sensitive n-gram hashing 기법으로 전역적으로(globally) 문서 중복을 제거함.
  • 품질 필터링: 휴리스틱과 모델 기반 분류기를 사용해 높은 품질의 데이터를 확보함.
  • 벤치마크 데이터 필터링: 811개의 일반적인 사전 학습 평가 데이터셋을 기준으로, 특정 n-그램이 기준 데이터셋과 일치할 경우 필터링함. 단, 해당 n-그램이 1000회 이상 사용된 일반적인 표현으로 판단되면 필터링에서 제외함.

 

3.1.2) Licensed Datasets

  • 제한된 고품질 데이터를 확보하기 위해 출판사로부터 라이선스를 획득하고, 다양한 문맥 길이의 고품질 데이터를 제공하는 이 데이터를 사전 학습 데이터에 포함.
  • 웹 페이지 데이터와 동일한 방법으로 라이선스 데이터를 정제하고 검증.

 

3.1.3) Code

  • GitHub의 오픈 소스 저장소에서 MIT, Apache, BSD, CC0 등 허용된 라이선스의 데이터를 사용하여 코드 데이터를 수집함.
  • Swift, Python, C, Objective-C, C++, JavaScript, Java, Go 등 14개의 주요 프로그래밍 언어가 포함됨.
  • PII(Personally Identifiable Information) 제거 및 품질 필터링을 수행하여 개인 식별 정보를 걸러냄.
  • 웹 페이지 데이터와 동일한 방법으로 라이선스 데이터를 정제하고 검증.

 

3.1.4) Math

웹에서 수집한 고품질 수학 데이터를 두 가지 유형으로 구분하여 사용합니다.

  • 첫 번째는 수학 Q&A 데이터셋으로, 수학 콘텐츠가 풍부한 20개 웹 도메인에서 수집된 30억 개 토큰으로 구성됨. HTML 페이지의 관련 태그를 식별해 질문과 답변을 추출함.
  • 두 번째는 수학 커뮤니티 데이터로 수학 포럼, 블로그, 튜토리얼, 세미나 등에서 수집된 140억 개 토큰의 웹 페이지 데이터로 구성됨.

수학 데이터 필터링 과정은 다음과 같습니다.

  • 수학 태그 필터: 40개의 문자열을 사용해 수학 템플릿을 식별.
  • 수학 기호 필터: 350개의 Unicode 및 LaTeX 기호를 사용해 수학 콘텐츠를 검출.
  • 품질 필터: 수학을 위한 언어 모델 분류기 사용.
  • 도메인 필터: 사람이 수동으로 레이블링한 도메인에서 데이터 수집.
  • 중복 제거, PII 제거 작업을 통해 최종 데이터셋을 완성.

 

3.1.5) Public Datasets

  • 사용 허가가 명확히 된 고품질의 공개 데이터셋을 선별하여 활용합니다.
  • 개인 식별 정보를 모두 제거하여 데이터의 안전성을 확보합니다.

 

3.1.6) Tokenizer

Tokenizer로는, SentencePiece의 byte-pair encoding (BPE) 방식을 사용합니다. 세부 설정은 다음과 같습니다.

  • 모든 숫자는 개별 숫자로 나뉘어 처리됩니다.
  • 알 수 없는 UTF-8 문자의 경우, 바이트 토큰으로 분해하여 처리합니다(byte-fallback).
  • AFM-server는 10만, AFM-on-device는 4.9만 개의 토큰으로 구성된 어휘 집합을 사용합니다.

 

3.2 Recipe

Recipe 파트에서는 사전 학습 프로세스를 소개하고 있습니다. AFM의 사전 학습은 세 단계로 진행됩니다.

  • 첫 번째는 Core 학습 단계로, 대부분의 계산 자원을 사용해 모델의 기본 성능을 구축합니다.
  • 두 번째 Continued 학습 단계에서는, 낮은 품질의 웹 데이터를 줄이고 코드와 수학 데이터, 라이선스 데이터를 활용해 특화된 능력을 강화합니다.
  • 세 번째 Context-Lengthening 단계에서는 이전 단계에서보다 긴 context 데이터를 포함하여 모델이 장기 문맥을 처리할 수 있도록 학습합니다.
  • 모든 학습 단계에서는 Decoupled Weight Decay와 단순화된 버전의 µParam 기법을 사용하며, float32와 bfloat16 변환을 통해 학습 효율성을 높입니다.

 

3.2.1 ) Core pre-training

 

AFM-server와 AFM-on-device 모델의 Core 단계 사전 학습은 서로 다른 접근법을 사용합니다.

 

AFM-server

  • AFM-server는 scratch(완전한 초기 가중치) 상태에서 6.3T 토큰에 대해 학습합니다.
  • 8192 TPUv4 칩을 사용하고 시퀀스 길이는 4096, 배치 크기는 4096 시퀀스로 설정됩니다.
  • 배치 크기는 원래 모델 크기와 컴퓨팅 자원에 맞춘 스케일링 법칙에 따라 설정되었으나, 예측된 최적 배치 크기(~3072)와 실제 칩 활용도를 최대화하는 배치 크기(4096)가 반드시 일치하지는 않았다고 합니다. 결국 칩 활용도를 최적화하는 배치 크기로 사용한 것으로 보입니다.
  • 최적 학습률은 0.01로 설정되었으며, µParam을 사용해 linear layer에서는 학습률이 약 0.1로 조정됩니다.
  • 가중치 감소는 3.16e-4로 설정하고, 학습률 스케줄은 5000 단계의 linear warm-up 이후 코사인 감소(cosine decay) 방식을 적용해 학습을 진행합니다.

 

AFM-on-device

  • AFM-on-device는 더 큰 모델로부터 지식을 distillation하며 구조적(structural) pruning을 거친 모델을 사용해 학습합니다.
  • 먼저, 6.4B 모델에서 pruning된 구조로 초기화하고, Soft-Top-K masking 방식을 사용해 Feed-Forward 레이어의 숨겨진 차원을 pruning한다.
  • 학습 초기 단계에서 188B 규모 토큰에 대해 동일한 데이터 혼합 비율로 mask를 학습(다음 토큰 예측)한 후, Core 사전 학습에서는 distillation loss를 적용해 학습합니다.
  • distillation은 원 레이블과 teacher 모델의 예측(0.9 가중치 부여)을 결합해 사용했고, 이는 MMLU와 GSM8K 성능을 각각 5%와 3% 향상시켰습니다.
  • 모든 하이퍼파라미터는 AFM-server와 동일하게 유지하며, 배치 크기만 변경합니다. 이를 통해 데이터 효율성과 벤치마크 성능을 개선했습니다.

 

3.2.2) Continued Pre-Training

  • AFM-server와 AFM-on-device 모두 Continued 사전 학습에서는 시퀀스 길이를 8192로 확장하고, 총 1T 토큰을 사용하여 학습합니다.
  • 이 단계에서는 수학과 코드 데이터의 비중을 높이고, 일반 웹 크롤링 데이터의 비중을 줄입니다. 또한, 3.1.2에서 설명한 라이선스 데이터도 포함합니다.
  • 학습률 설정은 Core 사전 학습과 다르게, 최대(peak) 학습률을 3e-4로 설정하고, decoupled 가중치 감소는 1e-5로 조정합니다.
  • 학습 초반에는 1000단계 동안 학습률을 warm-up한 뒤, 최종 학습률인 0.001까지 코사인 감소(cosine decay) 스케줄을 적용합니다.
    • 학습률에 대해서는, 쉽게 말하면 학습 초반에는 높은 학습률로 속도를 높이는(동시에 local minimum을 피하며) 학습 후반에는 낮은 학습률로 정밀도를 높인다고 이해하면 됩니다.
  • 배치 크기 등의 다른 설정은 Core 사전 학습에서 그대로 가져옵니다.
  • AFM-on-device의 경우 Core 사전 학습에서 유용했던 distillation loss가 Continued 단계에서는 효과적이지 않아 사용되지 않았습니다. 따라서, 이 단계의 학습 방식은 AFM-server와 동일하게 진행됩니다.

 

3.2.3) Context Lengthening

  • Context Lengthening 단계에서는 시퀀스 길이를 32,768 토큰으로 확장하고, 총 100B 토큰에 대해 학습을 진행합니다.
  • 이 과정에서는 Continued 사전 학습에서 사용된 데이터에 합성된 장문 Q&A 데이터를 추가하여 장문 학습을 강화합니다.
  • RoPE(Rotary Positional Embedding)의 기본 주파수를 500k에서 6315089로 증가시키며, 이는 Liu et al., 2024에서 제시된 스케일링 법칙을 따릅니다.
  • 이러한 조정은 짧은 문맥 데이터를 학습한 모델이 장문 데이터를 더 잘 일반화할 수 있도록 돕는데, 이는 대부분의 사전 학습 데이터가 32k 토큰보다 짧은 문서로 구성된 점을 고려한 설계입니다.
  • 이 단계의 학습 과정은 Continued 사전 학습과 유사하게 구성됩니다.

 

3.2.4) Optimizer

  • 최적화 알고리즘으로는 RMSProp(Hinton, 2012)의 변형된 버전을 사용하며, 모멘텀(momentum)을 포함합니다.
  • 모멘텀은, 관성이라는 뜻처럼 이전 단계 학습 방향에 가중치를 두어 안정적인 학습 변화를 막는다는 의미입니다.
  • 기울기 안정성을 위해 일정 수준 이상의 기울기를 특정 값으로 내려버리는 클리핑(clipping)을 수행합니다.
  • 이 외에는 다소 지엽적인 내용으로, 간단히 보고 넘어갑니다.

 

3.3 Training Infrastructure

  • AFM 모델의 사전 학습은 v4 및 v5p Cloud TPU 클러스터에서 진행되며, JAX 기반의 딥러닝 라이브러리인 AXLearn 프레임워크를 활용합니다. 
  • 각종 병렬화 방법 - 텐서 병렬 처리, 완전 분산 데이터 병렬 처리(Fully-Sharded Data Parallel), 시퀀스 병렬 처리를 조합하여 학습 효율을 극대합니다.
  • AFM-server 학습은 8192개 TPUv4 칩으로 학습을 진행했으며, 8×1024 구조의 칩 슬라이스(slices)로 구성됩니다.
  • AFM-on-device 학습은 2048개 TPUv5p 칩으로 구성된 단일 슬라이스에서 학습을 진행합니다.
  • 이러한 인프라 설계를 통해 높은 연산 성능을 유지하면서 확장 가능성을 극대화한다고 합니다.

 

 

개인적으로 생각하는, 본 파트에서 나타나는  AFM의 주요 특징은 다음과 같습니다. 

  1. 타 사의 경량 모델이 플래그십 모델의 보급형처럼 여겨지는 것과 달리, AFM은 경량인 on-device 모델이 중심이 되고 이 on-device 모델성능을 최대화하기 위해 server 모델이 필요한 것처럼 보입니다. 이는 애플의 강점이 모바일 생태계인 점이 반영되었다고 볼 수 있습니다.
  2. 장문 데이터에 대한 처리 능력을 높이기 위해 RoPE 알고리즘과 context lengthening 학습 과정을 적용한 것이 눈에 띕니다. 이는 모바일 디바이스에서의 멀티턴(multi-turn) 형식의 메시지 데이터나 이메일 같은 장문의 입력에 대응하고자 하는 것으로 보입니다. 

 

그럼 다음 포스트에서는, 사후 학습(post-training)부터 다루어 보도록 하겠습니다.

 

댓글