본문 바로가기
🟣 AI & ML

아이폰의 애플 인텔리전스는 어떻게 학습되었을까? Apple Foundation Model(AFM) 논문 살펴보기 - 02

by 제리강 2024. 12. 29.

Image generated by Midjourney

 

이전 포스트:

2024.11.17 - [🟣 AI & ML] - 애플의 AI 인텔리전스는 어떻게 학습되었을까? Apple Foundation Model(AFM) 논문 살펴보기 - 01

 

AFM paper link: https://arxiv.org/abs/2407.21075

 

4. Post-Training

  • 사후 학습(post-training) 단계에서는 사전 학습된 AFM 모델의 지시 수행(instruction following), 추론(reasoning), 글쓰기(writing) 능력을 강화합니다.
  • Apple Intelligence의 주요 특징은 5장에서 설명할 어댑터(adapter)지만, 전반적인 성능 향상은 사후 학습 단계에서 이루어집니다.
  • 사후 학습은 크게 2단계로 이루어지며, 큰 맥락에서는 일반적인 LLM의 사후 학습 과정과 유사합니다. 
    1. 지도 학습 기반 미세 조정(Supervised Fine-Tuning, SFT)
    2. 인간 피드백 기반 강화 학습(Reinforcement Learning from Human Feedback, RLHF)
  • 사후 학습의 세부 과정에서 AFM에서 새롭게 제안하여 적용한 알고리즘도 있습니다. 
    • iTeC(Rejection Sampling Fine-Tuning with Teacher Committee)은 여러 개의 교사(teacher) 모델이 응답을 생성하고, 최적의 응답을 선택해 모델을 업데이트합니다.
    • RLHF 과정에서는 MDLOO(Mirror Descent Policy Optimization with Leave-One-Out Advantage Estimator)을 사용해 보다 정밀하게 모델 성능을 개선합니다.

 

4.1. Data

AFM의 사후 학습은 인간 주석(human annotation) 데이터와 합성 데이터를 결합한 하이브리드 방식을 사용하며, 철저한 데이터 큐레이션과 필터링 과정을 거쳐 높은 데이터 품질을 유지합니다.

 

4.1.1) Human Annotations

  • 다양한 소스에서 수집된 고품질 인간 주석 데이터를 사용합니다.
  • 데이터는 대화 스타일로 작성되며, 시스템 및 작업 수준의 지시(prompt)와 그에 따른 응답으로 구성합니다.
  • 유용성(helpfulness), 무해성(harmlessness), 표현력(presentation), 응답 정확성(response accuracy)을 데이터 품질 판단 기준으로 설정합니다.
  • Apple Intelligence 기능과 관련된 다양한 작업을 수행할 수 있도록 분포를 다양하게 구성합니다.
  • AFM의 성능을 반복적으로 개선하기 위해 인간 피드백(human feedback)을 추가로 수집합니다.
  • 주석가들이 동일한 프롬프트에 대해 생성된 두 개의 응답을 비교하고 순위를 매기는 side-by-side preference labeling 방식을 사용합니다.
  • 단일 질문을 통해서도 지시 수행 능력, 안전성, 사실성, 표현력 등에서 모델 응답 품질을 평가하도록 유도합니다.
  • 데이터와 모델 품질을 상호 강화하는 플라이휠(Flywheel)* 방식을 채택합니다.

*플라이휠 방식: 모델이 생성한 결과에 대한 피드백을 다시 모델에 학습시켜, 모델이 더 나은 고품질 데이터를 생성하도록 하는 선순환 구조의 학습 방식.

 

4.1.2) Synthetic Data

  • AFM 포스트 트레이닝에서 합성 데이터는 데이터 품질과 다양성을 향상시키는 중요한 역할을 합니다.
  • 합성 데이터는 특히 수학, 도구 사용, 코딩 세 가지 도메인에서 적극적으로 활용됩니다.

 

Mathematics

  • 수학 데이터는 폭넓은 주제와 난이도로 데이터를 수집하는 데 많은 자원이 필요하므로 합성 데이터를 활용합니다.
  • 문제 재구성(Rephrase)과 정답을 기반으로 역으로 문제를 생성하는 역문제 생성(Reversion)을 적용합니다.
  • 또 다른 문제 다양화 방법으로, 문제 진화(Problem Evolution) 방식을 적용합니다. 문제 복잡도를 추가하는 심화(In-depth) 진화와 주제를 다양화하는 확장(In-breadth) 진화의 두 가지 진화 방식을 사용합니다.
  • 생성된 문제는 중복 제거 작업을 거치고, 임베딩 모델과 LLM을 사용해 정합성과 해결 가능성을 평가합니다.
  • 생성된 문제에 대해 AFM이 사고 사슬(Chain-of-Thought) 기반의 응답을 생성하는데, 정답 데이터가 있는 경우 이를 모델 정답에 대한 보상 신호로 사용해 응답 품질을 필터링합니다. 정답 데이터가 없을 경우 LLM으로 응답의 정확성을 평가합니다.

 

