
Classifiers

Built-in classifier implementations and the factory functions used to resolve classifier specs.


Factory functions

lmprobe.classifiers.resolve_classifier

resolve_classifier(classifier: str | BaseEstimator, random_state: int | None = None, classifier_kwargs: dict | None = None) -> BaseEstimator

Resolve a classifier specification to an estimator instance.

Parameters:

Name Type Description Default
classifier str | BaseEstimator

Either a string name of a built-in classifier, or a custom sklearn-compatible estimator instance.

required
random_state int | None

Random seed. Only used for built-in classifiers (strings). Custom estimators must set their own random_state.

None
classifier_kwargs dict | None

Additional keyword arguments for built-in classifiers. Ignored when a custom estimator instance is provided.

None

Returns:

Type Description
BaseEstimator

The resolved classifier instance.

lmprobe.classifiers.build_classifier

build_classifier(name: str, random_state: int | None = None, classifier_kwargs: dict | None = None) -> BaseEstimator

Build a classifier by name with the given random_state.

Parameters:

Name Type Description Default
name str

Name of the built-in classifier. One of:

- "logistic_regression": L2-regularized logistic regression (default)
- "logistic_regression_cv": Logistic regression with CV-tuned regularization
- "ridge": Ridge classifier (fast, no probabilities)
- "ridge_regression": Ridge regression for regression tasks
- "svm": Linear SVM with Platt scaling for probabilities
- "sgd": SGD classifier (scalable to large datasets)
- "mass_mean": Mass-Mean Probing (difference-in-means direction)
- "lda": Linear Discriminant Analysis (covariance-corrected mass mean)

required
random_state int | None

Random seed for reproducibility. Propagated from LinearProbe.

None
classifier_kwargs dict | None

Additional keyword arguments passed to the sklearn classifier constructor. These override the defaults (e.g., {"C": 0.01, "solver": "liblinear"} for logistic regression).

None

Returns:

Type Description
BaseEstimator

An sklearn-compatible classifier instance.

Raises:

Type Description
ValueError

If the classifier name is not recognized.
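A minimal sketch of the name lookup and keyword override described above, under stated assumptions: `build_params_sketch` and the `_DEFAULTS` table are hypothetical, the default values shown are placeholders rather than lmprobe's real defaults, and the function returns the merged constructor arguments instead of an estimator so that it runs without sklearn.

```python
from typing import Optional

# Hypothetical defaults for illustration only; the real defaults may differ.
_DEFAULTS = {
    "logistic_regression": {"C": 1.0, "solver": "lbfgs"},
    "ridge": {"alpha": 1.0},
}


def build_params_sketch(
    name: str,
    random_state: Optional[int] = None,
    classifier_kwargs: Optional[dict] = None,
) -> dict:
    # Unknown names are rejected up front, mirroring the documented ValueError.
    if name not in _DEFAULTS:
        raise ValueError(
            f"Unknown classifier {name!r}; expected one of {sorted(_DEFAULTS)}"
        )
    # Start from the defaults and the seed, then let user kwargs override them.
    params = {**_DEFAULTS[name], "random_state": random_state}
    params.update(classifier_kwargs or {})
    return params
```

With this merge order, passing {"C": 0.01, "solver": "liblinear"} replaces both placeholder defaults while the seed is kept.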

lmprobe.classifiers.validate_classifier

validate_classifier(clf: BaseEstimator) -> None

Validate that a classifier has the required interface.

Parameters:

Name Type Description Default
clf BaseEstimator

The classifier to validate.

required

Raises:

Type Description
TypeError

If the classifier lacks fit() or predict() methods.

Warns:

Type Description
UserWarning

If the classifier lacks a predict_proba() method.
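The interface check described above might look something like the following. `validate_sketch` is a hypothetical name, and the exact messages and the use of `callable` are assumptions; only the exception and warning types come from the documentation.

```python
import warnings
from typing import Any


def validate_sketch(clf: Any) -> None:
    # fit() and predict() are mandatory: fail loudly if either is missing.
    for method in ("fit", "predict"):
        if not callable(getattr(clf, method, None)):
            raise TypeError(f"{type(clf).__name__} has no {method}() method")
    # predict_proba() is optional, so its absence only triggers a warning.
    if not callable(getattr(clf, "predict_proba", None)):
        warnings.warn(
            f"{type(clf).__name__} has no predict_proba() method",
            UserWarning,
            stacklevel=2,
        )
```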


Custom classifier implementations

lmprobe.classifiers.MassMeanClassifier

Mass-Mean Probing classifier using difference-in-means direction.

This classifier computes the probe direction as the difference between the means of the positive-class and negative-class activations:

θ = μ_true - μ_false

This is extremely efficient (no optimization needed). Marks & Tegmark (2023) report that it identifies directions that are more causally implicated in model outputs than those found by logistic regression, despite similar classification accuracy.

For a covariance-corrected version (equivalent to Fisher's Linear Discriminant), use sklearn's LinearDiscriminantAnalysis instead.

Attributes:

Name Type Description
coef_ ndarray

The difference-in-means direction, shape (n_features,).

intercept_ float

Decision threshold (midpoint between class means projected onto coef_).

classes_ ndarray

Class labels [0, 1].

mean_positive_ ndarray

Mean of positive class samples.

mean_negative_ ndarray

Mean of negative class samples.
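The difference-in-means computation and the midpoint threshold can be sketched in a few lines of NumPy. The function names are hypothetical, and the sign convention (subtracting the projected midpoint from the score) is an assumption; the class may store the equivalent quantity in intercept_ with a different sign.

```python
import numpy as np


def mass_mean_fit(X: np.ndarray, y: np.ndarray):
    """Return (coef, threshold) for a difference-in-means probe."""
    mu_true = X[y == 1].mean(axis=0)   # mean of positive-class activations
    mu_false = X[y == 0].mean(axis=0)  # mean of negative-class activations
    coef = mu_true - mu_false          # theta = mu_true - mu_false
    # Midpoint between the class means, projected onto the probe direction.
    threshold = 0.5 * (mu_true + mu_false) @ coef
    return coef, threshold


def mass_mean_scores(X: np.ndarray, coef: np.ndarray, threshold: float) -> np.ndarray:
    # Positive scores fall on the class-1 side of the midpoint.
    return X @ coef - threshold


X = np.array([[2.0, 0.0], [4.0, 0.0], [-2.0, 0.0], [-4.0, 0.0]])
y = np.array([1, 1, 0, 0])
coef, threshold = mass_mean_fit(X, y)
scores = mass_mean_scores(X, coef, threshold)
preds = (scores > 0).astype(int)
```

On this toy data the class means are (3, 0) and (-3, 0), so the direction is (6, 0), the threshold is 0, and the four samples are labeled 1, 1, 0, 0.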

References

Marks & Tegmark, "The Geometry of Truth" (2023)

fit

fit(X: ndarray, y: ndarray) -> MassMeanClassifier

Fit the Mass-Mean classifier.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required
y ndarray

Binary labels, shape (n_samples,).

required

Returns:

Type Description
self

decision_function

decision_function(X: ndarray) -> np.ndarray

Compute decision scores.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Decision scores, shape (n_samples,). Positive values indicate class 1, negative values indicate class 0.

predict

predict(X: ndarray) -> np.ndarray

Predict class labels.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Predicted labels, shape (n_samples,).

predict_proba

predict_proba(X: ndarray) -> np.ndarray

Predict class probabilities using Platt-scaled decision scores.

Uses a logistic regression fitted on the 1D decision scores during fit() to produce calibrated probabilities (Platt scaling). This yields better-ranked probabilities (higher AUROC) than raw sigmoid.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Class probabilities, shape (n_samples, 2).
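As a rough illustration of Platt scaling on 1D decision scores, the sketch below fits a slope and offset by plain gradient descent on the log loss. This is a swapped-in, unregularized fitting procedure chosen to keep the example NumPy-only; the class itself is documented as using a logistic regression, and the names `fit_platt` and `platt_proba` are hypothetical.

```python
import numpy as np


def fit_platt(scores: np.ndarray, y: np.ndarray, lr: float = 0.1, n_iter: int = 2000):
    """Fit p(y=1 | s) = sigmoid(a * s + b) by gradient descent on log loss."""
    a, b = 1.0, 0.0
    for _ in range(n_iter):
        p = 1.0 / (1.0 + np.exp(-(a * scores + b)))
        # Gradients of the mean log loss with respect to slope and offset.
        grad_a = np.mean((p - y) * scores)
        grad_b = np.mean(p - y)
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b


def platt_proba(scores: np.ndarray, a: float, b: float) -> np.ndarray:
    p1 = 1.0 / (1.0 + np.exp(-(a * scores + b)))
    return np.column_stack([1.0 - p1, p1])  # columns: P(class 0), P(class 1)


# Toy scores with two deliberately mislabeled points so the fit stays finite.
scores = np.array([-2.0, -1.0, -0.5, 0.5, 1.0, 2.0, -0.2, 0.3])
y = np.array([0, 0, 0, 1, 1, 1, 1, 0])
a, b = fit_platt(scores, y)
proba = platt_proba(scores, a, b)
```

The fitted slope rescales how quickly probabilities move away from 0.5 as the score grows, which is what a fixed raw sigmoid cannot adapt to.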

score

score(X: ndarray, y: ndarray) -> float

Compute accuracy.

Parameters:

Name Type Description Default
X ndarray

Feature matrix.

required
y ndarray

True labels.

required

Returns:

Type Description
float

Accuracy.

get_params

get_params(deep: bool = True) -> dict

Get parameters for this estimator (sklearn compatibility).

Parameters:

Name Type Description Default
deep bool

Ignored (no nested estimators).

True

Returns:

Type Description
dict

Empty dict (no hyperparameters).

set_params

set_params(**params: Any) -> MassMeanClassifier

Set parameters for this estimator (sklearn compatibility).

Parameters:

Name Type Description Default
**params Any

Ignored (no hyperparameters).

{}

Returns:

Type Description
self

lmprobe.classifiers.EnsembleClassifier

Ensemble classifier that averages predictions across regularization strengths.

Trains multiple logistic regression models at different C values and averages their predicted probabilities for more robust predictions.

Parameters:

Name Type Description Default
C_values list[float] | None

Regularization strengths. Defaults to [0.01, 0.1, 0.5, 1.0, 5.0].

None
solver str

Solver for LogisticRegression. Default "lbfgs".

'lbfgs'
max_iter int

Maximum iterations per model. Default 1000.

1000
random_state int | None

Random seed for reproducibility.

None

Attributes:

Name Type Description
classes_ ndarray

Class labels [0, 1].

estimators_ list[LogisticRegression]

Fitted logistic regression models.

coef_ ndarray

Averaged coefficients across all models.

intercept_ ndarray

Averaged intercepts across all models.

fit

fit(X: ndarray, y: ndarray) -> EnsembleClassifier

Fit one LogisticRegression per C value.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required
y ndarray

Labels, shape (n_samples,).

required

Returns:

Type Description
self

predict_proba

predict_proba(X: ndarray) -> np.ndarray

Average predicted probabilities across all models.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Averaged class probabilities, shape (n_samples, n_classes).

predict

predict(X: ndarray) -> np.ndarray

Predict by thresholding averaged probabilities at 0.5.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Predicted labels, shape (n_samples,).
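The averaging in predict_proba and the thresholding in predict can be sketched together for the binary case. The per-model probability arrays below are made-up values standing in for fitted models, the helper names are hypothetical, and treating an averaged probability of exactly 0.5 as class 1 is an assumption.

```python
import numpy as np


def average_probas(probas: list) -> np.ndarray:
    """Average per-model probability arrays of shape (n_samples, n_classes)."""
    return np.mean(np.stack(probas, axis=0), axis=0)


def predict_from_average(avg: np.ndarray, classes: np.ndarray) -> np.ndarray:
    # Binary case: pick class 1 when its averaged probability reaches 0.5.
    return classes[(avg[:, 1] >= 0.5).astype(int)]


# Hypothetical outputs of three models trained at different C values.
probas = [
    np.array([[0.9, 0.1], [0.4, 0.6]]),
    np.array([[0.7, 0.3], [0.5, 0.5]]),
    np.array([[0.8, 0.2], [0.3, 0.7]]),
]
avg = average_probas(probas)
labels = predict_from_average(avg, classes=np.array([0, 1]))
```

Here the averaged probabilities are (0.8, 0.2) and (0.4, 0.6), giving labels 0 and 1.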

score

score(X: ndarray, y: ndarray) -> float

Compute accuracy.

Parameters:

Name Type Description Default
X ndarray

Feature matrix.

required
y ndarray

True labels.

required

Returns:

Type Description
float

Accuracy.

get_params

get_params(deep: bool = True) -> dict

Get parameters for this estimator (sklearn compatibility).

set_params

set_params(**params: Any) -> EnsembleClassifier

Set parameters for this estimator (sklearn compatibility).

lmprobe.classifiers.GroupLassoClassifier

Wrapper around skglm Group Lasso for automatic layer selection.

This classifier treats each layer's hidden dimensions as a group and applies L2,1 regularization (Group Lasso) to encourage entire groups (layers) to become zero, effectively performing layer selection.

Parameters:

Name Type Description Default
hidden_dim int

Hidden dimension size per layer.

required
n_layers int

Number of layers being probed.

required
alpha float

Regularization strength. Higher values induce more sparsity.

0.01
random_state int | None

Random seed for reproducibility.

None

Attributes:

Name Type Description
coef_ ndarray

Fitted coefficients, shape (n_features,) = (hidden_dim * n_layers,).

intercept_ ndarray

Fitted intercept.

classes_ ndarray

Class labels.

selected_groups_ list[int]

Indices of groups (layers) with non-zero norms after fitting.

group_norms_ ndarray

L2 norm of coefficients for each group, shape (n_layers,).
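To make the relationship between coef_, group_norms_, and selected_groups_ concrete, here is a small NumPy sketch. It assumes the feature vector is the concatenation of per-layer blocks of length hidden_dim (layer 0 first); that layout, the helper names, and the zero tolerance are assumptions, not confirmed details of the class.

```python
import numpy as np


def group_norms(coef: np.ndarray, hidden_dim: int, n_layers: int) -> np.ndarray:
    """L2 norm of each layer's coefficient block, shape (n_layers,)."""
    # Assumes features are concatenated layer by layer.
    return np.linalg.norm(coef.reshape(n_layers, hidden_dim), axis=1)


def selected_groups(norms: np.ndarray, tol: float = 0.0) -> list:
    # A layer counts as selected when its block norm exceeds the tolerance.
    return [int(i) for i in np.flatnonzero(norms > tol)]


coef = np.array([0.0, 0.0, 3.0, 4.0, 0.0, 0.0])  # 3 layers, hidden_dim = 2
norms = group_norms(coef, hidden_dim=2, n_layers=3)
layers = selected_groups(norms)
```

In this toy case only the middle block is non-zero, so the norms are 0, 5, 0 and the only selected layer index is 1.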

fit

fit(X: ndarray, y: ndarray) -> GroupLassoClassifier

Fit the Group Lasso classifier.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, hidden_dim * n_layers).

required
y ndarray

Labels, shape (n_samples,).

required

Returns:

Type Description
self

predict

predict(X: ndarray) -> np.ndarray

Predict class labels.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Predicted labels, shape (n_samples,).

predict_proba

predict_proba(X: ndarray) -> np.ndarray

Predict class probabilities.

Note: skglm's GeneralizedLinearEstimator does not have native predict_proba. We compute probabilities from the linear scores using the sigmoid function.

Parameters:

Name Type Description Default
X ndarray

Feature matrix, shape (n_samples, n_features).

required

Returns:

Type Description
ndarray

Class probabilities, shape (n_samples, 2).
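The sigmoid conversion mentioned in the note can be sketched as below. `sigmoid_proba` is a hypothetical helper that takes the coefficients and intercept explicitly; it illustrates the idea and is not the class's actual method.

```python
import numpy as np


def sigmoid_proba(X: np.ndarray, coef: np.ndarray, intercept: float) -> np.ndarray:
    scores = X @ coef + intercept           # linear scores, shape (n_samples,)
    p1 = 1.0 / (1.0 + np.exp(-scores))      # sigmoid gives P(class 1)
    return np.column_stack([1.0 - p1, p1])  # shape (n_samples, 2)


X = np.array([[0.0, 0.0], [1.0, 1.0]])
proba = sigmoid_proba(X, coef=np.array([1.0, 1.0]), intercept=0.0)
```

A score of zero maps to a probability of 0.5 for each class, and each row sums to 1 by construction.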

score

score(X: ndarray, y: ndarray) -> float

Compute accuracy.

Parameters:

Name Type Description Default
X ndarray

Feature matrix.

required
y ndarray

True labels.

required

Returns:

Type Description
float

Accuracy.