Probe Ensembles¶
Combine multiple probes into an ensemble for more robust predictions and uncertainty estimation. Ensembles are especially useful when you're unsure which classifier or layer gives the best signal.
When to use ensembles¶
- Robustness: Averaging across classifiers or layers smooths out individual probe weaknesses
- Uncertainty estimation: Bootstrap ensembles quantify how stable your predictions are
- Diverse perspectives: Combine probes at different layers to capture both shallow and deep representations
Basic ensemble¶
Create probes with different configurations and combine them:
from lmprobe import Probe, ProbeEnsemble
p1 = Probe(model="meta-llama/Llama-3.1-8B-Instruct", layers=-1, classifier="logistic_regression")
p2 = Probe(model="meta-llama/Llama-3.1-8B-Instruct", layers=-1, classifier="svm")
p3 = Probe(model="meta-llama/Llama-3.1-8B-Instruct", layers=16, classifier="logistic_regression")
ensemble = ProbeEnsemble([p1, p2, p3], voting="soft")
ensemble.fit(positive_prompts, negative_prompts)
predictions = ensemble.predict(test_prompts) # (n_samples,)
probabilities = ensemble.predict_proba(test_prompts) # (n_samples, n_classes)
accuracy = ensemble.score(test_prompts, test_labels)
Voting strategies¶
| Strategy | How it works | When to use |
|---|---|---|
| "soft" (default) | Average class probabilities across probes, then argmax | When all classifiers support predict_proba() |
| "hard" | Majority vote on predicted labels | When using classifiers like Ridge that lack predict_proba() |
# Soft voting (default) — requires predict_proba support
ensemble = ProbeEnsemble([p1, p2], voting="soft")
# Hard voting — works with any classifier
ensemble = ProbeEnsemble([p1, p2, p3], voting="hard")
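The two strategies can give different answers on the same inputs. The following is an illustrative sketch of the combination rules described above, in plain Python rather than lmprobe internals (the function names here are not part of the library's API):

```python
# Minimal models of the two voting rules. member_probs holds one
# [P(class 0), P(class 1)] pair per ensemble member.

def soft_vote(member_probs, weights=None):
    """Average per-class probabilities across members, then argmax."""
    n_members = len(member_probs)
    n_classes = len(member_probs[0])
    weights = weights or [1.0] * n_members
    total = sum(weights)
    avg = [
        sum(w * p[c] for w, p in zip(weights, member_probs)) / total
        for c in range(n_classes)
    ]
    return max(range(n_classes), key=lambda c: avg[c])

def hard_vote(member_labels):
    """Majority vote on predicted labels (lower label wins ties)."""
    return max(set(member_labels), key=lambda lbl: (member_labels.count(lbl), -lbl))

# One member is very confident in class 0; two lean weakly toward class 1.
probs = [[0.3, 0.7], [0.4, 0.6], [0.9, 0.1]]
soft = soft_vote(probs)        # averaged probs ~[0.53, 0.47], so class 0
hard = hard_vote([1, 1, 0])    # two of three members voted class 1
```

Note that soft and hard voting disagree here: soft voting lets the confident member outweigh two weak votes, while hard voting counts each member equally.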
Weighted voting¶
Give more weight to probes you trust more:
ensemble = ProbeEnsemble(
[p1, p2, p3],
weights=[2.0, 1.0, 1.0], # p1 gets double weight
voting="soft",
)
Weights are normalized to sum to 1 internally.
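A quick sketch of the effect of that normalization (illustrative arithmetic, not lmprobe code): with weights [2.0, 1.0, 1.0], the first probe contributes half of the averaged probability.

```python
# Normalize weights to sum to 1, then take the weighted average of each
# member's positive-class probability for a single sample.
raw = [2.0, 1.0, 1.0]
norm = [w / sum(raw) for w in raw]          # [0.5, 0.25, 0.25]

pos_probs = [0.9, 0.4, 0.3]                 # each member's P(positive)
weighted = sum(w * p for w, p in zip(norm, pos_probs))
unweighted = sum(pos_probs) / len(pos_probs)
# weighted (0.625) sits closer to the double-weighted first member
# than the unweighted mean (~0.533) does.
```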
Factory construction¶
Create ensembles from config dicts sharing a common model. This avoids repeating the model name and shared parameters:
ensemble = ProbeEnsemble.from_configs(
model="meta-llama/Llama-3.1-8B-Instruct",
configs=[
{"layers": -1, "classifier": "logistic_regression"},
{"layers": -1, "classifier": "svm"},
{"layers": 16, "classifier": "ridge"},
],
voting="hard", # required when using Ridge
device="auto", # shared kwargs
)
Bootstrap stability analysis¶
Clone a single probe into N bootstrap resamples to measure how stable predictions are across different training subsets:
base_probe = Probe(
model="meta-llama/Llama-3.1-8B-Instruct",
layers=-1,
classifier="logistic_regression",
)
ensemble = ProbeEnsemble.bootstrap(base_probe, n_resamples=10, random_state=42)
ensemble.fit(positive_prompts, negative_prompts)
# Per-sample uncertainty: high std = ensemble members disagree
uncertainty = ensemble.prediction_std(test_prompts) # (n_samples,)
prediction_std() returns the standard deviation of the positive-class probability across ensemble members. High values indicate samples where the probe is unreliable — useful for flagging borderline cases or identifying data distribution shifts.
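To make the quantity concrete, here is a small self-contained sketch of the statistic being described: the per-sample standard deviation of positive-class probabilities across members (computed with the standard library, not lmprobe):

```python
from statistics import pstdev

# 3 ensemble members x 2 test samples: each entry is a member's
# P(positive) for that sample. Members agree on sample 0 but
# disagree sharply on sample 1.
member_probs = [
    [0.91, 0.20],
    [0.88, 0.75],
    [0.90, 0.55],
]

per_sample_std = [
    pstdev([m[j] for m in member_probs])
    for j in range(len(member_probs[0]))
]
# Sample 0: std ~0.012 (members agree). Sample 1: std ~0.23 (they don't),
# so sample 1 would be flagged as unreliable.
```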
Interpreting uncertainty¶
| prediction_std | Interpretation |
|---|---|
| < 0.05 | High confidence: members agree |
| 0.05 – 0.15 | Moderate uncertainty |
| > 0.15 | Low confidence: consider manual review |
These thresholds are approximate and task-dependent. Calibrate on your own data.
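One way to apply such thresholds is a simple triage rule over the prediction_std values. This is a hypothetical workflow sketch, not a library feature; the cutoffs mirror the approximate table above and should be tuned per task:

```python
def triage(std, low=0.05, high=0.15):
    """Bucket a per-sample prediction_std into a review category."""
    if std < low:
        return "high_confidence"
    if std <= high:
        return "moderate"
    return "manual_review"

# e.g. applied to three samples' uncertainty values:
flags = [triage(s) for s in [0.02, 0.10, 0.30]]
# ["high_confidence", "moderate", "manual_review"]
```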
Group-balanced bootstrap¶
When your training data has groups (e.g., different prompt sources or categories), ensure each bootstrap resample draws evenly from all groups:
ensemble = ProbeEnsemble.bootstrap(base_probe, n_resamples=10, random_state=42)
ensemble.fit(
positive_prompts, negative_prompts,
groups=group_labels, # group-balanced resampling
sample_weight=weights, # optional per-sample importance weights
)
Each ensemble member draws ceil(total / n_groups) samples from each group. Without groups, standard bootstrap resampling can under-represent minority groups.
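The resampling rule described above can be sketched as follows. This is an illustrative model of the behavior, assuming sampling with replacement within each group, and is not lmprobe's actual implementation:

```python
import math
import random

def group_balanced_indices(groups, rng):
    """Draw ceil(total / n_groups) indices with replacement from each group."""
    by_group = {}
    for idx, g in enumerate(groups):
        by_group.setdefault(g, []).append(idx)
    per_group = math.ceil(len(groups) / len(by_group))
    sample = []
    for members in by_group.values():
        sample.extend(rng.choices(members, k=per_group))
    return sample

groups = ["a"] * 8 + ["b"] * 2   # "b" is a minority group
idx = group_balanced_indices(groups, random.Random(42))
# 10 samples across 2 groups -> 5 draws from each group, so the minority
# group "b" makes up half the resample instead of 20%.
```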
Note
groups is only used in bootstrap mode. Passing it to a non-bootstrap ensemble emits a warning.
Save and load¶
ensemble.save("my_ensemble.pkl")
loaded = ProbeEnsemble.load("my_ensemble.pkl")
predictions = loaded.predict(test_prompts)
Performance notes¶
- Activation extraction happens once. The ensemble performs a single warmup pass covering all layers needed by any member probe; individual fit() and predict() calls then hit the cache.
- Cost scales with the number of probes, not model forward passes. Adding more ensemble members is cheap if they share cached activations.
- Bootstrap ensembles are particularly efficient since all members use the same model, layers, and pooling — only the training data subset differs.