# Source code for ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.db2_adapter

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

import copy
import logging
from typing import Any, TypeAlias, cast

from langchain_core.documents import Document

from ibm_watsonx_ai import APIClient
from ibm_watsonx_ai.foundation_models.embeddings import BaseEmbeddings
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.langchain_vector_store_adapter import (
    DEFAULT_CHUNK_SEQUENCE_NUMBER_FIELD,
    DEFAULT_DOCUMENT_NAME_FIELD,
    LangChainVectorStoreAdapter,
)
from ibm_watsonx_ai.utils.utils import is_lib_installed
from ibm_watsonx_ai.wml_client_error import (
    MissingExtension,
    VectorStoreSerializationError,
)

# `langchain-db2` is an optional dependency: fail fast with a message that
# points at the "rag" extra instead of raising an opaque ImportError later.
if not is_lib_installed(ext := "langchain-db2"):
    raise MissingExtension(ext, extra_info="rag")

# These imports are deliberately placed after the availability check above,
# so the MissingExtension error is raised before any ImportError can occur.
from langchain_core.embeddings import Embeddings as LCEmbeddings
from langchain_db2 import DB2VS
from langchain_db2.db2vs import clear_table

logger = logging.getLogger(__name__)

# Type Alias
# Embedding implementations accepted by this adapter: either the watsonx.ai
# BaseEmbeddings wrapper or any plain LangChain Embeddings object.
EmbeddingType: TypeAlias = BaseEmbeddings | LCEmbeddings