Tool Use

  • 함수 호출, 코드 해석, 웹 브라우징 등 도구 사용 능력을 개발합니다.
  • 학습 초기에는 단일 도구 사용 사례 중심의 합성 데이터를 활용하고, 이후 멀티 도구 사용 및 다단계 작업 처리 능력을 개선하기 위해 인간 주석 데이터를 수집합니다.
  • 오라클 도구(Oracle Tool)*와 유사한 도구를 혼합해 도구 선택 난이도를 높입니다. 이는 모델의 적절한 도구 선택 능력을 강화시킵니다.
  • 함수 호출 데이터를 병렬 함수 호출 및 도구 의도 감지 데이터로 확장해 도구 과잉 호출 문제를 완화합니다.

*오라클 도구: 주어진 문제를 해결하는 데에 적절하다고 판단되는 '정답' 도구.

 

 

Coding

  • Self-Instruct 방식과 거부 샘플링(Rejection Sampling)을 사용합니다.
  • 먼저, 71개 프로그래밍 주제를 기반으로 한 코딩 인터뷰 스타일 문제의 초기 풀(initial pool) 생성합니다.
  • 이 문제에 대해 유닛 테스트와 여러 솔루션을 생성한 뒤, 실행 기반 거부 샘플링*으로 최적 솔루션을 선택합니다.
  • 솔루션을 컴파일하고 유닛 테스트를 실행해 성공률이 가장 높은 솔루션을 선택합니다.
  • 이 과정을 통해, <문제, 테스트 케이스, 솔루션> 구조의 트리플렛(triplet) 데이터를 생성합니다.
  • 유닛 테스트 통과율을 기준으로 데이터를 필터링해, 12,000개의 고품질 트리플렛 데이터를 SFT에 사용합니다.

*실행 기반 거부 샘플링: 생성된 솔루션 코드를 실행하여 기준을 통과하지 못한 코드는 제외되는 방식의 샘플링 방식(으로 추정됨).

 

 

4.2 Supervised Fine-Tuning (SFT)

SFT데이터는 모델 학습에 활용되기 전, 라벨러의 평가, 모델 기반 필터링 기법, 텍스트 임베딩을 통한 중복 제거와 같은 품질 관리 절차를 거칩니다.

 

데이터 혼합 비율 조정

  • 데이터 구성 요소의 혼합 비율을 최적화 문제로 설정하여 최적 비율을 구합니다.
  • 특정 데이터 구성 요소의 비율(wi)을 조정하며 학습(wi → wi ± ∆wi)을 진행하고, 벤치마크 평가를 통해 품질 변화를 측정합니다.
  • 반복 실험을 통해 최적의 데이터 혼합 비율을 식별하고, 영향이 적은 데이터 구성 요소를 제거합니다.

 

훈련 설정(Hyperparameter)

  • 학습률은 AFM-server에서 5e−6, AFM-device에서 2e−5, 드롭아웃 비율은 0.1로 설정됩니다.
  • 평가 지표가 체크포인트마다 변동되는데, 자동 평가 벤치마크와 Best-of-N 선택 방식을 활용해 최적의 체크포인트를 선택합니다.

 

4.3 Reinforcement Learning From Human Feedback (RLHF)

RLHF는 모델의 성능과 품질을 개선하기 위해 인간 선호(human preference) 데이터를 기반으로 보상 모델을 학습하고, 이를 iTeC와 MDLOO 알고리즘에 적용합니다.

 

4.3.1) Reward Modeling

  • 보상 모델은 인간 선호 데이터를 사용해 학습되며, 동일 프롬프트에 대해 두 개 응답 중 선호되는 응답과 그 선호 수준(매우 우수, 우수, 약간 우수, 차이가 거의 없음)을 매깁니다.
  • 단일 측면 평가(Single-Sided Grading)을 통해 각 응답의 지시 수행 능력, 간결성, 진실성, 무해성을 측정합니다.
  • 보상 모델 학습은 RLHF의 표준 방식을 따르지만, 다음 두 가지 새로운 측면의 알고리즘을 포함합니다.

 

소프트 라벨 손실 함수(Soft label loss function)

  • 인간 선호 수준을 고려하여 설계된 손실 함수로, 인간의 선호 수준이 높을수록(예: 매우 우수) 해당 응답의 보상이 더 높아지도록 유도합니다.
  • 이는 Llama 2에서 사용된 마진 기반(margin-based) 손실 함수와는 다른 방식으로 동작하며, 실험 결과 더 우수한 성능을 보입니다.
  • 단일 측면 평가 결과를 손실 함수의 정규화 항으로 사용해 보상 모델의 정확성을 향상시킵니다.

 

