The Interface Between Reinforcement Learning Theory and Language Model Post-Training
We have another technical blog post, this time by Akshay Krishnamurthy and Audrey Huang, about how ideas from reinforcement learning theory can inspire new algorithms for language model post-training.
Over the last several years, we have seen an explosion of interest and research activity into generative models—particularly large language models like ChatGPT, Claude, and Gemini—which operate via textual inputs and outputs and can be used for a variety of general-purpose tasks like question-answering, creative writing, and reasoning. At a high level, training these models comprises two phases: (1) in the pre-training phase, the model is trained on a large corpus of text to predict each token (word) given the previous tokens in each document; and (2) in the post-training phase, a variety of techniques are deployed to align the model, making it suitable for downstream use. For instance, alignment techniques are used to control or steer the model away from producing inappropriate or offensive content, which is essential for safe deployment.
One of the standard approaches for language model alignment is known as Reinforcement Learning from Human Feedback (RLHF). The idea is to treat the language model as a decision-making policy and use techniques from reinforcement learning (RL) to optimize for desirable outcomes, where the notion of desirability is derived from a dataset of outcomes curated with human feedback. These RLHF approaches are pervasive; they are employed in the training of essentially every language model. This new application of RL presented an exciting opportunity for the RL research community to translate, refine, and deploy their ideas toward improving language model alignment, and progress in this direction has been rather rapid. In this blog post, we will discuss some recent advances—focusing on theoretical developments—in reinforcement learning for language model alignment.
As an outline, most of this blog post will focus on the most standard setting for RLHF. We will start with some background to set the stage, highlight the central challenge of overoptimization, and then present a new algorithm, chi-squared preference optimization, that we (in joint work with Wenhao Zhan, Tengyang Xie, Jason Lee, Wen Sun, and Dylan Foster) developed to mitigate this issue. To wrap up, we’ll briefly highlight some other work at the interface of RL theory and LLM post-training, and close with some parting thoughts.
Background
The most basic formulation of RLHF considers single-turn chat scenarios where there is a space $\mathcal{X}$ of possible prompts and a space $\mathcal{Y}$ of possible responses. In RL parlance, the prompt is the state of the environment and the response is the action. There are two main ingredients: a pre-trained language model policy $\pi_{\mathrm{ref}}$, which (stochastically) responds to prompts with responses, and a dataset $\mathcal{D} = \{(x, y^{+}, y^{-})\}$ comprising prompts $x$ along with preferred and dispreferred responses $y^{+}$ and $y^{-}$. For mathematical analysis, it is often assumed that this preference dataset is generated via the following process:
- Prompts $x$ are drawn from some prompt distribution $\rho$,
- Two responses $y, y'$ are drawn independently from $\pi_{\mathrm{ref}}(\cdot \mid x)$,
- These responses are ordered as $(y^{+}, y^{-})$ based on the Bradley-Terry model parametrized by an unknown reward function $r^{\star}$:

$$\mathbb{P}\big[y \succ y' \mid x\big] = \frac{\exp(r^{\star}(x, y))}{\exp(r^{\star}(x, y)) + \exp(r^{\star}(x, y'))}.$$
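To make the data-generating process concrete, here is a minimal Python sketch; the `prompts`, `pi_ref`, and `r_star` arguments are hypothetical stand-ins for the prompt distribution, reference policy, and unknown reward function, and this is illustrative rather than anyone's actual data pipeline.

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_preference_pair(prompts, pi_ref, r_star):
    """Simulate one (x, y_plus, y_minus) triple under the Bradley-Terry model.

    prompts: a list of prompts (a uniform prompt distribution, for simplicity)
    pi_ref:  a function mapping a prompt to a sampled response (the reference policy)
    r_star:  a function mapping (prompt, response) to a scalar reward (unknown in practice)
    """
    x = prompts[rng.integers(len(prompts))]     # x ~ rho
    y, y_prime = pi_ref(x), pi_ref(x)           # y, y' drawn independently from pi_ref(. | x)
    # Bradley-Terry: P[y preferred to y'] = sigmoid(r*(x, y) - r*(x, y'))
    p_prefer_y = 1.0 / (1.0 + np.exp(-(r_star(x, y) - r_star(x, y_prime))))
    if rng.random() < p_prefer_y:
        return x, y, y_prime                    # (x, y_plus, y_minus)
    return x, y_prime, y

# Toy usage: responses are scalar "quality" scores and the true reward is the score itself.
dataset = [
    sample_preference_pair(["prompt"], pi_ref=lambda x: rng.normal(), r_star=lambda x, y: y)
    for _ in range(1000)
]
```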
Given this dataset, we aim to learn a policy $\hat{\pi}$ that has high reward $\mathbb{E}_{x \sim \rho,\, y \sim \hat{\pi}(\cdot \mid x)}[r^{\star}(x, y)]$.
A natural approach, first proposed by Christiano et al (2017), is to use the preference dataset to fit an estimated reward function $\hat{r}$ and then find a policy that has a high reward according to the estimated reward function. In practice, this is done using a reinforcement learning algorithm to optimize the KL-regularized objective

$$\hat{\pi} = \arg\max_{\pi}\ \mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}\big[\hat{r}(x, y)\big] - \beta \cdot \mathrm{KL}\big(\pi \,\|\, \pi_{\mathrm{ref}}\big),$$

where $\mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}}) := \mathbb{E}_{x \sim \rho}\big[\mathrm{KL}\big(\pi(\cdot \mid x) \,\|\, \pi_{\mathrm{ref}}(\cdot \mid x)\big)\big]$ is the average (over prompts) KL divergence between the policies' response distributions and $\beta > 0$ is a regularization parameter. We refer to this method as "standard RLHF."
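As a rough sketch of the reward-fitting step, $\hat{r}$ is typically obtained by maximum likelihood under the Bradley-Terry model; in PyTorch-style code (with a hypothetical `reward_model` callable that scores batches of prompt/response pairs), the loss looks like the following.

```python
import torch.nn.functional as F

def bradley_terry_loss(reward_model, x, y_plus, y_minus):
    """Negative log-likelihood of the preference data under the Bradley-Terry model.

    reward_model(x, y) is assumed to return a 1-D tensor of scalar scores,
    one per (prompt, response) pair in the batch.
    """
    margin = reward_model(x, y_plus) - reward_model(x, y_minus)  # r(x, y+) - r(x, y-)
    return -F.logsigmoid(margin).mean()
```

The fitted $\hat{r}$ is then plugged into the KL-regularized objective above and optimized with an RL method (PPO-style policy optimization is the typical choice in practice).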
One issue with this approach is that, by using reinforcement learning for optimization, it inherits the brittleness and instability of deep reinforcement learning. To address this, Rafailov et al (2023) observed a certain duality between policies and rewards and used it to derive a much simpler method, called Direct Preference Optimization (DPO). The idea is that for any reward function $r$, the optimal policy for the KL-regularized objective above (with $r$ instead of $\hat{r}$) has a closed-form solution,

$$\pi_{r}(y \mid x) = \frac{1}{Z_{r}(x)}\, \pi_{\mathrm{ref}}(y \mid x) \exp\big(r(x, y)/\beta\big),$$

where $Z_{r}(x)$ is a normalizing constant that ensures $\pi_{r}(\cdot \mid x)$ is a distribution. Rafailov et al rearranged this expression to parameterize reward functions by policies, writing $r(x, y) = \beta \log\frac{\pi_{r}(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)} + \beta \log Z_{r}(x)$, and then used this parametrization to fit a reward function to the preference data. This essentially amounts to solving a supervised learning problem with a particular parameterized function class/architecture, but it directly produces a policy $\hat{\pi}$, avoiding the need for complicated reinforcement learning subroutines.
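Concretely, the resulting DPO loss has roughly the following form (a sketch in PyTorch-style code, not the authors' implementation), with the policy's sequence-level log-probabilities standing in for the reward.

```python
import torch.nn.functional as F

def dpo_loss(logp_pi_plus, logp_pi_minus, logp_ref_plus, logp_ref_minus, beta=0.1):
    """DPO loss: a Bradley-Terry fit with the implicit reward
    r(x, y) = beta * log(pi(y|x) / pi_ref(y|x)); the normalizing constant Z(x)
    cancels when taking the difference between preferred and dispreferred responses.

    Inputs are per-example (sequence-level) log-probabilities under the policy
    being trained and under the frozen reference policy.
    """
    reward_plus = beta * (logp_pi_plus - logp_ref_plus)
    reward_minus = beta * (logp_pi_minus - logp_ref_minus)
    return -F.logsigmoid(reward_plus - reward_minus).mean()
```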
Unfortunately, both standard RLHF and DPO have been observed to suffer from a phenomenon referred to as overoptimization (e.g., in Gao et al (2023)), where the policy degrades in quality, rather than improves, during the optimization process. As we will see in the next section, one explanation for overoptimization is that it arises from a certain statistical inefficiency of both methods, which can be addressed via a novel algorithm design.
Overoptimization hurts performance in RLHF
Overoptimization can be understood by connecting the RLHF setting to a subfield of reinforcement learning theory known as offline reinforcement learning. Although RL typically concerns an agent interacting with an environment in an online manner, it can be more practical/feasible to learn in an offline manner, from data that was previously collected by some other decision-making policy (this is also a useful subroutine in online methods). Since we are unable to interact with the environment in these settings and the dataset may not contain information/demonstrations of near-optimal behavior, a natural desideratum is to do the best we can with the data that we have, i.e., find a policy whose performance is competitive with the best policy “supported” by the data. Recent developments in the theory of offline RL have formalized such guarantees via a notion of “single-policy concentrability” (whose definition is not essential for understanding this blog post).
The fundamental challenge in offline RL is a mismatch between what we can numerically optimize—i.e., an estimate of policy performance based on the data we have—and what we care to optimize—the actual policy performance—resulting in an instance of Goodhart's Law. To see this in more detail, observe that the RLHF setting described above is a special case of offline RL because the dataset is collected a priori by $\pi_{\mathrm{ref}}$ and no other information about the ground truth reward $r^{\star}$ is available. Standard RLHF optimizes the estimated reward $\hat{r}$ as a surrogate for the true reward $r^{\star}$, resulting in overfitting to $\hat{r}$ while achieving poor performance as measured via $r^{\star}$. Indeed, this is precisely what is observed experimentally by Gao et al, where it is referred to as overoptimization. Accordingly, this viewpoint suggests a statistical mechanism behind overoptimization: it is equivalent to the known challenge of overfitting in offline RL.
To address the overfitting challenge (and achieve guarantees based on single-policy concentrability), the offline RL literature has developed algorithms based on the principle of pessimism—which quantify reward uncertainty and maximize a high-confidence lower bound on reward (thus guaranteeing a certain amount of reward). Pessimism can be seen as a form of regularization which forces the optimization process to stay in the region where $\hat{r} \approx r^{\star}$, avoiding overfitting. Even though existing RLHF methods (including standard RLHF and DPO) employ KL-regularization to prevent deviating from the data collection policy $\pi_{\mathrm{ref}}$, the fact that these methods overfit suggests that they are not adequately regularized. Indeed, in our paper (Proposition A.1), we construct an example showing that regularization with KL-divergence is not sufficient to achieve single-policy concentrability guarantees, thus identifying a formal limitation of existing RLHF methods.
Deep dive into Chi-squared preference optimization
Although regularization with KL-divergence is insufficient, it turns out that regularization with the $\chi^{2}$-divergence—defined as $\chi^{2}(\pi \,\|\, \pi_{\mathrm{ref}}) := \mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}\!\big[\tfrac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\big] - 1$—is! The $\chi^{2}$-divergence is a stronger regularizer (we have $\mathrm{KL}(\pi \,\|\, \pi_{\mathrm{ref}}) \le \chi^{2}(\pi \,\|\, \pi_{\mathrm{ref}})$), but, more importantly, the $\chi^{2}$-divergence more accurately captures the uncertainty about a policy $\pi$'s reward when data is collected from $\pi_{\mathrm{ref}}$.
To see this in a simplified setup, suppose we have "non-preference" data of the form $(x, y, r)$, where $y \sim \pi_{\mathrm{ref}}(\cdot \mid x)$ and $r = r^{\star}(x, y)$. If we fit a reward estimate $\hat{r}$ via least squares over some function class $\mathcal{R}$, we can expect to have low in-distribution risk, say, $\mathbb{E}_{x \sim \rho,\, y \sim \pi_{\mathrm{ref}}(\cdot \mid x)}\big[(\hat{r}(x, y) - r^{\star}(x, y))^{2}\big] \le \varepsilon^{2}$. Considering some policy $\pi$, the difference between its true reward $\mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}[r^{\star}(x, y)]$ and its estimated reward $\mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}[\hat{r}(x, y)]$ can be bounded as

$$\Big|\mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}\big[\hat{r}(x, y) - r^{\star}(x, y)\big]\Big| \;\le\; \varepsilon \cdot \sqrt{1 + \chi^{2}(\pi \,\|\, \pi_{\mathrm{ref}})}.$$
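For the curious reader, the calculation behind this bound is a change of measure followed by the Cauchy-Schwarz inequality (writing $\mathbb{E}_{\pi}$ as shorthand for $\mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}$):

$$\Big|\mathbb{E}_{\pi}\big[\hat{r} - r^{\star}\big]\Big| \;\le\; \mathbb{E}_{\pi_{\mathrm{ref}}}\!\left[\frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\,\big|\hat{r} - r^{\star}\big|\right] \;\le\; \sqrt{\mathbb{E}_{\pi_{\mathrm{ref}}}\!\left[\left(\frac{\pi(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}\right)^{2}\right]} \cdot \sqrt{\mathbb{E}_{\pi_{\mathrm{ref}}}\big[(\hat{r} - r^{\star})^{2}\big]} \;\le\; \sqrt{1 + \chi^{2}(\pi \,\|\, \pi_{\mathrm{ref}})} \cdot \varepsilon,$$

where the last inequality uses the definition of the $\chi^{2}$-divergence above together with the in-distribution risk bound.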
In other words, the $\chi^{2}$-divergence controls the accuracy of our estimate for policy $\pi$'s reward when the reward function is trained on data collected by $\pi_{\mathrm{ref}}$. It correctly captures the uncertainty in the reward function, which is the main requirement for appropriate regularization in offline RL. Using essentially the above calculation, one can show that solving the $\chi^{2}$-regularized RLHF objective

$$\hat{\pi} = \arg\max_{\pi}\ \mathbb{E}_{x \sim \rho,\, y \sim \pi(\cdot \mid x)}\big[\hat{r}(x, y)\big] - \beta \cdot \chi^{2}\big(\pi \,\|\, \pi_{\mathrm{ref}}\big)$$

and appropriately tuning $\beta$ leads to single-policy concentrability guarantees, and thus overcomes the theoretical limitation of KL-regularized approaches.
Based on this observation, the main contribution of the paper is a "direct" variant of $\chi^{2}$-regularization, analogous to DPO. The derivation also sheds some light on the favorable statistical properties of $\chi^{2}$-regularization. Recall that the DPO derivation uses the closed-form solution to the KL-regularized objective, namely that $\pi_{r}(y \mid x) \propto \pi_{\mathrm{ref}}(y \mid x)\exp\big(r(x, y)/\beta\big)$. Unfortunately, with $\chi^{2}$-regularization there is no closed form, but we approximately have that $\pi(y \mid x) \propto \pi_{\mathrm{ref}}(y \mid x) \cdot r(x, y)/\beta$ (for responses with large reward). From this we can see that $\chi^{2}$-regularization is much less greedy, or equivalently much more heavy-tailed, than KL-regularization: it does not aggressively overfit to responses that have a high estimated reward. And although there is no closed-form solution to the $\chi^{2}$-regularized objective, we can still mostly follow the derivation of DPO to obtain our main algorithm: a "direct" method based on $\chi^{2}$-regularization—called $\chi^{2}$-preference optimization or $\chi^{2}$PO, which avoids RL-style optimization and provably achieves single-policy concentrability guarantees.
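To give a flavor of the resulting objective, here is a rough sketch of a $\chi^{2}$PO-style loss in the same form as the DPO snippet above, using a link function $\phi(z) = z + \log z$ applied to the density ratio in place of DPO's $\log z$. This sketch glosses over important details (such as clipping the density ratio for numerical stability), so consult the paper for the precise objective.

```python
import torch
import torch.nn.functional as F

def chi2po_loss(logp_pi_plus, logp_pi_minus, logp_ref_plus, logp_ref_minus, beta=0.1):
    """A chi^2PO-style loss: the same Bradley-Terry fit as DPO, but the implicit
    reward applies phi(z) = z + log z to the ratio z = pi(y|x)/pi_ref(y|x), rather
    than DPO's log z, reflecting chi^2-style rather than pure KL regularization.

    Details such as clipping the (potentially large) density ratio are omitted.
    """
    log_ratio_plus = logp_pi_plus - logp_ref_plus
    log_ratio_minus = logp_pi_minus - logp_ref_minus
    reward_plus = beta * (torch.exp(log_ratio_plus) + log_ratio_plus)    # beta * phi(ratio)
    reward_minus = beta * (torch.exp(log_ratio_minus) + log_ratio_minus)
    return -F.logsigmoid(reward_plus - reward_minus).mean()
```

The only change relative to DPO is that the implicit reward depends on the density ratio itself in addition to its logarithm, which is what makes the ideal solution heavy-tailed rather than exponentially concentrated around high-reward responses.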
As a final note, we have run preliminary experiments with $\chi^{2}$PO on the TL;DR summarization task. Matching our theoretical predictions, $\chi^{2}$PO exhibits significantly less distribution shift from $\pi_{\mathrm{ref}}$, which leads to performance gains over DPO across a range of training epochs and regularization parameter settings. Notably, the performance gap between $\chi^{2}$PO and DPO grows as regularization decreases and training length increases, indicating that $\chi^{2}$PO is an effective mitigator of distribution shift.
At the same time, the fact that we do not observe large performance gains indicates that statistical overfitting is not the whole story, and suggests many avenues for further investigation. For a theoretical audience, perhaps the most interesting of these are (a) that the way preference data is collected in standard benchmarks does not precisely conform to our mathematical setup (in particular, it is not clear that the responses are sampled from $\pi_{\mathrm{ref}}$), and (b) that $\chi^{2}$PO seems to induce a more challenging optimization landscape, in part due to the heavy-tailed nature of the ideal $\chi^{2}$-regularized distribution. The latter point raises interesting research questions regarding computational-statistical tradeoffs of direct alignment objectives, and whether we can design algorithms that retain the statistical benefits of pessimism while avoiding optimization challenges, for example through the use of inference-time computation.
Parting thoughts
As we mentioned in the introduction, the interface between RL theory and LLMs is a very active research area. Directly relevant to $\chi^{2}$PO, there are several other works about mitigating overoptimization. There are a growing number of theoretical papers trying to simplify, demystify, and understand standard RLHF and DPO. There are also works that focus on reward modeling, which identify shortcomings with the Bradley-Terry model and develop algorithms based on more flexible alternatives. Finally, a direction we are currently quite excited about involves developing LLM post-training methods that deliberately gather novel information via online exploration.
To summarize, we believe there is tremendous potential for a diversity of theoretical perspectives to have an impact in language model post-training and generative AI more broadly. New formalizations and connections with other areas can lead to deeper understanding and novel algorithmic interventions. At the same time, clean mathematical testbeds, while useful, are unlikely to capture the full complexity of modern generative AI. As we've learned through our experience working on $\chi^{2}$PO, it is important to stay grounded in the empirics to understand when and how the formalisms might break down. The upshot is that iterating between theory and practice produces a seemingly endless stream of interesting questions and opportunities.