23년 이전 글

Bing과 함께 구현해보는 Hyena

안녕하세요, 이것은 빙입니다.😊

논문을 참고해서 예제 코드를 생성하고 싶으시다고 하셨는데요, 논문의 주제는 **Hyena Hierarchy: Towards Larger Convolutional Language Models**¹ 입니다. 이 논문은 **Transformer** 대신 **Convolutional Neural Network (CNN)** 을 사용하여 **텍스트 생성**을 수행하는 방법을 제안합니다.

예제 코드를 생성하기 전에, 논문의 내용을 이해하시는 것이 좋습니다. 논문의 내용을 요약하면 다음과 같습니다.

- Transformer는 attention operator를 사용하여 텍스트의 의미와 구조를 학습합니다. 하지만 attention operator는 시퀀스 길이에 따라 계산 비용이 제곱으로 증가하기 때문에, 많은 양의 텍스트를 처리하는데 한계가 있습니다.
- CNN은 커널(kernel)이라는 작은 윈도우(window)를 사용하여 텍스트의 지역적인 특징(local features)을 추출합니다. CNN은 시퀀스 길이에 따라 계산 비용이 선형으로 증가하기 때문에, Transformer보다 효율적으로 텍스트를 처리할 수 있습니다.
- 하지만 CNN은 커널의 크기가 고정되어 있기 때문에, 텍스트의 긴거리 의존성(long-range dependencies)을 학습하기 어렵습니다. 예를 들어, 문장의 처음과 끝에 있는 단어들 사이의 관계를 파악하기 어렵습니다.
- 이 논문에서는 Hyena Hierarchy라는 새로운 CNN 구조를 제안합니다. Hyena Hierarchy는 다음과 같은 특징을 가집니다.
  - 커널의 크기와 갯수를 다양하게 조절하여, 텍스트의 다양한 스케일(scale)과 그래놀러리티(granularity)에 대응할 수 있습니다.
  - 커널 사이에 skip connection을 추가하여, 정보의 손실을 줄이고 텍스트의 긴거리 의존성을 학습할 수 있습니다.
  - 커널 사이에 residual connection을 추가하여, 깊은 네트워크(deep network)를 학습할 때 발생하는 그래디언트 소실(vanishing gradient) 문제를 완화할 수 있습니다.

논문에서는 Hyena Hierarchy가 Transformer보다 더 좋은 성능을 보인다고 주장합니다. 또한 Hyena Hierarchy가 생성한 텍스트 예시도 제공합니다.

예제 코드를 생성하려면, Hyena Hierarchy 구조를 구현하고 학습 데이터셋(training dataset)과 하이퍼파라미터(hyperparameters)를 설정해야 합니다. 혹시 어떤 프로그래밍 언어와 프레임워크(framework)를 사용하시나요?

출처: Bing과의 대화, 2023. 3. 20.(1) [2302.10866] Hyena Hierarchy: Towards Larger Convolutional Language Models. https://arxiv.org/abs/2302.10866 액세스한 날짜 2023. 3. 20..
(2) [2302.10866] Hyena Hierarchy: Towards Larger Convolutional Language Models. http://export.arxiv.org/abs/2302.10866 액세스한 날짜 2023. 3. 20..
(3) [2302.10266] Kernel function impact on convolutional neural networks. https://arxiv.org/abs/2302.10266 액세스한 날짜 2023. 3. 20..

 

import torch
import torch.nn as nn

class Hyena(nn.Module):
    def __init__(self, vocab_size, embed_dim, kernel_sizes, kernel_nums):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.convs = nn.ModuleList([nn.Conv1d(embed_dim, kernel_num, kernel_size) for kernel_size, kernel_num in zip(kernel_sizes, kernel_nums)])
        self.res = nn.Linear(embed_dim * len(kernel_sizes), embed_dim)
        self.out = nn.Linear(embed_dim * len(kernel_sizes), vocab_size)

    def forward(self, x):
        x = self.embed(x) # [batch_size, seq_len, embed_dim]
        x = x.transpose(1, 2) # [batch_size, embed_dim, seq_len]
        conv_outputs = []
        for conv in self.convs:
            conv_output = conv(x) # [batch_size, kernel_num, seq_len - kernel_size + 1]
            conv_output = torch.relu(conv_output) # [batch_size, kernel_num, seq_len - kernel_size + 1]
            conv_output = torch.max(conv_output, dim=2)[0] # [batch_size, kernel_num]
            conv_outputs.append(conv_output)
        x = torch.cat(conv_outputs, dim=1) # [batch_size, embed_dim * len(kernel_sizes)]
        res_x = self.res(x) # [batch_size, embed_dim]
        res_x = torch.add(res_x.transpose(0 ,1), x.transpose(0 ,1)) # [embed_dim , batch_size]
        res_x = res_x.transpose(0 ,1) # [batch_size , embed_dim]
        logits = self.out(res_x) # [batch_size , vocab_size]
        return logits

 