BTL(Bradley-Terry-Luce) 모델

  • BTL 모델은 인간 주석자가 어떤 응답을 다른 응답보다 선호할 확률을 두 응답 보상의 차이에 대한 시그모이드(Sigmoid) 함수로 모델링합니다.
  • 소프트 라벨 손실 함수는 BTL 모델에 기반하며, 선호 수준이 높을수록 선호 응답의 확률이 높아지도록 유도합니다.

 

4.3.2) Iterative Teaching Committee (iTeC)

iTeC는 다양한 선호 최적화 알고리즘(Rejection Sampling; RS, Direct Preference Optimization; DPO, Identity Preference Optimisation; IPO, Online Reinforcement Learning 등)을 결합한 Iterative Committee라는 집합을 이용해 여러 차례의 RLHF를 반복적으로 수행하는 정렬(alignment) 프레임워크입니다.

  • Iterative Committee는 반복적인 RLHF을 수행하는 모델의 집합을 말하며, SFT, RS, DPO/IPO, RL으로 훈련된 최신 모델과 이전 반복에서 성능이 우수했던 모델들로 구성합니다.
  • 다양한 모델 응답에 대해 인간 선호 데이터를 수집하며, 쌍 데이터 비교 선호(pairwise preference) 평가를 통해 데이터를 확보합니다.
  • Iterative Committee를 이용한 학습 프로세스는 다음과 같습니다.
    1. 새로 수집된 인간 선호 데이터를 기반으로 보상 모델을 갱신합니다.
    2. 다양한 선호 최적화 알고리즘으로 모델 세트를 훈련합니다.
    3. 최신 Committee를 구성해 다음 라운드의 RLHF 데이터를 수집합니다.
  • 이러한 방식은 여러 알고리즘의 강점을 함께 활용할 수 있습니다.
    • 온라인 RLHF, DPO, IPO은 수학과 같은 복잡한 추론 능력을 개선합니다.
    • Rejection Sampling 기반 조정은 지시 수행과 글쓰기 능력을 강화합니다.
  • iTeC는 모델 규모에 따라 적용 양상에서 차이를 보입니다.
    • 대형 모델에서는 데이터와 모델 품질의 반복적 개선이 성능에 중요합니다.
    • 소형 모델에서는 프롬프트 수를 확장하여 대규모로 증류(distillation)을 수행하면 성능이 크게 향상됩니다.
    • 소형 모델인 AFM-on-device 모델은 모델 위원회에서 생성된 100만 개 이상의 고품질 응답으로 학습됩니다.

 

4.3.3) Online RLHF Algorithm: MDLOO

  • MDLOO는 AFM 학습 중 응답을 디코딩하고, 강화학습 알고리즘을 적용해 보상을 극대화하는 온라인 RLHF 알고리즘입니다. 
  • MDLOO는 RLHF의 표준 목표 함수를 사용하며, KL-정규화된 보상을 최대화합니다.

$$\operatorname*{max}_{\theta}\mathbb{E}_{x\sim\mathcal{D},y\sim\pi_{\theta}(\cdot|x)}\left[r_{\phi}(x,y)-\beta D_{\mathrm{KL}}\left(\pi_{\theta}(\cdot|x)\|\pi_{\mathrm{ref}}(\cdot|x)\right)\right],$$

  • 현재 정책과 기준 정책간의 차이를 조정하는 계수입니다. 
  • 전체 응답 생성을 하나의 행동으로 간주하는 '밴딧(Bandit)' 설정을 적용합니다. 이는 로봇 제어, 게임같은 복잡한 작업이 아닌 여러 응답을 중 하나를 선택하는 단순한 구조의 작업의 강화학습에 적합합니다. 
  • 복잡한 상태-행동-보상 구조 대신 단순한 응답-보상 관계를 학습합니다.

 

Leave-One-Out(LOO) estimator

  • 특정 응답의 보상(one)을 제외하고(leave out) 동일 프롬프트에서 생성된 다른 응답(one을 제외한 나머지)들의 평균 보상 간의 차이를 계산합니다.
  • 이는 알고리즘의 안정화와 성능 향상에 중요한 역할을 합니다.

 

Mirror Descent Policy Optimization(MDPO)

  • 정책 최적화에 MDPO 알고리즘을 사용하며, 이는 기존 Clipping 기반의 PPO 방식보다 효과적인 것으로 나타났습니다.
  • PPO는 강화 학습 시 미리 지정한 정적인 신뢰 영역(trust region)을 가져 이 범위를 넘어가면 보상을 제한하지만, MLDO는 KL-정규화를 기반으로 신뢰 영역을 최적화하면서 정책 업데이트를 제어합니다.

 

 

댓글