Machine Learning

PSI: detect feature drift in production

The Population Stability Index (PSI) compares a variable's distribution between the training set and production data. Common rule-of-thumb thresholds: below 0.10 stable, 0.10 to 0.25 worth watching, above 0.25 major drift.
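As a quick illustration of the arithmetic behind the index, here is a minimal sketch on three bins. The proportions are hypothetical values chosen for the example, not taken from any dataset; each bin contributes (production share - training share) times the natural log of their ratio, and the PSI is the sum of those contributions.

```python
import numpy as np

# Hypothetical bin proportions (illustrative values only).
ref_pct = np.array([0.25, 0.50, 0.25])   # share of training rows per bin
cur_pct = np.array([0.20, 0.50, 0.30])   # share of production rows per bin

# Per-bin contribution: (cur - ref) * ln(cur / ref).
# Each term is non-negative, because the difference and the log share the same sign.
terms = (cur_pct - ref_pct) * np.log(cur_pct / ref_pct)

psi_value = float(terms.sum())
print(f"PSI={psi_value:.3f}")
```

With these numbers the unchanged middle bin contributes nothing and the total stays well under 0.10, so the feature would be read as stable.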

Prerequisites

numpy, pandas

Python
import numpy as np

def psi(reference, courant, bins=10):
    """Population Stability Index between two 1D samples."""
    # Bin edges are the quantiles of the reference (training) sample.
    edges = np.quantile(reference, np.linspace(0, 1, bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf      # cover the whole support
    ref_pct = np.histogram(reference, edges)[0] / len(reference)
    cur_pct = np.histogram(courant, edges)[0] / len(courant)
    ref_pct = np.clip(ref_pct, 1e-6, None)     # avoid log(0)
    cur_pct = np.clip(cur_pct, 1e-6, None)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

# X_train and X_prod are assumed to be pandas DataFrames sharing the same numeric columns.
rapport = {c: psi(X_train[c].dropna(), X_prod[c].dropna())
           for c in X_train.select_dtypes("number").columns}

for col, v in sorted(rapport.items(), key=lambda t: -t[1])[:10]:
    statut = "DRIFT" if v > 0.25 else ("watch" if v > 0.10 else "ok")
    print(f"{col:<25} PSI={v:.3f}  [{statut}]")

Result

montant                   PSI=0.412  [DRIFT]
delai_livraison           PSI=0.218  [watch]
anciennete_jours          PSI=0.131  [watch]
nb_commandes_90j          PSI=0.087  [ok]
age                       PSI=0.041  [ok]
remise_moyenne            PSI=0.029  [ok]
frequence_connexion       PSI=0.018  [ok]
solde_moyen               PSI=0.011  [ok]
nb_incidents              PSI=0.008  [ok]
score_fidelite            PSI=0.004  [ok]
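Before pointing the function at real features, one option is to check it on synthetic samples where the expected verdict is known. The sketch below repeats psi so that it runs on its own; the seed, sample sizes and the one-standard-deviation shift are arbitrary assumptions made for the example.

```python
import numpy as np

def psi(reference, courant, bins=10):
    """Population Stability Index between two 1D samples."""
    edges = np.quantile(reference, np.linspace(0, 1, bins + 1))
    edges[0], edges[-1] = -np.inf, np.inf      # cover the whole support
    ref_pct = np.histogram(reference, edges)[0] / len(reference)
    cur_pct = np.histogram(courant, edges)[0] / len(courant)
    ref_pct = np.clip(ref_pct, 1e-6, None)     # avoid log(0)
    cur_pct = np.clip(cur_pct, 1e-6, None)
    return float(np.sum((cur_pct - ref_pct) * np.log(cur_pct / ref_pct)))

# Synthetic data (assumed setup): standard normal reference,
# a fresh draw from the same law, and a draw shifted by one standard deviation.
rng = np.random.default_rng(0)
reference = rng.normal(0.0, 1.0, 50_000)
same = rng.normal(0.0, 1.0, 50_000)
shifted = rng.normal(1.0, 1.0, 50_000)

psi_same = psi(reference, same)        # same distribution: expected near 0
psi_shifted = psi(reference, shifted)  # shifted mean: expected above 0.25

print(f"same distribution    PSI={psi_same:.3f}")
print(f"shifted distribution PSI={psi_shifted:.3f}")
```

If the first value is not close to zero or the second does not clear the 0.25 threshold, the binning or the sample sizes probably deserve a second look before trusting the production report.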