QSage Tutorial#

This tutorial demonstrates how to use Quantum Sage (QSage), a meta-learning tool that predicts which quantum or classical ML model will perform best on your dataset.

What is QSage?#

QSage is a surrogate model trained on extensive benchmarking data that:

  • Predicts model performance without running expensive experiments

  • Recommends best models based on dataset characteristics

  • Saves computational resources by avoiding trial-and-error

  • Supports both quantum and classical ML algorithms

⏱️ Note on Training Time: Training QSage sub-sages can take a significant amount of time depending on your hardware and dataset size. If you want to skip training, you can download a pre-compiled QSage model and jump directly to the prediction steps. See Section 3b: Load Pre-compiled Model (Recommended) below.

1. Setup and Imports#

[ ]:
import pandas as pd
import os
import math
import pickle

# Import QSage
from apps.sage.sage import QuantumSage

print("✓ Imports successful")

2. What QSage Needs — Input Data#

QSage requires a compiled benchmarking results table that combines:

  • Dataset complexity features — intrinsic properties of each dataset (e.g., number of features, number of samples, intrinsic dimension, fractal dimension, Fisher discriminant ratio, etc.)

  • Model performance metrics — accuracy, F1-score, AUC for each model/embedding combination

  • Metadata — dataset name, embedding type, model name
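To make the expected shape concrete, here is a toy version of such a compiled table (the column names below are illustrative stand-ins; your actual QProfiler output defines the real schema):

```python
import pandas as pd

# Hypothetical miniature of a compiled benchmarking table.
# A real QProfiler output contains many more complexity features and metrics.
toy_results = pd.DataFrame({
    # Metadata
    'Dataset':    ['iris', 'iris', 'wine'],
    'embeddings': ['angle', 'none', 'angle'],
    'model':      ['qsvc', 'svc', 'qsvc'],
    # Dataset complexity features
    'n_samples':  [150, 150, 178],
    'n_features': [4, 4, 13],
    # Model performance metrics
    'accuracy':   [0.95, 0.97, 0.90],
    'f1_score':   [0.95, 0.97, 0.89],
})
print(toy_results)
```

Each row is one model/embedding/dataset combination; QSage learns the mapping from the complexity-feature columns to the metric columns.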

How to generate this data with QProfiler#

The easiest way to produce this input is to run the QProfiler Tutorial first. QProfiler benchmarks multiple ML models across your datasets and outputs a ModelResults.csv file.

Once you have run QProfiler, load its output below. You can also compile results from multiple QProfiler runs across different datasets to build a richer training set for QSage.
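Compiling multiple runs amounts to concatenating the individual result tables and dropping duplicates. A minimal sketch with in-memory stand-ins (in practice each frame would come from `pd.read_csv` on a run's ModelResults.csv):

```python
import pandas as pd

# Stand-ins for two hypothetical QProfiler runs on different datasets
run_a = pd.DataFrame({'Dataset': ['iris'], 'model': ['svc'], 'accuracy': [0.97]})
run_b = pd.DataFrame({'Dataset': ['wine'], 'model': ['svc'], 'accuracy': [0.90]})

# Stack the runs, drop exact duplicate rows, and renumber the index
combined = (
    pd.concat([run_a, run_b], ignore_index=True)
    .drop_duplicates()
    .reset_index(drop=True)
)
print(f"Combined {len(combined)} benchmark rows")
```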

[ ]:
# Load QProfiler output (ModelResults.csv)
# Run the QProfiler tutorial first to generate this file:
# tutorial/QProfiler/example_qprofiler.ipynb

file_input = '../QProfiler/ModelResults.csv'  # Path to QProfiler output

results_df = pd.read_csv(file_input)
results_df['embeddings'] = results_df['embeddings'].fillna('none')
results_df = results_df.reset_index(drop=True)
results_df = results_df.replace([math.inf, -math.inf], 0)  # Zero out infinite metric values
results_df = results_df.drop_duplicates()

print(f"Loaded {len(results_df)} benchmark results")
print(f"Datasets: {results_df['Dataset'].nunique()}")
print(f"Models: {results_df['model'].nunique()}")
print(f"Columns: {list(results_df.columns)}")
results_df.head()

2b. Prepare QProfiler Output for QSage#

QSage expects a few additional metadata columns that are not directly output by QProfiler. The cell below adds them automatically:

  • ``datatype`` — the file/dataset name (derived from Dataset)

  • ``model_embed_datatype`` — a combined identifier string in the format model_embedding_datatype

  • ``iteration`` — trial/run index (set to 1 if not present)

[ ]:
# Add columns required by QSage that are not directly in QProfiler output

# 'datatype': use the Dataset column as-is
results_df['datatype'] = results_df['Dataset']

# 'model_embed_datatype': combined identifier used internally by QSage
results_df['model_embed_datatype'] = (
    results_df['model'] + '_' +
    results_df['embeddings'] + '_' +
    results_df['datatype']
)

# 'iteration': use existing column if present, otherwise default to 1
if 'iteration' not in results_df.columns:
    results_df['iteration'] = 1

print("✓ QProfiler output prepared for QSage")
print("Added columns: datatype, model_embed_datatype, iteration")
results_df[['Dataset', 'embeddings', 'model', 'datatype', 'model_embed_datatype', 'iteration']].head()

3. Initialize QSage#

Important: the current QSage API takes a single data_input parameter (not data, features, metrics, or sage_type).

[ ]:
# Select a held-out dataset for testing
held_out_dataset = results_df['Dataset'].unique()[0]  # Use first dataset as held-out example
print(f"Held-out dataset: {held_out_dataset}")

# Split data into training and held-out
train_df = results_df[results_df['Dataset'] != held_out_dataset]
held_out_df = results_df[results_df['Dataset'] == held_out_dataset]

print(f"Training data: {len(train_df)} results")
print(f"Held-out data: {len(held_out_df)} results")

# Initialize QSage with correct API
sage = QuantumSage(data_input=train_df)

print("\n✓ QSage initialized")
print(f"Available models: {sage._available_models}")
print(f"Available metrics: {sage._available_metrics}")

3b. Load Pre-compiled Model (Recommended)#

⏱️ Training QSage takes considerable time. To skip training and use a ready-made model, download the pre-compiled QSage model and load it directly:

Download the pre-compiled model: QSage Pre-compiled Model

After downloading, place the .pkl file in your working directory and run the cell below. If you use the pre-compiled model, you can skip Sections 4 and 5 and proceed directly to Section 6 (Make Predictions).

[ ]:
# Load pre-compiled QSage model (skip training)
# Download the model from:
# https://ibm.box.com/s/4vv39mpplq8juhffno114bqkjls38sgm

file_precompiled = 'qsage_model.pkl'  # Update path/filename if needed

with open(file_precompiled, 'rb') as f:
    sage = pickle.load(f)

print("✓ Pre-compiled QSage model loaded successfully")
print(f"Available models: {sage._available_models}")
print(f"Available metrics: {sage._available_metrics}")
print("\n→ You can now skip to Section 6 (Make Predictions)")

4. Train QSage Sub-Sages (Optional — Skip if using pre-compiled model)#

Train surrogate models for each ML model and metric combination.

⏱️ Warning: This step can take a significant amount of time (potentially hours depending on your hardware). If you downloaded the pre-compiled model in Section 3b, skip this section and proceed to Section 6.
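QSage's internal training loop is not reproduced here, but the n_iter and cv arguments follow the familiar randomized hyperparameter search pattern: sample n_iter candidate settings and score each with cv-fold cross-validation. A self-contained illustration of that pattern on synthetic data (independent of QSage; scikit-learn is assumed to be available):

```python
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV

# Synthetic regression data standing in for (complexity features -> metric)
X, y = make_regression(n_samples=200, n_features=5, noise=0.1, random_state=0)

# n_iter = number of sampled hyperparameter settings; cv = folds per setting
search = RandomizedSearchCV(
    RandomForestRegressor(random_state=0),
    param_distributions={'n_estimators': [50, 100, 200], 'max_depth': [3, 5, None]},
    n_iter=5,
    cv=3,
    random_state=0,
)
search.fit(X, y)
print(f"Best params: {search.best_params_}")
```

Larger n_iter explores more candidates (at proportionally higher cost), and more cv folds give a more stable score per candidate.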

[ ]:
# Train sub-sages with Random Forest
# NOTE: This step is optional if you loaded the pre-compiled model above.
print("Training QSage sub-sages (this may take a long time)...")
sage.train_sub_sages(
    test_size=0.2,
    sage_type='random_forest',  # or 'mlp'
    n_iter=50,  # Number of hyperparameter search iterations
    cv=5  # Cross-validation folds
)
print("✓ Training complete!")

5. Visualize Training Results (Optional — Skip if using pre-compiled model)#

[ ]:
# Plot training results
sage.plot_results(figsize=(8, 5))

6. Make Predictions#

Use QSage to predict which model will perform best on a new dataset.

The input to predict() is the dataset complexity features from your QProfiler output — the columns that describe the intrinsic properties of your dataset (e.g., intrinsic dimension, number of samples, Fisher discriminant ratio, etc.).

Note: If you loaded the pre-compiled model (Section 3b) and skipped Section 3, the cell below will load the QProfiler data and select a dataset to predict on automatically.

[ ]:
# If held_out_df is not defined (pre-compiled model path), load QProfiler data now
if 'held_out_df' not in dir() or held_out_df is None:
    _qp_df = pd.read_csv('../QProfiler/ModelResults.csv')
    _qp_df['embeddings'] = _qp_df['embeddings'].fillna('none')
    _qp_df['datatype'] = _qp_df['Dataset']
    _qp_df['model_embed_datatype'] = _qp_df['model'] + '_' + _qp_df['embeddings'] + '_' + _qp_df['datatype']
    if 'iteration' not in _qp_df.columns:
        _qp_df['iteration'] = 1
    held_out_dataset = _qp_df['Dataset'].unique()[0]
    held_out_df = _qp_df[_qp_df['Dataset'] == held_out_dataset]
    print(f"Loaded QProfiler data. Using '{held_out_dataset}' as the target dataset.")

# Extract dataset complexity features for the held-out dataset
# Use only the feature columns that QSage was trained on
available_features = [c for c in sage._columns_data_features if c in held_out_df.columns]
held_out_features = held_out_df[available_features].drop_duplicates()

print(f"Dataset: {held_out_df['Dataset'].unique()}")
print(f"\nDataset complexity features:")
print(held_out_features.T)

# Make predictions for accuracy metric
predictions = sage.predict(held_out_features, metric='accuracy')

print(f"\n✓ Generated predictions for {len(predictions)} models")
print("\nRanked Model Recommendations (by predicted accuracy × confidence):")
print(predictions.to_string(index=False))

7. Visualize QSage Predictions vs QProfiler Actual Results#

The plot below shows two views side by side:

  • Left: QSage predicted accuracy for each model it knows about, ranked from best to worst

  • Right: Actual accuracy measured by QProfiler on your local datasets

If the model names overlap (e.g., when you trained QSage on your own QProfiler data), a direct scatter comparison is also shown.
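The core of the scatter comparison is just an inner merge on model name followed by a mean absolute error. On toy numbers (hypothetical accuracies, not real benchmark results):

```python
import pandas as pd

# Toy predicted and actual accuracy tables
pred = pd.DataFrame({'model': ['svc', 'qsvc', 'knn'], 'accuracy': [0.95, 0.90, 0.85]})
actual = pd.DataFrame({'model': ['svc', 'qsvc'], 'actual_accuracy': [0.97, 0.88]})

# Inner merge keeps only models present in both tables
comparison = pred.merge(actual, on='model')
mae = (comparison['accuracy'] - comparison['actual_accuracy']).abs().mean()
print(f"Models compared: {len(comparison)}, MAE={mae:.3f}")
```

Models without a counterpart in the other table (here, 'knn') simply drop out of the comparison.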

[ ]:
import matplotlib.pyplot as plt

# Actual results from QProfiler
actual_results = held_out_df.groupby('model')['accuracy'].mean().reset_index()
actual_results.columns = ['model', 'actual_accuracy']
actual_results = actual_results.sort_values('actual_accuracy', ascending=False)

# Check for overlapping model names
comparison = predictions.merge(actual_results, on='model')

if len(comparison) > 0:
    # Direct comparison scatter plot
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# --- Plot 1: QSage predictions (ranked bar chart) ---
ax = axes[0]
pred_sorted = predictions.sort_values('accuracy', ascending=True)
colors = ['#2196F3' if v >= pred_sorted['accuracy'].median() else '#90CAF9'
          for v in pred_sorted['accuracy']]
ax.barh(pred_sorted['model'], pred_sorted['accuracy'], color=colors)
ax.set_xlabel('Predicted Accuracy')
ax.set_title('QSage: Predicted Model Accuracy')
ax.set_xlim(0, 1)
ax.axvline(pred_sorted['accuracy'].median(), color='gray', linestyle='--', alpha=0.5, label='median')
ax.legend()

# --- Plot 2: QProfiler actual results (ranked bar chart) ---
ax = axes[1]
actual_sorted = actual_results.sort_values('actual_accuracy', ascending=True)
colors2 = ['#4CAF50' if v >= actual_sorted['actual_accuracy'].median() else '#A5D6A7'
           for v in actual_sorted['actual_accuracy']]
ax.barh(actual_sorted['model'], actual_sorted['actual_accuracy'], color=colors2)
ax.set_xlabel('Actual Accuracy')
ax.set_title('QProfiler: Actual Model Accuracy')
ax.set_xlim(0, 1)
ax.axvline(actual_sorted['actual_accuracy'].median(), color='gray', linestyle='--', alpha=0.5, label='median')
ax.legend()

# --- Plot 3 (if overlap): Direct scatter comparison ---
if len(comparison) > 0:
    ax = axes[2]
    ax.scatter(comparison['accuracy'], comparison['actual_accuracy'], alpha=0.7, s=100, color='#9C27B0')
    for _, row in comparison.iterrows():
        ax.annotate(row['model'], (row['accuracy'], row['actual_accuracy']),
                    textcoords='offset points', xytext=(5, 5), fontsize=8)
    ax.plot([0, 1], [0, 1], 'r--', label='Perfect prediction')
    ax.set_xlabel('Predicted Accuracy (QSage)')
    ax.set_ylabel('Actual Accuracy (QProfiler)')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend()
    ax.grid(True, alpha=0.3)
    mae = abs(comparison['accuracy'] - comparison['actual_accuracy']).mean()
    ax.set_title(f'QSage vs QProfiler: Direct Comparison (MAE={mae:.3f})')

plt.tight_layout()
plt.show()

print("\nQSage top recommendations:")
print(predictions[['model', 'accuracy', 'r2']].head(5).to_string(index=False))
print("\nQProfiler top performers:")
print(actual_results.head(5).to_string(index=False))

8. Save Trained QSage Model#

[ ]:
# Save trained model
file_sage = 'my_qsage_model.pkl'
with open(file_sage, 'wb') as f:
    pickle.dump(sage, f)
print(f"✓ Model saved to {file_sage}")
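Pickle files are sensitive to Python and library versions, so it can be worth verifying the saved file immediately. A quick round-trip check, sketched here with a stand-in dictionary rather than the trained sage object:

```python
import os
import pickle
import tempfile

# Stand-in object; in the tutorial this would be the trained `sage` instance
model = {'sages': ['rf_accuracy'], 'trained': True}

path = os.path.join(tempfile.mkdtemp(), 'qsage_model.pkl')
with open(path, 'wb') as f:
    pickle.dump(model, f)

# Round-trip: reload and confirm the object survives intact
with open(path, 'rb') as f:
    restored = pickle.load(f)
print(f"Round-trip OK: {restored == model}")
```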

Summary#

In this tutorial, you learned how to:

  1. ✅ Understand what input data QSage requires and how to generate it with QProfiler

  2. ✅ Prepare QProfiler output for QSage (add required metadata columns)

  3. ✅ Initialize QSage with the correct API (data_input parameter)

  4. ✅ Load a pre-compiled QSage model to skip training

  5. ✅ Train QSage sub-sages using train_sub_sages() method (optional)

  6. ✅ Make predictions for new datasets using predict() method

  7. ✅ Visualize QSage predictions alongside QProfiler actual results

  8. ✅ Save and load trained QSage models

Key API Points#

  • Initialization: QuantumSage(data_input=df) - takes only the data_input parameter

  • Training: sage.train_sub_sages(sage_type='random_forest') - specify model type here (takes time)

  • Prediction: sage.predict(features, metric='accuracy') - predict for specific metric

  • Visualization: sage.plot_results() - plot training performance

Next Steps#

  • Run QProfiler on more datasets to build a richer training set for QSage

  • Combine results from multiple QProfiler runs across different datasets for better QSage predictions

  • Experiment with different sage types (Random Forest vs MLP)

  • Try different metrics (accuracy, f1_score, auc)

See Also#