Source code for genai.extensions.huggingface.agent
import logging
try:
from transformers import Agent
except ImportError:
raise ImportError( # noqa: B904
"Could not import HuggingFace transformers: Please install ibm-generative-ai[huggingface] extension."
)
from typing import Optional
from genai._utils.general import to_model_instance
from genai.client import Client
from genai.schema import TextGenerationParameters
logger = logging.getLogger(__name__)
[docs]
class IBMGenAIAgent(Agent):
[docs]
def __init__(
self,
client: Client,
model: Optional[str] = None,
parameters: Optional[TextGenerationParameters] = None,
chat_prompt_template: Optional[str] = None,
run_prompt_template: Optional[str] = None,
additional_tools: Optional[list[str]] = None,
):
super().__init__(
chat_prompt_template=chat_prompt_template,
run_prompt_template=run_prompt_template,
additional_tools=additional_tools,
)
self.client = client
self.model = model
self.parameters = parameters
[docs]
def generate_one(self, prompt: str, stop: Optional[list[str]] = None):
return self._generate([prompt], stop)[0]
[docs]
def generate_many(self, prompts: list[str], stop: Optional[list[str]] = None):
return self._generate(prompts, stop)
def _generate(self, prompts: list[str], stop: Optional[list[str]] = None) -> list[str]:
final_results: list[str] = []
if len(prompts) == 0:
return final_results
params = to_model_instance(self.parameters, TextGenerationParameters)
params.stop_sequences = stop or params.stop_sequences
for response in self.client.text.generation.create(model_id=self.model, inputs=prompts, parameters=params):
for result in response.results:
generated_text = result.generated_text or ""
logger.info("Output of GENAI call: {}".format(generated_text))
final_results.append(generated_text)
return final_results