Scalable MatMul-free Language Modeling


September 30, 2024

1.

Scalable MatMul-free Language Modeling
Rui-Jie Zhu et al., University of California, Santa Cruz
Alfredo Solano, Matsuo Laboratory

2.

Overview
• Introduction
• Method
• Experiments
• FPGA
• Conclusion
References:
  – Paper: https://arxiv.org/pdf/2406.02528
  – Code: https://github.com/ridgerchu/matmulfreellm/tree/master

3.

Introduction
• Motivation
  – General matrix-matrix multiplication (GEMM) is a dominant operation in neural networks
  – The standard algorithm takes O(N^3) time for N×N matrices
    • doubling the input size makes the cost eight times higher
  – GPUs are designed for this
    • hardware lottery
  – Used in both training and inference
  – If GEMM use could be reduced or simplified, training and inference time should drop
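The cubic growth can be checked by counting the multiply-add steps of a textbook triple-loop GEMM. This is a minimal sketch in plain Python; the function names are illustrative and not taken from the paper or its code.

```python
def naive_matmul(a, b):
    """Multiply two square matrices (lists of lists) and count multiply-adds."""
    n = len(a)
    out = [[0.0] * n for _ in range(n)]
    mult_adds = 0
    for i in range(n):
        for j in range(n):
            acc = 0.0
            for k in range(n):
                acc += a[i][k] * b[k][j]
                mult_adds += 1
            out[i][j] = acc
    return out, mult_adds


def identity(n):
    """Build an n x n identity matrix as a list of lists."""
    return [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]


_, ops_small = naive_matmul(identity(8), identity(8))
_, ops_large = naive_matmul(identity(16), identity(16))
print(ops_small, ops_large, ops_large // ops_small)  # 512 4096 8
```

Doubling N from 8 to 16 raises the count from 8^3 to 16^3, which is the factor of eight mentioned on the slide.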

4.

Method
• To simplify GEMM:
  – replace it with simpler operations
    • AdderNet replaces multiplication with addition in CNNs
  – use binary / ternary quantization
    • binary: w = { 0, 1 }
    • ternary: w = { -1, 0, 1 }
    • binary activations in Spiking Neural Networks (SNN)
    • binary / ternary weights in Binary / Ternary Neural Networks (BNN / TNN)
  – BitNet showed it is possible to scale binarized transformers to 3B parameters
    • replaces Linear with BitLinear
    • keeps standard attention
      – Q, K calculated dynamically
      – requires custom CUDA kernels for optimization
  – When a ternary quantization of the BitNet attention layers was tested, training failed to converge
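The slides do not state how float weights are mapped to ternary values. The sketch below assumes an absmean rule (divide by the mean absolute weight, round, clip), in the style of the BitNet line of work. The function name `ternarize` and the example matrix are hypothetical.

```python
import numpy as np


def ternarize(w, eps=1e-8):
    """Map a float matrix to {-1, 0, 1} plus one scale (absmean rule, an assumption)."""
    scale = np.mean(np.abs(w)) + eps
    w_q = np.clip(np.round(w / scale), -1, 1).astype(np.int8)
    return w_q, scale


w = np.array([[0.9, -0.05, -1.2],
              [0.1,  0.6,  -0.4]])
w_q, scale = ternarize(w)
print(w_q.tolist())  # [[1, 0, -1], [0, 1, -1]]
```

Small weights collapse to 0 and the rest keep only their sign, so each weight needs under two bits of information instead of sixteen.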

5.

Method (2)
• Replace linear layers with BitLinear layers that use ternary values
  – multiplication is replaced by addition and negation
• Replace weight matrix W values with w = { -1, 0, 1 }
  – GEMM is then conceptually similar to, for each output column w:
    • np.sum(x[w == 1]) - np.sum(x[w == -1])
    • not efficient enough, so a custom kernel is used
• Optimize memory access with fused root mean square normalization (RMSNorm) and quantization of activations
  – reduces HBM I/O costs
  – memory size is already reduced by ternary values
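The ternary product and the normalize-then-quantize step can be sketched in NumPy as an unfused reference. It assumes 8-bit absmax quantization for the activations, which the slides do not specify, and all function names are hypothetical. The actual speed and memory gains come from doing the normalization and quantization in a single kernel pass, which this sketch does not attempt.

```python
import numpy as np


def ternary_matvec(x, w_q):
    """Compute x @ w_q for w_q in {-1, 0, 1} using only selection, addition and subtraction."""
    out = np.zeros(w_q.shape[1], dtype=x.dtype)
    for j in range(w_q.shape[1]):
        col = w_q[:, j]
        out[j] = x[col == 1].sum() - x[col == -1].sum()
    return out


def rmsnorm_then_quantize(x, eps=1e-6, bits=8):
    """Unfused reference: RMS-normalize, then absmax-quantize (assumed scheme)."""
    x_norm = x / np.sqrt(np.mean(x * x) + eps)
    q_max = 2 ** (bits - 1) - 1
    scale = q_max / (np.max(np.abs(x_norm)) + eps)
    x_q = np.clip(np.round(x_norm * scale), -q_max - 1, q_max)
    return x_q, scale, x_norm


x = np.array([2.0, -1.0, 0.5])
w_q = np.array([[ 1,  0],
                [-1,  1],
                [ 0, -1]], dtype=np.int8)

y = ternary_matvec(x, w_q)           # same result as x @ w_q: 3.0 and -1.5
x_q, scale, x_norm = rmsnorm_then_quantize(x)
y_q = ternary_matvec(x_q, w_q) / scale  # dequantized output, close to x_norm @ w_q
print(y, y_q)
```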

6.

Method (3)

7.

Method (4)
• Replace attention with another mechanism
• View the transformer as:
  – token mixer: sequence / temporal information: self-attention, Mamba, etc.
  – channel mixer: embedding / spatial information: feed-forward, GLU, etc.
• For the token mixer:
  – ternarize the Q, K matrices to get a ternary attention map
    • fails to converge
  – replace self-attention with a modified gated recurrent unit (GRU)
    • simpler RNN-based architecture
    • replaces GEMM with element-wise operations and accumulation

8.

Method (5)
Standard GRU

9.

Method (6)
• MatMul-free GRU
  – remove hidden-state related weights (Wcc, Whr, Whf)
  – remove the activation between hidden states (tanh)
  – enables parallel computation
  – add a data-dependent gate between hidden state and output
  – decouple the candidate state from the hidden state
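One reading of these changes gives the recurrence sketched below. The gate names, the SiLU candidate activation and the weight shapes are assumptions made for illustration and should be checked against the paper. The `@` products stand in for BitLinear layers: with ternary weights they reduce to additions and subtractions.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def silu(z):
    return z * sigmoid(z)


def mlgru_forward(xs, w_f, w_c, w_g, w_o):
    """Run the recurrence over xs of shape (T, d); all weights are ternary (d, d)."""
    h = np.zeros(xs.shape[1])
    outputs = []
    for x_t in xs:
        f_t = sigmoid(x_t @ w_f)         # forget gate depends on the input only
        c_t = silu(x_t @ w_c)            # candidate is decoupled from h_{t-1}
        h = f_t * h + (1.0 - f_t) * c_t  # element-wise update: no tanh, no hidden weights
        g_t = sigmoid(x_t @ w_g)         # data-dependent output gate
        outputs.append((g_t * h) @ w_o)
    return np.stack(outputs), h


rng = np.random.default_rng(0)
T, d = 5, 4
xs = rng.normal(size=(T, d))
w_f, w_c, w_g, w_o = (rng.integers(-1, 2, size=(d, d)) for _ in range(4))
outs, h_last = mlgru_forward(xs, w_f, w_c, w_g, w_o)
print(outs.shape)  # (5, 4)
```

Because the gates and candidates depend only on the inputs, they can be computed for all time steps at once. Only the element-wise accumulation of the hidden state remains sequential.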

10.

Method (7)
MatMul-free GRU

11.

Method (8)
• For the channel mixer
  – use a gated linear unit (GLU), as in recent LLMs such as Llama and Mistral
    • uses only dense layers
  – replace its dense layers with BitLinear layers
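A GLU channel mixer with ternary weights can be sketched as follows. It assumes the SiLU-gated variant used in Llama-style models; the function name, parameter names and sizes are hypothetical.

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


def ternary_glu(x, w_gate, w_up, w_down):
    """Gated linear unit whose three dense layers use ternary weights."""
    gate = silu(x @ w_gate)      # gating branch
    up = x @ w_up                # value branch
    return (gate * up) @ w_down  # element-wise gate, then project back to d


rng = np.random.default_rng(0)
d, hidden = 4, 8
w_gate = rng.integers(-1, 2, size=(d, hidden))
w_up = rng.integers(-1, 2, size=(d, hidden))
w_down = rng.integers(-1, 2, size=(hidden, d))
y = ternary_glu(rng.normal(size=d), w_gate, w_up, w_down)
print(y.shape)  # (4,)
```

The only true multiplications left are the element-wise gate and the activation. The three projections need additions and subtractions only.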

12.

Experiments
• Training details
  – use a surrogate gradient to handle non-differentiable functions such as sign and clip
    • via the Straight-Through Estimator (STE)
  – larger learning rate than for traditional transformers
    • small learning rates may lead to no weight updates after clipping
  – learning rate scheduler
    • the model shows different learning dynamics
    • start with a cosine scheduler
    • halve the learning rate midway through training
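The learning-rate point can be illustrated with a toy straight-through update. This is a sketch under several assumptions: absmean ternarization, a squared-error loss, plain SGD, and hypothetical names and numbers throughout. The gradient computed for the quantized weights is applied unchanged to the latent float weights. A step that is too small then leaves the ternary values exactly where they were.

```python
import numpy as np


def ternarize(w, eps=1e-8):
    """Absmean ternarization (assumed rule): returns values in {-1, 0, 1} and a scale."""
    scale = np.mean(np.abs(w)) + eps
    return np.clip(np.round(w / scale), -1, 1), scale


def ste_sgd_step(w_latent, x, target, lr):
    """One SGD step on 0.5 * ||x @ (scale * w_q) - target||^2 with a straight-through gradient."""
    w_q, scale = ternarize(w_latent)
    err = x @ (scale * w_q) - target
    grad_wrt_quantized = np.outer(x, err)  # dL / d(scale * w_q)
    # Straight-through: treat d(scale * w_q) / d(w_latent) as the identity.
    return w_latent - lr * grad_wrt_quantized


w_latent = np.array([[0.2, -0.7],
                     [0.05, 0.6]])
x = np.array([1.0, 2.0])
target = np.array([1.0, 1.0])

q_before, _ = ternarize(w_latent)
q_small, _ = ternarize(ste_sgd_step(w_latent, x, target, lr=1e-4))
q_large, _ = ternarize(ste_sgd_step(w_latent, x, target, lr=0.5))
print(np.array_equal(q_before, q_small))  # True: the step was too small to move any ternary value
print(np.array_equal(q_before, q_large))  # False: at least one ternary value changed
```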

13.

Experiments (2)
• Compare two architectures
  – the advanced transformer architecture from Llama 2, named Transformer++ in the charts
  – MatMul-free LM
• Three model sizes: 370M, 1.3B, and 2.7B
• All models pre-trained on the SlimPajama dataset
  – 370M model trained on 15 billion tokens; 1.3B and 2.7B models trained on 100 billion tokens
• 8× NVIDIA H100 GPUs
  – ~5 hours for the 370M model
  – ~84 hours for the 1.3B model
  – ~173 hours for the 2.7B model

14.

Experiments (3)
Loss graph

15.

Experiments (4)
• Loss curve is initially better for MatMul-free
• It is then overtaken by Transformer++
• Scaling projections seem to indicate a steeper descent for MatMul-free
  – more efficient resource usage
  – projected to intersect at 10^23 FLOPs (similar to the training compute of Llama 3 8B)
    • but only 3 data points
• Downstream tasks
  – multiple benchmarks: ARC-Challenge, HellaSwag, Winogrande, etc.
  – zero-shot
  – results show competitive performance

16.

Experiments (5)
Downstream tasks

17.

Experiments (6)
• Training efficiency
  – Vanilla BitLinear compared to Fused BitLinear
  – Fused operator benefits from larger batch sizes
    • faster speed: 25.6% speedup for the 1.3B model
    • reduced memory: 61% reduction for the 1.3B model
    • more samples are processed per time step
• Inference efficiency
  – MatMul-free LM compared to Transformer++
  – Lower memory usage and latency
    • 4.19 GB vs 48.5 GB for the 13B model
    • 695 ms vs 3184 ms for the 13B model
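Part of the memory gap follows directly from the storage width of the weights. The arithmetic below is a back-of-the-envelope sketch for a hypothetical 1B-parameter model. It assumes 16-bit baseline weights and ternary weights packed at 2 bits each. It ignores activations, embeddings and runtime buffers, so it does not reproduce the measured figures above.

```python
def weight_memory_gb(num_params, bits_per_weight):
    """Storage needed for the weights alone, in gigabytes (10^9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


params = 1_000_000_000                      # hypothetical model size
fp16_gb = weight_memory_gb(params, 16)      # 16-bit floating-point weights
ternary_gb = weight_memory_gb(params, 2)    # ternary weights packed at 2 bits (assumption)
print(fp16_gb, ternary_gb, fp16_gb / ternary_gb)  # 2.0 0.25 8.0
```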

18.

FPGA
• Field-programmable gate array
  – configurable integrated circuit (IC)
  – lower level than a GPU, higher level than an ASIC (application-specific integrated circuit)
• Used to test efficiency on hardware that supports ternary operations
• Programmed in Verilog
• Deployed on Intel Cloud

19.

FPGA (2)
Verilog RTL

20.

FPGA (3)
• Clock rate of 60 MHz
• Around 13 W
• Only a single core was implemented; multi-core performance is estimated from it
• 1.3B model projected at 42 ms per token, i.e. 23.8 tokens/second
  – comparable to human reading speed
  – low power consumption
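The throughput figure is the reciprocal of the projected latency, assuming the 42 ms applies to one generated token:

```python
latency_s = 0.042                      # projected time per token for the 1.3B model
tokens_per_second = 1.0 / latency_s
print(round(tokens_per_second, 1))     # 23.8
```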

21.

Conclusion
• MatMul-free models are feasible
• Performance can be comparable to standard transformers
  – with reduced memory usage and latency
• GPUs are optimized for GEMM, though, so custom hardware may be needed
• Code is available on GitHub:
  – https://github.com/ridgerchu/matmulfreellm
  – compatible with HuggingFace libraries
  – CUDA kernels implemented in the Triton language
• Needs to be tested on larger-scale models (100B+ parameters)