# Source code for ibm_watson_machine_learning.foundation_models.prompts.prompt_template

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2023-2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from __future__ import print_function, annotations

from typing import Optional, List, Dict, Union, TYPE_CHECKING

if TYPE_CHECKING:
    import langchain
from dataclasses import dataclass
from datetime import datetime, timezone

import pandas

import ibm_watson_machine_learning._wrappers.requests as requests
from ibm_watson_machine_learning import APIClient
from ibm_watson_machine_learning.wml_client_error import (WMLClientError, ValidationError,
                                                          InvalidValue, InvalidMultipleArguments, PromptVariablesError)
from ibm_watson_machine_learning.wml_resource import WMLResource
from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes, PromptTemplateFormats
from ibm_watson_machine_learning.foundation_models.utils.utils import TemplateFormatter


@dataclass
class PromptTemplateLock:
    """Lock status of a prompt template asset."""

    # True when the asset is currently locked for editing.
    locked: bool
    # Identifier of the user holding the lock, when available.
    locked_by: Optional[str] = None


class PromptTemplate:
    """Storage for prompt template parameters.

    :param prompt_id: Id of prompt template, defaults to None.
    :type prompt_id: Optional[str], attribute setting not allowed

    :param created_at: Time the prompt was created (UTC), defaults to None.
    :type created_at: Optional[str], attribute setting not allowed

    :param lock: Locked state of asset, defaults to None.
    :type lock: Optional[PromptTemplateLock], attribute setting not allowed

    :param is_template: True if prompt is a template, False otherwise; defaults to None.
    :type is_template: Optional[bool], attribute setting not allowed

    :param name: Prompt template name, defaults to None.
    :type name: Optional[str], optional

    :param model_id: Foundation model id, defaults to None.
    :type model_id: Optional[ModelTypes], optional

    :param model_params: Model parameters, defaults to None.
    :type model_params: Optional[Dict], optional

    :param template_version: Semver version for tracking in IBM AI Factsheets, defaults to None.
    :type template_version: Optional[str], optional

    :param task_ids: List of task ids, defaults to None.
    :type task_ids: Optional[List[str]], optional

    :param description: Prompt template asset description, defaults to None.
    :type description: Optional[str], optional

    :param input_text: Input text for prompt, defaults to None.
    :type input_text: Optional[str], optional

    :param input_variables: Input variables can be present in fields: `instruction`,
                            `input_prefix`, `output_prefix`, `input_text`, `examples`
                            and are identified by braces ('{' and '}'), defaults to None.
    :type input_variables: (List | Dict[str, Dict[str, str]] | None), optional

    :param instruction: Instruction for model, defaults to None.
    :type instruction: Optional[str], optional

    :param input_prefix: Prefix string placed before input text, defaults to None.
    :type input_prefix: Optional[str], optional

    :param output_prefix: Prefix before model response, defaults to None.
    :type output_prefix: Optional[str], optional

    :param examples: Examples may help the model to adjust the response; [[input1, output1], ...], defaults to None.
    :type examples: Optional[List[List[str]]], optional

    :param validate_template: If True, the Prompt Template is validated for the presence of input variables, defaults to True.
    :type validate_template: bool, optional


    **Examples**

    Example of invalid Prompt Template:

    .. code-block:: python

        prompt_template = PromptTemplate(input_text='What are the most famous monuments in ?',
                                         input_variables=['country'])

        Traceback (most recent call last):
            ...
        ValidationError: Invalid prompt template; check for mismatched or missing input variables. Missing input variable: {'country'}

    Example of valid Prompt Template:

    .. code-block:: python

        prompt_template = PromptTemplate(input_text='What are the most famous monuments in {country}?',
                                         input_variables=['country'])

    """

    def __init__(self,
                 name: Optional[str] = None,
                 model_id: Optional[ModelTypes] = None,
                 model_params: Optional[Dict] = None,
                 template_version: Optional[str] = None,
                 task_ids: Optional[List[str]] = None,
                 description: Optional[str] = None,
                 input_text: Optional[str] = None,
                 input_variables: (List | Dict[str, Dict[str, str]] | None) = None,
                 instruction: Optional[str] = None,
                 input_prefix: Optional[str] = None,
                 output_prefix: Optional[str] = None,
                 examples: Optional[List[List[str]]] = None,
                 validate_template: bool = True) -> None:
        self.name = name
        # Server-populated metadata; exposed read-only via properties below.
        self._prompt_id = None
        self._created_at = None
        self._lock = None
        self._is_template = None
        self.model_id = model_id
        if isinstance(self.model_id, ModelTypes):
            self.model_id = self.model_id.value
        # Copy caller-supplied containers so later mutation of the caller's
        # objects does not silently change this template.
        self.model_params = model_params.copy() if model_params is not None else model_params
        self.task_ids = task_ids.copy() if task_ids is not None else task_ids
        self.template_version = template_version
        self.description = description
        self.input_text = input_text
        self.input_variables = input_variables.copy() if input_variables is not None else input_variables
        self.instruction = instruction
        self.input_prefix = input_prefix
        self.output_prefix = output_prefix
        self.examples = examples.copy() if examples is not None else examples

        # template validation
        if validate_template:
            self._validation()

    def __repr__(self):
        # Show only public, explicitly-set attributes.
        args = [f"{key}={value!r}" for key, value in self.__dict__.items()
                if not key.startswith('_') and value is not None]
        return f"{type(self).__name__}({', '.join(args)})"

    @property
    def prompt_id(self):
        return self._prompt_id

    @property
    def created_at(self):
        # `_created_at` is an epoch timestamp in milliseconds; render it as a
        # naive UTC string without the fractional-seconds part.
        # `datetime.utcfromtimestamp` is deprecated since Python 3.12, so build
        # an aware UTC datetime and drop the tzinfo to keep the same output.
        utc_time = datetime.fromtimestamp(self._created_at / 1000, tz=timezone.utc).replace(tzinfo=None)
        return str(utc_time).split(".")[0]

    @property
    def lock(self):
        return self._lock

    @property
    def is_template(self):
        return self._is_template

    def _validation(self):
        """Validate template structure.

        :raises ValidationError: raises when input_variables does not fit placeholders in input body.
        """
        input_variables = self.input_variables if self.input_variables is not None else []
        template_text = " ".join(filter(None, [self.instruction,
                                               self.input_prefix,
                                               self.output_prefix]))
        if self.examples:
            for example in self.examples:
                template_text += " ".join(example)
        try:
            def _validate(full_text):
                # Substitute a dummy value for every declared variable; any
                # placeholder left unmatched raises KeyError.
                dummy_inputs = {input_variable: "wx" for input_variable in input_variables}
                TemplateFormatter().format(full_text, **dummy_inputs)

            if self.input_text:
                # BUGFIX: the previous code passed `template_text + input_text`
                # into a helper that prepended `template_text` again, so the
                # template section was validated twice.
                _validate(template_text + self.input_text)
            elif template_text:
                _validate(template_text)
        except KeyError as key:
            raise ValidationError(key)


