TRL 라이브러리를 사용한 Lora 튜닝 시 주의사항
카테고리 없음

TRL 라이브러리를 사용한 Lora 튜닝 시 주의사항

TL;DR

Llama나, Mistral 등의 파인튜닝을 할 때, tokenizer.pad_token = tokenizer.eos_token을 하는 사람은 읽어보는게 좋습니다.

TRL의 DataCollatorForCompletionOnlyLM가 상속받는 로직으로 인해, 문장 끝에 eos token가 label -100으로 치환되며, 계산이 안되는 이슈.

내 모델이 계속 말을 끊지 않고 이어나갈 경우, 이 이슈가 문제일 가능성이 높음.

TRL 라이브러리 운영자는, Warning 추가 후 별도 조치 없음.

 

최소한의 솔루션

Vicuna에서 따온 방식인데, pad_token을 unk_token으로 대체하는 방법

또는 정석대로 pad_token을 별도로 추가하는 방법이 있으나, 모델 설정과 토크나이저 설정을 꽤 많이 건드려야 하므로, unk token이 별도로 있으면 unk_token을 pad_token으로 설정하는 것을 추천

 

 

문제의 발단

최근 경량화 학습이 보편적으로 퍼지게 되면서, 그리고 예제 코드가 퍼지게 되면서 저도 기본 코드를 많이 따와서 쓰는 편이었고, Transformers + TRL 조합이 편해서 주로 학습툴로 많이 사용하는 편이었습니다.

Polyglot 학습부터 한 2번은 당한 것 같고(eos 토큰을 못뱉고 억지로 말을 지어내는 현상), 이때는 pad 토큰을 추가하는 방식으로 문제를 잠깐 스쳐지나갔었습니다.

그러나 문득, 최근에는 이런일이 Transformers 만 사용할 때는 겪지 않았던 일이라는 것을 떠올리게 되었고, TRL에 문제가 있을 것 같은 직감을 가지게 되었고, 비슷한 현상을 겪는 사람이 제공한 코드 템플릿을 보면 대게 Lora를 사용한 경우를 식별했습니다.

 

그렇다면 어디서부터 문제인가?

내부적으로 연산이 되지 않는 다는 것에서 두가지 경우를 생각했습니다

1. Attention_mask에서 pad_token과 eos_token을 구분하지 못하게 되면서 모두 0으로 하는건가?

2. 모델 내부에서 연산할 때, -100으로 처리 된 것은 로스 계산을 하지 않으므로, 모델 내부에서 잘못되는 건가?

 

우선 토크나이저 부터 확인했습니다

허깅페이스 토크나이저 결과

토크나이저는 값을 입력할 때, 좌측 padding된 2 값을 0으로, eos 토큰 역할을 하는 2는 1로 마스킹을 정확하게 하는 것을 확인할 수 있습니다.

 

그렇다면 이 값이 모델로 들어가면서, -100으로 바뀌어 간다는 부분이 유력해지는 부분 입니다.

 

TRL 기반의 학습 베이스는 주로

Lora 모델 준비, 데이터셋 준비, 데이터 Collator 준비, TrainingArgument, SFTTrainer 방식으로 이뤄지게 됩니다. 나열된 순서와는 별개로, SFTTrainer 객체에 데이터셋을 넣는 순간 데이터를 토크나이징 하고, 학습 간 배치에 Collator가 동작을 하게 될겁니다.

위와 같이, 토크나이징에서 문제가 없는 것을 확인했으니, 다음은 Collator를 확인할 차례 입니다.

VScode에서는 해당 객체를 Ctrl + 좌클릭으로 접근해서 볼 수 있기 때문에, 탐색하기에 매우 편리합니다.

허깅페이스의 DataColltor 상속 받은 모습
학습 간 Collator가 동작하면서 사용하는 함수

위 torch_call에서는 -100으로 바꾸는 로직을 찾을 수 없었고, 첫줄을 보면 부모 객체의 torch_call을 사용하는 모습을 볼 수 있으니, 부모 객체로 한번 더 타고 들어갑니다.

 

부모 Collator 객체의 torch_call 함수

코드가 많고, 모르겠는게 많지만 결국은 내부 동작을 알고 있다면( 연산하지 않는 토큰은 -100으로 바꾸는 것) 원인을 바로 찾을 수 있을 겁니다.

맨 아래쪽에 보면 pad_token_id와 같은 값이 있으면 모두 -100으로 바꿔버리는 코드가 있었네요. 요즘에는 보통 tokenizer의 pad token을 eos token으로 동일하게 설정해버리기 때문에.. 이전에 만들어졌던 코드와 이런 호환성이 안맞게 된 케이스가 생긴 것 같습니다.

그렇다면 왜 바꾸지 않는가?

모르겠습니다. Transformers가 바꿔야 할까요? 그러기엔, Transformers에 별도로 Seq2Seq Collator가 존재하며, 해당 이슈를 없앨 수 있습니다.

TRL이 바꿔야 할까요?  TRL 측은 간단한 Warning만 추가하며, 사용자 스스로 바꾸게끔 하였습니다

참조: https://github.com/huggingface/trl/pull/988#event-10958286174

다소 미온적인 태도가 아쉬운 마음에, 원인을 분석하고 공유해서 다른 분들은 이런 이슈를 사전에 예방하길 바라는 마음에 글을 올립니다.

무엇보다 같은 실수로 여러번 학습을 돌리는 불상사가 없기를 바랍니다.

반응형