Compare a set of hyperparameters

Run a grid search over all possible combinations of the given parameter values and compare the generated outputs.

See compare_parameters.py on GitHub.
from dotenv import load_dotenv

from genai.client import Client
from genai.credentials import Credentials
from genai.schema import (
    DecodingMethod,
    TextGenerationComparisonCreateRequestRequest,
    TextGenerationComparisonParameters,
    TextGenerationParameters,
)

# Load environment variables from a .env file before the client is created.
# Make sure you have a .env file under the genai root directory containing:
#   GENAI_KEY=<your-genai-key>
#   GENAI_API=<genai-api-endpoint>  (optional; defaults to "https://bam-api.res.ibm.com")
load_dotenv()


def heading(text: str) -> str:
    """Build an 80-column console banner for *text*.

    The text is surrounded by single spaces, padded on both sides with ``=``
    characters, and wrapped in leading and trailing newlines. Text wider than
    80 columns is returned unpadded.
    """
    banner = f" {text} ".center(80, "=")
    return f"\n{banner}\n"


client = Client(credentials=Credentials.from_env())

# Few-shot style prompt: one completed "capital of <country> is <city>" example,
# followed by an unfinished sentence for the model to complete in the same pattern.
prompt = "The capital of Spain is Madrid. The capital of Canada is"

print(heading("Run text generation with many parameter combinations"))
response = client.text.generation.compare(
    request=TextGenerationComparisonCreateRequestRequest(
        model_id="google/flan-t5-xl",
        # Base generation parameters for the request; the grid below varies the rest.
        parameters=TextGenerationParameters(min_new_tokens=1, max_new_tokens=10, decoding_method=DecodingMethod.SAMPLE),
        input=prompt,
    ),
    # Grid search through all possible combinations of the following parameters
    # (2 x 2 x 4 = 16 combinations):
    compare_parameters=TextGenerationComparisonParameters(
        top_k=[10, 50],
        repetition_penalty=[1.0, 1.5],
        temperature=[0.7, 0.9, 1.5, 2.0],
    ),
)

print(f"Prompt: {prompt}\n")

for params_combination in response.results:
    # exclude_none drops unset fields so only the parameters actually in effect are shown.
    print(f"Used params: {params_combination.parameters.model_dump(exclude_none=True)}")
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would turn a missing result into an AttributeError/IndexError below.
    # A missing result is reported and the remaining combinations are still shown.
    generation = params_combination.result
    if not generation or not generation.results:
        print("Generated text: <no result returned for this combination>\n")
        continue
    print(f"Generated text: {generation.results[0].generated_text}\n")