Tractable ICL Tuple Optimization

Optimizing in-context learning (ICL) demonstration tuples is an important problem because foundation models appear to perform substantially better on a task when given in-context examples. Our goal is to make this demonstration optimization, which has exponential complexity, tractable.

Problem Statement

Consider a training set \(\mathcal{T}\) (the pool of candidate demonstrations for ICL), from which we can pick at most \(k\) examples, in order (for now, we pick exactly \(k\)). The resulting ICL tuple \(\vec{S}\), which can depend on the query at hand \(x\) but not on its label \(y\), is:

\(\vec{S} = \langle (x_1, y_1), (x_2, y_2), \dots, (x_k, y_k) \rangle, \quad S \subseteq \mathcal{T}\)

\(\vec{S}\) must be chosen so that the loss on a separate validation set \(\mathcal{V}\) is minimized (equivalently, task performance is maximized), for some per-example metric or loss \(l\):

\(\mathcal{L}(\mathcal{T}, \vec{S}, \mathcal{V}) = \sum\limits_{(x, y) \in \mathcal{V}} l(f_{\theta} (y|p, x, \vec{S}), y)\)

for some initial prompt \(p\). The optimal ICL tuple is thus:

\(\vec{S}^* = \mathop{\text{argmin}}\limits_{\vec{S}} \mathcal{L}(\mathcal{T}, \vec{S}, \mathcal{V})\)
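To make the objective concrete, here is a minimal Python sketch of \(\mathcal{L}(\mathcal{T}, \vec{S}, \mathcal{V})\). It assumes \(l\) is the negative log-likelihood of the gold label and uses a hypothetical `log_prob(y, context)` callable standing in for \(\log f_{\theta}(y \mid p, x, \vec{S})\); the `Input:`/`Output:` prompt template is likewise an illustrative choice, not part of the formulation.

```python
from typing import Callable, Sequence, Tuple

Example = Tuple[str, str]  # an (x, y) pair


def build_context(prompt: str, demos: Sequence[Example], x: str) -> str:
    """Concatenate p, the ordered demonstrations S, and the query x (template is illustrative)."""
    parts = [prompt]
    parts += [f"Input: {xi}\nOutput: {yi}" for xi, yi in demos]
    parts.append(f"Input: {x}\nOutput:")
    return "\n\n".join(parts)


def validation_loss(prompt: str, demos: Sequence[Example], val: Sequence[Example],
                    log_prob: Callable[[str, str], float]) -> float:
    """L(T, S, V), with l assumed to be the negative log-likelihood of the gold label y."""
    # log_prob is hypothetical: one LLM call per validation example
    return sum(-log_prob(y, build_context(prompt, demos, x)) for x, y in val)
```

Because the demonstrations are concatenated in order, two tuples with the same examples in a different order produce different contexts, which is why \(\vec{S}\) is treated as an ordered tuple rather than a set.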

Complexity

Exhaustively evaluating every possible ordered tuple and picking the best one gives the exact global optimum, but the number of LLM calls is combinatorially explosive:

\(|\mathcal{V}| \sum\limits_{j=0}^{k} \binom{|\mathcal{T}|}{j} \cdot j!\)
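To get a rough sense of scale, the following sketch evaluates this count directly; the function name and the example sizes are illustrative assumptions.

```python
from math import comb, factorial


def exhaustive_llm_calls(n_val: int, n_train: int, k: int) -> int:
    """|V| * sum_{j=0}^{k} C(|T|, j) * j!: one validation pass per ordered tuple of length 0..k."""
    return n_val * sum(comb(n_train, j) * factorial(j) for j in range(k + 1))


print(exhaustive_llm_calls(100, 1000, 4))  # 99500899600100 (about 1e14)
```

Even a modest setting (\(|\mathcal{V}| = 100\), \(|\mathcal{T}| = 1000\), \(k = 4\)) needs on the order of \(10^{14}\) LLM calls, dominated by the \(j = k\) term.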

We therefore need to shrink the search space, the optimization time, or both. Several approaches have been proposed in traditional settings, but some are unlikely to work here; for instance, several papers found diversity-based selection to be ineffective. This is expected: those methods were designed to optimize the learning of probabilities, whereas here we care about learning conditional probabilities, conditioned on the query datapoint, the prompt instruction, and the model's internal representation of everything it has learned. We therefore consider the following:

Tractability

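First, we turn the multiplicative factor of this exponential into an additive one by picking examples sequentially rather than jointly. That is, instead of optimizing the whole tuple at once, we pick the \((i+1)\)th example given the first \(i\). This replaces the search over all ordered tuples with \(k\) searches over \(|\mathcal{T}|\) candidates each, i.e. \(|\mathcal{V}| \cdot k \cdot |\mathcal{T}|\) LLM calls. The tuple grows as: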

\(\vec{S}_{i+1} = \vec{S}_i \oplus t_{i+1}, \quad S_i \subseteq \mathcal{T}, \quad t_i = (x_i, y_i) \in \mathcal{T}\)

Here \(\oplus\) appends an example to the tuple and \(\vec{S}_0\) is the empty tuple. We return \(\vec{S}_{k}^*\) as the best ICL tuple, where for all \(i = 0, \dots, k-1\):

\(\vec{S}_{i+1}^* = \vec{S}_{i}^* \oplus \mathop{\text{argmin}}\limits_{t_{i+1} \in{\mathcal{T}}} \mathcal{L}(\mathcal{T}, \vec{S}_{i}^* \oplus t_{i+1}, \mathcal{V}) \)
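A minimal Python sketch of this greedy procedure follows. `val_loss` is a hypothetical callable standing in for \(\mathcal{L}(\mathcal{T}, \cdot, \mathcal{V})\), where in practice each call costs \(|\mathcal{V}|\) LLM calls. As in the argmin above, an example may be picked more than once; excluding already-chosen examples would be a one-line change.

```python
from typing import Callable, Sequence, Tuple

Example = Tuple[str, str]  # an (x, y) pair from the training pool T


def greedy_select(
    train: Sequence[Example],
    k: int,
    val_loss: Callable[[Tuple[Example, ...]], float],
) -> Tuple[Tuple[Example, ...], int]:
    """Build S_k greedily; returns the tuple and how many times val_loss was called."""
    chosen: Tuple[Example, ...] = ()  # S_0 is the empty tuple
    calls = 0
    for _ in range(k):
        best_t, best_loss = None, float("inf")
        for t in train:
            loss = val_loss(chosen + (t,))  # L(T, S_i ⊕ t, V)
            calls += 1
            if loss < best_loss:
                best_t, best_loss = t, loss
        chosen = chosen + (best_t,)  # S_{i+1} = S_i ⊕ argmin_t L(...)
    return chosen, calls
```

The loss oracle is called exactly \(k \cdot |\mathcal{T}|\) times, so the total cost grows linearly in \(k\) instead of exponentially, at the price of possibly missing the global optimum.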

RL Formulation

Second, we train a reward model (a smaller model that we can train or fine-tune) to approximate, for each validation query, the reduction in loss obtained by appending a candidate example \(t_{i+1}\) to the first \(i\) examples already picked:

\(A_{\phi}(\widetilde{p}, \widetilde{x}_v, \widetilde{\vec{S}}_i, \widetilde{t}_{i+1}) \approx \mathcal{L}(\mathcal{T}, \vec{S}_i, \{(x_v, y_v)\}) - \mathcal{L}(\mathcal{T}, \vec{S}_i \oplus t_{i+1}, \{(x_v, y_v)\})\)

\(\forall (x_v, y_v) \in \mathcal{V},\ t_{i+1} \in \mathcal{T}; \quad |\phi| \ll |\theta|\)

where the tildes denote conversion into floating-point vectors that can be fed to the reward model.
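One way to produce training data for the reward model is sketched below in Python. The regression target is the per-query loss before appending \(t_{i+1}\) minus the loss after, and `LinearAdvantageModel` is a hypothetical, deliberately tiny stand-in for \(A_{\phi}\): a linear model over a concatenated feature vector \([\widetilde{p}, \widetilde{x}_v, \widetilde{\vec{S}}_i, \widetilde{t}_{i+1}]\), trained with squared-error SGD. A real \(A_{\phi}\) would more likely be a small neural network.

```python
import random
from typing import Sequence


def advantage_target(loss_before: float, loss_after: float) -> float:
    """Target for A_phi on one validation query: L(S_i) - L(S_i ⊕ t_{i+1})."""
    return loss_before - loss_after


class LinearAdvantageModel:
    """Hypothetical stand-in for A_phi: linear regression over concatenated feature vectors."""

    def __init__(self, dim: int, lr: float = 0.01, seed: int = 0):
        rng = random.Random(seed)
        self.w = [rng.gauss(0.0, 0.01) for _ in range(dim)]  # small random init
        self.b = 0.0
        self.lr = lr

    def predict(self, features: Sequence[float]) -> float:
        return sum(wi * fi for wi, fi in zip(self.w, features)) + self.b

    def sgd_step(self, features: Sequence[float], target: float) -> float:
        """One gradient step on 0.5 * (pred - target)^2; returns the loss before the step."""
        err = self.predict(features) - target
        self.w = [wi - self.lr * err * fi for wi, fi in zip(self.w, features)]
        self.b -= self.lr * err
        return 0.5 * err * err
```

A positive prediction means the candidate is expected to lower the LLM's loss on that query, so larger is better when the model is used for selection.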

Once the network \(A_{\phi}\) is trained, we can use it directly to select an ICL tuple for any query in the test set, without making LLM calls during selection.
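At test time, selection can then run entirely on the learned model. The sketch below assumes a hypothetical `advantage(prompt, query, chosen, t)` wrapper that featurizes its inputs and returns the output of \(A_{\phi}\). Since \(A_{\phi}\) approximates a loss *reduction*, each step takes the argmax.

```python
from typing import Callable, Sequence, Tuple

Example = Tuple[str, str]  # an (x, y) pair from the training pool T


def select_for_query(
    prompt: str,
    query: str,
    train: Sequence[Example],
    k: int,
    advantage: Callable[[str, str, Tuple[Example, ...], Example], float],
) -> Tuple[Example, ...]:
    """Query-dependent ICL tuple built only from the learned model (no LLM calls here)."""
    chosen: Tuple[Example, ...] = ()
    for _ in range(k):
        # pick the candidate with the largest predicted loss reduction given <p, x, S_i>
        best = max(train, key=lambda t: advantage(prompt, query, chosen, t))
        chosen = chosen + (best,)
    return chosen
```

Because `query` is an input to every score, different test queries can receive different tuples, which is the query-dependent behaviour the formulation aims for.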

Thus, this becomes an RL problem where we have the following MDP:

  • The state \(s\) consists of the prompt, the query, and the ICL tuple chosen so far, \(\langle p, x, \vec{S}_i \rangle\), where \(\vec{S}_0\) is the empty tuple.
  • The action space \(\mathcal{A}\) is the choice of the next ICL example from \(\mathcal{T}\).
  • The episode ends once \(k\) examples have been chosen.
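The MDP above can be written down as a small Python sketch. `State` and `step` are illustrative names, and the reward is deliberately left out because it comes from LLM calls.

```python
from dataclasses import dataclass
from typing import Tuple

Example = Tuple[str, str]  # an (x, y) pair from the training pool T


@dataclass(frozen=True)
class State:
    """MDP state <p, x, S_i>."""
    prompt: str
    query: str
    chosen: Tuple[Example, ...] = ()  # S_0 is the empty tuple


def step(state: State, action: Example, k: int) -> Tuple[State, bool]:
    """Apply an action (append the next ICL example) and report whether the episode has ended."""
    next_state = State(state.prompt, state.query, state.chosen + (action,))
    return next_state, len(next_state.chosen) >= k
```

Making `State` frozen keeps stored transitions immutable, which matters once they are placed in a replay buffer.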

At each time step, we pick a random example from the training set with probability \(\epsilon\), and otherwise pick the best one according to the current network (\(\epsilon\)-greedy exploration). We make LLM calls to obtain the reward (loss) for that step and store the transition in a replay buffer.

We then sample batches from the buffer and update our network.
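The exploration and replay steps can be sketched as follows. `score_fn` (the current network's score for a candidate) and `reward_fn` (which would make the LLM calls) are hypothetical callables, and for brevity the state is reduced to the tuple of chosen examples, omitting \(p\) and \(x\). The network update itself is not shown.

```python
import random
from collections import deque
from typing import Callable, Deque, List, Sequence, Tuple

State = Tuple[object, ...]  # simplified state: the examples chosen so far
Transition = Tuple[State, int, float, State, bool]  # (s, action index, reward, s', done)


class ReplayBuffer:
    def __init__(self, capacity: int):
        # a deque with maxlen evicts the oldest transition once full
        self.data: Deque[Transition] = deque(maxlen=capacity)

    def push(self, transition: Transition) -> None:
        self.data.append(transition)

    def sample(self, batch_size: int, rng: random.Random) -> List[Transition]:
        return rng.sample(list(self.data), min(batch_size, len(self.data)))


def epsilon_greedy(scores: Sequence[float], epsilon: float, rng: random.Random) -> int:
    """Random candidate with probability epsilon, otherwise the highest-scoring one."""
    if rng.random() < epsilon:
        return rng.randrange(len(scores))
    return max(range(len(scores)), key=lambda i: scores[i])


def run_episode(candidates: Sequence[object], k: int,
                score_fn: Callable[[State, object], float],
                reward_fn: Callable[[State, object], float],
                buffer: ReplayBuffer, epsilon: float, rng: random.Random) -> State:
    state: State = ()  # S_0 is the empty tuple
    for i in range(k):
        scores = [score_fn(state, t) for t in candidates]
        a = epsilon_greedy(scores, epsilon, rng)
        next_state = state + (candidates[a],)
        reward = reward_fn(state, candidates[a])  # hypothetical: this is where LLM calls happen
        buffer.push((state, a, reward, next_state, i == k - 1))
        state = next_state
    return state
```

Training would then alternate `run_episode` with updates on minibatches drawn via `buffer.sample`, decaying \(\epsilon\) over time.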

Another option is to solve the same MDP with policy-gradient (PG) methods such as actor-critic (AC) and PPO. Everything else stays largely the same, except that we learn a policy network \(\pi_{\phi}\) instead of a value network \(A_{\phi}\).
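As the simplest instance of a policy-gradient method, here is a REINFORCE sketch (swapped in for AC/PPO purely for brevity) with a hypothetical linear-softmax policy over the candidate examples' feature vectors. The update uses \(\nabla_w \log \pi(a \mid s) = f_a - \sum_j \pi_j f_j\), the score function of a linear softmax.

```python
import math
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class LinearSoftmaxPolicy:
    """Hypothetical pi_phi: a shared linear score per candidate, softmax over candidates."""

    def __init__(self, dim: int, lr: float = 0.1):
        self.w = [0.0] * dim
        self.lr = lr

    def probs(self, feats: Sequence[Sequence[float]]) -> List[float]:
        logits = [sum(wi * fi for wi, fi in zip(self.w, f)) for f in feats]
        return softmax(logits)

    def reinforce_update(self, feats: Sequence[Sequence[float]], action: int,
                         ret: float, baseline: float = 0.0) -> None:
        """w <- w + lr * (G - b) * grad_w log pi(action | s)."""
        p = self.probs(feats)
        # expected feature vector under the current policy: sum_j pi_j * f_j
        expected = [sum(p[j] * feats[j][d] for j in range(len(feats)))
                    for d in range(len(self.w))]
        grad = [feats[action][d] - expected[d] for d in range(len(self.w))]
        scale = self.lr * (ret - baseline)
        self.w = [wi + scale * g for wi, g in zip(self.w, grad)]
```

Here `ret` would be the episode return (for example, the negative validation loss of the final tuple) and `baseline` reduces variance; AC replaces the fixed baseline with a learned critic, and PPO additionally clips the policy update.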

Differences from a Related Paper

A November 2022 paper (Active Example Selection for In-Context Learning) applied RL to ICL example selection, but our formulation differs from theirs in several ways:

  1. They do not condition on any prompt or instruction (like most 2022 ICL papers).
  2. They do not perform query-dependent selection, so the same examples are selected for every query.
  3. They use a 3-layer MLP as the model and train it with Q-learning; there is room for improvement on both counts.
  4. They train using GPT-2, evaluate on both GPT-2 and GPT-3, and report that their method did not work on the larger models, attributing this to the emergent capabilities of larger models.

Implementation Details

WIP: To be decided and updated later.

  • We start with several datasets/tasks, split each into train, validation, and test sets (as in standard ML), and keep the test set completely untouched until final evaluation.
  • For each task, we set the following:
    • S: is a model
  • Baselines: random selection, uncertainty-based selection, similarity-based selection, the RL-based active selection paper above, etc.
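For the data split and the random-selection baseline, a minimal Python sketch might look like this; the function names, split fractions, and fixed seed are illustrative assumptions.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

Item = TypeVar("Item")


def train_val_test_split(data: Sequence[Item], val_frac: float, test_frac: float,
                         seed: int = 0) -> Tuple[List[Item], List[Item], List[Item]]:
    items = list(data)
    random.Random(seed).shuffle(items)  # fixed seed keeps the held-out test split reproducible
    n_test = int(len(items) * test_frac)
    n_val = int(len(items) * val_frac)
    test = items[:n_test]
    val = items[n_test:n_test + n_val]
    train = items[n_test + n_val:]
    return train, val, test


def random_baseline(train: Sequence[Item], k: int, rng: random.Random) -> Tuple[Item, ...]:
    """Random-selection baseline: k distinct training examples in random order."""
    return tuple(rng.sample(list(train), k))
```

Only the train split feeds \(\mathcal{T}\) and the validation split feeds \(\mathcal{V}\); the test split is touched once, for the final comparison against the baselines.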

Thank you.

Discover more from 7vik
