# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2024.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from typing import Any, Type, TypeVar, get_origin, get_args
from tabulate import tabulate
from enum import Enum
from ibm_watsonx_ai.utils.utils import StrEnum
from dataclasses import dataclass, is_dataclass, fields
T = TypeVar("T", bound="BaseSchema")
[docs]
@dataclass
class BaseSchema:
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> "BaseSchema":
kwargs = {}
for field in fields(cls):
field_name = field.name
field_type = field.type
if field_name in data:
value = data[field_name]
origin = get_origin(field_type)
if origin is not None and issubclass(origin, BaseSchema):
if hasattr(origin, "from_dict"):
value = origin.from_dict(value)
kwargs[field_name] = value
return cls(**kwargs)
def to_dict(self) -> dict[str, Any]:
def unpack(
value: Enum | list[Any] | Any,
) -> int | dict[str, Any] | list[Any] | Any:
if isinstance(value, Enum):
return value.value
elif is_dataclass(value):
return {
k: unpack(v)
for k, v in value.__dict__.items()
if v is not None and not k.startswith("_")
}
elif isinstance(value, list):
return [unpack(v) for v in value]
else:
return value
return {
k: unpack(v)
for k, v in self.__dict__.items()
if v is not None and not k.startswith("_")
}
[docs]
@classmethod
def show(cls) -> None:
"""Displays a table with the parameter name, type, and example value."""
sample_params = cls.get_sample_params()
table_data = []
for field in fields(cls):
field_name = field.name
field_type = field.type
origin = get_origin(field_type) or field_type
args = get_args(field_type)
if args:
display_type = f"{', '.join(arg.__name__ if hasattr(arg, '__name__') else str(arg) for arg in args)}"
else:
display_type = (
origin.__name__ if hasattr(origin, "__name__") else str(origin)
)
example_value = sample_params.get(field_name, "N/A")
table_data.append([field_name, display_type, example_value])
print(
tabulate(
table_data,
headers=["PARAMETER", "TYPE", "EXAMPLE VALUE"],
tablefmt="grid",
)
)
[docs]
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Override this method in subclasses to provide example values for parameters."""
return {}
##############
# TEXT-GEN #
##############
[docs]
class TextGenDecodingMethod(StrEnum):
GREEDY = "greedy"
SAMPLE = "sample"
[docs]
@dataclass
class TextGenLengthPenalty:
decay_factor: float | None = None
start_index: int | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for TextGenLengthPenalty."""
return {
"decay_factor": 2.5,
"start_index": 5,
}
[docs]
@dataclass
class ReturnOptionProperties(BaseSchema):
input_text: bool | None = None
generated_tokens: bool | None = None
input_tokens: bool | None = None
token_logprobs: bool | None = None
token_ranks: bool | None = None
top_n_tokens: bool | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for ReturnOptionProperties."""
return {
"input_text": True,
"generated_tokens": True,
"input_tokens": True,
"token_logprobs": True,
"token_ranks": False,
"top_n_tokens": False,
}
[docs]
@dataclass
class TextGenParameters(BaseSchema):
decoding_method: str | TextGenDecodingMethod | None = None
length_penalty: dict | TextGenLengthPenalty | None = None
temperature: float | None = None
top_p: float | None = None
top_k: int | None = None
random_seed: int | None = None
repetition_penalty: float | None = None
min_new_tokens: int | None = None
max_new_tokens: int | None = None
stop_sequences: list[str] | None = None
time_limit: int | None = None
truncate_input_tokens: int | None = None
return_options: dict | ReturnOptionProperties | None = None
include_stop_sequence: bool | None = None
prompt_variables: dict | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for TextChatParameters."""
return {
"decoding_method": list(TextGenDecodingMethod)[1].value,
"length_penalty": TextGenLengthPenalty.get_sample_params(),
"temperature": 0.5,
"top_p": 0.2,
"top_k": 1,
"random_seed": 33,
"repetition_penalty": 2,
"min_new_tokens": 50,
"max_new_tokens": 1000,
"stop_sequences": 200,
"time_limit": 600000,
"truncate_input_tokens": 200,
"return_options": ReturnOptionProperties.get_sample_params(),
"include_stop_sequence": True,
"prompt_variables": {"doc_type": "emails", "entity_name": "Golden Retail"},
}
###############
# TEXT-CHAT #
###############
[docs]
class TextChatResponseFormatType(StrEnum):
JSON_OBJECT = "json_object"
[docs]
@dataclass
class TextChatResponseFormat(BaseSchema):
type: str | TextChatResponseFormatType | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for TextChatResponseFormat."""
return {"type": list(TextChatResponseFormatType)[0].value}
[docs]
@dataclass
class TextChatParameters(BaseSchema):
frequency_penalty: float | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
presence_penalty: float | None = None
response_format: dict | TextChatResponseFormat | None = None
temperature: float | None = None
max_tokens: int | None = None
time_limit: int | None = None
top_p: float | None = None
n: int | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for TextChatParameters."""
return {
"frequency_penalty": 0.5,
"logprobs": True,
"top_logprobs": 3,
"presence_penalty": 0.3,
"response_format": TextChatResponseFormat.get_sample_params(),
"temperature": 0.7,
"max_tokens": 100,
"time_limit": 600000,
"top_p": 0.9,
"n": 1,
}
############
# RERANK #
############
[docs]
@dataclass
class RerankReturnOptions(BaseSchema):
top_n: int | None = None
inputs: bool | None = None
query: bool | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for RerankReturnOptions."""
return {"top_n": 1, "inputs": False, "query": False}
[docs]
@dataclass
class RerankParameters(BaseSchema):
truncate_input_tokens: int | None = None
return_options: dict | RerankReturnOptions | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for RerankParameters."""
return {
"truncate_input_tokens": 100,
"return_options": RerankReturnOptions.get_sample_params(),
}
#################
# TIME SERIES #
#################
[docs]
@dataclass
class TSForecastParameters(BaseSchema):
"""
:param timestamp_column: A valid column in the data that should be treated as the timestamp. if using calendar dates (simple integer time offsets are also allowed), users should consider using a format such as ISO 8601 that includes a UTC offset (e.g., '2024-10-18T01:09:21.454746+00:00'). This will avoid potential issues such as duplicate dates appearing due to daylight savings change overs. There are many date formats in existence and inferring the correct one can be a challenge so please do consider adhering to ISO 8601.
:type timestamp_column: str
:param prediction_length: The prediction length for the forecast. The service will return this many periods beyond the last timestamp in the inference data payload. If specified, prediction_length must be an integer >=1 and no more than the model default prediction length. When omitted the model default prediction_length will be used.
:type prediction_length: int, optional
:param id_columns: Columns that define a unique key for time series. This is similar to a compound primary key in a database table.
:type id_columns: list[str], optional
:param freq: A freqency indicator for the given timestamp_column. See https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#period-aliases for a description of the allowed values. If not provided, we will attempt to infer it from the data. Possible values: 0 ≤ length ≤ 100, Value must match regular expression ^\d+(B|D|W|M|Q|Y|h|min|s|ms|us|ns)$|^\s*$
:type freq: str, optional
:param target_columns: An array of column headings which constitute the target variables. These are the data that will be forecasted.
:type target_columns: list[str], optional
"""
timestamp_column: str
prediction_length: int | None = None
id_columns: list[str] | None = None
freq: str | None = None
target_columns: list[str] | None = None
observable_columns: list[str] | None = None
control_columns: list[str] | None = None
conditional_columns: list[str] | None = None
static_categorical_columns: list[str] | None = None
@classmethod
def get_sample_params(cls) -> dict[str, Any]:
"""Provide example values for TSForecastParameters."""
return {
"prediction_length": 10,
"timestamp_column": "date",
"id_columns": ["id1"],
"freq": "D",
"target_columns": ["col1", "col2"],
}