# activeft

*Active Fine-Tuning* (`activeft`) is a Python package for intelligent active data selection.
## Why Active Data Selection?
As opposed to random data selection, active data selection chooses data adaptively, utilizing the current model. In other words, **active data selection pays *attention* to the most useful data**, which allows for faster learning and adaptation. There are two main reasons why some data may be particularly useful:

1. **Relevance**: The data is closely related to a particular task, such as answering a specific prompt.
2. **Diversity**: The data contains non-redundant information that is not yet captured by the model.
A dataset that is both relevant and diverse is informative for the model. This is related to memory recall, where the brain recalls informative and relevant memories (think "data") to make sense of the current sensory input. Focusing recall on useful data enables efficient learning from few examples.
`activeft` provides a simple interface for active data selection, which can be used as a drop-in replacement for random data selection or nearest neighbor retrieval.
## Getting Started
You can install `activeft` from [PyPI](https://pypi.org/project/activeft/) via pip:

```bash
pip install activeft
```
We briefly discuss how to use `activeft` for standard fine-tuning and test-time fine-tuning.
### Example: Fine-tuning
Given a [PyTorch](https://pytorch.org) model, which may be (but does not have to be!) pre-trained, we can use `activeft` to efficiently fine-tune the model.
This model may be generative (e.g., a language model) or discriminative (e.g., a classifier), and can use any architecture.
We only need the following things:

- A dataset of inputs `dataset` (such that `dataset[i]` returns a vector of length $d$) from which we want to select batches for fine-tuning. If one has a supervised dataset returning input-label pairs, then `activeft.data.InputDataset(dataset)` can be used to obtain a dataset over the input space.
- A tensor of prediction targets `target` ($m \times d$) which specifies the task we want to fine-tune the model for. Here, $m$ can be quite small, e.g., equal to the number of classes in a classification task. If there is no *specific* task for training, then active data selection can still be useful, as we will see later.
- The `model` can be any PyTorch `nn.Module` with an `embed(x)` method that computes (latent) embeddings for the given inputs `x`, e.g., the representation of `x` from the penultimate layer. See `activeft.model.ModelWithEmbedding` for more details. Alternatively, the model can have a `kernel(x1, x2)` method that computes a kernel for given inputs `x1` and `x2` (see `activeft.model.ModelWithKernel`). A minimal example of a model with an `embed` method is sketched after the note below.
For active data selection to be effective, it is important that the model's embeddings are somewhat representative of the data. In particular, embeddings should capture the relationship between the data and the task.
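As a concrete illustration, here is a minimal sketch of a discriminative model exposing an `embed(x)` method. The architecture, dimensions, and class names are hypothetical; only the `embed` interface is what `activeft` relies on (see `activeft.model.ModelWithEmbedding`):

```python
import torch
import torch.nn as nn


class Classifier(nn.Module):
    """Toy classifier whose penultimate-layer activations serve as embeddings."""

    def __init__(self, input_dim: int = 768, hidden_dim: int = 256, num_classes: int = 10):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.backbone(x))

    def embed(self, x: torch.Tensor) -> torch.Tensor:
        # Latent representation of x from the penultimate layer, used by activeft.
        return self.backbone(x)
```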
With this in place, we can initialize the "active" data loader:

```python
from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
```
To obtain the next batch from `dataset`, we can then simply call

```python
batch = dataset[data_loader.next(model)]
```
Note that the active data selection of the next batch is utilizing the current `model` to select the most relevant data with respect to the given `target`.
Combining the data selection with a model update step, we can implement a simple training loop as follows:

```python
while not converged:
    batch = dataset[data_loader.next(model)]
    model.step(batch)
```
Notice the feedback loop(!): the batch selection improves as the model learns and the model learns faster as the batch selection improves.
This is it! Training with active data selection is as simple as that.
"Undirected" Data Selection
If there is no specific task for training, then all data is equally relevant; yet we can still use active data selection to select the most informative data. To do this, simply initialize

```python
data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64)
```
### Example: Test-Time Fine-Tuning
The above example described active data selection in the context of training a model with multiple batches. This usually happens at "train-time" or during "post-training".
The following example demonstrates how to use `activeft` at "test-time" to obtain a model that is as good as possible on a specific test instance.
For example, with a language model, this would fine-tune the model for a few gradient steps on data selected specifically for a given prompt.
We refer to the following paper for more details: [Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs](https://arxiv.org/abs/2410.08020).
We can also use the intelligent retrieval of informative and relevant data outside a training loop — for example, for in-context learning and retrieval-augmented generation.
The setup is analogous to the previous section: we have a pre-trained `model`, a dataset `dataset` to query from, and a `target` (e.g., a prompt) for which we want to retrieve relevant data.
We can use `activeft` to query the most useful data and then add it to the model's context:

```python
from activeft import ActiveDataLoader

data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10)
data = dataset[data_loader.next(model)]
model.step(data)
```
Again: very simple!
### Scaling to Large Datasets
By default, `activeft` maintains a matrix whose size grows with the size of the dataset in memory, which is not feasible for very large datasets. Some acquisition functions (such as `activeft.acquisition_functions.LazyVTL`) allow for efficient computation of the acquisition function without storing the entire dataset in memory. An alternative approach is to pre-select a subset of the data using nearest neighbor retrieval (using [Faiss](https://github.com/facebookresearch/faiss)) before initializing the `ActiveDataLoader`.
The following is an example of this approach in the context of test-time fine-tuning:

```python
import torch
import faiss
from activeft.sift import Retriever

# Before Test-Time
embeddings = torch.randn(1000, 768)
index = faiss.IndexFlatIP(embeddings.size(1))
index.add(embeddings)
retriever = Retriever(index)

# At Test-Time, given query
query_embeddings = torch.randn(1, 768)
indices = retriever.search(query_embeddings, N=10, K=1_000)
data = embeddings[indices]
model.step(data)  # Use data to fine-tune base model, then forward pass query
```
`activeft.sift.Retriever` first pre-selects the `K` nearest neighbors and then uses `activeft` to select the `N` most informative data for the given query from this subset.
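In practice, the index would typically be built from the model's own embeddings rather than random vectors. The following is a minimal sketch under illustrative assumptions: `dataset` is an in-memory float32 tensor of inputs, `model` exposes the `embed(x)` method described earlier, and the batch size of 256 is arbitrary:

```python
import torch
import faiss
from activeft.sift import Retriever

# Embed the dataset in batches with the model (illustrative batching over a tensor dataset).
with torch.no_grad():
    embeddings = torch.cat(
        [model.embed(dataset[i : i + 256]) for i in range(0, len(dataset), 256)]
    )

# Build an inner-product index over the embeddings (float32, CPU) and wrap it in a Retriever.
index = faiss.IndexFlatIP(embeddings.size(1))
index.add(embeddings.numpy())
retriever = Retriever(index)
```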
## Citation
If you use the code in a publication, please cite our papers:
```bibtex
# Large-Scale Learning at Test-Time with SIFT
@article{hubotter2024efficiently,
    title = {Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs},
    author = {H{\"u}botter, Jonas and Bongni, Sascha and Hakimi, Ido and Krause, Andreas},
    year = 2024,
    journal = {arXiv Preprint}
}

# Theory and Fundamental Algorithms for Transductive Active Learning
@inproceedings{hubotter2024transductive,
    title = {Transductive Active Learning: Theory and Applications},
    author = {H{\"u}botter, Jonas and Sukhija, Bhavya and Treven, Lenart and As, Yarden and Krause, Andreas},
    year = 2024,
    booktitle = {Advances in Neural Information Processing Systems}
}
```
1r""" 2*Active Fine-Tuning* (`activeft`) is a Python package for intelligent active data selection. 3 4## Why Active Data Selection? 5 6As opposed to random data selection, active data selection chooses data adaptively utilizing the current model. 7In other words, <p style="text-align: center;">active data selection pays *attention* to the most useful data</p> which allows for faster learning and adaptation. 8There are mainly two reasons for why some data may be particularly useful: 9 101. **Relevance**: The data is closely related to a particular task, such as answering a specific prompt. 112. **Diversity**: The data contains non-redundant information that is not yet captured by the model. 12 13A dataset that is both relevant and diverse is *informative* for the model. 14This is related to memory recall, where the brain recalls informative and relevant memories (think "data") to make sense of the current sensory input. 15Focusing recall on useful data enables efficient learning from few examples. 16 17`activeft` provides a simple interface for active data selection, which can be used as a drop-in replacement for random data selection or nearest neighbor retrieval. 18 19## Getting Started 20 21You can install `activeft` from [PyPI](https://pypi.org/project/activeft/) via pip: 22 23```bash 24pip install activeft 25``` 26 27We briefly discuss how to use `activeft` for standard [fine-tuning](#example-fine-tuning) and [test-time fine-tuning](#example-test-time-fine-tuning). 28 29### Example: Fine-tuning 30 31Given a [PyTorch](https://pytorch.org) model which may (but does not have to be!) pre-trained, we can use `activeft` to efficiently fine-tune the model. 32This model may be generative (e.g., a language model) or discriminative (e.g., a classifier), and can use any architecture. 33 34We only need the following things: 35- A dataset of inputs `dataset` (such that `dataset[i]` returns a vector of length $d$) from which we want to select batches for fine-tuning. If one has a supervised dataset returning input-label pairs, then `activeft.data.InputDataset(dataset)` can be used to obtain a dataset over the input space. 36- A tensor of prediction targets `target` ($m \times d$) which specifies the task we want to fine-tune the model for. 37Here, $m$ can be quite small, e.g., equal to the number of classes in a classification task. 38If there is no *specific* task for training, then active data selection can still be useful as we will see [later](#undirected-data-selection). 39- The `model` can be any PyTorch `nn.Module` with an `embed(x)` method that computes (latent) embeddings for the given inputs `x`, e.g., the representation of `x` from the penultimate layer. 40See `activeft.model.ModelWithEmbedding` for more details. Alternatively, the model can have a `kernel(x1,x2)` method that computes a kernel for given inputs `x1` and `x2` (see `activeft.model.ModelWithKernel`). 41 42.. note:: 43 44 For active data selection to be effective, it is important that the model's embeddings are somewhat representative of the data. 45 In particular, embeddings should capture the relationship between the data and the task. 
46 47With this in place, we can initialize the "active" data loader 48 49```python 50from activeft import ActiveDataLoader 51 52data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64) 53``` 54 55To obtain the next batch from `data`, we can then simply call 56 57```python 58batch = data[data_loader.next(model)] 59``` 60 61Note that the active data selection of the next batch is utilizing the current `model` to select the most relevant data with respect to the given `target`. 62 63Combining the data selection with a model update step, we can implement a simple training loop as follows: 64 65```python 66while not converged: 67 batch = dataset[data_loader.next(model)] 68 model.step(batch) 69``` 70 71Notice the feedback loop(!): the batch selection improves as the model learns and the model learns faster as the batch selection improves. 72 73This is it! 74Training with active data selection is as simple as that. 75 76#### "Undirected" Data Selection 77 78If there is no specific task for training then all data is equally relevant, yet, we can still use active data selection to select the most informative data. 79To do this, simply initialize 80 81```python 82data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64) 83``` 84 85### Example: Test-Time Fine-Tuning 86 87The above example described active data selection in the context of training a model with multiple batches. This usually happens at "train-time" or during "post-training". 88 89The following example demonstrates how to use `activeft` at "test-time" to obtain a model that is as good as possible on a specific test instance. 90For example, with a language model, this would fine-tune the model for a few gradient steps on data selected specifically for a given prompt. 91We refer to the following paper for more details: [Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs](https://arxiv.org/abs/2410.08020). 92 93We can also use the intelligent retrieval of informative and relevant data outside a training loop — for example, for in-context learning and retrieval-augmented generation. 94 95The setup is analogous to the previous section: we have a pre-trained `model`, a dataset `data` to query from, and `target`s (e.g., a prompt) for which we want to retrieve relevant data. 96We can use `activeft` to query the most useful data and then add it to the model's context: 97 98```python 99from activeft import ActiveDataLoader 100 101data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=10) 102data = dataset[data_loader.next(model)] 103model.step(data) 104``` 105 106Again: very simple! 107 108### Scaling to Large Datasets 109 110By default `activeft` maintains a matrix of size of the dataset in memory. This is not feasible for very large datasets. 111Some acquisition functions (such as `activeft.acquisition_functions.LazyVTL`) allow for efficient computation of the acquisition function without storing the entire dataset in memory. 112An alternative approach is to pre-select a subset of the data using nearest neighbor retrieval (using [Faiss](https://github.com/facebookresearch/faiss)), before initializing the `ActiveDataLoader`. 
113The following is an example of this approach in the context of [test-time fine-tuning](#example-test-time-fine-tuning): 114 115```python 116import torch 117import faiss 118from activeft.sift import Retriever 119 120# Before Test-Time 121embeddings = torch.randn(1000, 768) 122index = faiss.IndexFlatIP(embeddings.size(1)) 123index.add(embeddings) 124retriever = Retriever(index) 125 126# At Test-Time, given query 127query_embeddings = torch.randn(1, 768) 128indices = retriever.search(query_embeddings, N=10, K=1_000) 129data = embeddings[indices] 130model.step(data) # Use data to fine-tune base model, then forward pass query 131``` 132 133`activeft.sift.Retriever` first pre-selects `K` nearest neighbors and then uses `activeft` to select the `N` most informative data for the given query from this subset. 134 135## Citation 136 137If you use the code in a publication, please cite our papers: 138 139```bibtex 140# Large-Scale Learning at Test-Time with SIFT 141@article{hubotter2024efficiently, 142 title = {Efficiently Learning at Test-Time: Active Fine-Tuning of LLMs}, 143 author = {H{\"u}botter, Jonas and Bongni, Sascha and Hakimi, Ido and Krause, Andreas}, 144 year = 2024, 145 journal = {arXiv Preprint} 146} 147 148# Theory and Fundamental Algorithms for Transductive Active Learning 149@inproceedings{hubotter2024transductive, 150 title = {Transductive Active Learning: Theory and Applications}, 151 author = {H{\"u}botter, Jonas and Sukhija, Bhavya and Treven, Lenart and As, Yarden and Krause, Andreas}, 152 year = 2024, 153 booktitle = {Advances in Neural Information Processing Systems} 154} 155``` 156 157--- 158""" 159 160from activeft.active_data_loader import ActiveDataLoader 161from activeft import acquisition_functions, data, embeddings, model, sift 162 163__all__ = [ 164 "ActiveDataLoader", 165 "acquisition_functions", 166 "data", 167 "embeddings", 168 "model", 169 "sift", 170] 171__version__ = "0.1.1" 172__author__ = "Jonas Hübotter" 173__credits__ = "ETH Zurich, Switzerland"
17class ActiveDataLoader(Generic[M]): 18 r""" 19 `ActiveDataLoader` can be used as a drop-in replacement for random data selection or nearest neighbor retrieval: 20 21 ```python 22 data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64) 23 batch = dataset[data_loader.next(model)] 24 ``` 25 26 where 27 - `model` is a PyTorch `nn.Module`, 28 - `dataset` is a dataset of inputs (where `dataset[i]` returns a vector of length $d$), and 29 - `target` is a tensor of prediction targets (shape $m \times d$) or `None`. 30 31 If `dataset` already includes pre-computed embeddings, `model` can be omitted: 32 33 ```python 34 data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64) 35 batch = dataset[data_loader.next()] 36 ``` 37 38 The target can also be updated sequentially: 39 40 ```python 41 data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64, force_targeted=True) 42 for target in targets: 43 batch = dataset[data_loader.with_target(target).next(model)] 44 ``` 45 """ 46 47 dataset: Dataset 48 r"""Inputs (shape $n \times d$) to be selected from.""" 49 50 batch_size: int 51 r"""Size of the batch to be selected.""" 52 53 acquisition_function: AcquisitionFunction[M] 54 r"""Acquisition function to be used for data selection.""" 55 56 device: torch.device | None = None 57 r"""Device used for computation of the acquisition function.""" 58 59 def __init__( 60 self, 61 dataset: Dataset, 62 batch_size: int, 63 acquisition_function: AcquisitionFunction[M], 64 device: torch.device | None = None, 65 ): 66 """ 67 Explicitly constructs an active data loader with a custom acquisition function. 68 `activeft` supports a wide range of acquisition functions which are summarized in `activeft.acquisition_functions`. 69 """ 70 71 assert len(dataset) > 0, "Data must be non-empty" 72 assert batch_size > 0, "Batch size must be positive" 73 74 self.dataset = dataset 75 self.batch_size = batch_size 76 self.acquisition_function = acquisition_function 77 self.device = device 78 79 @classmethod 80 def initialize( 81 cls, 82 dataset: Dataset, 83 target: torch.Tensor | None, 84 batch_size: int, 85 device: torch.device | None = None, 86 subsampled_target_frac: float = 1, 87 max_target_size: int | None = None, 88 mini_batch_size: int = DEFAULT_MINI_BATCH_SIZE, 89 embedding_batch_size: int = DEFAULT_EMBEDDING_BATCH_SIZE, 90 num_workers: int = DEFAULT_NUM_WORKERS, 91 subsample_acquisition: bool = DEFAULT_SUBSAMPLE, 92 force_targeted: bool = False, 93 ): 94 r""" 95 Initializes an active data loader. 96 97 :param dataset: Inputs (shape $n \times d$) to be selected from. 98 :param target: Tensor of prediction targets (shape $m \times d$) or `None`. 99 :param batch_size: Size of the batch to be selected. 100 :param device: Device used for computation of the acquisition function. 101 :param subsampled_target_frac: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$. Ignored if `target` is `None`. 102 :param max_target_size: Maximum size of the target to be subsampled in each iteration. Default is `None` in which case the target may be arbitrarily large. Ignored if `target` is `None`. 103 :param mini_batch_size: Size of mini batches used for computing the acquisition function. 104 :param embedding_batch_size: Batch size used for computing the embeddings. 105 :param num_workers: Number of workers used for data loading. 
106 :param subsample_acquisition: Whether to subsample the data to a single mini batch before computing the acquisition function. 107 :param force_targeted: Whether to force targeted data selection. If `True`, `target` must be provided subsequently using `with_target`. 108 """ 109 110 if target is not None or force_targeted: 111 acquisition_function = VTL( 112 target=target if target is not None else torch.tensor([]), 113 subsampled_target_frac=subsampled_target_frac, 114 max_target_size=max_target_size, 115 mini_batch_size=mini_batch_size, 116 embedding_batch_size=embedding_batch_size, 117 num_workers=num_workers, 118 subsample=subsample_acquisition, 119 ) 120 else: 121 acquisition_function = UndirectedVTL( 122 mini_batch_size=mini_batch_size, 123 embedding_batch_size=embedding_batch_size, 124 num_workers=num_workers, 125 subsample=subsample_acquisition, 126 ) 127 return cls( 128 dataset=dataset, 129 batch_size=batch_size, 130 acquisition_function=acquisition_function, # type: ignore 131 device=device, 132 ) 133 134 def next(self, model: M | None = None) -> Tuple[torch.Tensor, torch.Tensor]: 135 r""" 136 Selects the next batch of data provided a `model` which is a PyTorch `nn.Module`. 137 138 .. warning:: 139 140 The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to value other than `None`. 141 142 :param model: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None` in which case the data is treated as if it was already embedded. 143 :return: Indices of the selected data and corresponding value of the acquisition function in the format `(indices, values)`. 144 """ 145 146 return self.acquisition_function.select( 147 batch_size=self.batch_size, 148 model=model, # type: ignore 149 dataset=self.dataset, 150 device=self.device, 151 ) 152 153 def with_target(self, target: torch.Tensor) -> ActiveDataLoader[M]: 154 r""" 155 Returns the active data loader with a new target. 156 157 :param target: Tensor of prediction targets (shape $m \times d$). 158 :return: Updated active data loader. 159 """ 160 161 assert isinstance( 162 self.acquisition_function, Targeted 163 ), "Acquisition function must be targeted" 164 self.acquisition_function.set_target(target) 165 return self
## `ActiveDataLoader`

`ActiveDataLoader` can be used as a drop-in replacement for random data selection or nearest neighbor retrieval:

```python
data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
batch = dataset[data_loader.next(model)]
```

where
- `model` is a PyTorch `nn.Module`,
- `dataset` is a dataset of inputs (where `dataset[i]` returns a vector of length $d$), and
- `target` is a tensor of prediction targets (shape $m \times d$) or `None`.
If `dataset` already includes pre-computed embeddings, `model` can be omitted:

```python
data_loader = ActiveDataLoader.initialize(dataset, target, batch_size=64)
batch = dataset[data_loader.next()]
```
The target can also be updated sequentially:

```python
data_loader = ActiveDataLoader.initialize(dataset, target=None, batch_size=64, force_targeted=True)
for target in targets:
    batch = dataset[data_loader.with_target(target).next(model)]
```
### `ActiveDataLoader(dataset, batch_size, acquisition_function, device=None)`

Explicitly constructs an active data loader with a custom acquisition function. `activeft` supports a wide range of acquisition functions, which are summarized in `activeft.acquisition_functions`.

#### Parameters
- `dataset`: Inputs (shape $n \times d$) to be selected from.
- `batch_size`: Size of the batch to be selected.
- `acquisition_function`: Acquisition function to be used for data selection.
- `device`: Device used for computation of the acquisition function.
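For example, a loader with an explicitly chosen acquisition function might be constructed as follows. This is a sketch: it assumes that `VTL` is importable from `activeft.acquisition_functions` and that it can be constructed from a `target` alone, with all other arguments left at their defaults:

```python
from activeft import ActiveDataLoader
from activeft.acquisition_functions import VTL  # assumed import path

data_loader = ActiveDataLoader(
    dataset=dataset,
    batch_size=64,
    acquisition_function=VTL(target=target),  # assumes remaining VTL arguments have defaults
)
```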
### `ActiveDataLoader.initialize(dataset, target, batch_size, ...)`

Initializes an active data loader.

#### Parameters
- `dataset`: Inputs (shape $n \times d$) to be selected from.
- `target`: Tensor of prediction targets (shape $m \times d$) or `None`.
- `batch_size`: Size of the batch to be selected.
- `device`: Device used for computation of the acquisition function.
- `subsampled_target_frac`: Fraction of the target to be subsampled in each iteration. Must be in $(0,1]$. Default is $1$. Ignored if `target` is `None`.
- `max_target_size`: Maximum size of the target to be subsampled in each iteration. Default is `None`, in which case the target may be arbitrarily large. Ignored if `target` is `None`.
- `mini_batch_size`: Size of mini batches used for computing the acquisition function.
- `embedding_batch_size`: Batch size used for computing the embeddings.
- `num_workers`: Number of workers used for data loading.
- `subsample_acquisition`: Whether to subsample the data to a single mini batch before computing the acquisition function.
- `force_targeted`: Whether to force targeted data selection. If `True`, `target` must be provided subsequently using `with_target`.
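For instance, a large target can be subsampled in each iteration; the parameter values below are purely illustrative:

```python
data_loader = ActiveDataLoader.initialize(
    dataset,
    target,
    batch_size=64,
    subsampled_target_frac=0.5,  # subsample half of the target in each iteration
    max_target_size=1_000,       # cap the size of the (subsampled) target
)
```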
### `ActiveDataLoader.next(model=None)`

Selects the next batch of data, provided a `model` which is a PyTorch `nn.Module`.

The computational complexity of `next` scales cubically with the size of the target. If the target is large, it is recommended to set `max_target_size` to a value other than `None`.

#### Parameters
- `model`: Model to be used for data selection. For embedding-based acquisition functions, `model` can be `None`, in which case the data is treated as if it was already embedded.

#### Returns
Indices of the selected data and the corresponding values of the acquisition function, in the format `(indices, values)`.
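Since `next` returns both the selected indices and the corresponding acquisition values, they can also be unpacked explicitly:

```python
# Inspect the acquisition values alongside the selected batch.
indices, values = data_loader.next(model)
batch = dataset[indices]
```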
### `ActiveDataLoader.with_target(target)`

Returns the active data loader with a new target.

#### Parameters
- `target`: Tensor of prediction targets (shape $m \times d$).

#### Returns
The updated active data loader.