class PromptTemplateManager(WMLResource):
    """Instantiate the prompt template manager.

    :param credentials: Credentials to watsonx.ai instance.
    :type credentials: dict

    :param project_id: ID of project
    :type project_id: str

    :param space_id: ID of space
    :type space_id: str

    :param verify: user can pass as verify one of following:

        - the path to a CA_BUNDLE file
        - the path of directory with certificates of trusted CAs
        - `True` - default path to truststore will be taken
        - `False` - no verification will be made
    :type verify: bool or str, optional

    .. note::
        One of these parameters is required: ['project_id', 'space_id']

    **Example**

    .. code-block:: python

        from ibm_watson_machine_learning.foundation_models.prompts import PromptTemplate, PromptTemplateManager
        from ibm_watson_machine_learning.foundation_models.utils.enums import ModelTypes

        prompt_mgr = PromptTemplateManager(
                        credentials={
                            "apikey": "***",
                            "url": "https://us-south.ml.cloud.ibm.com"
                        },
                        project_id="*****")

        prompt_template = PromptTemplate(name="My prompt",
                                         model_id=ModelTypes.GRANITE_13B_CHAT_V2,
                                         input_prefix="Human:",
                                         output_prefix="Assistant:",
                                         input_text="What is {object} and how does it work?",
                                         input_variables=['object'],
                                         examples=[['What is the Stock Market?',
                                                    'A stock market is a place where investors buy and sell shares of publicly traded companies.']])

        stored_prompt_template = prompt_mgr.store_prompt(prompt_template)
        print(stored_prompt_template.prompt_id)   # id of prompt template asset
    """

    def __init__(self,
                 credentials: Optional[dict] = None,
                 *,
                 project_id: Optional[str] = None,
                 space_id: Optional[str] = None,
                 verify=None,
                 api_client: Optional[APIClient] = None) -> None:
        self.project_id = project_id
        self.space_id = space_id
        if credentials:
            self._client = APIClient(credentials, verify=verify)
        elif api_client:
            self._client = api_client
        else:
            raise InvalidMultipleArguments(params_names_list=["credentials", "api_client"],
                                           reason="None of the arguments were provided.")

        if self.space_id is not None and self.project_id is not None:
            raise InvalidMultipleArguments(params_names_list=["project_id", "space_id"],
                                           reason="Both arguments were provided.")

        self.params = {}
        if self.space_id:
            self._client.set.default_space(space_id)
            self.params = {'space_id': self.space_id}
        elif self.project_id:
            self._client.set.default_project(project_id)
            self.params = {'project_id': self.project_id}
        elif api_client:
            # No explicit scope given: fall back to whatever default the
            # provided client already carries (may be neither).
            if (project_id := self._client.default_project_id):
                self.params = {'project_id': project_id}
            elif (space_id := self._client.default_space_id):
                self.params = {'space_id': space_id}
        else:
            raise InvalidMultipleArguments(params_names_list=["space_id", "project_id"],
                                           reason="None of the arguments were provided.")

        if not self._client.CLOUD_PLATFORM_SPACES and self._client.CPD_version < 4.8:
            raise WMLClientError(error_msg="Operation is unsupported for this release.")

        WMLResource.__init__(self, __name__, self._client)

    def _create_request_body(self, prompt_template: PromptTemplate) -> Dict:
        """Method is used to create request body from PromptTemplate object.

        :param prompt_template: Object of type PromptTemplate based on which the request body will be created.
        :type prompt_template: PromptTemplate

        :return: Request body
        :rtype: Dict
        """
        json_data: Dict = {'prompt': dict()}
        if prompt_template.description is not None:
            json_data.update({'description': prompt_template.description})
        if prompt_template.input_variables is not None:
            PromptTemplateManager._validate_type(prompt_template.input_variables,
                                                 u'input_variables', [dict, list], False)
            # A plain list of names is normalized to the dict shape the API expects.
            if isinstance(prompt_template.input_variables, list):
                json_data.update({'prompt_variables': {key: {} for key in prompt_template.input_variables}})
            else:
                json_data.update({'prompt_variables': prompt_template.input_variables})
        if prompt_template.task_ids is not None:
            PromptTemplateManager._validate_type(prompt_template.task_ids, u'task_ids', list, False)
            json_data.update({'task_ids': prompt_template.task_ids})
        if prompt_template.template_version is not None:
            json_data.update({"model_version": {"number": prompt_template.template_version}})
        if prompt_template.input_text:
            PromptTemplateManager._validate_type(prompt_template.input_text, u'input_text', str, False)
            json_data['prompt'].update({'input': [[prompt_template.input_text, '']]})
        PromptTemplateManager._validate_type(prompt_template.model_id, u'model_id', str, True)
        if prompt_template.model_id is not None:
            json_data['prompt'].update({'model_id': prompt_template.model_id})
        if prompt_template.model_params is not None:
            PromptTemplateManager._validate_type(prompt_template.model_params,
                                                 u'model_parameters', dict, False)
            json_data['prompt'].update({'model_parameters': prompt_template.model_params})

        data: Dict = dict()
        if prompt_template.instruction is not None:
            data.update({'instruction': prompt_template.instruction})
        if prompt_template.input_prefix is not None:
            data.update({'input_prefix': prompt_template.input_prefix})
        if prompt_template.output_prefix is not None:
            data.update({'output_prefix': prompt_template.output_prefix})
        if prompt_template.examples is not None:
            PromptTemplateManager._validate_type(prompt_template.examples, u'examples', list, False)
            data.update({'examples': prompt_template.examples})
        json_data['prompt'].update({'data': data})

        return json_data

    def _from_json_to_prompt(self, response: Dict) -> PromptTemplate:
        """Convert json response to PromptTemplate object.

        :param response: Response body after request operation.
        :type response: Dict

        :return: PromptTemplate object with given details.
        :rtype: PromptTemplate
        """
        prompt_field: Dict = response.get('prompt', dict())
        data_field: Dict = prompt_field.get('data', dict())
        prompt_template = PromptTemplate(name=response.get('name'),
                                         description=response.get('description'),
                                         model_id=prompt_field.get('model_id'),
                                         model_params=prompt_field.get('model_parameters'),
                                         task_ids=response.get("task_ids"),
                                         template_version=response.get("model_version", dict()).get("number"),
                                         input_variables=response.get('prompt_variables'),
                                         input_text=prompt_field.get('input', [[None, None]])[0][0],
                                         instruction=data_field.get('instruction'),
                                         input_prefix=data_field.get('input_prefix'),
                                         output_prefix=data_field.get('output_prefix'),
                                         examples=data_field.get('examples'),
                                         validate_template=False)
        # Server-side metadata goes into the read-only backing attributes.
        prompt_template._prompt_id = response.get('id')
        prompt_template._created_at = response.get('created_at')
        # PromptTemplateLock does not accept `lock_type`; drop it before unpacking.
        if "lock_type" in response.get('lock', dict()):
            del response['lock']["lock_type"]
        prompt_template._lock = PromptTemplateLock(**response.get('lock',
                                                                  {"locked": None, "locked_by": None}))
        prompt_template._is_template = response.get('is_template')

        return prompt_template

    def _get_details(self, limit: Optional[int] = None) -> List:
        """Method retrieves details of all prompt templates. If limit is set to None,
        then all prompt templates are fetched.

        :param limit: limit number of fetched records, defaults to None.
        :type limit: Optional[int]

        :return: List of prompts metadata
        :rtype: List
        """
        headers = self._client._get_headers()
        url = self._client.service_instance._href_definitions.get_prompts_all_href()
        json_data = {"query": u"asset.asset_type:wx_prompt",
                     u"sort": "-asset.created_at<string>"}
        if limit is not None:
            if limit < 1:
                raise WMLClientError('Limit cannot be lower than 1.')
            elif limit > 200:
                raise WMLClientError('Limit cannot be larger than 200.')
            json_data.update({'limit': limit})
        else:
            # Page size cap; pagination below collects the remaining records.
            json_data.update({'limit': 200})

        prompts_list = []
        bookmark = True
        while bookmark is not None:
            response = requests.post(url=url, json=json_data, headers=headers, params=self.params)
            details_json = self._handle_response(200, "Get next details", response)
            bookmark = details_json.get('next', {'href': None}).get('bookmark', None)
            prompts_list.extend(details_json.get('results', []))

            # An explicit limit means a single page only.
            if limit is not None:
                break
            json_data.update({'bookmark': bookmark})

        return prompts_list

    def _change_lock(self, prompt_id: str, locked: bool, force: bool = False) -> Dict:
        """Change prompt template lock state.

        :param prompt_id: Id of prompt template.
        :type prompt_id: str

        :param locked: New lock state.
        :type locked: bool

        :param force: force lock state overwrite, defaults to False.
        :type force: bool, optional

        :return: Response content after lock state change.
        :rtype: Dict
        """
        headers = self._client._get_headers()
        params = (self.params | {"prompt_id": prompt_id, "force": force})
        json_data = {"locked": locked}
        url = self._client.service_instance._href_definitions.get_prompts_href() + f"/{prompt_id}/lock"
        response = requests.put(url=url, json=json_data, headers=headers, params=params)

        return self._handle_response(200, u'change_lock', response)
