Formulating Prompt Structuring as an RL Problem

With the rise of foundational language models (FLMs), there has been a lot of work on prompt optimization, covering both the tuning of instructions and the selection of in-context examples (demonstrations). At a coarser, macro level, one can also treat the structure of the prompt itself as a search space: optimize the structure first, then fine-tune the individual sections. A prompt can have (a subset of) the following sections/parts:

A sample prompt structure for our task.

We will use this task as a running example for the formulation.

Does prompt structuring matter?

There is no universal answer, since it depends on the FLM, the task, and the dataset. This post is relevant to use cases where prompt structuring does matter.

Notes:

  • Results from evaluating the model on a handful of test examples (or even on all of them) are not automatically statistically significant. One should therefore run statistical significance (SS) tests, e.g. computing a z-score and p-value against the null hypothesis that the structures perform equally well.
  • At what level do we expect an optimal structure \(p^{*}\) to generalize? It could be at the task-type level (say, classification), the individual task level, the test-set level, or even the level of individual test examples. The answer should guide which algorithm(s) we pick.
  • Why should structure matter at all? (Can mechanistic interpretability help answer this?)
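As a minimal sketch of the first note, assuming two structures A and B are each evaluated on their own set of test examples with 0/1 correctness, a two-sided two-proportion z-test could look like this (the counts are made up):

```python
import math

def two_proportion_z_test(correct_a, n_a, correct_b, n_b):
    """Two-sided z-test of H0: structures A and B have the same accuracy.
    Assumes independent test sets and 0/1 correctness per example."""
    p_a, p_b = correct_a / n_a, correct_b / n_b
    pooled = (correct_a + correct_b) / (n_a + n_b)  # common accuracy under H0
    se = math.sqrt(pooled * (1 - pooled) * (1 / n_a + 1 / n_b))
    z = (p_a - p_b) / se
    p_value = math.erfc(abs(z) / math.sqrt(2))  # equals 2 * (1 - Phi(|z|))
    return z, p_value

# Hypothetical results: structure A gets 172/200 correct, structure B gets 160/200.
z, p = two_proportion_z_test(172, 200, 160, 200)
print(f"z = {z:.2f}, p = {p:.3f}")
```

Here a 6-point accuracy gap on 200 examples per structure gives \(z \approx 1.6\) and \(p \approx 0.11\), which is not significant at the 0.05 level. If both structures are scored on the same examples, a paired test such as McNemar's is the more appropriate choice.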

Problem Statement

Given a foundational language model \(\mathcal{L}\), a task \(\mathcal{T}\), and an initial prompt \(p\) that is structured as an ordered list of sections as follows:

\(p = p_{1} p_{2} \cdots p_{k}\)

the goal is to find a prompt structure \(p^{*}\) that maximizes the performance of \(\mathcal{L}\) on a test set \(\mathcal{X}\) for the task \(\mathcal{T}\).
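One way to write this objective, assuming a per-example score \(m\) (e.g. accuracy against the label \(y_x\)) and writing \(\mathcal{P}(p)\) for the set of candidate structures that can be built from the sections of \(p\) (both symbols are introduced here only for illustration), with \(\mathcal{L}(p', x)\) denoting the FLM's output for query \(x\) under structure \(p'\):

```latex
p^{*} = \arg\max_{p' \in \mathcal{P}(p)} \; \frac{1}{|\mathcal{X}|} \sum_{x \in \mathcal{X}} m\big(\mathcal{L}(p', x),\, y_x\big)
```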

Assumptions and design choices:
  • The test set \(\mathcal{X}\) comes from a distribution \(\mathcal{D}\), which is also the distribution of the training data and of the ICL examples.
  • The final prompt can be either general (shared across queries) or query-specific, \(p_x\).
  • Likewise, the process that creates the prompt can be either query-dependent or query-independent.

Based on these assumptions and design choices, we can formulate prompt structuring as one of the following three problems:


Multi-armed Bandit

For \(k\) sections, we describe an MAB as a tuple \(\langle \mathcal{A}, \mathcal{R} \rangle\), where:

  • There are \(K\) machines (prompt structures) with reward probabilities \(\theta_1, \ldots, \theta_K\). Since each structure is a non-empty ordered subset of the \(k\) sections, \(K = \sum_{j=1}^{k} \frac{k!}{(k-j)!}\), which is bounded above by \(2^k \cdot k!\).
  • At time step \(t\), an action \(a_t\) picks a machine, and the environment (the FLM run with that prompt structure) returns a reward \(r_t\) (say, \(1\) if the output is correct and \(0\) otherwise, i.e. one minus the \(0\)-\(1\) loss).
  • We assume the reward function \(\mathcal{R}\) is stochastic and independent of the query, so \(r_t = \mathcal{R}(a_t)\) is \(1\) with probability \(\theta_{a_t}\) and \(0\) otherwise.

The goal is to maximize the cumulative reward \(\sum_{t=1}^{T} r_t\), or equivalently to minimize the cumulative regret. We can then use any standard MAB strategy, such as \(\epsilon\)-greedy or UCB1.
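A minimal sketch of UCB1 on this Bernoulli bandit, assuming each pull is one FLM evaluation on a randomly drawn query. The success probabilities in `theta` are made up and stand in for the unknown \(\theta_1, \ldots, \theta_K\):

```python
import math
import random

def ucb1(pull, num_arms, horizon):
    """UCB1 over `num_arms` prompt structures; pull(arm) returns a reward in [0, 1]."""
    counts = [0] * num_arms
    means = [0.0] * num_arms
    total = 0.0
    for t in range(1, horizon + 1):
        if t <= num_arms:
            arm = t - 1  # play each arm once so every count is non-zero
        else:
            # Empirical mean plus the exploration bonus sqrt(2 ln t / n_a).
            arm = max(range(num_arms),
                      key=lambda a: means[a] + math.sqrt(2 * math.log(t) / counts[a]))
        r = pull(arm)
        counts[arm] += 1
        means[arm] += (r - means[arm]) / counts[arm]  # incremental mean update
        total += r
    return counts, means, total

# Hypothetical Bernoulli machines: theta[a] is the accuracy of prompt structure a.
theta = [0.3, 0.5, 0.8, 0.4]
rng = random.Random(0)
counts, means, total = ucb1(lambda a: 1.0 if rng.random() < theta[a] else 0.0,
                            len(theta), 5000)
print(counts)  # most pulls should go to structure 2
```

Unlike \(\epsilon\)-greedy, UCB1 has no exploration rate to tune: its bonus shrinks automatically as a machine is pulled more often.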

Notes:

  • Even for \(k=5\), the bound \(2^k \cdot k!\) gives \(3840\) machines (and there are still \(325\) distinct non-empty orderings), which is far too many. To reduce this, we can impose constraints: for example, the task and the query must always be present, the task, definition, and instruction must always come before the query, and sections such as “think again” must come after the task and the instruction. We should aim for no more than \(50\) machines to start with.
  • That said, we are not sure that such seemingly illogical orderings would actually perform worse, so randomly sampling a subset of the machines may also make sense.
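The notes above can be made concrete with a small sketch. The five section names are hypothetical labels taken from the constraints in the first note, and `is_legal` encodes only those constraints. The code computes the loose \(2^k \cdot k!\) bound, the exact number of non-empty ordered subsets, and the number of structures that survive the constraints, and also shows the random-sampling alternative from the second note:

```python
import random
from itertools import permutations
from math import factorial

# Hypothetical section labels, borrowed from the constraints in the notes.
SECTIONS = ("task", "definition", "instruction", "query", "think_again")

def ordered_subsets(sections):
    """Yield every non-empty ordered subset (arrangement) of `sections`."""
    for size in range(1, len(sections) + 1):
        yield from permutations(sections, size)

def is_legal(structure):
    """Hypothetical constraints: task and query are present, task/definition/instruction
    precede the query, and think_again follows task and instruction."""
    pos = {name: i for i, name in enumerate(structure)}
    if "task" not in pos or "query" not in pos:
        return False
    if any(pos[n] > pos["query"] for n in ("task", "definition", "instruction") if n in pos):
        return False
    if "think_again" in pos and any(
        pos[n] > pos["think_again"] for n in ("task", "instruction") if n in pos
    ):
        return False
    return True

k = len(SECTIONS)
upper_bound = 2**k * factorial(k)                 # the loose 2^k * k! count
all_structures = list(ordered_subsets(SECTIONS))  # exact set of ordered subsets
legal = [s for s in all_structures if is_legal(s)]
# Alternative from the second note: sample 50 machines uniformly at random instead.
sampled = random.Random(0).sample(all_structures, 50)
print(upper_bound, len(all_structures), len(legal))  # prints: 3840 325 36
```

Under these particular constraints only 36 structures remain, which is already within the suggested budget of 50 machines.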

Contextual Bandit

It is possible (even likely) that the best prompt structure depends on the query and differs from one query to the next. Under this assumption, we can formulate the problem as an i.i.d. contextual bandit (CB) problem with either partial or full information:

  • At each round \(t = 1, 2, \ldots, T\):
    1. The learner receives a context \(x_t \sim \mathcal{D}\).
    2. Using policy \(\pi\), the learner takes an action \(a_t = \pi(x_t) \in \mathcal{A}\).
    3. The environment (an FLM call) returns a reward \(r_t \sim \mathcal{D}(r \mid x_t, a_t)\).

The objective is thus to minimize the regret with respect to the best policy in a policy class \(\Pi\):

\(\text{Regret}_T = \max_{\pi \in \Pi} \sum_{t=1}^{T} \mathbb{E}_{\mathcal{D}}\left[r_t(\pi(x_t))\right] - \sum_{t=1}^{T} \mathbb{E}_{\mathcal{D}}\left[r_t(a_t)\right]\)

In principle, for the full-information i.i.d. CB, where we observe the reward of every action, we could use the greedy algorithm \(\pi_t = \arg\max_{\pi \in \Pi} \sum_{s=1}^{t-1} r_s(\pi(x_s))\). However, observing every action's reward means running the FLM once per candidate structure for every query, which quickly becomes too expensive. In the partial-information case, where only the reward of the chosen action is observed, we can instead use the \(\tau\)-greedy algorithm or the online cover algorithm.
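A minimal sketch of the partial-information setting, using plain \(\epsilon\)-greedy exploration (a simpler relative of the algorithms named above, not \(\tau\)-greedy or online cover themselves) with a tabular policy class that keeps one running reward estimate per (context, structure) pair. The discrete contexts (hypothetical query types "short" and "long") and their success rates in `THETA` are made up; real queries would first need featurization or clustering:

```python
import random

def epsilon_greedy_cb(contexts, num_actions, get_reward, epsilon=0.1, seed=0):
    """Partial-information epsilon-greedy contextual bandit with a tabular policy."""
    rng = random.Random(seed)
    counts, means = {}, {}  # both keyed by (context, action)

    def greedy(x):
        # Unseen (context, action) pairs default to an estimate of 0.0.
        return max(range(num_actions), key=lambda a: means.get((x, a), 0.0))

    total = 0.0
    for x in contexts:
        # Explore uniformly with probability epsilon, otherwise exploit.
        a = rng.randrange(num_actions) if rng.random() < epsilon else greedy(x)
        r = get_reward(x, a)  # only the chosen structure's reward is observed
        n = counts.get((x, a), 0) + 1
        counts[(x, a)] = n
        old = means.get((x, a), 0.0)
        means[(x, a)] = old + (r - old) / n  # incremental mean update
        total += r
    return greedy, total

# Hypothetical environment: success probability of each of 3 structures per query type.
THETA = {"short": [0.8, 0.4, 0.3], "long": [0.2, 0.3, 0.9]}
env_rng = random.Random(1)
contexts = [env_rng.choice(["short", "long"]) for _ in range(5000)]
policy, total = epsilon_greedy_cb(
    contexts, 3, lambda x, a: 1.0 if env_rng.random() < THETA[x][a] else 0.0
)
print(policy("short"), policy("long"))
```

The learned policy picks a different structure for each query type, which is exactly the query dependence that motivates the contextual formulation.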

Markov Decision Process (MDP)

Following TEMPERA, we can formulate test-time, query-dependent prompt structuring as an MDP given by the tuple \(\langle \mathcal{S}, \mathcal{A}, \mathcal{R}, \mathcal{T}, \gamma \rangle\) (here \(\mathcal{T}\) denotes the transition function, not the task), where:

  • A state \(s\) consists of the current prompt and the query, with starting state \((p_0, x)\). Embedding the prompt and query gives continuous states, with the query \(x\) sampled from the distribution \(\mathcal{D}\).
  • The action space \(\mathcal{A}\) consists of adding a section, deleting a section, or swapping two sections, so for \(k\) sections \(|\mathcal{A}| = 2k + k(k-1)/2\).
  • At time step \(t\), taking action \(a_t\) leads to a new state \(s_{t+1}\), and evaluating the FLM on that state gives a performance measure \(f(s_{t+1})\). Since we care about the performance of the final prompt rather than the performance accumulated along the way, we define the reward \(\mathcal{R}\) as the difference in performance between consecutive states, \(r_t = f(s_{t+1}) - f(s_t)\). With \(\gamma = 1\), the return telescopes to \(f(s_{\mathcal{H}}) - f(s_0)\).
  • The transition \(\mathcal{T}\) is deterministic: applying \(a_t\) to \(s_t\) always yields the same \(s_{t+1}\). \(\gamma\) is the discount factor, and episodes have a fixed horizon \(\mathcal{H}\).
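The state, action, transition, and reward definitions above can be sketched as follows. This is a minimal illustration under several assumptions: a state is reduced to the ordered tuple of section names (ignoring the query and any embeddings), an added section is appended at the end of the prompt, the section names are hypothetical, and `perf` is a made-up stand-in for evaluating the FLM. Inapplicable actions (e.g. deleting a section that is absent) end the episode with a penalty, in the spirit of the constraint handling proposed in the Notes below:

```python
from itertools import combinations

# Hypothetical section names; k = 5.
SECTIONS = ("task", "definition", "instruction", "query", "think_again")

def action_space(sections):
    """Add or delete any section, or swap any pair: 2k + k(k-1)/2 actions."""
    actions = [("add", s) for s in sections] + [("delete", s) for s in sections]
    return actions + [("swap", a, b) for a, b in combinations(sections, 2)]

def transition(structure, action):
    """Deterministic transition on a tuple of section names; None if inapplicable."""
    kind, name = action[0], action[1]
    if kind == "add":
        # Assumption: an added section is appended at the end of the prompt.
        return None if name in structure else structure + (name,)
    if kind == "delete":
        return tuple(s for s in structure if s != name) if name in structure else None
    other = action[2]  # kind == "swap"
    if name not in structure or other not in structure:
        return None
    swapped = list(structure)
    i, j = swapped.index(name), swapped.index(other)
    swapped[i], swapped[j] = swapped[j], swapped[i]
    return tuple(swapped)

def step(structure, action, perf, penalty=-1.0):
    """Returns (next_state, reward, done); reward is perf(next) - perf(current)."""
    nxt = transition(structure, action)
    if nxt is None:
        return structure, penalty, True  # end the trajectory with a bad reward
    return nxt, perf(nxt) - perf(structure), False

# Hypothetical stand-in for "evaluate the FLM with this structure".
def perf(structure):
    return 0.1 * len(structure) + (0.3 if structure[-1:] == ("query",) else 0.0)

state, ret = ("task", "query"), 0.0
for act in [("add", "definition"), ("swap", "definition", "query")]:
    state, reward, done = step(state, act, perf)
    ret += reward  # with gamma = 1 the return telescopes to perf(final) - perf(start)
print(state, round(ret, 3))  # ('task', 'definition', 'query') 0.1
```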

We can then solve this MDP with any RL algorithm (PPO has been seen to perform well). We should start with the techniques TEMPERA found crucial for RL stability: observation and reward normalization, and conditioning the policy on the action history.
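TEMPERA's exact normalization scheme is not spelled out here; as a generic sketch, reward (or per-feature observation) normalization is often implemented with a running mean and standard deviation, e.g. via Welford's algorithm. The class name and the sample rewards below are hypothetical:

```python
import math

class RunningNormalizer:
    """Running mean/variance (Welford's algorithm) for standardizing a scalar stream,
    e.g. rewards; for embedding observations, keep one normalizer per dimension."""

    def __init__(self, eps=1e-8):
        self.n = 0
        self.mean = 0.0
        self.m2 = 0.0  # running sum of squared deviations from the mean
        self.eps = eps

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)

    def std(self):
        # Population standard deviation; fall back to 1.0 until two samples exist.
        return math.sqrt(self.m2 / self.n) if self.n > 1 else 1.0

    def normalize(self, x):
        return (x - self.mean) / (self.std() + self.eps)

norm = RunningNormalizer()
for r in [0.1, -0.05, 0.2, 0.0]:  # hypothetical raw rewards (performance differences)
    norm.update(r)
print(norm.mean, norm.std(), norm.normalize(0.2))
```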

Notes:

  • The constraints imposed in the bandit setting can be enforced here as well, by terminating trajectories that contain illegal action sequences and assigning them large negative rewards.
  • The same RL algorithm can later be used to jointly optimize the individual sections too, simply by extending the action space \(\mathcal{A}\) with section-level edits.
  • We can also apply the same method without test-time optimization by leaving the query out of the state. The TEMPERA paper showed that this performs well too.

Discussion

  • How do we justify optimizing the macro-structure first, followed by micro-level optimization of the instructions and the ICL examples? Can we say that the best structure is invariant under changes to the individual sections? If not, would it be better to treat this as a joint optimization problem?
  • In other words, does it make more sense to optimize the individual sections first and then the structure, or the structure first?
  • Which of the three formulations (MAB, CB, or MDP) makes the most sense?

Thank you!


Discover more from 7vik
