Adding Custom Metrics
Add new evaluation metrics beyond accuracy and the confusion matrix.
Overview
Metric code lives in ml_src/core/metrics/.
Adding a New Metric
Step 1: Define the Metric Function
Add the following functions to a module in ml_src/core/metrics/:
def calculate_f1_per_class(true_labels, pred_labels, class_names):
    """Compute a per-class F1 score and key it by class name.

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices.
        class_names: Class names; position ``i`` names label index ``i``.

    Returns:
        A dict mapping each class name to its F1 score, in the order of
        ``class_names``.
    """
    from sklearn.metrics import f1_score

    # average=None asks scikit-learn for one score per label rather than a
    # single aggregate; passing the label indices explicitly pins the order
    # of the returned scores to the order of class_names.
    label_ids = range(len(class_names))
    per_class = f1_score(
        true_labels,
        pred_labels,
        average=None,
        labels=label_ids
    )
    return dict(zip(class_names, per_class))
def save_f1_scores(true_labels, pred_labels, class_names, path):
    """Write a plain-text report of per-class F1 scores to ``path``.

    The report has a header, one ``name: score`` line per class, and a
    trailing macro average (the unweighted mean of the per-class scores).

    Args:
        true_labels: Ground-truth class indices.
        pred_labels: Predicted class indices.
        class_names: Class names; position ``i`` names label index ``i``.
        path: Destination file; it is overwritten if it already exists.
    """
    f1_dict = calculate_f1_per_class(true_labels, pred_labels, class_names)

    # Build the whole report before opening the file so that a failure while
    # formatting cannot leave a half-written report behind.
    # An empty score dict would otherwise divide by zero; report 0.0 instead.
    macro_f1 = sum(f1_dict.values()) / len(f1_dict) if f1_dict else 0.0

    lines = ["F1 Scores per Class\n", "=" * 40 + "\n\n"]
    lines.extend(
        f"{class_name:15s}: {f1:.4f}\n" for class_name, f1 in f1_dict.items()
    )
    lines.append(f"\nMacro Average: {macro_f1:.4f}\n")

    # Explicit UTF-8 so class names with non-ASCII characters are written the
    # same way on every platform instead of depending on the locale default.
    with open(path, 'w', encoding='utf-8') as f:
        f.writelines(lines)
Step 2: Call from trainer.py
Edit the trainer module in ml_src/core/trainers/ and add the call at the end of train_model():
# Existing metric outputs; their arguments are elided in this example.
save_confusion_matrix(...)
save_classification_report(...)
# NEW: write the per-class F1 report for this split.
# NOTE(review): assumes run_dir is a pathlib.Path (it is joined with `/`) and
# that all_true_labels, all_pred_labels, class_names and split are already
# defined in train_model() -- confirm against the trainer code.
save_f1_scores(
    all_true_labels,
    all_pred_labels,
    class_names,
    run_dir / 'logs' / f'f1_scores_{split}.txt'
)
Example: ROC Curve
def save_roc_curve(true_labels, pred_probs, class_names, path):
    """Plot ROC curve(s) and save the figure to ``path``.

    With exactly two classes a single curve is drawn from the scores in
    column 1 of ``pred_probs``. Otherwise one one-vs-rest curve is drawn per
    class, using column ``i`` of ``pred_probs`` for class index ``i``. A
    dashed diagonal marks the random-classifier baseline.

    Args:
        true_labels: Ground-truth class indices, one per sample.
        pred_probs: Per-class scores indexed as ``[sample, class]``.
        class_names: Class names; position ``i`` names label index ``i``.
        path: Output image path passed to ``savefig``.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    from sklearn.metrics import auc, roc_curve

    # Coerce to arrays so plain Python lists work too: the element-wise
    # ``== i`` comparison and the ``[:, i]`` column slicing below both need
    # array semantics.
    true_labels = np.asarray(true_labels)
    pred_probs = np.asarray(pred_probs)
    num_classes = len(class_names)

    # Use an explicit figure instead of pyplot's implicit "current figure",
    # and close it in ``finally`` so a failure while plotting or saving does
    # not leave an open figure behind.
    fig, ax = plt.subplots(figsize=(10, 8))
    try:
        if num_classes == 2:
            # Binary classification: column 1 is the positive-class score.
            fpr, tpr, _ = roc_curve(true_labels, pred_probs[:, 1])
            roc_auc = auc(fpr, tpr)
            ax.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.2f})')
        else:
            # Multi-class: one one-vs-rest curve per class.
            for i, class_name in enumerate(class_names):
                binary_labels = (true_labels == i).astype(int)
                fpr, tpr, _ = roc_curve(binary_labels, pred_probs[:, i])
                roc_auc = auc(fpr, tpr)
                ax.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')

        ax.plot([0, 1], [0, 1], 'k--', label='Random')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('ROC Curve')
        ax.legend()
        ax.grid(True)
        fig.savefig(path, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
Useful Metrics
- Precision, Recall, F1 (per-class)
- ROC curves and AUC
- Precision-Recall curves
- Top-K accuracy
- Calibration curves
- Per-sample confidence scores
Best Practices
- Save metrics to files rather than only printing them
- Visualize metrics when possible
- Include new metrics in the run summary
- Compare metrics across runs
- Document how to interpret each metric