[docs] def load_prompt(self, prompt_id: str, astype: PromptTemplateFormats = PromptTemplateFormats.PROMPTTEMPLATE, *, prompt_variables: Optional[Dict[str, str]] = None): """Retrive a prompt template asset. :param prompt_id: Id of prompt template which is processed. :type prompt_id: str :param astype: Type of return object. :type astype: PromptTemplateFormats :param prompt_variables: Dictionary of input variables and values with which input variables will be replaced. :type prompt_variables: Dict[str, str] :return: Prompt template asset. :rtype: PromptTemplate | str | langchain.prompts.PromptTemplate **Example** .. code-block:: python loaded_prompt_template = prompt_mgr.load_prompt(prompt_id) loaded_prompt_template_lc = prompt_mgr.load_prompt(prompt_id, PromptTemplateFormats.LANGCHAIN) loaded_prompt_template_string = prompt_mgr.load_prompt(prompt_id, PromptTemplateFormats.STRING) """ headers = self._client._get_headers() params = (self.params | {"prompt_id": prompt_id}) url = self._client.service_instance._href_definitions.get_prompts_href() + f"/{prompt_id}" if isinstance(astype, PromptTemplateFormats): astype = astype.value if astype == 'prompt': response = requests.get(url=url, headers=headers, params=params) return self._from_json_to_prompt(self._handle_response(200, u'_load_json_prompt', response)) elif astype in ('langchain', 'string'): response = requests.post(url=url + u"/input", headers=headers, params=params) response_input = self._handle_response(200, u'load_prompt', response).get("input") if astype == 'string': try: return response_input if prompt_variables is None else response_input.format(**prompt_variables) except KeyError as key: raise PromptVariablesError(key) else: from langchain.prompts import PromptTemplate as LcPromptTemplate return LcPromptTemplate.from_template(response_input) else: raise InvalidValue(u'astype')
[docs] def list(self, *, limit=None) -> pandas.core.frame.DataFrame: """List all available prompt templates in the DataFrame format. :param limit: limit number of fetched records, defaults to None. :type limit: Optional[int] :return: Dataframe of fundamental properties of availabale prompts. :rtype: pandas.core.frame.DataFram **Example** .. code-block:: python prompt_mgr.list(limit=5) # list of 5 recent created prompt template assets .. hint:: Additionally you can sort available prompt templates by "LAST MODIFIED" field. .. code-block:: python df_prompts = prompt_mgr.list() df_prompts.sort_values("LAST MODIFIED", ascending=False) """ details = ['metadata.asset_id', 'metadata.name', 'metadata.created_at', 'metadata.usage.last_updated_at'] prompts_details = self._get_details(limit=limit) data_normalize = pandas.json_normalize(prompts_details) prompts_data = data_normalize.reindex(columns=details) df_details = pandas.DataFrame(prompts_data, columns=details) df_details.rename(columns={'metadata.asset_id': 'ID', 'metadata.name': 'NAME', 'metadata.created_at': 'CREATED', 'metadata.usage.last_updated_at': 'LAST MODIFIED', }, inplace=True) return df_details
[docs] def store_prompt(self, prompt_template: Union[PromptTemplate, langchain.prompts.PromptTemplate]) -> PromptTemplate: """Store a new prompt template. :param prompt_template: PromptTemplate to be stored. :type prompt_template: (PromptTemplate | langchain.prompts.PromptTemplate) :return: PromptTemplate object initialized with values provided in the server response object. :rtype: PromptTemplate """ if isinstance(prompt_template, PromptTemplate): pass else: from langchain.prompts import PromptTemplate as LcPromptTemplate if isinstance(prompt_template, LcPromptTemplate): prompt_template = PromptTemplate(name="My prompt", model_id=ModelTypes.FLAN_UL2, input_text=prompt_template.template, input_variables=prompt_template.input_variables) else: raise WMLClientError(error_msg="Unsupported type for `prompt_template`") headers = self._client._get_headers() PromptTemplateManager._validate_type(prompt_template.name, u'prompt_template.name', str, True) json_data: Dict = { "name": prompt_template.name, "lock": {"locked": True}, "prompt": dict() } json_data.update(self._create_request_body(prompt_template)) url = self._client.service_instance._href_definitions.get_prompts_href() response = requests.post(url=url, json=json_data, headers=headers, params=self.params) response = self._handle_response(201, u'store_prompt', response) return self._from_json_to_prompt(response)
[docs] def delete_prompt(self, prompt_id: str, *, force: bool = False) -> str: """Remove prompt template from project or space. :param prompt_id: Id of prompt template that will be delete. :type prompt_id: str :param force: If True then prompt template is unlocked and then delete, defaults to False. :type force: bool :return: Status 'SUCESS' if the prompt template is successfully deleted. :rtype: str **Example** .. code-block:: python prompt_mgr.delete_prompt(prompt_id) # delete if asset is unclocked """ if force: self.unlock(prompt_id) headers = self._client._get_headers() params = (self.params | {"prompt_id": prompt_id}) url = self._client.service_instance._href_definitions.get_prompts_href() + f"/{prompt_id}" response = requests.delete(url=url, headers=headers, params=params) return self._handle_response(204, u'delete_prompt', response)
[docs] def update_prompt(self, prompt_id: str, prompt_template: PromptTemplate) -> Dict: """Update prompt template data. :param prompt_id: Id of the updated prompt template. :type prompt_id: str :param prompt: PromptTemplate with new data. :type prompt: PromptTemplate :return: metadata of updated deployment :rtype: dict **Example** .. code-block:: python updataed_prompt_template = PromptTemplate(name="New name") prompt_mgr.update_prompt(prompt_id, prompt_template) # {'name': 'New name'} in metadata """ headers = self._client._get_headers() params = (self.params | {"prompt_id": prompt_id}) new_body: Dict = dict() current_prompt_template = self.load_prompt(prompt_id) for attribute in prompt_template.__dict__: if getattr(prompt_template, attribute) is not None and not attribute.startswith("_"): setattr(current_prompt_template, attribute, getattr(prompt_template, attribute)) if current_prompt_template.name is not None: new_body.update({'name': current_prompt_template.name}) new_body.update(self._create_request_body(current_prompt_template)) url = self._client.service_instance._href_definitions.get_prompts_href() + f"/{prompt_id}" response = requests.patch(url=url, json=new_body, headers=headers, params=params) return self._handle_response(200, u'update_prompt', response)
[docs] def get_lock(self, prompt_id: str) -> Dict: """Get the current locked state of a prompt template. :param prompt_id: Id of prompt template :type prompt_id: str :return: Information about locked state of prompt template asset. :rtype: Dict **Example** .. code-block:: python print(prompt_mgr.get_lock(prompt_id)) """ headers = self._client._get_headers() params = (self.params | {"prompt_id": prompt_id}) url = self._client.service_instance._href_definitions.get_prompts_href() + f"/{prompt_id}/lock" response = requests.get(url=url, headers=headers, params=params) return self._handle_response(200, u'get_lock', response)
[docs] def lock(self, prompt_id: str, force: bool = False) -> Dict: """Lock the prompt template if it is unlocked and user has permission to do that. :param promp_id: Id of prompt template. :type promp_id: str :param force: If True, method forcefully overwrite a lock. :type force: bool :return: Status 'SUCCESS' or response content after an attempt to lock prompt template. :rtype: (str | Dict) **Example** .. code-block:: python prompt_mgr.lock(prompt_id) """ return self._change_lock(prompt_id=prompt_id, locked=True, force=force)
[docs] def unlock(self, prompt_id: str) -> Dict: """Unlock the prompt template if it is locked and user has permission to do that. :param promp_id: Id of prompt template. :type promp_id: str :return: Response content after an attempt to unlock prompt template. :rtype: Dict **Example** .. code-block:: python prompt_mgr.unlock(prompt_id) """ # server returns status code 400 after trying to unlock unlocked prompt lock_state = self.get_lock(prompt_id) if lock_state['locked']: return self._change_lock(prompt_id=prompt_id, locked=False, force=False) else: return lock_state