class DB2VectorStore(LangChainVectorStoreAdapter):
    """DB2VectorStore vector store client for a RAG pattern.

    :param api_client: api client is required if connecting by connection_id, defaults to None
    :type api_client: APIClient, optional

    :param connection_id: connection asset ID, defaults to None
    :type connection_id: str, optional

    :param vector_store: initialized langchain_db2 vector store, defaults to None
    :type vector_store: langchain_db2.DB2VS, optional

    :param embedding_function: dense embedding function, defaults to None
    :type embedding_function: BaseEmbeddings | LCEmbeddings, optional

    :param table_name: name of the DB2 table, defaults to None
    :type table_name: str, optional

    :param document_name_field: mapping field for document name, defaults to `document_id`
    :type document_name_field: str, optional

    :param chunk_sequence_number_field: mapping field for chunk sequence number, defaults to `sequence_number`
    :type chunk_sequence_number_field: str, optional

    :param text_field: mapping field for text field
    :type text_field: str, optional

    :param kwargs: keyword arguments that will be directly passed to `langchain_db2.DB2VS` constructor
    :type kwargs: Any, optional

    **Example:**

    To connect, provide the connection asset ID.
    You can use custom embeddings to add and search documents.

    .. code-block:: python

        from ibm_watsonx_ai import APIClient, Credentials
        from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores import (
            DB2VectorStore,
        )
        from ibm_watsonx_ai.foundation_models.embeddings import Embeddings

        credentials = Credentials(
            api_key=IAM_API_KEY, url="https://us-south.ml.cloud.ibm.com"
        )

        api_client = APIClient(credentials, project_id="<PROJECT_ID>")

        embedding = Embeddings(
            model_id=EmbeddingTypes.IBM_SLATE_30M_ENG, api_client=api_client
        )

        vector_store = DB2VectorStore(
            api_client,
            connection_id="***",
            collection_name="my_test_collection",
            embedding_function=embedding,
        )

        vector_store.add_documents(
            [
                {
                    "content": "document one content",
                    "metadata": {"url": "ibm.com"},
                },
                {
                    "content": "document two content",
                    "metadata": {"url": "ibm.com"},
                },
            ]
        )
        # ['4CDDAF00329B3DF9', 'B8AE97421A8857E7']

        vector_store.search("one", k=1)
        # [Document(metadata={'url': 'ibm.com'}, page_content='document one content')]
    """

    def __init__(
        self,
        api_client: APIClient | None = None,
        *,
        connection_id: str | None = None,
        vector_store: DB2VS | None = None,
        embedding_function: EmbeddingType | None = None,
        table_name: str | None = None,
        document_name_field: str = DEFAULT_DOCUMENT_NAME_FIELD,
        chunk_sequence_number_field: str = DEFAULT_CHUNK_SEQUENCE_NUMBER_FIELD,
        text_field: str | None = None,
        **kwargs: Any,
    ) -> None:
        self._connection_id = connection_id
        self._client = api_client
        self._document_name_field = document_name_field
        self._chunk_sequence_number_field = chunk_sequence_number_field
        self._text_field = text_field
        # A user-supplied vector store instance cannot be reconstructed from a
        # dict, so to_dict()/from_dict() round-tripping is only supported when
        # this adapter builds the DB2VS instance itself.
        self._is_serializable = not bool(vector_store)
        self._embedding_function = embedding_function
        self._table_name = table_name
        self._additional_kwargs = kwargs

        # import is not at top-level of file due to circular import
        from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.vector_store_connector import (  # noqa: E501, PLC0415
            VectorStoreConnector,
            VectorStoreDataSourceType,
        )

        if vector_store is None:
            # Resolve connection details: either from the connection asset
            # (requires both an API client and a connection ID) or fall back
            # to a bare DB2 datasource type with no extra properties.
            if self._client is not None and self._connection_id is not None:
                self._datasource_type, connection_properties = self._connect_by_type(
                    cast(str, self._connection_id)
                )
            else:
                self._datasource_type, connection_properties = (
                    VectorStoreDataSourceType.DB2,
                    {},
                )

            logger.info("Initializing vector store of type: %s", self._datasource_type)

            # Connection properties come first so that explicit constructor
            # kwargs (self._additional_kwargs) can override them.
            self._properties = {
                **connection_properties,
                **self._additional_kwargs,
                "embedding_function": self._embedding_function,
                "table_name": self._table_name,
            }
            if self._text_field is not None:
                self._properties["text_field"] = self._text_field

            # Normalize the merged properties into the exact parameter set
            # expected by the DB2VS constructor.
            self._properties = VectorStoreConnector(
                self._properties
            )._get_db2_connection_params()

            vector_store = DB2VS(**self._properties)
        else:
            # Pre-built store supplied by the caller: derive the datasource
            # type from the instance and adopt its text field (if any),
            # overriding whatever `text_field` argument was passed.
            self._datasource_type = (
                VectorStoreConnector.get_type_from_langchain_vector_store(vector_store)
            )
            self._text_field = getattr(vector_store, "_text_field", None)

        super().__init__(
            vector_store=vector_store,
            document_name_field=self._document_name_field,
            chunk_sequence_number_field=self._chunk_sequence_number_field,
        )
