Saturday, July 22, 2023

Costs in Training LLMs

I went through the Llama-2 white paper that was released with the model by Meta. I was hoping to learn some special technique they may be using to train their models. Apparently, there isn't any. The learning process is straightforward. What is different is the huge cost associated with fine-tuning after the model is trained. This fine-tuning requires human interaction and feedback. To incorporate the feedback, the model has to be altered, which requires more computation. Training and fine-tuning the model cost more than $20M (~$4 per hour and 5M hours). This immediately limits the number of players who will actively develop LLMs. The cost of adding safety to these models (e.g. blocking prompts for ransom letters) is almost as high as the cost of training the model.

Another interesting tidbit from the paper was the assertion that a RoCE-based 200 Gbps interconnected cluster was adequate and more economical than an InfiniBand-based cluster. RoCE uses commodity Ethernet. If one can train a 70B-parameter model on trillions of tokens using commodity Ethernet with RDMA, what is the compelling need to move to expensive NVLink-linked superchip-based systems? Maybe they are overfitting? (pun intended)

There is a significant cost to building these models that is shared with the public (unknowingly): the climate cost. The paper puts the carbon emissions of these clusters at 539 tonnes of CO2e, over 3.3M GPU hours (A100-80GB). All of this to chat with a bot?
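For the curious, a quick back-of-envelope check of these figures (my own arithmetic; the ~$4/hour rate is the rough cloud price assumed above, the GPU hours and CO2e are from the paper):

```python
train_and_tune_hours = 5.0e6   # total incl. fine-tuning, per the estimate above
pretrain_hours = 3.3e6         # A100-80GB pretraining hours from the paper
rate = 4.0                     # assumed USD per GPU hour
co2e_kg = 539 * 1000           # 539 tonnes CO2e from the paper

print(f"total compute bill ~ ${train_and_tune_hours * rate / 1e6:.0f}M")
print(f"emissions ~ {co2e_kg / pretrain_hours:.2f} kg CO2e per GPU hour")
```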

I found more benchmarks and metrics related to safety, climate and other social concerns in the paper than what one typically finds in a technical paper.

It was easy to play with the model using oobabooga's text-generation web UI. I used the 13B-parameter model from the family of Llama-2s released. It is a bit dated. You can see for yourself.
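If you'd rather skip the UI, a minimal sketch with the Hugging Face transformers library (my own, assuming you have been granted access to the gated checkpoint and have a GPU with roughly 26 GB free for the 13B model in fp16):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-13b-chat-hf"  # gated; requires approval from Meta
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float16, device_map="auto"  # needs accelerate installed
)

inputs = tokenizer("What happened in the news this week?", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```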




Sunday, July 09, 2023

Learning a Model

Neural Networks have a bad reputation for being very confident when they are wrong. This is the result of bad probability estimates being calculated (i.e. learned). They also suffer from adversarial attacks. Training is the activity that takes the most time in arriving at a functional LLM. Besides collecting, curating and integrating data sets, we also have to navigate around potholes by employing optimization techniques on the objective function. The objective function, or goal-seeking function, is a function that takes data and model parameters as arguments and outputs a number. The goal is to find values for these parameters which either maximize or minimize this number. Maximum likelihood estimation (MLE) is one of the most often used approaches for finding the set of parameters that best fits the observed data.
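To make this concrete, here is a toy sketch (my own, not from any paper) of what MLE looks like for next-token prediction: maximizing the likelihood of the observed tokens is the same as minimizing the negative log-likelihood, which for a categorical distribution is exactly cross-entropy loss.

```python
import torch
import torch.nn.functional as F

# Toy next-token model: embedding plus a linear head over a small vocabulary.
vocab_size, hidden = 100, 32
emb = torch.nn.Embedding(vocab_size, hidden)
head = torch.nn.Linear(hidden, vocab_size)

tokens = torch.randint(0, vocab_size, (1, 16))   # stand-in for observed data
inputs, targets = tokens[:, :-1], tokens[:, 1:]  # predict each next token

logits = head(emb(inputs))  # (1, 15, vocab_size)
# MLE: maximize log p(targets | inputs) == minimize negative log-likelihood,
# which is cross-entropy between predicted distribution and observed tokens.
nll = F.cross_entropy(logits.reshape(-1, vocab_size), targets.reshape(-1))
nll.backward()  # gradients push the parameters toward higher likelihood
```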

LLMs have three model architectures: (a) encoder-only (BERT), (b) decoder-only (GPT), (c) encoder-decoder (T5). Looking at (b): the model outputs a probability distribution over the next word given the prompt, arrived at by taking a smoothed exponential (softmax) of scores calculated using scaled dot products between each new prediction word and the prompt. We then use MLE to find the distribution that best fits the observed data.
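The "softmax of scaled dot products" step is short enough to show directly. A minimal sketch of one attention lookup (a single query against a 10-token prompt; the dimensions are made up):

```python
import torch

d = 64                   # per-head dimension
q = torch.randn(1, d)    # query for the word being predicted
K = torch.randn(10, d)   # keys for the 10 prompt tokens
V = torch.randn(10, d)   # values for the prompt tokens

scores = (q @ K.T) / d**0.5          # scaled dot products against the prompt
weights = torch.softmax(scores, -1)  # "smoothed exponential" -> probabilities
context = weights @ V                # weighted summary of the prompt
```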

Stochastic gradient descent (SGD) and ADAM (adaptive moment estimation) are two common methods used to optimize the objective function. The latter is memory intensive. There are many knobs, like the size of a floating point, how moments are calculated (more values per parameter), and changing learning rates, among others, that can be used to learn a model. Sometimes the knob settings result in generic learning and other times in overfitting. More often than not we just don't converge, i.e. no learning. ADAM is a popular optimizer (I use it on the open-source transformers from Hugging Face); it keeps roughly 3X more values per parameter than vanilla SGD. Adafactor is an optimization on ADAM to reduce memory consumption, but it has been known to not always work.
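You can see ADAM's extra memory directly in PyTorch. A small sketch (the linear layer is a stand-in for a real model):

```python
import torch

model = torch.nn.Linear(1024, 1024)  # stand-in for a real LLM
adam = torch.optim.Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024)).pow(2).mean()
loss.backward()
adam.step()

# ADAM keeps two extra tensors per parameter (first and second moment
# estimates); vanilla SGD without momentum keeps none.
state = adam.state[next(model.parameters())]
print(sorted(state.keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```

For a 70B-parameter model those two extra moment tensors are the difference between fitting in GPU memory and not, which is why tricks like Adafactor exist.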

A rule of thumb in ML is to gather more data, as a dumb model with lots of data beats a smart model with limited data. But training on a large amount of data is costly. It uses computational resources, and as the whole process is iterative, we need fast processors to crunch through the data so the model can be adapted and iterated upon. More data does not guarantee convergence, i.e. learning. The whole exercise of learning a model looks and feels more like art bordering on black magic than anything analytic or scientific. If modeling is writing the recipe, then this is cooking the recipe. The end result has a lot of variance.