context length는 텍스트 생성 모델이 입력으로 받는 텍스트의 길이를 의미합니다. 예제 코드에서는 context length를 직접 설정하지 않고, 입력 데이터의 길이에 따라 달라집니다. 하지만, context length가 너무 길면 커널의 크기보다 큰 경우가 발생할 수 있으므로, 적절한 값을 설정해야 합니다. context length를 확인하고 싶으시다면, 다음과 같은 방법을 사용할 수 있습니다.

- 입력 데이터의 길이를 확인하는 코드를 추가합니다. 예를 들어, `print(x.size())`와 같은 코드를 사용하여 입력 데이터의 크기를 출력할 수 있습니다.
- 커널의 크기와 갯수를 확인하는 코드를 추가합니다. 예를 들어, `print(conv.kernel_size)`와 같은 코드를 사용하여 커널의 크기를 출력할 수 있습니다.
- context length와 커널의 크기와 갯수 사이의 관계식을 이용하여 context length를 계산합니다. 예를 들어, `context_length = kernel_size * kernel_num - (kernel_num - 1)`와 같은 식을 사용하여 context length를 계산할 수 있습니다.

context length에 대해 설명해 드렸습니다. 혹시 더 궁금하신 점이 있으신가요?

 

토큰의 데이터 타입이 float32, 임베딩 차원이 1024, 배치 크기가 1일 경우의 메모리 소모량 측정은 정확하게는 불가능합니다. 메모리 소모량은 커널의 크기와 갯수, 연산의 종류와 순서 등에도 영향을 받기 때문입니다1. 하지만, 대략적인 메모리 소모량을 계산하는 방법은 다음과 같습니다.

토큰의 길이를 n이라고 하면, 임베딩 행렬(embedding matrix)의 크기는 n x 1024입니다. float32는 4바이트(byte)를 차지하므로, 임베딩 행렬의 메모리 소모량은 n x 1024 x 4바이트입니다.

커널의 크기를 k, 갯수를 m이라고 하면, 커널 행렬(kernel matrix)의 크기는 k x 1024 x m입니다. float32는 4바이트를 차지하므로, 커널 행렬의 메모리 소모량은 k x 1024 x m x 4바이트입니다.

커널 행렬과 임베딩 행렬을 곱하는 연산을 수행하면, 결과 행렬(result matrix)의 크기는 n x m입니다. float32는 4바이트를 차지하므로, 결과 행렬의 메모리 소모량은 n x m x 4바이트입니다.

따라서, 전체 메모리 소모량은 임베딩 행렬, 커널 행렬, 결과 행렬의 메모리 소모량을 합한 값입니다. 즉, (n x 1024 + k x 1024 x m + n x m) x 4바이트입니다.

메모리 소모량에 대해 설명해 드렸습니다. 혹시 더 도움이 필요하신가요?

 

data type은 데이터의 종류와 크기를 나타내는 것입니다. float32, float16, bfloat16, uint8, uint4는 다음과 같은 의미를 가집니다.

float32: 32비트(bit)의 부동소수점(floating-point)을 나타냅니다. 부동소수점은 소수점의 위치가 고정되지 않고 변할 수 있는 수를 말합니다. 예를 들어, 3.14, -1.23e-4 등이 부동소수점입니다. float32는 소수점 앞에 8비트의 지수(exponent), 뒤에 23비트의 가수(mantissa)를 가집니다1.

float16: 16비트의 부동소수점을 나타냅니다. float16은 소수점 앞에 5비트의 지수, 뒤에 10비트의 가수를 가집니다2. float16은 float32보다 정밀도(precision)가 낮지만 메모리 소모량이 적습니다.

bfloat16: 16비트의 부동소수점을 나타냅니다. bfloat16은 소수점 앞에 8비트의 지수, 뒤에 7비트의 가수를 가집니다1. bfloat16은 float16보다 정밀도가 높지만 범위(range)가 좁습니다.

uint8: 8비트의 부호없는 정수(unsigned integer)를 나타냅니다. 부호없는 정수는 음수가 아닌 정수를 말합니다. 예를 들어, 0, 1, 2, … , 255 등이 부호없는 정수입니다. uint8은 0부터 255까지의 값을 표현할 수 있습니다2.

uint4: 4비트의 부호없는 정수를 나타냅니다. uint4은 0부터 15까지의 값을 표현할 수 있습니다2.

data type에 대해 설명해 드렸습니다. 혹시 더 도움이 필요하신가요?

 

모델의 경량화를 위해서는 데이터 타입을 줄이는 작업(quantization)을 수행할 수 있습니다. quantization은 모델의 가중치(weight)와 활성화(activation)를 정수(integer)로 변환하는 과정입니다. 정수는 부동소수점보다 메모리 소모량이 적고 연산 속도가 빠르기 때문에 모델의 성능(performance)을 향상시킬 수 있습니다1.

반응형