Ensemble

Multi-probe ensemble support with voting-based aggregation and bootstrap stability analysis.


lmprobe.ensemble.ProbeEnsemble

Ensemble of multiple Probe instances with voting-based aggregation.

Wraps multiple Probe instances and provides a unified fit/predict/predict_proba/score interface. Supports soft voting (averaged probabilities) and hard voting (majority predictions).

Parameters:

Name Type Description Default
probes list[Probe]

Pre-configured Probe instances.

required
weights list[float] | None

Per-probe weights, normalized to sum to 1. If None, all probes are weighted uniformly.

None
voting str

Aggregation strategy: "soft" (average probabilities) or "hard" (majority vote). Default "soft".

Note: Soft voting requires all member probes to support predict_proba() (e.g. logistic regression, SVM with probability=True). Classifiers like Ridge that lack predict_proba must use voting="hard".

'soft'
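
The weight handling described above can be sketched in plain Python. This is only an illustration of "normalized to sum to 1, uniform if None"; the helper name normalize_weights and its error handling are assumptions, not part of lmprobe.

```python
def normalize_weights(weights, n_probes):
    """Return per-probe weights that sum to 1 (hypothetical helper)."""
    if weights is None:
        # No weights given: every probe gets the same share.
        return [1.0 / n_probes] * n_probes
    if len(weights) != n_probes:
        raise ValueError("need exactly one weight per probe")
    total = sum(weights)
    if total <= 0:
        raise ValueError("weights must have a positive sum")
    # Rescale so the weights sum to 1.
    return [w / total for w in weights]


uniform = normalize_weights(None, 4)
custom = normalize_weights([2.0, 1.0, 1.0], 3)
```

With the inputs above, the first probe in custom carries half of the total vote and the other two a quarter each.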

Examples:

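>>> from lmprobe import Probe
>>> from lmprobe.ensemble import ProbeEnsemble
>>> # Illustrative prompts; substitute your own data.
>>> pos_prompts = ["I loved this film.", "What a great day."]
>>> neg_prompts = ["I hated this film.", "What an awful day."]
>>> test_prompts = ["This was wonderful.", "This was terrible."]
>>> p1 = Probe(model="stas/tiny-random-llama-2", layers=-1,
...            classifier="logistic_regression", device="cpu")
>>> p2 = Probe(model="stas/tiny-random-llama-2", layers=-1,
...            classifier="random_forest", device="cpu")
>>> ensemble = ProbeEnsemble([p1, p2])
>>> ensemble.fit(pos_prompts, neg_prompts)
>>> preds = ensemble.predict(test_prompts)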

from_configs classmethod

from_configs(model: str, configs: list[dict], weights: list[float] | None = None, voting: str = 'soft', **shared_kwargs: Any) -> ProbeEnsemble

Create an ensemble from per-probe config dicts.

Each dict in configs is merged with shared_kwargs to construct a Probe.

Parameters:

Name Type Description Default
model str

HuggingFace model ID (shared across all probes).

required
configs list[dict]

Per-probe overrides (e.g. layers, classifier, pca_components).

required
weights list[float] | None

Per-probe weights.

None
voting str

Aggregation strategy.

'soft'
**shared_kwargs Any

Shared keyword arguments passed to every Probe constructor (e.g. device, remote, pooling, random_state).

{}

Returns:

Type Description
ProbeEnsemble
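
As a rough sketch of the merge described above, the following builds one keyword-argument dict per probe. The helper merge_probe_kwargs is hypothetical, and the precedence (per-probe config wins over shared_kwargs) is an assumption based on configs being described as "overrides"; the library's actual merge order may differ.

```python
def merge_probe_kwargs(model, configs, **shared_kwargs):
    """Build one kwargs dict per probe (hypothetical helper)."""
    merged = []
    for config in configs:
        # Assumed precedence: shared kwargs first, per-probe overrides last.
        merged.append({"model": model, **shared_kwargs, **config})
    return merged


per_probe = merge_probe_kwargs(
    "stas/tiny-random-llama-2",
    [
        {"layers": -1, "classifier": "logistic_regression"},
        {"layers": -2, "classifier": "random_forest"},  # illustrative values
    ],
    device="cpu",
    random_state=0,
)
```

Each resulting dict would then be passed to the Probe constructor, so every member shares model, device and random_state while differing in layers and classifier.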

bootstrap classmethod

bootstrap(base_probe: Probe, n_resamples: int = 10, random_state: int | None = None, weights: list[float] | None = None, voting: str = 'soft') -> ProbeEnsemble

Create a bootstrap ensemble by cloning a probe.

During fit(), each member trains on a different bootstrap resample of the training data (activations are extracted once and served from cache).

Parameters:

Name Type Description Default
base_probe Probe

Template probe to clone.

required
n_resamples int

Number of bootstrap resamples (ensemble members).

10
random_state int | None

Random seed for reproducible bootstrap sampling.

None
weights list[float] | None

Per-probe weights.

None
voting str

Aggregation strategy.

'soft'

Returns:

Type Description
ProbeEnsemble
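
To make "a different bootstrap resample" concrete, here is a minimal sketch of seeded sampling with replacement. The function bootstrap_indices is a hypothetical stand-in, not the library's implementation, which may use a different random number generator.

```python
import random


def bootstrap_indices(n_samples, n_resamples, random_state=None):
    """Draw n_resamples index lists, each sampled with replacement."""
    rng = random.Random(random_state)  # a fixed seed gives reproducible draws
    return [
        [rng.randrange(n_samples) for _ in range(n_samples)]
        for _ in range(n_resamples)
    ]


resamples = bootstrap_indices(n_samples=6, n_resamples=3, random_state=0)
```

Each inner list has the same length as the training set but may repeat some rows and omit others, which is what makes the ensemble members differ.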

fit

fit(positive_prompts: list[str], negative_prompts: list[str] | ndarray | list[int] | None = None, remote: bool | None = None, sample_weight: ndarray | list[float] | None = None, groups: ndarray | list | None = None) -> ProbeEnsemble

Fit all probes in the ensemble.

Performs a single warmup extraction pass covering all layers, then fits each member probe. In bootstrap mode, each member trains on a different bootstrap resample.

Parameters:

Name Type Description Default
positive_prompts list[str]

In contrastive mode: positive-class prompts. In standard mode: all prompts.

required
negative_prompts list[str] | ndarray | list[int] | None

In contrastive mode: negative-class prompts. In standard mode: labels.

None
remote bool | None

Override remote setting for all probes.

None
sample_weight ndarray | list[float] | None

Per-sample weights passed to each probe's classifier fit(). Length must match the total number of training samples. In bootstrap mode, weights are resampled along with the data.

None
groups ndarray | list | None

Group labels for group-balanced bootstrap resampling. Used only in bootstrap mode, where each ensemble member draws an equal number of samples from each group; ignored otherwise. Length must match the total number of training samples.

None

Returns:

Type Description
ProbeEnsemble

Self, for method chaining.
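
The group-balanced resampling and the resampling of sample_weight can be sketched together as follows. This is an illustration under stated assumptions: the helper name is hypothetical, and the number of draws per group (here, the size of the smallest group) is a guess, since the text above only says the draws are equal across groups.

```python
import random
from collections import defaultdict


def group_balanced_resample(groups, sample_weight=None, random_state=None):
    """Draw the same number of indices, with replacement, from every group."""
    rng = random.Random(random_state)
    by_group = defaultdict(list)
    for idx, group in enumerate(groups):
        by_group[group].append(idx)
    # Assumption: each group contributes as many draws as the smallest group has.
    per_group = min(len(members) for members in by_group.values())
    indices = []
    for group in sorted(by_group):
        indices.extend(rng.choices(by_group[group], k=per_group))
    # Weights follow their rows, so they are resampled along with the data.
    weights = None if sample_weight is None else [sample_weight[i] for i in indices]
    return indices, weights


groups = ["a", "a", "a", "a", "b", "b"]
sample_weight = [1.0, 1.0, 2.0, 2.0, 0.5, 0.5]
indices, weights = group_balanced_resample(groups, sample_weight, random_state=0)
```

Although group "a" has twice as many rows as group "b", both contribute the same number of draws to the resample.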

predict_proba

predict_proba(prompts: list[str], remote: bool | None = None) -> np.ndarray

Predict weighted-average class probabilities.

Parameters:

Name Type Description Default
prompts list[str]

Text prompts to classify.

required
remote bool | None

Override remote setting for all probes.

None

Returns:

Type Description
ndarray

Averaged class probabilities, shape (n_samples, n_classes).
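
The weighted average can be written out in a few lines. This is a plain-Python sketch with nested lists standing in for arrays; soft_vote is a hypothetical name and the probabilities are made up for illustration.

```python
def soft_vote(member_probas, weights):
    """Weighted average of per-probe probability tables.

    member_probas[m][i][c] is probe m's probability of class c for sample i;
    weights are assumed to be normalized already.
    """
    n_samples = len(member_probas[0])
    n_classes = len(member_probas[0][0])
    averaged = [[0.0] * n_classes for _ in range(n_samples)]
    for proba, weight in zip(member_probas, weights):
        for i in range(n_samples):
            for c in range(n_classes):
                averaged[i][c] += weight * proba[i][c]
    return averaged


member_probas = [
    [[0.75, 0.25], [0.5, 0.5]],  # probe 1
    [[0.25, 0.75], [0.0, 1.0]],  # probe 2
]
averaged = soft_vote(member_probas, [0.5, 0.5])
```

The result keeps the shape (n_samples, n_classes), and each row still sums to 1 as long as the weights do.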

predict

predict(prompts: list[str], remote: bool | None = None) -> np.ndarray

Predict class labels.

Soft voting: argmax of weighted-average probabilities. Hard voting: mode of per-probe predictions.

Parameters:

Name Type Description Default
prompts list[str]

Text prompts to classify.

required
remote bool | None

Override remote setting for all probes.

None

Returns:

Type Description
ndarray

Predicted labels, shape (n_samples,).
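
The two voting rules can disagree on the same inputs, as the sketch below shows. It is unweighted, the function names are hypothetical, and the tie-breaking rule in hard_vote (smallest label wins) is an assumption, as the text above does not specify one.

```python
from collections import Counter


def hard_vote(member_preds):
    """Most common label per sample across probes (ties: smallest label)."""
    labels = []
    for i in range(len(member_preds[0])):
        votes = Counter(preds[i] for preds in member_preds)
        top = max(votes.values())
        labels.append(min(label for label, n in votes.items() if n == top))
    return labels


def soft_vote_labels(averaged_proba):
    """Argmax over classes of already-averaged probabilities."""
    return [max(range(len(row)), key=row.__getitem__) for row in averaged_proba]


# One sample, three probes: two lean slightly negative, one is strongly positive.
member_pos = [0.45, 0.45, 0.9]
member_preds = [[int(p >= 0.5)] for p in member_pos]
mean_pos = sum(member_pos) / len(member_pos)

hard = hard_vote(member_preds)
soft = soft_vote_labels([[1.0 - mean_pos, mean_pos]])
```

Here the majority of probes predicts class 0, but the one confident probe pulls the averaged probability above 0.5, so soft voting picks class 1.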

score

score(prompts: list[str], labels: list[int] | ndarray, remote: bool | None = None) -> float

Compute accuracy on test data.

Parameters:

Name Type Description Default
prompts list[str]

Test prompts.

required
labels list[int] | ndarray

True labels.

required
remote bool | None

Override remote setting.

None

Returns:

Type Description
float

Accuracy.

prediction_std

prediction_std(prompts: list[str], remote: bool | None = None) -> np.ndarray

Per-sample standard deviation of positive-class probability.

Useful for bootstrap stability analysis: a high standard deviation indicates that the ensemble members disagree on that sample.

Parameters:

Name Type Description Default
prompts list[str]

Text prompts.

required
remote bool | None

Override remote setting.

None

Returns:

Type Description
ndarray

Standard deviation of positive-class probability per sample, shape (n_samples,).
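
A minimal sketch of the per-sample spread, using only the standard library. The helper name and the data are illustrative, and it assumes the population standard deviation; whether the library uses the population or the sample form is not stated above.

```python
import statistics


def positive_class_std(member_pos_proba):
    """Per-sample std of the positive-class probability across probes."""
    n_samples = len(member_pos_proba[0])
    return [
        # Assumption: population standard deviation (divide by n, not n - 1).
        statistics.pstdev([member[i] for member in member_pos_proba])
        for i in range(n_samples)
    ]


member_pos_proba = [
    [0.9, 0.2],  # probe 1: positive-class probability for samples 0 and 1
    [0.9, 0.8],  # probe 2
    [0.9, 0.5],  # probe 3
]
stds = positive_class_std(member_pos_proba)
```

All three probes agree on the first sample, so its spread is zero, while the second sample, where they diverge, gets a clearly larger value.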

save

save(path: str) -> None

Save the fitted ensemble to disk.

Parameters:

Name Type Description Default
path str

Path to save the ensemble.

required

load classmethod

load(path: str) -> ProbeEnsemble

Load a fitted ensemble from disk.

Parameters:

Name Type Description Default
path str

Path to the saved ensemble.

required

Returns:

Type Description
ProbeEnsemble