
activeft

Active Fine-Tuning (activeft) is a Python package for intelligent active data selection.

Why Active Data Selection?

As opposed to random data selection, active data selection chooses data adaptively utilizing the current model. In other words,

active data selection pays attention to the most useful data

which allows for faster learning and adaptation. There are two main reasons why some data may be particularly useful:

  1. Relevance: The data is closely related to a particular task, such as answering a specific prompt.
  2. Diversity: The data contains non-redundant information that is not yet captured by the model.

A dataset that is both relevant and diverse is informative for the model. This is related to memory recall, where the brain recalls informative and relevant memories (think "data") to make sense of the current sensory input. Focusing recall on useful data enables efficient learning from few examples.

activeft provides a simple interface for active data selection, which can be used as a drop-in replacement for random data selection or nearest neighbor retrieval.

Getting Started

You can install activeft from PyPI via pip:

pip install activeft

We briefly discuss how to use activeft for standard fine-tuning and test-time fine-tuning.

Example: Fine-tuning

Given a PyTorch model, which may (but does not have to) be pre-trained, we can use activeft to efficiently fine-tune the model. This model may be generative (e.g., a language model) or discriminative (e.g., a classifier), and can use any architecture.

We only need the following things:

  • A dataset of inputs dataset (such that dataset[i] returns a vector of length $d$) from which we want to select batches for fine-tuning. If one has a supervised dataset returning input-label pairs, then activeft.data.InputDataset(dataset) can be used to obtain a dataset over the input space.
  • A tensor of prediction targets target ($m \times d$) which specifies the task we want to fine-tune the model for. Here, $m$ can be quite small, e.g., equal to the number of classes in a classification task. If there is no specific task for training, then active data selection can still be useful as we will see later.
  • The model can be any PyTorch nn.Module with an embed(x) method that computes (latent) embeddings for the given inputs x, e.g., the representation of x from the penultimate layer. See activeft.model.ModelWithEmbedding for more details. Alternatively, the model can have a kernel(x1,x2) method that computes a kernel for given inputs x1 and x2 (see activeft.model.ModelWithKernel).

For active data selection to be effective, it is important that the model's embeddings are somewhat representative of the data. In particular, embeddings should capture the relationship between the data and the task.
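To make the embedding requirement concrete, the following is a minimal sketch of a PyTorch classifier exposing an embed(x) method that returns its penultimate-layer activations. The architecture and dimensions are illustrative placeholders, not part of activeft:

from torch import nn

class Classifier(nn.Module):
    def __init__(self, input_dim=784, hidden_dim=256, num_classes=10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.head = nn.Linear(hidden_dim, num_classes)

    def embed(self, x):
        # Latent embedding used for active data selection (penultimate-layer representation)
        return self.backbone(x)

    def forward(self, x):
        return self.head(self.embed(x))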

With this in place, we can initialize the "active" data loader

from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)

To obtain the next batch from dataset, we can then simply call

batch = dataset[data_loader.next(model)]

Note that the active data selection of the next batch is utilizing the current model to select the most relevant data with respect to the given target.

Combining the data selection with a model update step, we can implement a simple training loop as follows:

while not converged:
    batch = dataset[data_loader.next(model)]
    model.step(batch)

Notice the feedback loop(!): the batch selection improves as the model learns and the model learns faster as the batch selection improves.

This is it! Training with active data selection is as simple as that.
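For a fully explicit variant of the loop, recall that next returns the selected indices together with the corresponding acquisition values (see ActiveDataLoader.next below). The sketch below assumes a supervised dataset labeled_dataset of input-label pairs (with integer class labels) wrapped as dataset = activeft.data.InputDataset(labeled_dataset), and uses an arbitrary optimizer, loss, and step count for illustration:

import torch

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.CrossEntropyLoss()

for _ in range(100):
    indices, _ = data_loader.next(model)  # next returns (indices, values)
    idx = indices.tolist()
    inputs = torch.stack([dataset[i] for i in idx])
    labels = torch.tensor([labeled_dataset[i][1] for i in idx])
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), labels)
    loss.backward()
    optimizer.step()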

"Undirected" Data Selection

If there is no specific task for training, then all data is equally relevant; yet we can still use active data selection to select the most informative data. To do this, simply initialize

data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64)

Example: Test-Time Fine-Tuning

The above example described active data selection in the context of training a model with multiple batches. This usually happens at "train-time" or during "post-training".

The following example demonstrates how to use activeft at "test-time" to obtain a model that is as good as possible on a specific test instance. For example, with a language model, this would fine-tune the model for a few gradient steps on data selected specifically for a given prompt. We refer to the following paper for more details: Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs.

We can also use the intelligent retrieval of informative and relevant data outside a training loop — for example, for in-context learning and retrieval-augmented generation.

The setup is analogous to the previous section: we have a pre-trained model, a dataset to query from, and a target (e.g., a prompt) for which we want to retrieve relevant data. We can use activeft to query the most useful data and then add it to the model's context:

from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10)
data = dataset[data_loader.next(model)]
model.step(data)

Again: very simple!
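How the target tensor is obtained depends on the application. One possible sketch for the language-model setting is to embed the prompt with the same encoder that produced the dataset embeddings; the encoder object and its encode method below are hypothetical placeholders, not activeft API:

import torch

prompt = "How do I fine-tune a model at test time?"
with torch.no_grad():
    target = encoder.encode(prompt)  # hypothetical: returns an (m x d) tensor in the dataset's embedding space

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10)
indices, _ = data_loader.next()  # model can be omitted if dataset holds pre-computed embeddings
context = dataset[indices]  # add to the model's context or fine-tune on it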

Scaling to Large Datasets

By default, activeft maintains a matrix whose size scales with the size of the dataset in memory. This is not feasible for very large datasets. Some acquisition functions (such as activeft.acquisition_functions.LazyVTL) allow for efficient computation of the acquisition function without storing the entire dataset in memory. An alternative approach is to pre-select a subset of the data using nearest neighbor retrieval (using Faiss) before initializing the ActiveDataLoader. The following is an example of this approach in the context of test-time fine-tuning:

import torch
import faiss
from activeft.sift import Retriever

# Before Test-Time
embeddings = torch.randn(1000, 768)
index = faiss.IndexFlatIP(embeddings.size(1))
index.add(embeddings)
retriever = Retriever(index)

# At Test-Time, given query
query_embeddings = torch.randn(1, 768)
indices = retriever.search(query_embeddings, N=10, K=1_000)
data = embeddings[indices]
model.step(data)  # Use data to fine-tune base model, then forward pass query

activeft.sift.Retriever first pre-selects K nearest neighbors and then uses activeft to select the N most informative data for the given query from this subset.

Citation

If you use the code in a publication, please cite our papers:

# Large-Scale Learning at Test-Time with SIFT
@article{hubotter2024efficiently,
        title        = {Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs},
        author       = {H{\"u}botter, Jonas and Bongni, Sascha and Hakimi, Ido and Krause, Andreas},
        year         = 2024,
        journal      = {arXiv Preprint}
}

# Theory and Fundamental Algorithms for Transductive Active Learning
@inproceedings{hubotter2024transductive,
        title        = {Transductive Active Learning: Theory and Applications},
        author       = {H{\"u}botter, Jonas and Sukhija, Bhavya and Treven, Lenart and As, Yarden and Krause, Andreas},
        year         = 2024,
        booktitle    = {Advances in Neural Information Processing Systems}
}

  1r"""
  2*Active Fine-Tuning* (`activeft`) is a Python package for intelligent active data selection.
  3
  4## Why Active Data Selection?
  5
  6As opposed to random data selection, active data selection chooses data adaptively utilizing the current model.
  7In other words, <p style="text-align: center;">active data selection pays *attention* to the most useful data</p> which allows for faster learning and adaptation.
  8There are mainly two reasons for why some data may be particularly useful:
  9
 101. **Relevance**: The data is closely related to a particular task, such as answering a specific prompt.
 112. **Diversity**: The data contains non-redundant information that is not yet captured by the model.
 12
 13A dataset that is both relevant and diverse is *informative* for the model.
 14This is related to memory recall, where the brain recalls informative and relevant memories (think "data") to make sense of the current sensory input.
 15Focusing recall on useful data enables efficient learning from few examples.
 16
 17`activeft` provides a simple interface for active data selection, which can be used as a drop-in replacement for random data selection or nearest neighbor retrieval.
 18
 19## Getting Started
 20
 21You can install `activeft` from [PyPI](https://pypi.org/project/activeft/) via pip:
 22
 23```bash
 24pip install activeft
 25```
 26
 27We briefly discuss how to use `activeft` for standard [fine-tuning](#example-fine-tuning) and [test-time fine-tuning](#example-test-time-fine-tuning).
 28
 29### Example: Fine-tuning
 30
 31Given a [PyTorch](https://pytorch.org) model which may (but does not have to be!) pre-trained, we can use `activeft` to efficiently fine-tune the model.
 32This model may be generative (e.g., a language model) or discriminative (e.g., a classifier), and can use any architecture.
 33
 34We only need the following things:
 35- A dataset of inputs `dataset` (such that `dataset[i]` returns a vector of length $d$) from which we want to select batches for fine-tuning. If one has a supervised dataset returning input-label pairs, then `activeft.data.InputDataset(dataset)` can be used to obtain a dataset over the input space.
 36- A tensor of prediction targets `target` ($m \times d$) which specifies the task we want to fine-tune the model for.
 37Here, $m$ can be quite small, e.g., equal to the number of classes in a classification task.
 38If there is no *specific* task for training, then active data selection can still be useful as we will see [later](#undirected-data-selection).
 39- The `model` can be any PyTorch `nn.Module` with an `embed(x)` method that computes (latent) embeddings for the given inputs `x`, e.g., the representation of `x` from the penultimate layer.
 40See `activeft.model.ModelWithEmbedding` for more details. Alternatively, the model can have a `kernel(x1,x2)` method that computes a kernel for given inputs `x1` and `x2` (see `activeft.model.ModelWithKernel`).
 41
 42.. note::
 43
 44   For active data selection to be effective, it is important that the model's embeddings are somewhat representative of the data.
 45   In particular, embeddings should capture the relationship between the data and the task.
 46
 47With this in place, we can initialize the "active" data loader
 48
 49```python
 50from activeft import ActiveDataLoader
 51
 52data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
 53```
 54
 55To obtain the next batch from `data`, we can then simply call
 56
 57```python
 58batch = data[data_loader.next(model)]
 59```
 60
 61Note that the active data selection of the next batch is utilizing the current `model` to select the most relevant data with respect to the given `target`.
 62
 63Combining the data selection with a model update step, we can implement a simple training loop as follows:
 64
 65```python
 66while not converged:
 67    batch = dataset[data_loader.next(model)]
 68    model.step(batch)
 69```
 70
 71Notice the feedback loop(!): the batch selection improves as the model learns and the model learns faster as the batch selection improves.
 72
 73This is it!
 74Training with active data selection is as simple as that.
 75
 76#### "Undirected" Data Selection
 77
 78If there is no specific task for training then all data is equally relevant, yet, we can still use active data selection to select the most informative data.
 79To do this, simply initialize
 80
 81```python
 82data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64)
 83```
 84
 85### Example: Test-Time Fine-Tuning
 86
 87The above example described active data selection in the context of training a model with multiple batches. This usually happens at "train-time" or during "post-training".
 88
 89The following example demonstrates how to use `activeft` at "test-time" to obtain a model that is as good as possible on a specific test instance.
 90For example, with a language model, this would fine-tune the model for a few gradient steps on data selected specifically for a given prompt.
 91We refer to the following paper for more details: [Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs](https://arxiv.org/abs/2410.08020).
 92
 93We can also use the intelligent retrieval of informative and relevant data outside a training loop — for example, for in-context learning and retrieval-augmented generation.
 94
 95The setup is analogous to the previous section: we have a pre-trained `model`, a dataset `data` to query from, and `target`s (e.g., a prompt) for which we want to retrieve relevant data.
 96We can use `activeft` to query the most useful data and then add it to the model's context:
 97
 98```python
 99from activeft import ActiveDataLoader
100
101data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10)
102data = dataset[data_loader.next(model)]
103model.step(data)
104```
105
106Again: very simple!
107
108### Scaling to Large Datasets
109
110By default `activeft` maintains a matrix of size of the dataset in memory. This is not feasible for very large datasets.
111Some acquisition functions (such as `activeft.acquisition_functions.LazyVTL`) allow for efficient computation of the acquisition function without storing the entire dataset in memory.
112An alternative approach is to pre-select a subset of the data using nearest neighbor retrieval (using [Faiss](https://github.com/facebookresearch/faiss)), before initializing the `ActiveDataLoader`.
113The following is an example of this approach in the context of [test-time fine-tuning](#example-test-time-fine-tuning):
114
115```python
116import torch
117import faiss
118from activeft.sift import Retriever
119
120# Before Test-Time
121embeddings = torch.randn(1000, 768)
122index = faiss.IndexFlatIP(embeddings.size(1))
123index.add(embeddings)
124retriever = Retriever(index)
125
126# At Test-Time, given query
127query_embeddings = torch.randn(1, 768)
128indices = retriever.search(query_embeddings, N=10, K=1_000)
129data = embeddings[indices]
130model.step(data)  # Use data to fine-tune base model, then forward pass query
131```
132
133`activeft.sift.Retriever` first pre-selects `K` nearest neighbors and then uses `activeft` to select the `N` most informative data for the given query from this subset.
134
135## Citation
136
137If you use the code in a publication, please cite our papers:
138
139```bibtex
140# Large-Scale Learning at Test-Time with SIFT
141@article{hubotter2024efficiently,
142	title        = {Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs},
143	author       = {H{\"u}botter, Jonas and Bongni, Sascha and Hakimi, Ido and Krause, Andreas},
144	year         = 2024,
145	journal      = {arXiv Preprint}
146}
147
148# Theory and Fundamental Algorithms for Transductive Active Learning
149@inproceedings{hubotter2024transductive,
150	title        = {Transductive Active Learning: Theory and Applications},
151	author       = {H{\"u}botter, Jonas and Sukhija, Bhavya and Treven, Lenart and As, Yarden and Krause, Andreas},
152	year         = 2024,
153	booktitle    = {Advances in Neural Information Processing Systems}
154}
155```
156
157---
158"""
159
class ActiveDataLoader(typing.Generic[~M]):

ActiveDataLoader can be used as a drop-in replacement for random data selection or nearest neighbor retrieval:

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
batch = dataset[data_loader.next(model)]

where

  • model is a PyTorch nn.Module,
  • dataset is a dataset of inputs (where dataset[i] returns a vector of length $d$), and
  • target is a tensor of prediction targets (shape $m \times d$) or None.

If dataset already includes pre-computed embeddings, model can be omitted:

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
batch = dataset[data_loader.next()]

The target can also be updated sequentially:

data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64, force_targeted=True)
for target in targets:
    batch = dataset[data_loader.with_target(target).next(model)]
ActiveDataLoader(dataset: activeft.data.Dataset, batch_size: int, acquisition_function: activeft.acquisition_functions.AcquisitionFunction[~M], device: torch.device | None = None)
def __init__(
    self,
    dataset: Dataset,
    batch_size: int,
    acquisition_function: AcquisitionFunction[M],
    device: torch.device | None = None,
):
    """
    Explicitly constructs an active data loader with a custom acquisition function.
    `activeft` supports a wide range of acquisition functions which are summarized in `activeft.acquisition_functions`.
    """

    assert len(dataset) > 0, "Data must be non-empty"
    assert batch_size > 0, "Batch size must be positive"

    self.dataset = dataset
    self.batch_size = batch_size
    self.acquisition_function = acquisition_function
    self.device = device

Explicitly constructs an active data loader with a custom acquisition function. activeft supports a wide range of acquisition functions which are summarized in activeft.acquisition_functions.

dataset: activeft.data.Dataset

Inputs (shape $n \times d$) to be selected from.

batch_size: int

Size of the batch to be selected.

acquisition_function: activeft.acquisition_functions.AcquisitionFunction[~M]

Acquisition function to be used for data selection.

device: torch.device | None = None

Device used for computation of the acquisition function.

@classmethod
def initialize(cls, dataset: activeft.data.Dataset, target: torch.Tensor | None, batch_size: int, device: torch.device | None = None, subsampled_target_frac: float = 1, max_target_size: int | None = None, mini_batch_size: int = 1000, embedding_batch_size: int = 100, num_workers: int = 0, subsample_acquisition: bool = False, force_targeted: bool = False):
@classmethod
def initialize(
    cls,
    dataset: Dataset,
    target: torch.Tensor | None,
    batch_size: int,
    device: torch.device | None = None,
    subsampled_target_frac: float = 1,
    max_target_size: int | None = None,
    mini_batch_size: int = DEFAULT_MINI_BATCH_SIZE,
    embedding_batch_size: int = DEFAULT_EMBEDDING_BATCH_SIZE,
    num_workers: int = DEFAULT_NUM_WORKERS,
    subsample_acquisition: bool = DEFAULT_SUBSAMPLE,
    force_targeted: bool = False,
):
    r"""
    Initializes an active data loader.

    :param dataset: Inputs (shape $n \times d$) to be selected from.
    :param target: Tensor of prediction targets (shape $m \times d$) or `None`.
    :param batch_size: Size of the batch to be selected.
    :param device: Device used for computation of the acquisition function.
    :param subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$. Ignored if `target` is `None`.
    :param max_target_size: Maximum size of the target to be subsampled in each iteration. Default is `None` in which case the target may be arbitrarily large. Ignored if `target` is `None`.
    :param mini_batch_size: Size of mini batches used for computing the acquisition function.
    :param embedding_batch_size: Batch size used for computing the embeddings.
    :param num_workers: Number of workers used for data loading.
    :param subsample_acquisition: Whether to subsample the data to a single mini batch before computing the acquisition function.
    :param force_targeted: Whether to force targeted data selection. If `True`, `target` must be provided subsequently using `with_target`.
    """

    if target is not None or force_targeted:
        acquisition_function = VTL(
            target=target if target is not None else torch.tensor([]),
            subsampled_target_frac=subsampled_target_frac,
            max_target_size=max_target_size,
            mini_batch_size=mini_batch_size,
            embedding_batch_size=embedding_batch_size,
            num_workers=num_workers,
            subsample=subsample_acquisition,
        )
    else:
        acquisition_function = UndirectedVTL(
            mini_batch_size=mini_batch_size,
            embedding_batch_size=embedding_batch_size,
            num_workers=num_workers,
            subsample=subsample_acquisition,
        )
    return cls(
        dataset=dataset,
        batch_size=batch_size,
        acquisition_function=acquisition_function,  # type: ignore
        device=device,
    )

Initializes an active data loader.

Parameters
  • dataset: Inputs (shape $n \times d$) to be selected from.
  • target: Tensor of prediction targets (shape $m \times d$) or None.
  • batch_size: Size of the batch to be selected.
  • device: Device used for computation of the acquisition function.
  • subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$. Ignored if target is None.
  • max_target_size: Maximum size of the target to be subsampled in each iteration. Default is None in which case the target may be arbitrarily large. Ignored if target is None.
  • mini_batch_size: Size of mini batches used for computing the acquisition function.
  • embedding_batch_size: Batch size used for computing the embeddings.
  • num_workers: Number of workers used for data loading.
  • subsample_acquisition: Whether to subsample the data to a single mini batch before computing the acquisition function.
  • force_targeted: Whether to force targeted data selection. If True, target must be provided subsequently using with_target.
def next(self, model: Optional[~M] = None) -> Tuple[torch.Tensor, torch.Tensor]:
def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]:
    r"""
    Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`.

    .. warning::

        The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to a value other than `None`.

    :param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded.
    :return: Indices of the selected data and corresponding value of the acquisition function in the format `(indices, values)`.
    """

    return self.acquisition_function.select(
        batch_size=self.batch_size,
        model=model,  # type: ignore
        dataset=self.dataset,
        device=self.device,
    )

Selects the next batch of data provided a model which is a PyTorch nn.Module.

The computational complexity of next scales cubically with the size of the target. If the target is large, it is recommended to set max_target_size to a value other than None.

Parameters
  • model: Model to be used for data selection. For embedding-based acquisition functions, model can be None in which case the data is treated as if it was already embedded.
Returns

Indices of the selected data and corresponding value of the acquisition function in the format (indices, values).
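For illustration, when the target is large (e.g., embeddings of many prompts), capping its per-iteration size via max_target_size keeps next tractable; the cap of 512 below is an arbitrary example value:

data_loader = ActiveDataLoader.initialize(
    dataset,
    target,  # large target, e.g., embeddings of many prompts
    batch_size=64,
    max_target_size=512,  # subsample the target to at most 512 points per iteration
)
indices, values = data_loader.next(model)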

def with_target(self, target: torch.Tensor) -> ActiveDataLoader[~M]:
def with_target(self, target: torch.Tensor) -> ActiveDataLoader[M]:
    r"""
    Returns the active data loader with a new target.

    :param target: Tensor of prediction targets (shape $m \times d$).
    :return: Updated active data loader.
    """

    assert isinstance(
        self.acquisition_function, Targeted
    ), "Acquisition function must be targeted"
    self.acquisition_function.set_target(target)
    return self

Returns the active data loader with a new target.

Parameters
  • target: Tensor of prediction targets (shape $m \times d$).
Returns

Updated active data loader.