Active Fine-Tuning (activeft) is a Python package for intelligent active data selection.

Why Active Data Selection?

As opposed to random data selection, active data selection chooses data adaptively utilizing the current model. In other words,

active data selection pays attention to the most useful data

which allows for faster learning and adaptation. There are two main reasons why some data may be particularly useful:

  1. Relevance: The data is closely related to a particular task, such as answering a specific prompt.
  2. Diversity: The data contains non-redundant information that is not yet captured by the model.

A dataset that is both relevant and diverse is informative for the model. This is related to memory recall, where the brain recalls informative and relevant memories (think "data") to make sense of the current sensory input. Focusing recall on useful data enables efficient learning from few examples.
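To make the two criteria concrete, here is a minimal, framework-free sketch. It is not the acquisition function that activeft uses; it swaps in a simple maximal-marginal-relevance-style heuristic to show how plain nearest neighbor retrieval can return redundant data while a relevance-plus-diversity rule does not. The function names, the trade_off parameter, and the toy vectors are all assumptions made for illustration.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors given as tuples."""
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))


def nearest_neighbors(candidates, target, k):
    """Relevance only: the k candidates most similar to the target."""
    order = sorted(range(len(candidates)), key=lambda i: -cosine(candidates[i], target))
    return order[:k]


def relevant_and_diverse(candidates, target, k, trade_off=0.5):
    """Greedy selection that rewards relevance and penalizes redundancy."""
    selected = []
    remaining = list(range(len(candidates)))
    while remaining and len(selected) < k:

        def score(i):
            relevance = cosine(candidates[i], target)
            # Redundancy: similarity to the closest already-selected candidate.
            redundancy = max(
                (cosine(candidates[i], candidates[j]) for j in selected), default=0.0
            )
            return trade_off * relevance - (1 - trade_off) * redundancy

        best = max(remaining, key=score)
        selected.append(best)
        remaining.remove(best)
    return selected


# Candidates 0 and 1 are exact duplicates; candidate 2 is equally relevant but different.
candidates = [(1.0, 0.1), (1.0, 0.1), (0.1, 1.0), (-1.0, 0.0)]
target = (1.0, 1.0)

nn_pick = nearest_neighbors(candidates, target, k=2)      # picks the two duplicates
mmr_pick = relevant_and_diverse(candidates, target, k=2)  # skips the duplicate
```

In this toy setting, nearest neighbor retrieval spends half of its budget on a duplicate, whereas the second rule selects two distinct, equally relevant candidates.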

activeft provides a simple interface for active data selection, which can be used as a drop-in replacement for random data selection or nearest neighbor retrieval.

Getting Started

You can install activeft from PyPI via pip:

pip install activeft

We briefly discuss how to use activeft for standard fine-tuning and test-time fine-tuning.

Example: Fine-tuning

Given a PyTorch model which may be (but does not have to be!) pre-trained, we can use activeft to efficiently fine-tune the model. This model may be generative (e.g., a language model) or discriminative (e.g., a classifier), and can use any architecture.

We only need the following things:

  • A dataset of inputs dataset (such that dataset[i] returns a vector of length $d$) from which we want to select batches for fine-tuning. If one has a supervised dataset returning input-label pairs, then activeft.data.InputDataset(dataset) can be used to obtain a dataset over the input space.
  • A tensor of prediction targets target ($m \times d$) which specifies the task we want to fine-tune the model for. Here, $m$ can be quite small, e.g., equal to the number of classes in a classification task. If there is no specific task for training, then active data selection can still be useful as we will see later.
  • The model can be any PyTorch nn.Module with an embed(x) method that computes (latent) embeddings for the given inputs x, e.g., the representation of x from the penultimate layer. See activeft.model.ModelWithEmbedding for more details. Alternatively, the model can have a kernel(x1,x2) method that computes a kernel for given inputs x1 and x2 (see activeft.model.ModelWithKernel).
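The two model interfaces can be related by a short sketch. The class below is a framework-free stand-in (a real model would be a PyTorch nn.Module), and the choice of a linear kernel, i.e. the inner product of embeddings, is an assumption made for illustration rather than something the package prescribes. ToyModel and its weights are hypothetical.

```python
class ToyModel:
    """Framework-free stand-in; a real model would subclass torch.nn.Module."""

    def __init__(self, weights):
        # Rows of a hypothetical linear embedding layer.
        self.weights = weights

    def embed(self, xs):
        # Embed each input vector: one dot product per weight row.
        return [
            [sum(w * x for w, x in zip(row, vec)) for row in self.weights]
            for vec in xs
        ]

    def kernel(self, xs1, xs2):
        # Linear kernel: Gram matrix of inner products between embeddings.
        e1, e2 = self.embed(xs1), self.embed(xs2)
        return [[sum(a * b for a, b in zip(u, v)) for v in e2] for u in e1]


model = ToyModel(weights=[[1.0, 0.0, 1.0], [0.0, 2.0, 0.0]])
inputs = [[1.0, 1.0, 1.0], [0.0, 1.0, 0.0]]

embeddings = model.embed(inputs)     # one 2-dimensional embedding per input
gram = model.kernel(inputs, inputs)  # 2 x 2 matrix of pairwise similarities
```

Either method gives the data selection a notion of similarity between inputs, which is why the quality of the embeddings matters.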

For active data selection to be effective, it is important that the model's embeddings are somewhat representative of the data. In particular, embeddings should capture the relationship between the data and the task.

With this in place, we can initialize the "active" data loader

from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)

To obtain the next batch from dataset, we can then simply call

batch = dataset[data_loader.next(model)]

Note that selecting the next batch uses the current model to pick the most relevant data with respect to the given target.

Combining the data selection with a model update step, we can implement a simple training loop as follows:

while not converged:
    batch = dataset[data_loader.next(model)]
    model.step(batch)  # model update step on the selected batch

Notice the feedback loop(!): the batch selection improves as the model learns and the model learns faster as the batch selection improves.

This is it! Training with active data selection is as simple as that.
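For readers who want to run the loop end to end, the following is a toy sketch with stand-ins for everything activeft would provide. ToyModel, ToyLoader, mean_squared_error, and the rule of selecting the worst-fit examples are all hypothetical and chosen only so that the example is self-contained; the real loader selects data differently. What the sketch does preserve is the shape of the loop: the loader is queried with the current model, and the model is updated on the returned batch.

```python
class ToyModel:
    """Hypothetical one-parameter model y = w * x with a `step` update."""

    def __init__(self):
        self.w = 0.0

    def predict(self, x):
        return self.w * x

    def step(self, batch, lr=0.5):
        # One gradient step on the mean squared error of the batch.
        grad = sum(2 * (self.predict(x) - y) * x for x, y in batch) / len(batch)
        self.w -= lr * grad


class ToyLoader:
    """Hypothetical stand-in that mimics the `next(model)` call shape."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size

    def next(self, model):
        # Stand-in rule: pick the examples the current model fits worst.
        errors = [abs(model.predict(x) - y) for x, y in self.dataset]
        order = sorted(range(len(errors)), key=lambda i: -errors[i])
        return order[: self.batch_size]


def mean_squared_error(model, dataset):
    return sum((model.predict(x) - y) ** 2 for x, y in dataset) / len(dataset)


dataset = [(i / 10, 2 * i / 10) for i in range(1, 11)]  # samples of y = 2x
model = ToyModel()
data_loader = ToyLoader(dataset, batch_size=4)

steps = 0
# "converged" is made concrete as a loss threshold plus a step limit.
while mean_squared_error(model, dataset) > 1e-6 and steps < 1000:
    batch = [dataset[i] for i in data_loader.next(model)]
    model.step(batch)
    steps += 1
```

Because the selection depends on the model, the batch returned by the loader can change from one step to the next as the model improves.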

"Undirected" Data Selection

If there is no specific task for training, then all data is equally relevant. We can nevertheless use active data selection to select the most informative data. To do this, simply initialize

data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64)
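Without a target, only diversity is left as a selection criterion. The sketch below illustrates that idea with greedy farthest-point selection, a simple stand-in technique that is not the acquisition function activeft uses for undirected selection. The function name and the toy points are assumptions.

```python
import math


def farthest_point_selection(points, k, start=0):
    """Greedily pick k points, each as far as possible from those already chosen."""
    selected = [start]
    while len(selected) < min(k, len(points)):
        remaining = [i for i in range(len(points)) if i not in selected]
        # Distance of a point to the selected set is its distance to the nearest member.
        best = max(
            remaining,
            key=lambda i: min(math.dist(points[i], points[j]) for j in selected),
        )
        selected.append(best)
    return selected


# Two tight clusters near (0, 0) and (5, 0), plus one isolated point at (0, 5).
points = [(0.0, 0.0), (0.1, 0.0), (5.0, 0.0), (0.0, 5.0), (5.1, 0.0)]
picked = farthest_point_selection(points, k=3)
```

The selection covers each region once instead of spending the budget on near-duplicates within a cluster.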

Example: Test-Time Fine-Tuning

The above example described active data selection in the context of training a model with multiple batches. This usually happens at "train-time" or during "post-training".

The following example demonstrates how to use activeft at "test-time" to obtain a model that is as good as possible on a specific test instance. For example, with a language model, this would fine-tune the model for a few gradient steps on data selected specifically for a given prompt. We refer to the following paper for more details: Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs.

We can also use the intelligent retrieval of informative and relevant data outside a training loop — for example, for in-context learning and retrieval-augmented generation.

The setup is analogous to the previous section: we have a pre-trained model, a dataset to query from, and targets (e.g., a prompt) for which we want to retrieve relevant data. We can use activeft to query the most useful data and then add it to the model's context:

from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10)
data = dataset[data_loader.next(model)]

Again: very simple!
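How the retrieved data is added to the context is left to the caller. One simple possibility is sketched below, assuming the dataset entries have text counterparts and that the selected indices are already available. The build_prompt helper, the prompt layout, and the hard-coded indices are hypothetical and not part of activeft.

```python
def build_prompt(documents, indices, question, separator="\n\n"):
    """Prepend the selected documents to the question, in selection order."""
    context = separator.join(documents[i] for i in indices)
    return f"{context}{separator}Question: {question}\nAnswer:"


documents = [
    "Paris is the capital of France.",
    "The Seine flows through Paris.",
    "Berlin is the capital of Germany.",
]
selected = [0, 1]  # stand-in for the indices returned by the data loader

prompt = build_prompt(documents, selected, "Which river flows through the French capital?")
```

The resulting string can then be passed to the language model in place of the bare question.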

Scaling to Large Datasets

By default, activeft maintains a matrix of the size of the dataset in memory, which is not feasible for very large datasets. Some acquisition functions (such as activeft.acquisition_functions.LazyVTL) allow for efficient computation of the acquisition function without storing the entire dataset in memory. An alternative approach is to pre-select a subset of the data using nearest neighbor retrieval (using Faiss) before initializing the ActiveDataLoader. The following is an example of this approach in the context of test-time fine-tuning:

import torch
import faiss
from activeft.sift import Retriever

# Before Test-Time
embeddings = torch.randn(1000, 768)
index = faiss.IndexFlatIP(embeddings.size(1))
index.add(embeddings.numpy())  # populate the index; without this, the search has nothing to retrieve
retriever = Retriever(index)

# At Test-Time, given query
query_embeddings = torch.randn(1, 768)
indices = retriever.search(query_embeddings, N=10, K=1_000)
data = embeddings[indices]
model.step(data)  # Use data to fine-tune base model, then forward pass query

activeft.sift.Retriever first pre-selects K nearest neighbors and then uses activeft to select the N most informative data for the given query from this subset.
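The two-stage structure can be sketched without Faiss or activeft. In the sketch below, the first stage is an exact inner-product search and the second stage is a simple greedy rule that penalizes redundancy; both are stand-ins chosen for brevity and do not reproduce the actual retriever. All function names and the toy embeddings are assumptions.

```python
def inner(u, v):
    return sum(a * b for a, b in zip(u, v))


def preselect(embeddings, query, K):
    """Stage 1: indices of the K candidates with the largest inner product."""
    order = sorted(range(len(embeddings)), key=lambda i: -inner(embeddings[i], query))
    return order[:K]


def select_informative(embeddings, query, candidates, N):
    """Stage 2 (stand-in): greedy choice among the candidates, penalizing redundancy."""
    chosen = []
    pool = list(candidates)
    while pool and len(chosen) < N:

        def score(i):
            redundancy = max(
                (inner(embeddings[i], embeddings[j]) for j in chosen), default=0.0
            )
            return inner(embeddings[i], query) - redundancy

        best = max(pool, key=score)
        chosen.append(best)
        pool.remove(best)
    return chosen


def search(embeddings, query, N, K):
    # Only the K pre-selected candidates are ever scored by the second stage.
    return select_informative(embeddings, query, preselect(embeddings, query, K), N)


embeddings = [(1.0, 0.0), (1.0, 0.0), (0.6, 0.8), (0.0, 1.0), (-1.0, 0.0)]
query = (1.0, 0.2)

shortlist = preselect(embeddings, query, K=3)
result = search(embeddings, query, N=2, K=3)
```

The returned indices refer to the full dataset, so they can be used directly to look up the selected data, while the more expensive second stage only ever touches K items.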


If you use the code in a publication, please cite our papers:

  • Large-Scale Learning at Test-Time with SIFT: Hübotter, J., Bongni, S., Hakimi, I., and Krause, A. "Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs." arXiv Preprint, 2024.
  • Theory and Fundamental Algorithms for Transductive Active Learning: Hübotter, J., Sukhija, B., Treven, L., As, Y., and Krause, A. "Transductive Active Learning: Theory and Applications." Advances in Neural Information Processing Systems, 2024.

class ActiveDataLoader(Generic[M]):
    r"""
    `ActiveDataLoader` can be used as a drop-in replacement for random data selection or nearest neighbor retrieval:

    ```python
    data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
    batch = dataset[data_loader.next(model)]
    ```

    where
    - `model` is a PyTorch `nn.Module`,
    - `dataset` is a dataset of inputs (where `dataset[i]` returns a vector of length $d$), and
    - `target` is a tensor of prediction targets (shape $m \times d$) or `None`.

    If `dataset` already includes pre-computed embeddings, `model` can be omitted:

    ```python
    data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
    batch = dataset[data_loader.next()]
    ```

    The target can also be updated sequentially:

    ```python
    data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64, force_targeted=True)
    for target in targets:
        batch = dataset[data_loader.with_target(target).next(model)]
    ```
    """

    dataset: Dataset
    r"""Inputs (shape $n \times d$) to be selected from."""

    batch_size: int
    r"""Size of the batch to be selected."""

    acquisition_function: AcquisitionFunction[M]
    r"""Acquisition function to be used for data selection."""

    device: torch.device | None = None
    r"""Device used for computation of the acquisition function."""

    def __init__(
        self,
        dataset: Dataset,
        batch_size: int,
        acquisition_function: AcquisitionFunction[M],
        device: torch.device | None = None,
    ):
        """
        Explicitly constructs an active data loader with a custom acquisition function.
        `activeft` supports a wide range of acquisition functions which are summarized in `activeft.acquisition_functions`.
        """

        assert len(dataset) > 0, "Data must be non-empty"
        assert batch_size > 0, "Batch size must be positive"

        self.dataset = dataset
        self.batch_size = batch_size
        self.acquisition_function = acquisition_function
        self.device = device

    @classmethod
    def initialize(
        cls,
        dataset: Dataset,
        target: torch.Tensor | None,
        batch_size: int,
        device: torch.device | None = None,
        subsampled_target_frac: float = 1,
        max_target_size: int | None = None,
        mini_batch_size: int = DEFAULT_MINI_BATCH_SIZE,
        embedding_batch_size: int = DEFAULT_EMBEDDING_BATCH_SIZE,
        num_workers: int = DEFAULT_NUM_WORKERS,
        subsample_acquisition: bool = DEFAULT_SUBSAMPLE,
        force_targeted: bool = False,
    ):
        r"""
        Initializes an active data loader.

        :param dataset: Inputs (shape $n \times d$) to be selected from.
        :param target: Tensor of prediction targets (shape $m \times d$) or `None`.
        :param batch_size: Size of the batch to be selected.
        :param device: Device used for computation of the acquisition function.
        :param subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$. Ignored if `target` is `None`.
        :param max_target_size: Maximum size of the target to be subsampled in each iteration. Default is `None` in which case the target may be arbitrarily large. Ignored if `target` is `None`.
        :param mini_batch_size: Size of mini batches used for computing the acquisition function.
        :param embedding_batch_size: Batch size used for computing the embeddings.
        :param num_workers: Number of workers used for data loading.
        :param subsample_acquisition: Whether to subsample the data to a single mini batch before computing the acquisition function.
        :param force_targeted: Whether to force targeted data selection. If `True`, `target` must be provided subsequently using `with_target`.
        """

        if target is not None or force_targeted:
            acquisition_function = VTL(
                target=target if target is not None else torch.tensor([]),
                subsampled_target_frac=subsampled_target_frac,
                max_target_size=max_target_size,
                mini_batch_size=mini_batch_size,
                embedding_batch_size=embedding_batch_size,
                num_workers=num_workers,
                subsample=subsample_acquisition,
            )
        else:
            acquisition_function = UndirectedVTL(
                mini_batch_size=mini_batch_size,
                embedding_batch_size=embedding_batch_size,
                num_workers=num_workers,
                subsample=subsample_acquisition,
            )
        return cls(
            dataset=dataset,
            batch_size=batch_size,
            acquisition_function=acquisition_function,  # type: ignore
            device=device,
        )

    def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`.

        .. warning::

            The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to a value other than `None`.

        :param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded.
        :return: Indices of the selected data and corresponding value of the acquisition function in the format `(indices, values)`.
        """

        return self.acquisition_function.select(
            batch_size=self.batch_size,
            model=model,  # type: ignore
            dataset=self.dataset,
            device=self.device,
        )

    def with_target(self, target: torch.Tensor) -> ActiveDataLoader[M]:
        r"""
        Returns the active data loader with a new target.

        :param target: Tensor of prediction targets (shape $m \times d$).
        :return: Updated active data loader.
        """

        assert isinstance(
            self.acquisition_function, Targeted
        ), "Acquisition function must be targeted"
        self.acquisition_function.set_target(target)
        return self
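Two details of the listing above are easy to miss: `with_target` returns the loader itself, which is what makes the chained call `with_target(target).next(model)` work, and the return annotation of `next` describes a pair `(indices, values)`. The sketch below mirrors only this call shape with a hypothetical, framework-free `StubLoader`. Its inner-product scoring rule and its single-vector targets are assumptions made for illustration, not the behavior of `ActiveDataLoader`.

```python
class StubLoader:
    """Hypothetical stand-in that mirrors the call shape of the loader."""

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.target = None

    def with_target(self, target):
        self.target = target
        return self  # returning self is what enables method chaining

    def next(self, model=None):
        assert self.target is not None, "Call with_target first"
        # Stand-in score: inner product between each input and the target.
        scores = [sum(a * b for a, b in zip(x, self.target)) for x in self.dataset]
        order = sorted(range(len(scores)), key=lambda i: -scores[i])[: self.batch_size]
        return order, [scores[i] for i in order]


dataset = [(1.0, 0.0), (0.0, 1.0), (0.7, 0.7)]
loader = StubLoader(dataset, batch_size=2)

selections = []
for target in [(1.0, 0.0), (0.0, 1.0)]:
    indices, values = loader.with_target(target).next()  # unpack the pair
    selections.append(indices)
```

The same loader object is reused across targets, so only the target changes between calls.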

