from __future__ import annotations

from pathlib import Path
from typing import Any, Mapping, Optional, Union

from langchain_core.messages import AIMessageChunk, BaseMessageChunk, merge_content
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from pydantic import BaseModel

from genai._utils.general import merge_objects
from genai.extensions._common.utils import extract_token_usage
from genai.schema import TextGenerationParameters

def update_token_usage(*, target: dict[str, Any], source: Optional[dict[str, Any]]) -> None:
    """Accumulate token counts from 'source' into 'target' in place, summing shared keys."""
    if not source:
        return

    for key, value in source.items():
        if key in target:
            target[key] += value
        else:
            target[key] = value

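# Illustrative sketch (not part of the module): accumulating the usage dicts of
# two hypothetical responses into one running total; the payloads are made up.
def _example_update_token_usage() -> None:
    usage: dict[str, Any] = {}
    update_token_usage(target=usage, source={"prompt_tokens": 5, "completion_tokens": 7})
    update_token_usage(target=usage, source={"prompt_tokens": 3, "completion_tokens": 4})
    assert usage == {"prompt_tokens": 8, "completion_tokens": 11}
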
def update_token_usage_stream(*, target: dict[str, Any], source: Optional[dict]) -> None:
    """Merge streamed token-usage info into 'target' and recompute the derived totals."""
    if not source:
        return

    def get_value(key: str, override: bool = False) -> int:
        current = target.get(key, 0) or 0
        new = source.get(key, 0) or 0
        # Take the incoming value when it is non-zero and either no value has
        # been recorded yet or overriding is explicitly requested.
        if new != 0 and (current == 0 or override):
            return new
        return current

    completion_tokens = get_value("completion_tokens", override=True)
    prompt_tokens = get_value("prompt_tokens")
    target.update(
        {
            "prompt_tokens": prompt_tokens,
            "input_token_count": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": completion_tokens + prompt_tokens,
            "generated_token_count": completion_tokens,
        }
    )

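# Illustrative sketch (not part of the module): merging two hypothetical stream
# chunks. The completion count is overridden by the latest non-zero value, while
# the prompt count keeps the first non-zero value seen.
def _example_update_token_usage_stream() -> None:
    usage: dict[str, Any] = {}
    update_token_usage_stream(target=usage, source={"prompt_tokens": 10, "completion_tokens": 1})
    update_token_usage_stream(target=usage, source={"prompt_tokens": 0, "completion_tokens": 5})
    assert usage["prompt_tokens"] == 10  # first non-zero value is kept
    assert usage["completion_tokens"] == 5  # latest non-zero value wins
    assert usage["total_tokens"] == 15
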
def create_llm_output(*, model: str, token_usages: Optional[list[Optional[dict]]] = None, **kwargs) -> dict[str, Any]:
    """Build an 'llm_output' payload with token usage aggregated across all generations."""
    final_token_usage = extract_token_usage({})
    for source in token_usages or []:
        update_token_usage(target=final_token_usage, source=source)
    return {"model_name": model, "token_usage": final_token_usage, **kwargs}

def load_config(file: Union[str, Path]) -> dict:
    """Load a JSON or YAML config file, parsing its 'parameters' section into TextGenerationParameters."""

    def parse_config() -> dict:
        file_path = Path(file) if isinstance(file, str) else file
        if file_path.suffix == ".json":
            import json

            with open(file_path) as f:
                return json.load(f)
        elif file_path.suffix == ".yaml":
            import yaml

            with open(file_path) as f:
                return yaml.safe_load(f)
        else:
            raise ValueError(f"Unsupported file type '{file_path.suffix}': expected .json or .yaml")

    config = parse_config()
    config["parameters"] = TextGenerationParameters(**config.get("parameters", {}))
    return config

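# Illustrative sketch (not part of the module): a hypothetical "config.yaml"
# consumed by load_config; max_new_tokens is a TextGenerationParameters field.
#
#     model_id: some-provider/some-model
#     parameters:
#       max_new_tokens: 50
#
# config = load_config("config.yaml")
# assert isinstance(config["parameters"], TextGenerationParameters)
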
def dump_optional_model(model: Optional[BaseModel]) -> Optional[Mapping[str, Any]]:
    """Serialize a pydantic model to a dict, dropping None fields; pass None through unchanged."""
    return model.model_dump(exclude_none=True) if model else None

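# Illustrative sketch (not part of the module): _Params is a throwaway pydantic
# model defined only for this example.
def _example_dump_optional_model() -> None:
    class _Params(BaseModel):
        temperature: Optional[float] = None
        top_k: int = 50

    assert dump_optional_model(_Params()) == {"top_k": 50}  # None fields dropped
    assert dump_optional_model(None) is None
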
class CustomChatGenerationChunk(ChatGenerationChunk):
    def __add__(self, other: ChatGenerationChunk) -> CustomChatGenerationChunk:
        """Replaces LangChain's 'merge_dicts' with our simplified 'merge_objects' utility"""
        if isinstance(other, ChatGenerationChunk):
            generation_info = merge_objects(
                self.generation_info or {},
                other.generation_info or {},
            )
            return CustomChatGenerationChunk(
                message=self.message + other.message,
                generation_info=generation_info or None,
            )
        else:
            raise TypeError(f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'")

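# Illustrative sketch (not part of the module): adding two chat chunks
# concatenates their message content and merges generation_info via
# merge_objects; the payloads below are made up.
def _example_add_chat_chunks() -> None:
    first = CustomChatGenerationChunk(
        message=AIMessageChunk(content="Hello"),
        generation_info={"stop_reason": None},
    )
    second = CustomChatGenerationChunk(
        message=AIMessageChunk(content=" world"),
        generation_info={"stop_reason": "max_tokens"},
    )
    assert (first + second).message.content == "Hello world"
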
class CustomGenerationChunk(GenerationChunk):
    def __add__(self, other: GenerationChunk) -> CustomGenerationChunk:
        """Replaces LangChain's 'merge_dicts' with our simplified 'merge_objects' utility"""
        if isinstance(other, GenerationChunk):
            generation_info = merge_objects(
                self.generation_info or {},
                other.generation_info or {},
            )
            return CustomGenerationChunk(
                text=self.text + other.text,
                generation_info=generation_info or None,
            )
        else:
            raise TypeError(f"unsupported operand type(s) for +: '{type(self)}' and '{type(other)}'")

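# Illustrative sketch (not part of the module): plain-text chunks concatenate
# their text the same way; the values are made up.
def _example_add_generation_chunks() -> None:
    merged = CustomGenerationChunk(text="foo", generation_info={"seq": 0}) + CustomGenerationChunk(
        text="bar", generation_info={"seq": 1}
    )
    assert merged.text == "foobar"
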
class CustomAIMessageChunk(AIMessageChunk):
    """Replaces LangChain's 'merge_dicts' with our simplified 'merge_objects' utility"""

    def __add__(self, other: Any) -> BaseMessageChunk:
        if isinstance(other, AIMessageChunk):
            if self.example != other.example:
                raise ValueError("Cannot concatenate AIMessageChunks with different example values.")
            return self.__class__(
                example=self.example,
                content=merge_content(self.content, other.content),
                additional_kwargs=merge_objects(self.additional_kwargs, other.additional_kwargs),
                response_metadata=merge_objects(self.response_metadata, other.response_metadata),
            )
        return super().__add__(other)

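# Illustrative sketch (not part of the module): adding two message chunks merges
# content and metadata, while mismatched 'example' flags raise; values are made up.
def _example_add_ai_message_chunks() -> None:
    merged = CustomAIMessageChunk(content="Hi", response_metadata={"model": "x"}) + CustomAIMessageChunk(
        content=" there", response_metadata={"finished": True}
    )
    assert merged.content == "Hi there"
    try:
        CustomAIMessageChunk(content="a") + CustomAIMessageChunk(content="b", example=True)
    except ValueError:
        pass  # chunks with different 'example' values cannot be concatenated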