[docs] def get_client(self) -> DB2VS: """Get langchain_db2.DB2VS instance.""" return super().get_client()
[docs] def clear(self) -> None: """ Clear table by removing all records. """ db2_client = self.get_client() clear_table(client=db2_client.client, table_name=db2_client.table_name)
[docs] def count(self) -> int: """ Count number of records in table. """ ids = self.get_client().get_pks() return len(ids) if ids else 0
[docs] def add_documents( self, content: list[str] | list[dict] | list[Document], **kwargs: Any ) -> list[str]: """ Embed documents and add to the vectorstore. :param content: Documents to add to the vectorstore. :type content: list[str] | list[dict] | list[langchain_core.documents.Document] :return: List of IDs of the added texts. :rtype: list[str] """ ids, docs = self._process_documents(content) texts = [doc.page_content for doc in docs] metadatas = [doc.metadata for doc in docs] if len(texts) == 0: return [] db2 = self.get_client() return db2.add_texts(texts=texts, metadatas=metadatas, ids=ids)
[docs] def to_dict(self) -> dict: """Serialize ``DB2VectorStore`` into a dict that allows reconstruction using the ``from_dict`` class method. :return: dict for the from_dict initialization :rtype: dict :raises VectorStoreSerializationError: when instance is not serializable """ if not self._is_serializable: raise VectorStoreSerializationError( "Serialization is not available when passing vector store instance in `DB2VectorStore` constructor." ) embedding = self._embedding_function if embedding is None: embedding_f = None else: watsonx = getattr(embedding, "watsonx_embed", None) to_dict_method = None if watsonx is not None: to_dict_method = getattr(watsonx, "to_dict", None) if to_dict_method is None: to_dict_method = getattr(embedding, "to_dict", None) if callable(to_dict_method): embedding_f = to_dict_method() else: raise VectorStoreSerializationError( f"Cannot serialize embedding-function of type {type(embedding).__name__}; " "expected `.watsonx_embed.to_dict()` or `.to_dict()`." ) data_dict = { "connection_id": self._connection_id, "embedding_function": embedding_f, "table_name": self._table_name, **self._additional_kwargs, "datasource_type": str(self._datasource_type), "document_name_field": self._document_name_field, "chunk_sequence_number_field": self._chunk_sequence_number_field, } if self._text_field is not None: data_dict["text_field"] = self._text_field return data_dict
[docs] @classmethod def from_dict( cls, api_client: APIClient | None = None, data: dict | None = None ) -> "DB2VectorStore": """Creates ``DB2VectorStore`` using only a primitive data type dict. :param api_client: initialised APIClient used in vector store constructor, defaults to None :type api_client: APIClient, optional :param data: dict in schema like the ``to_dict()`` method :type data: dict :return: reconstructed DB2VectorStore :rtype: DB2VectorStore """ d = copy.deepcopy(data) if isinstance(data, dict) else {} # Remove `datasource_type` if present d.pop("datasource_type", None) if "embeddings" in d: d.setdefault("embedding_function", d.pop("embeddings")) if "index_name" in d: d.setdefault("table_name", d.pop("index_name")) if "distance_metric" in d: d.setdefault("distance_strategy", d.pop("distance_metric")) d["embedding_function"] = BaseEmbeddings.from_dict( data=d.get("embedding_function", {}), api_client=api_client ) return cls(api_client, **d)
    def _get_window_documents(
        self, doc_id: str, seq_nums_window: list[int]
    ) -> list[Document]:
        """
        Receives a document ID and a list of chunks' sequence_numbers, and
        searches the vector store according to the metadata.

        :param doc_id: ID of document
        :type doc_id: str

        :param seq_nums_window: list of sequence numbers
        :type seq_nums_window: list[int]

        :return: list of documents from that document with these sequence_numbers
        :rtype: list[Document]
        """
        table_name = self._langchain_vector_store.table_name  # type: ignore[attr-defined]
        # One "?" parameter placeholder per requested sequence number, joined
        # for the SQL IN (...) clause below.
        placeholders = ",".join("?" for _ in seq_nums_window)
        # Pull seq_num / doc_id out of the JSON metadata column so the search
        # can filter on them; the actual values are bound as parameters.
        # NOTE(review): table and field names are interpolated directly into
        # the SQL text. They appear to come from adapter configuration rather
        # than user input, but confirm callers cannot control them
        # (SQL-injection surface).
        sql = f"""
        WITH extracted AS (
            SELECT
                JSON_VALUE(metadata, '$.{self._chunk_sequence_number_field}' RETURNING INTEGER) AS seq_num,
                JSON_VALUE(metadata, '$.{self._document_name_field}') AS doc_id,
                {self._text_field} AS page_content
            FROM {table_name}
        )
        SELECT seq_num, doc_id, page_content
        FROM extracted
        WHERE doc_id = ? AND seq_num IN ({placeholders})
        ORDER BY seq_num;
        """
        # First parameter binds the doc_id filter, the rest fill the IN list.
        params = [doc_id] + seq_nums_window
        cursor = self._langchain_vector_store.client.cursor()  # type: ignore[attr-defined]
        cursor.execute(sql, tuple(params))
        rows = cursor.fetchall()
        # Rebuild LangChain Documents, restoring the two metadata fields the
        # window-retrieval logic relies on (rows arrive ordered by seq_num).
        window_documents = [
            Document(
                page_content=row[2],
                metadata={
                    self._chunk_sequence_number_field: row[0],
                    self._document_name_field: row[1],
                },
            )
            for row in rows
        ]
        return window_documents