In this blog post, I will dive into some of the math foundations of natural language generation (NLG).
In language modeling, the likelihood of generating a word \(w_i\) is given by the conditional probability \(P(w_i|w_1, ..., w_{i-1})\). To generate a coherent and relevant sequence of \(t\) words using causal language modeling (CLM), three components are essential: a language model \(M\), a prompt \(X\) and a decoding algorithm.
Prompts are task-specific instructions used to guide a model's output (Amatriain, 2024; Sahoo et al., 2024); in other words, prompts are the user's input. The model \(M\) then generates the subsequent text word by word, where the probability of each word \(w_i\) is conditioned not only on the preceding generated words \(w_1, ..., w_{i-1}\) but also on the prompt \(X = x_1, ..., x_m\). Mathematically, the probability of the generated sequence factorizes as:
$$P(w_1, ..., w_t|X) = \prod_{i=1}^{t} P(w_i|x_1, ..., x_m, w_1, ..., w_{i-1})$$where \(t\) is the length of the generated text sequence (Bengio et al., 2003; Vaswani et al., 2017).
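To make the factorization concrete, here is a minimal sketch that scores a generated sequence under a hypothetical toy model (`toy_next_word_probs` is a made-up lookup table, not a real language model). It sums log-probabilities instead of multiplying probabilities, which is mathematically equivalent but avoids numerical underflow for long sequences.

```python
import math


# Hypothetical toy "language model": maps the context (prompt + words generated
# so far) to a probability distribution over the next word. Illustration only.
def toy_next_word_probs(context):
    last = context[-1] if context else "<s>"
    table = {
        "Sentiment:": {"positive": 0.6, "neutral": 0.3, "negative": 0.1},
        "positive": {"<eos>": 0.9, "!": 0.1},
    }
    return table.get(last, {"<eos>": 1.0})


def sequence_log_prob(prompt, words):
    """log P(w_1..w_t | X) = sum_i log P(w_i | x_1..x_m, w_1..w_{i-1})."""
    log_p = 0.0
    for i, w in enumerate(words):
        probs = toy_next_word_probs(prompt + words[:i])  # condition on prompt + prefix
        p = probs.get(w, 0.0)
        if p == 0.0:
            return float("-inf")  # the sequence is impossible under the model
        log_p += math.log(p)
    return log_p


prompt = ["Classify", "the", "text", "...", "Sentiment:"]
print(math.exp(sequence_log_prob(prompt, ["positive", "<eos>"])))  # ≈ 0.54 (= 0.6 * 0.9)
```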
One way to find effective prompts is prompt engineering, i.e., the systematic crafting of prompts (Amatriain, 2024; Sahoo et al., 2024). Many approaches exist; in this blog post, I will only cover some of the most widely used prompt engineering techniques.
Zero-Shot Prompting, introduced by Radford et al. (2019), involves directly instructing a model to perform a task without providing any examples or demonstrations, meaning the model must leverage its pre-existing knowledge to generate an answer (Sahoo et al., 2024). An example of a zero-shot prompt is:
“Classify the text into neutral, negative or positive.
Text: I think the vacation is okay.
Sentiment:”
Fine-tuning, in particular “instruction tuning” (i.e., fine-tuning on a collection of tasks phrased as natural-language instructions), has been shown to improve performance with zero-shot prompts (Wei et al., 2022).
Few-Shot Prompting, introduced by Brown et al. (2020), comprises a prompt that contains an instruction and a few examples to induce a better understanding of the task and generate a better response (Brown et al., 2020; Sahoo et al., 2024). It is called “few-shot” because the model is given a small number of examples (“shots”). An example of a few-shot prompt is:
“This is awesome! // Positive
This is bad! // Negative
Wow that movie was rad! // Positive
What a horrible show! //”
Note that few-shot prompting works best when the model is sufficiently large (Kaplan et al., 2020).
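Few-shot prompts like the one above are often assembled programmatically from labelled examples. A minimal sketch, assuming the `text // label` format of the example (the helper `build_few_shot_prompt` is hypothetical):

```python
def build_few_shot_prompt(examples, query, sep=" // "):
    """Join labelled demonstrations, then append the query with its label left blank."""
    lines = [f"{text}{sep}{label}" for text, label in examples]
    lines.append(f"{query}{sep}".rstrip())  # the model is expected to complete the label
    return "\n".join(lines)


demos = [
    ("This is awesome!", "Positive"),
    ("This is bad!", "Negative"),
    ("Wow that movie was rad!", "Positive"),
]
print(build_few_shot_prompt(demos, "What a horrible show!"))
```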
Chain-of-Thought (CoT) Prompting, introduced by Wei et al. (2022), enables complex reasoning through intermediate reasoning steps: the prompt includes a step-by-step explanation of how an answer is reached, which the model then imitates before giving its final answer. This approach helps language models better handle tasks requiring multi-step reasoning, such as mathematical problem solving, logical deduction, or comprehension tasks. An example of a (one-shot) CoT prompt is:
“Q: Roger has 5 tennis balls. He buys 2 more cans of tennis balls. Each can has 3 tennis balls. How many tennis balls does he have now?
A: Roger started with 5 balls. 2 cans of 3 tennis balls each is 6 tennis balls. 5 + 6 = 11. The answer is 11.
Q: The cafeteria had 23 apples. If they used 20 to make lunch and bought 6 more, how many apples do they have?”
Note that CoT prompting can be combined with zero-shot or few-shot prompting (cf. Kojima et al., 2022; Wang et al., 2023). Furthermore, as with few-shot prompting, CoT only pays off when the model is large enough to reason (Wei et al., 2022).
Tree of Thoughts (ToT) Prompting, introduced by Long (2023) and Yao et al. (2023), is a technique in which the model maintains a tree of thoughts, where thoughts serve as intermediate steps toward solving a problem and the model is encouraged to self-evaluate them. The tree of thoughts is then traversed using breadth-first search (BFS) or depth-first search (DFS). An example of a single ToT prompt is:
“Imagine three different experts are answering this question.
All experts will write down 1 step of their thinking,
then share it with the group.
Then all experts will go on to the next step, etc.
If any expert realises they're wrong at any point then they leave.
The question is...”
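A rough sketch of the BFS variant of this search, in the spirit of Yao et al. (2023): at every level, all kept states are expanded into candidate thoughts, each candidate is scored, and only the `breadth` best are kept. In a real system, `propose` and `evaluate` would be LLM calls; here they are hypothetical stand-ins on a toy arithmetic task (pick three digits that sum to 12), so only the search structure should be taken from it.

```python
# Toy stand-ins (hypothetical) for the two LLM calls a Tree-of-Thoughts system
# makes: propose() generates candidate next thoughts, evaluate() scores how
# promising a partial solution is. Toy task: pick STEPS digits (1-9) that sum
# to TARGET; each appended digit is one "thought".
TARGET, STEPS = 12, 3


def propose(state):
    return [state + (d,) for d in range(1, 10)]


def evaluate(state):
    remaining = TARGET - sum(state)
    slots = STEPS - len(state)
    return -abs(remaining - 5 * slots)  # 0 means "on track" (5 is the mean digit)


def tot_bfs(root, propose, evaluate, steps, breadth):
    """BFS over the tree of thoughts: expand all kept states, score every
    child, and keep only the `breadth` most promising ones per level."""
    frontier = [root]
    for _ in range(steps):
        children = [child for state in frontier for child in propose(state)]
        if not children:
            break
        children.sort(key=evaluate, reverse=True)  # stable: ties keep generation order
        frontier = children[:breadth]
    return max(frontier, key=evaluate)


print(tot_bfs((), propose, evaluate, steps=STEPS, breadth=2))  # (2, 5, 5)
```

Swapping the level-by-level loop for a recursive search that backtracks when `evaluate` drops below some cutoff would give the DFS variant.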
Decoding algorithms are then used to turn the probability distributions produced by the model into a sentence \(s = w_1, ..., w_t\) of coherent and intelligible text (natural language). The choice of decoding algorithm has been shown to significantly affect the quality and diversity of the generated text (Ippolito et al., 2019; Wiher et al., 2022; Zhang et al., 2021). In this blog post, I will cover several common decoding methods.
Greedy Search is the simplest decoding method. It selects the token with the highest probability as its next word at each timestep, aiming to generate the most likely sequence. Mathematically, this can be expressed as:
$$y_t = \arg \max_{y} P(y|y_1, ..., y_{t-1}, x)$$where \(y_t\) is the token with the highest probability at timestep \(t\), \(y_1, ..., y_{t-1}\) are the previously generated tokens and \(x\) is the input. The main problem with this approach is ...
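A minimal sketch of greedy search over a hypothetical toy model (`TOY` and `toy_model` are a made-up lookup table mapping a prefix to a next-token distribution; the input \(x\) is implicit in the table):

```python
# Hypothetical toy model: maps a prefix (tuple of tokens) to P(next token | prefix).
TOY = {
    (): {"A": 0.5, "B": 0.4, "<eos>": 0.1},
    ("A",): {"x": 0.4, "y": 0.3, "<eos>": 0.3},
    ("B",): {"x": 0.9, "<eos>": 0.1},
}


def toy_model(prefix):
    return TOY.get(tuple(prefix), {"<eos>": 1.0})


def greedy_decode(next_token_probs, max_len, eos="<eos>"):
    """At every timestep pick y_t = argmax_y P(y | y_1..y_{t-1}, x)."""
    seq = []
    for _ in range(max_len):
        probs = next_token_probs(seq)
        token = max(probs, key=probs.get)  # highest-probability next token
        seq.append(token)
        if token == eos:
            break
    return seq


print(greedy_decode(toy_model, max_len=5))  # ['A', 'x', '<eos>']
```

Under this toy distribution, the greedy output has probability \(0.5 \times 0.4 \times 1.0 = 0.2\), while `['B', 'x', '<eos>']` has probability \(0.4 \times 0.9 \times 1.0 = 0.36\).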
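Beam Search maintains multiple candidate sequences (“beams”), allowing it to explore several potential outputs and thereby reducing the risk of missing high-probability sequences. At each step \(t\), for each partial sequence \(y_1, ..., y_{t-1}\) in the beam, the algorithm considers the probability distribution over the next possible tokens \(P(y_t|y_1, ..., y_{t-1},x)\), extends the sequence with the candidate tokens, and keeps the \(B\) highest-scoring extended sequences according to:
$$\text{score}(y_1, ..., y_t) = \sum_{i=1}^{t} \log P(y_i|y_1, ..., y_{i-1}, x)$$where \(B\) is the beam width, i.e., the number of candidate sequences kept at each step, and \(\log P(y_i|y_1, ..., y_{i-1}, x)\) is the log-probability of token \(y_i\) given the previously generated tokens and the input \(x\).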
Alternative scoring objectives, such as diverse beam search (DBS; Vijayakumar et al., 2017), have been proposed to address the lack of diversity in the generated text.
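The following sketch implements beam search with a summed log-probability score, on a hypothetical toy model (`TOY` and `toy_model` are made up for illustration). Finished sequences are carried over unchanged, and no length normalization is applied; both are assumptions of this sketch rather than fixed parts of the method.

```python
import math

# Hypothetical toy model: maps a prefix (tuple of tokens) to P(next token | prefix).
TOY = {
    (): {"A": 0.5, "B": 0.4, "<eos>": 0.1},
    ("A",): {"x": 0.4, "y": 0.3, "<eos>": 0.3},
    ("B",): {"x": 0.9, "<eos>": 0.1},
}


def toy_model(prefix):
    return TOY.get(tuple(prefix), {"<eos>": 1.0})


def beam_search(next_token_probs, beam_width, max_len, eos="<eos>"):
    """Keep the beam_width sequences with the highest
    score(y_1..y_t) = sum_i log P(y_i | y_1..y_{i-1}, x)."""
    beams = [([], 0.0)]  # (tokens, summed log-probability)
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == eos:
                candidates.append((tokens, score))  # finished: carry over unchanged
                continue
            for token, p in next_token_probs(tokens).items():
                if p > 0.0:
                    candidates.append((tokens + [token], score + math.log(p)))
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:beam_width]
        if all(t and t[-1] == eos for t, _ in beams):
            break
    return beams[0]


tokens, score = beam_search(toy_model, beam_width=2, max_len=5)
print(tokens, round(math.exp(score), 2))  # ['B', 'x', '<eos>'] 0.36
```

With `beam_width=2`, the search recovers `['B', 'x', '<eos>']` (probability 0.36); with `beam_width=1` it reduces to greedy search and returns `['A', 'x', '<eos>']` (probability 0.2).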
Top-\(k\) sampling, introduced by Fan et al. (2018), is a truncation-based stochastic method: at each step of the text generation, the algorithm restricts sampling to the \(k\) highest-probability tokens. Given a distribution \(P(w_i|w_1, ..., w_{i-1})\), its top-\(k\) vocabulary \(V^{(k)} \subset V\) is defined as:
$$V^{(k)} = \{w \in V \mid \text{rank}(w) \le k\}$$where \(\text{rank}(\cdot)\) represents the rank of the probability of word \(w\), determined by sorting the probabilities in descending order, and \(k\) is the pre-defined threshold.
A more formal formulation given by Meister et al. (2023a) and Meister et al. (2023b) is:
$$V^{(k)} = \underset{V' \subset V}{\arg \max} \sum_{w \in V'} P(w|w_1, ..., w_{i-1}) \quad \textit{s.t.} \quad |V'| = k$$Let \(Z = \sum_{w \in V^{(k)}} P(w|w_1, ..., w_{i-1})\). The next word is then sampled from a re-scaled version of the original probability distribution:
$$P^{*}(w|w_1, ..., w_{i-1}) = \begin{cases} \frac{P(w|w_1, ..., w_{i-1})}{Z} & \text{if } w \in V^{(k)} \\ 0 & \text{otherwise} \end{cases}$$
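The limitation of top-\(k\) sampling is that a fixed vocabulary size \(k\) across the whole text generation process can increase the risk of generating generic and even incoherent text (Holtzman et al., 2020): when the distribution is flat, a small \(k\) discards many plausible words, whereas when it is peaked, a large \(k\) admits many implausible ones.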
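A minimal sketch of top-\(k\) truncation and sampling using only Python's standard library; the distribution `probs` is a hypothetical example. Words outside \(V^{(k)}\) are simply left out of the returned dictionary, which is equivalent to giving them probability 0.

```python
import random


def top_k_distribution(probs, k):
    """P*(w) = P(w) / Z for the k most probable words; all others are dropped (P* = 0)."""
    top = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    z = sum(p for _, p in top)  # Z: probability mass kept by the truncation
    return {w: p / z for w, p in top}


def top_k_sample(probs, k, rng=random):
    truncated = top_k_distribution(probs, k)
    words, weights = zip(*truncated.items())
    return rng.choices(words, weights=weights, k=1)[0]


probs = {"cat": 0.5, "dog": 0.3, "car": 0.15, "cup": 0.05}
print({w: round(p, 3) for w, p in top_k_distribution(probs, 2).items()})  # {'cat': 0.625, 'dog': 0.375}
print(top_k_sample(probs, 2, random.Random(0)))  # either 'cat' or 'dog'
```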
Nucleus sampling (also known as top-\(p\) sampling), introduced by Holtzman et al. (2020), is also a truncation-based stochastic method: at each step of the text generation, the algorithm samples the next token from the smallest set of tokens whose cumulative probability meets or exceeds a threshold \(p\). Mathematically, this can be expressed as:
$$\sum_{w \in V^{(p)}} P(w|w_1, ..., w_{i-1}) \ge p$$where \(V^{(p)} \subset V\) is the top-\(p\) vocabulary, i.e., the smallest set of words satisfying this inequality, and \(p \in [0,1]\) is the pre-determined threshold. This threshold is usually set to \(0.7 \le p \le 0.95\), as this range consistently produces good-quality text (DeLucia et al., 2021; Holtzman et al., 2020).
A more formal formulation given by Meister et al. (2023a) and Meister et al. (2023b) is:
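$$V^{(p)} = \underset{V' \subseteq V}{\arg \min} |V'| \quad \textit{s.t.} \quad \sum_{w \in V'} P(w|w_1, ..., w_{i-1}) \ge p$$Let \(Z = \sum_{w \in V^{(p)}} P(w|w_1, ..., w_{i-1})\). The next word is then again sampled from a re-scaled version of the original probability distribution:
$$P^{*}(w|w_1, ..., w_{i-1}) = \begin{cases} \frac{P(w|w_1, ..., w_{i-1})}{Z} & \text{if } w \in V^{(p)} \\ 0 & \text{otherwise} \end{cases}$$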
As with top-\(k\), the probabilities of all words that are not in the top-\(p\) vocabulary at a given step are set to 0. Unlike with top-\(k\), however, the number of words in \(V^{(p)}\), and hence \(P^{*}\), varies across the text generation process.
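A sketch of nucleus sampling in the same style (the two distributions are hypothetical). It illustrates the key difference from top-\(k\): with the same \(p\), the nucleus contains three words for a flat distribution but only one for a peaked one.

```python
import random


def top_p_distribution(probs, p):
    """Smallest set of most probable words with cumulative probability >= p,
    re-scaled by Z; all other words get probability 0 (they are dropped)."""
    nucleus, cumulative = [], 0.0
    for w, prob in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        nucleus.append((w, prob))
        cumulative += prob
        if cumulative >= p:  # stop as soon as the threshold is met
            break
    z = sum(prob for _, prob in nucleus)
    return {w: prob / z for w, prob in nucleus}


def top_p_sample(probs, p, rng=random):
    nucleus = top_p_distribution(probs, p)
    words, weights = zip(*nucleus.items())
    return rng.choices(words, weights=weights, k=1)[0]


flat = {"cat": 0.3, "dog": 0.3, "car": 0.2, "cup": 0.2}
peaked = {"cat": 0.9, "dog": 0.05, "car": 0.03, "cup": 0.02}
print(len(top_p_distribution(flat, 0.75)), len(top_p_distribution(peaked, 0.75)))  # 3 1
```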
Temperature, introduced by Ackley et al. (1985), modifies the probability distribution before sampling, i.e., before generating each word, and thereby controls the level of “creativity” of the model. Mathematically, the logits \(z\) (raw outputs of the model) are divided by the temperature \(t\), a positive scalar, before the softmax is applied:
$$P^{*}(w_i|w_1, ..., w_{i-1}) = \frac{\exp(z_i/t)}{\sum_j \exp(z_j/t)}$$where \(z_i\) is the logit of word \(w_i\), the sum runs over all words in the vocabulary, and \(t > 0\). A temperature of \(t=1\) preserves the original probability distribution, \(t<1\) sharpens it (favoring high-probability words), and \(t>1\) flattens it (increasing diversity). Temperature is often combined with techniques like top-\(k\) or top-\(p\) sampling to achieve a balance between diversity and coherence in the generated text.
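A minimal sketch of temperature scaling on a list of hypothetical logits. The maximum is subtracted before exponentiating, a standard trick for numerical stability that does not change the result.

```python
import math


def softmax_with_temperature(logits, t=1.0):
    """P*(w_i) = exp(z_i / t) / sum_j exp(z_j / t) for a temperature t > 0."""
    if t <= 0:
        raise ValueError("temperature t must be positive")
    scaled = [z / t for z in logits]
    m = max(scaled)  # subtracting the max avoids overflow and leaves the result unchanged
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
for t in (0.5, 1.0, 2.0):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])
```

Running the loop shows the probability of the top word shrinking as \(t\) grows, i.e., the distribution flattening.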
\(\boldsymbol{\eta}\)-sampling (pronounced eta-sampling), introduced by Hewitt et al. (2022), is an entropy-based method that improves on top-\(p\) sampling by introducing a dynamic threshold. Mathematically, \(\eta\)-sampling can be expressed as obtaining the truncation set \(\mathcal{C}(y_{\lt t})\) at each timestep:
$$\mathcal{C}(y_{\lt t}) = \{y \in V \mid P(y|y_{\lt t}) \gt \eta\}$$where \(\eta = \min(\epsilon, \sqrt{\epsilon} \exp(-H(P(y|y_{\lt t}))))\), \(H(\cdot)\) denotes the entropy, and \(\epsilon = 0.0009\) (cf. Hewitt et al., 2022).
Let \(Z = \sum_{y \in \mathcal{C}(y_{\lt t})} P(y|y_{\lt t})\). The next word is then once again sampled from a re-scaled version of the original probability distribution:
$$P^{*}(y|y_{\lt t}) = \begin{cases} \frac{P(y|y_{\lt t})}{Z} & \text{if } y \in \mathcal{C}(y_{\lt t}) \\ 0 & \text{otherwise} \end{cases}$$As with top-\(k\) and top-\(p\) sampling, the probabilities of all words that are not in the truncation set at a given step are set to 0, and the truncation set, and hence \(P^{*}\), varies across the text generation process.
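A sketch of the \(\eta\)-sampling truncation step, assuming the entropy is measured in nats (natural logarithm) and using the \(\epsilon\) value quoted above as the default; `probs` is a hypothetical next-word distribution.

```python
import math


def eta_truncation(probs, epsilon=0.0009):
    """eta-sampling truncation: keep words with P(y | y_<t) > eta, where
    eta = min(epsilon, sqrt(epsilon) * exp(-H)), then re-scale by Z."""
    entropy = -sum(p * math.log(p) for p in probs.values() if p > 0.0)  # H in nats
    eta = min(epsilon, math.sqrt(epsilon) * math.exp(-entropy))
    kept = {w: p for w, p in probs.items() if p > eta}
    z = sum(kept.values())
    return {w: p / z for w, p in kept.items()}


probs = {"cat": 0.6, "dog": 0.3995, "zzz": 0.0005}
print(sorted(eta_truncation(probs)))  # ['cat', 'dog']
```

Here \(\eta = 0.0009\), so only the very unlikely word `zzz` is removed. Because \(\eta \le \sqrt{\epsilon}\exp(-H) < \exp(-H) \le \max_y P(y|y_{\lt t})\) for \(\epsilon < 1\), the most probable word always survives the truncation, so \(Z\) is never zero.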
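Locally typical sampling, introduced by Meister et al. (2023a), is also an entropy-based method that uses conditional entropy and (locally) typical sets, i.e., sets of words whose information content (negative log-probability) is close to the conditional entropy at each step of the text generation process. Mathematically, locally typical sampling can be expressed as a subset optimization problem to define the truncation set \(\mathcal{C}(y_{\lt t})\):
$$\mathcal{C}(y_{\lt t}) = \underset{\mathcal{C}' \subseteq V}{\arg \min} \sum_{y \in \mathcal{C}'} \left| H(Y_t|Y_{\lt t} = y_{\lt t}) + \log P(y|y_{\lt t}) \right| \quad \textit{s.t.} \quad \sum_{y \in \mathcal{C}'} P(y|y_{\lt t}) \ge \tau$$where \(H(Y_t|Y_{\lt t} = y_{\lt t})\) is the conditional entropy of the next-word distribution, \(-\log P(y|y_{\lt t})\) is the information content of word \(y\), and \(\tau \in (0,1]\) is a pre-defined threshold on the probability mass of the truncation set.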
The aim of locally typical sampling is to minimize the absolute difference between the conditional entropy and the information content (i.e., the negative log-probability) of the words in the truncation set.
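One way to approximate this subset optimization problem is greedily: sort the words by \(|H + \log P(y|y_{\lt t})|\) and keep the shortest prefix of that ranking whose probability mass reaches \(\tau\). A sketch of that greedy procedure (the distribution and the \(\tau\) values are hypothetical; treat it as an approximation, not an exact solver):

```python
import math


def typical_truncation(probs, tau=0.95):
    """Greedy approximation of the locally typical set: rank words by
    |H + log P(y | y_<t)| (distance between information content and entropy),
    keep the shortest prefix whose mass reaches tau, then re-scale."""
    nonzero = {w: p for w, p in probs.items() if p > 0.0}
    entropy = -sum(p * math.log(p) for p in nonzero.values())  # H in nats
    ranked = sorted(nonzero.items(), key=lambda kv: abs(entropy + math.log(kv[1])))
    kept, mass = {}, 0.0
    for w, p in ranked:
        kept[w] = p
        mass += p
        if mass >= tau:
            break
    return {w: p / mass for w, p in kept.items()}


probs = {"cat": 0.5, "dog": 0.3, "car": 0.15, "cup": 0.05}
print(typical_truncation(probs, tau=0.2))  # {'dog': 1.0}
```

With \(\tau = 0.2\), only `dog` is kept even though `cat` is more probable: the information content of `dog` (\(-\log 0.3 \approx 1.20\) nats) is closest to the entropy of the distribution (\(\approx 1.14\) nats).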
Domain-specific fine-tuning, which is closely related to “instruction tuning” (Wei et al., 2022), is a technique in NLG where a pre-trained model's parameters are adapted to perform optimally within a specific field or subject area. It involves further training a pre-trained model on a domain-specific dataset \(D\) and, in doing so, adjusting the model's parameters \(\theta\) to better capture the nuances and intricacies of the target domain.
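Mathematically, domain-specific fine-tuning can be expressed as an optimization problem to find the set of parameters \(\theta^{*}\) that minimizes the loss function \(L(\theta)\):
$$\theta^{*} = \arg \min_{\theta} \sum_{(x_i,y_i) \in D} - \log P(y_i|x_i,\theta)$$where \(D\) is the domain-specific dataset of input-output pairs \((x_i, y_i)\), \(\theta\) are the model's parameters, and \(L(\theta) = \sum_{(x_i,y_i) \in D} - \log P(y_i|x_i,\theta)\) is the negative log-likelihood of the data.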
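The optimization is typically performed using gradient-based methods like stochastic gradient descent (SGD) or its variants (e.g., the Adam optimizer). The parameters are updated iteratively:
$$\theta \leftarrow \theta - \eta \nabla_{\theta}L(\theta)$$where \(\eta\) is the learning rate (unrelated to the threshold \(\eta\) in \(\eta\)-sampling) and \(\nabla_{\theta}L(\theta)\) is the gradient of the loss with respect to the parameters. In SGD, this gradient is estimated on a small random mini-batch of \(D\) rather than on the whole dataset.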
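A minimal sketch of fine-tuning as negative log-likelihood minimization, assuming a hypothetical one-feature logistic model in place of a language model and a made-up four-example dataset \(D\). It uses plain full-batch gradient descent for simplicity; SGD would instead compute each step's gradient on a random mini-batch.

```python
import math


# Hypothetical toy "pre-trained model": P(y = 1 | x, theta) = sigmoid(theta0 + theta1 * x).
def prob_one(theta, x):
    return 1.0 / (1.0 + math.exp(-(theta[0] + theta[1] * x)))


def nll(theta, data):
    """L(theta) = sum over (x_i, y_i) in D of -log P(y_i | x_i, theta)."""
    total = 0.0
    for x, y in data:
        p1 = prob_one(theta, x)
        total -= math.log(p1 if y == 1 else 1.0 - p1)
    return total


def nll_grad(theta, data):
    """Gradient of the NLL; for the sigmoid model it is sum (p - y) * [1, x]."""
    g0 = g1 = 0.0
    for x, y in data:
        err = prob_one(theta, x) - y
        g0 += err
        g1 += err * x
    return [g0, g1]


def fine_tune(theta, data, eta=0.1, steps=100):
    """Full-batch gradient descent: theta <- theta - eta * grad L(theta)."""
    theta = list(theta)
    for _ in range(steps):
        grad = nll_grad(theta, data)
        theta = [t - eta * g for t, g in zip(theta, grad)]
    return theta


domain_data = [(-2.0, 0), (-1.0, 0), (1.0, 1), (2.0, 1)]  # hypothetical dataset D
theta0 = [0.0, 0.0]  # stands in for the "pre-trained" parameters
theta_star = fine_tune(theta0, domain_data)
print(nll(theta0, domain_data), nll(theta_star, domain_data))  # the loss goes down
```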
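Moreover, regularization techniques such as weight decay and dropout are often employed during fine-tuning to prevent overfitting. Explicit regularizers such as weight decay add a penalty term \(R(\theta)\) to the loss function:
$$L_{\text{reg}}(\theta) = L(\theta) + \lambda R(\theta)$$where \(\lambda\) is a hyperparameter that controls the strength of the penalty (e.g., \(R(\theta) = \|\theta\|_2^2\) for L2 regularization).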
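Assuming \(R(\theta) = \|\theta\|_2^2\) (other choices of \(R\) are possible), the penalty adds \(2\lambda\theta\) to the gradient. The sketch below uses a hypothetical zero-gradient loss to isolate that effect: each plain gradient step then shrinks the parameters by a factor \(1 - 2\eta\lambda\), which is why L2 regularization is often called weight decay.

```python
def regularized_step(theta, loss_grad, eta=0.1, lam=0.01):
    """One gradient step on L_reg(theta) = L(theta) + lam * ||theta||_2^2.
    The penalty contributes 2 * lam * theta to the gradient, so every step
    also shrinks the parameters towards zero."""
    grad = loss_grad(theta)
    return [t - eta * (g + 2.0 * lam * t) for t, g in zip(theta, grad)]


def zero_grad(theta):
    # Hypothetical loss whose gradient is zero, to isolate the penalty's effect.
    return [0.0] * len(theta)


print(regularized_step([1.0, -2.0], zero_grad, eta=0.1, lam=0.5))  # ≈ [0.9, -1.8]
```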
By incorporating domain-specific fine-tuning, the model becomes adept at generating text that is not only grammatically correct but also contextually appropriate for the domain. This results in more coherent, relevant, and specialized outputs, which benefits both academic and industry applications.
Prompts, decoding algorithms, and domain-specific fine-tuning are all ways to affect the quality of machine-generated text.