Source code for ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.vector_store_connector

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024-2025.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from __future__ import annotations

import copy
import logging
from enum import Enum
from typing import Any, Optional
from warnings import warn

from langchain_core.vectorstores import VectorStore as LangChainVectorStore

from ibm_watsonx_ai.foundation_models.extensions.rag.utils.utils import (
    save_ssl_certificate_as_file,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.base_vector_store import (
    BaseVectorStore,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.langchain_vector_store_adapter import (
    LangChainVectorStoreAdapter,
)
from ibm_watsonx_ai.utils.utils import is_lib_installed
from ibm_watsonx_ai.wml_client_error import MissingExtension

logger = logging.getLogger(__name__)


class VectorStoreDataSourceType(str, Enum):
    """Names of the vector store data source types supported by the connector.

    Mixing in ``str`` lets members compare equal to (and serialize as) their
    plain string values.
    """

    ELASTICSEARCH = "elasticsearch"
    CHROMA = "chroma"
    MILVUS = "milvus"
    MILVUS_WXD = "milvuswxd"  # IBM watsonx.data Milvus
    DB2 = "db2"
    UNDEFINED = "undefined"

    def __str__(self) -> str:
        # Render members as their bare string value, e.g. "milvus".
        return self.value
def _convert_str_to_vs_datasource_type_enum(
    datasource_type_str: str,
) -> VectorStoreDataSourceType:
    """Translate a raw datasource-type string into a ``VectorStoreDataSourceType``.

    Any value that does not match a known member maps to
    ``VectorStoreDataSourceType.UNDEFINED``.
    """
    for member in VectorStoreDataSourceType:
        if member.value == datasource_type_str:
            return member
    return VectorStoreDataSourceType.UNDEFINED
[docs] class VectorStoreConnector: """Creates a proper vector store client using the provided properties. Properties are arguments to the LangChain vector stores of a desired type. Also parses properties extracted from connection assets into one that would fit for initialization. Custom or connection asset properties that are parsed are: * `index_name` * `distance_metric` * `username` * `password` * `ssl_certificate` * `embeddings` :param properties: dictionary with all the required key values to establish the connection :type properties: dict """ def __init__(self, properties: dict | None = None) -> None: def deepcopy_if_possible(obj: Any) -> Any: try: return copy.deepcopy(obj) except Exception: return obj self.properties: dict = ( {key: deepcopy_if_possible(value) for key, value in properties.items()} if isinstance(properties, dict) else {} )
[docs] @staticmethod def get_type_from_langchain_vector_store( langchain_vector_store: Any, ) -> VectorStoreDataSourceType: """Returns ``DataSourceType`` for concrete LangChain ``VectorStore`` class. :param langchain_vector_store: vector store object from LangChain :type langchain_vector_store: Any :return: DataSourceType name :rtype: VectorStoreDataSourceType """ vs_type = langchain_vector_store.__class__.__name__ match vs_type: case "ElasticsearchStore": return VectorStoreDataSourceType.ELASTICSEARCH case "Chroma": return VectorStoreDataSourceType.CHROMA case "Milvus": return VectorStoreDataSourceType.MILVUS case "DB2VS": return VectorStoreDataSourceType.DB2 case _: return VectorStoreDataSourceType.UNDEFINED
[docs] def get_from_type(self, type: VectorStoreDataSourceType) -> BaseVectorStore: """Gets a vector store based on the provided type (matching from DataSource names from SDK API). :param type: DataSource type string from SDK API :type type: VectorStoreDataSourceType :raises TypeError: unsupported type :return: proper BaseVectorStore type constructed from properties :rtype: BaseVectorStore """ match type: case VectorStoreDataSourceType.ELASTICSEARCH: return self.get_elasticsearch() case VectorStoreDataSourceType.CHROMA: return self.get_chroma() case ( VectorStoreDataSourceType.MILVUS | VectorStoreDataSourceType.MILVUS_WXD ): return self.get_milvus() case VectorStoreDataSourceType.DB2: return self.get_db2() case _: raise TypeError("Data source type not supported.")
[docs] def get_langchain_adapter( # type: ignore[return] self, langchain_vector_store: Any ) -> LangChainVectorStoreAdapter | None: """Creates an adapter for a concrete vector store from LangChain. :param langchain_vector_store: object that is a subclass of the LangChain vector store :type langchain_vector_store: Any :raises ImportError: LangChain required :return: proper adapter for the vector store :rtype: LangChainVectorStoreAdapter """ if isinstance(langchain_vector_store, LangChainVectorStore): return LangChainVectorStoreAdapter(vector_store=langchain_vector_store)
[docs] def get_chroma(self) -> LangChainVectorStoreAdapter: """Creates an in-memory vector store for Chroma. :raises ImportError: langchain required :return: vector store adapter for LangChain's Chroma :rtype: LangChainVectorStoreAdapter """ if not is_lib_installed(ext := "langchain-chroma"): raise MissingExtension(ext, extra_info="rag") from langchain_chroma import Chroma from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.chroma_adapter import ( ChromaVectorStore, ) parsed_params = self.properties parsed_params.pop("datasource_type", None) parsed_params.pop("text_field", None) # Parse collection name # 'collection_name' kwargs for Chroma has priority over generic 'index_name' collection_name = parsed_params.pop("index_name", None) if collection_name: parsed_params["collection_name"] = parsed_params.get( "collection_name", collection_name ) # Parse distance metric - set it in collection_metadata # Distance metric for Chroma is determined by collection metadata # See: Chroma._select_relevance_score_fn() distance_metric = parsed_params.pop("distance_metric", None) if distance_metric == "euclidean": collection_metadata = {"hnsw:space": "l2"} elif distance_metric == "cosine": collection_metadata = {"hnsw:space": "cosine"} else: collection_metadata = None parsed_params["collection_metadata"] = parsed_params.get( "collection_metadata", collection_metadata ) # Set embedding from params parsed_params["embedding_function"] = parsed_params.pop("embeddings", None) if parsed_params["embedding_function"] is None: raise ValueError("Embedding function is required for Chroma.") _doc_keys = ("document_name_field", "chunk_sequence_number_field") document_key_mapping_args = { k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params } return ChromaVectorStore( vector_store=Chroma(**parsed_params), **document_key_mapping_args )
[docs] def get_milvus(self) -> LangChainVectorStoreAdapter: """Creates a Milvus vector store. :raises ImportError: langchain required :return: vector store adapter for LangChain's Milvus :rtype: LangChainVectorStoreAdapter """ from langchain_milvus import Milvus from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.milvus_adapter import ( MilvusVectorStore, ) from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.milvus_utils import ( DEFAULT_INDEX_PARAM, ) parsed_params = self._get_milvus_connection_params() parsed_params.pop("datasource_type", None) # Connection 'index_name' is 'collection_name' in Milvus if index_name := parsed_params.pop("index_name", None): parsed_params["collection_name"] = index_name elif not parsed_params.get("collection_name"): raise ValueError("Provide 'index_name' or 'collection_name'.") # Parse distance metric # Distance metric is set in `index_params`. # Here we replace the default `index_params` with different metric type. # See: `Milvus._create_index()`. distance_metric = parsed_params.pop("distance_metric", None) if distance_metric == "cosine": index_params = DEFAULT_INDEX_PARAM else: index_params = None parsed_params["index_params"] = parsed_params.get("index_params", index_params) if "embedding_function" in parsed_params: if "embeddings" in parsed_params: raise ValueError( "Either `embeddings` or `embedding_function` must be specified, but not both." ) else: parsed_params["embedding_function"] = parsed_params.pop("embeddings", None) _doc_keys = ("document_name_field", "chunk_sequence_number_field") document_key_mapping_args = { k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params } return MilvusVectorStore( vector_store=Milvus(**parsed_params), **document_key_mapping_args )
def _get_milvus_connection_params(self) -> dict: parsed_params = self.properties # Prepare connection_args (if not present) if "connection_args" not in parsed_params: parsed_params["connection_args"] = {} # Set secure=True also when user set it in Connection UI if "ssl" in parsed_params: is_ssl = parsed_params.pop("ssl") parsed_params["secure"] = True if is_ssl == "true" else False # Get SSL certificate saved to file if "ssl_certificate" in parsed_params: parsed_params["server_pem_path"] = save_ssl_certificate_as_file( parsed_params.pop("ssl_certificate") ) # Connection 'username' is 'user' in Milvus if "username" in parsed_params: parsed_params["user"] = parsed_params.pop("username") # Connection 'database' is 'db_name' in Milvus if "database" in parsed_params: parsed_params["db_name"] = parsed_params.pop("database") # Move each param that was in parsed_params to connection_args if we expect it here for param in [ "uri", "host", "port", "user", "password", "db_name", "secure", "client_key_path", "client_pem_path", "ca_pem_path", "server_pem_path", "server_name", ]: if param in parsed_params.keys(): parsed_params["connection_args"][param] = parsed_params.pop(param) def build_address(connection_args: dict) -> str: """Build an address string from host and port.""" host = connection_args.get("host", "localhost") port = connection_args.get("port", 19530) return f"{host}:{port}" parsed_params["connection_args"]["address"] = build_address( parsed_params["connection_args"] ) return parsed_params
[docs] def get_elasticsearch(self) -> LangChainVectorStoreAdapter: """Creates an Elasticsearch vector store. :raises ImportError: langchain required :return: vector store adapter for LangChain's Elasticsearch :rtype: LangChainVectorStoreAdapter """ if not is_lib_installed(ext := "langchain-elasticsearch"): raise MissingExtension(ext, extra_info="rag") from langchain_elasticsearch import ( DenseVectorScriptScoreStrategy, DistanceMetric, ElasticsearchStore, RetrievalStrategy, SparseVectorStrategy, ) from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.es_adapter import ( ElasticsearchVectorStore, ) # Parse ES connection data - select proper connection type parsed_params = self._get_elasticsearch_connection_params() parsed_params = self.properties parsed_params.pop("datasource_type", None) # Parse distance metric # Match with ES DistanceMetric. # Default is cosine. match parsed_params.pop("distance_metric", None): case "euclidean": distance_metric = DistanceMetric.EUCLIDEAN_DISTANCE case "cosine": distance_metric = DistanceMetric.COSINE case _: distance_metric = DistanceMetric.COSINE parsed_params["distance_strategy"] = parsed_params.pop( "distance_strategy", distance_metric ) # Determine retrieval strategy type from parameters if "strategy" not in parsed_params or not isinstance( parsed_params["strategy"], RetrievalStrategy ): if "model_id" in parsed_params: parsed_params["strategy"] = SparseVectorStrategy( model_id=parsed_params.pop("model_id") ) else: retrieval_strategy_warning = ( "The default retrieval strategy will soon be updated from `DenseVectorScriptScoreStrategy` to `DenseVectorStrategy` " "for consistency with the `ElasticsearchVectorStore` class. " "Using the latter is recommended to avoid potential issues in the future." 
) warn(retrieval_strategy_warning) parsed_params["strategy"] = DenseVectorScriptScoreStrategy( distance=distance_metric, ) # Set embedding from params if parsed_params.get("embedding") is None: parsed_params["embedding"] = parsed_params.pop("embeddings", None) _doc_keys = ("document_name_field", "chunk_sequence_number_field") document_key_mapping_args = { k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params } return ElasticsearchVectorStore( vector_store=ElasticsearchStore(**parsed_params), **document_key_mapping_args, )
def _get_elasticsearch_connection_params(self) -> dict: parsed_params = self.properties # Always use empty es_params if not provided if "es_params" not in parsed_params: parsed_params["es_params"] = {} # Drop unnecessary stuff from connection asset if they are present parsed_params.pop("auth_method", None) parsed_params.pop("use_anonymous_access", None) # Parse ES connection data - select proper connection type # Connecting by 'url': username/password or api_key if "url" in parsed_params: # Get URL of ES instance parsed_params["es_url"] = parsed_params.pop("url") # Detect credentials given in connection asset if "username" in parsed_params and "password" in parsed_params: # Connect by username and password extracted from connection parsed_params["es_user"] = parsed_params.pop("username") parsed_params["es_password"] = parsed_params.pop("password") parsed_params.pop("api_key", None) elif "api_key" in parsed_params: # Connect by api key parsed_params["es_api_key"] = parsed_params.pop("api_key") parsed_params.pop("username", None) parsed_params.pop("password", None) else: raise ValueError( """To connect to given hostname ['url'] provide either ['username', 'password'] or ['api_key']. Make sure those fields are present in connection details or parameters given upon VectorStore initialization. """ ) elif "es_url" in parsed_params: if "es_user" in parsed_params and "es_password" in parsed_params: pass elif "es_api_key" in parsed_params: pass else: raise ValueError( """To connect to given hostname ['es_url'] provide either ['es_user', 'es_password'] or ['es_api_key']. Make sure those fields are present in parameters given upon VectorStore initialization. 
""" ) # Connecting by '(es_)cloud_id' to Elasticsearch cloud elif "cloud_id" in parsed_params and "api_key" in parsed_params: parsed_params["es_cloud_id"] = parsed_params.pop("cloud_id", None) parsed_params["es_api_key"] = parsed_params.pop("api_key", None) elif "es_cloud_id" in parsed_params and "es_api_key" in parsed_params: pass else: raise ValueError( """Connection data was not sufficent. Either provide: - ['url', 'username', 'password'], - ['url', 'api_key'], - ['cloud_id', 'api_key'] or - ['es_url', 'es_user', 'es_password'], - ['es_url', 'es_api_key'], - ['es_cloud_id', 'es_api_key'], in your connection asset or in params for VectorStore.""" ) if not parsed_params.get("index_name"): raise ValueError("Provide 'index_name'.") # Parse SSL certificate ssl_certificate_content = parsed_params.pop("ssl_certificate", None) if ssl_certificate_content: parsed_params["es_params"]["ca_certs"] = save_ssl_certificate_as_file( ssl_certificate_content ) return parsed_params
[docs] def get_db2(self) -> LangChainVectorStoreAdapter: """Creates a DV2 vector store. :raises ImportError: langchain-db2 required :return: vector store adapter for LangChain's DB2 :rtype: LangChainVectorStoreAdapter """ if not is_lib_installed(ext := "langchain-db2"): raise MissingExtension(ext, extra_info="rag") from langchain_db2 import DB2VS from langchain_db2.db2vs import DistanceStrategy from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.db2_adapter import ( DB2VectorStore, ) parsed_params = self._get_db2_connection_params() parsed_params.pop("datasource_type", None) # Connection 'index_name' is 'table_name' in DB2 if index_name := parsed_params.pop("index_name", None): parsed_params["table_name"] = index_name elif not parsed_params.get("table_name"): raise ValueError("Provide 'index_name' or 'table_name'.") # Parse distance_metric # Match with DB2 DistanceStrategy. distance_metric: Optional[DistanceStrategy] = None match parsed_params.pop("distance_metric", None): case "euclidean": distance_metric = DistanceStrategy.EUCLIDEAN_DISTANCE case "max_inner_product": distance_metric = DistanceStrategy.MAX_INNER_PRODUCT case "dot": distance_metric = DistanceStrategy.DOT_PRODUCT case "jaccard": distance_metric = DistanceStrategy.JACCARD case "cosine": distance_metric = DistanceStrategy.COSINE case _: pass distance_strategy = parsed_params.get("distance_strategy", None) if distance_strategy or distance_metric: parsed_params["distance_strategy"] = distance_strategy or distance_metric # Set embedding_function if "embedding_function" not in parsed_params: parsed_params["embedding_function"] = parsed_params.pop("embeddings", None) elif "embeddings" in parsed_params: raise ValueError( "Either `embeddings` or `embedding_function` must be specified, but not both." 
) _doc_keys = ("document_name_field", "chunk_sequence_number_field") document_key_mapping_args = { k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params } return DB2VectorStore( vector_store=DB2VS(**parsed_params), **document_key_mapping_args, )
def _get_db2_connection_params(self) -> dict: parsed_params = self.properties # Prepare connection_args (if not present) if "connection_args" not in parsed_params: parsed_params["connection_args"] = {} # Connection 'ssl' is 'security' in DB2 ssl_flag = parsed_params.pop("ssl", None) if ssl_flag is True or ( isinstance(ssl_flag, str) and ssl_flag.strip().lower() == "true" ): parsed_params["security"] = "ssl" # Move each param that was in parsed_params to connection_args if we expect it here for param in [ "database", "host", "port", "username", "password", "security", ]: if param in parsed_params.keys(): parsed_params["connection_args"][param] = parsed_params.pop(param) return parsed_params