# -----------------------------------------------------------------------------------------
# (C) Copyright IBM Corp. 2024-2025.
# https://opensource.org/licenses/BSD-3-Clause
# -----------------------------------------------------------------------------------------
from __future__ import annotations
import copy
import logging
from enum import Enum
from typing import Any, Optional
from warnings import warn
from langchain_core.vectorstores import VectorStore as LangChainVectorStore
from ibm_watsonx_ai.foundation_models.extensions.rag.utils.utils import (
save_ssl_certificate_as_file,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.base_vector_store import (
BaseVectorStore,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.langchain_vector_store_adapter import (
LangChainVectorStoreAdapter,
)
from ibm_watsonx_ai.utils.utils import is_lib_installed
from ibm_watsonx_ai.wml_client_error import MissingExtension
logger = logging.getLogger(__name__)
[docs]
class VectorStoreDataSourceType(str, Enum):
ELASTICSEARCH = "elasticsearch"
CHROMA = "chroma"
MILVUS = "milvus"
MILVUS_WXD = "milvuswxd" # IBM watsonx.data Milvus
DB2 = "db2"
UNDEFINED = "undefined"
def __str__(self) -> str:
return self.value
def _convert_str_to_vs_datasource_type_enum(
datasource_type_str: str,
) -> VectorStoreDataSourceType:
try:
return VectorStoreDataSourceType(datasource_type_str)
except ValueError:
return VectorStoreDataSourceType.UNDEFINED
[docs]
class VectorStoreConnector:
"""Creates a proper vector store client using the provided properties.
Properties are arguments to the LangChain vector stores of a desired type.
Also parses properties extracted from connection assets into one that would fit for initialization.
Custom or connection asset properties that are parsed are:
* `index_name`
* `distance_metric`
* `username`
* `password`
* `ssl_certificate`
* `embeddings`
:param properties: dictionary with all the required key values to establish the connection
:type properties: dict
"""
def __init__(self, properties: dict | None = None) -> None:
def deepcopy_if_possible(obj: Any) -> Any:
try:
return copy.deepcopy(obj)
except Exception:
return obj
self.properties: dict = (
{key: deepcopy_if_possible(value) for key, value in properties.items()}
if isinstance(properties, dict)
else {}
)
[docs]
@staticmethod
def get_type_from_langchain_vector_store(
langchain_vector_store: Any,
) -> VectorStoreDataSourceType:
"""Returns ``DataSourceType`` for concrete LangChain ``VectorStore`` class.
:param langchain_vector_store: vector store object from LangChain
:type langchain_vector_store: Any
:return: DataSourceType name
:rtype: VectorStoreDataSourceType
"""
vs_type = langchain_vector_store.__class__.__name__
match vs_type:
case "ElasticsearchStore":
return VectorStoreDataSourceType.ELASTICSEARCH
case "Chroma":
return VectorStoreDataSourceType.CHROMA
case "Milvus":
return VectorStoreDataSourceType.MILVUS
case "DB2VS":
return VectorStoreDataSourceType.DB2
case _:
return VectorStoreDataSourceType.UNDEFINED
[docs]
def get_from_type(self, type: VectorStoreDataSourceType) -> BaseVectorStore:
"""Gets a vector store based on the provided type (matching from DataSource names from SDK API).
:param type: DataSource type string from SDK API
:type type: VectorStoreDataSourceType
:raises TypeError: unsupported type
:return: proper BaseVectorStore type constructed from properties
:rtype: BaseVectorStore
"""
match type:
case VectorStoreDataSourceType.ELASTICSEARCH:
return self.get_elasticsearch()
case VectorStoreDataSourceType.CHROMA:
return self.get_chroma()
case (
VectorStoreDataSourceType.MILVUS | VectorStoreDataSourceType.MILVUS_WXD
):
return self.get_milvus()
case VectorStoreDataSourceType.DB2:
return self.get_db2()
case _:
raise TypeError("Data source type not supported.")
[docs]
def get_langchain_adapter( # type: ignore[return]
self, langchain_vector_store: Any
) -> LangChainVectorStoreAdapter | None:
"""Creates an adapter for a concrete vector store from LangChain.
:param langchain_vector_store: object that is a subclass of the LangChain vector store
:type langchain_vector_store: Any
:raises ImportError: LangChain required
:return: proper adapter for the vector store
:rtype: LangChainVectorStoreAdapter
"""
if isinstance(langchain_vector_store, LangChainVectorStore):
return LangChainVectorStoreAdapter(vector_store=langchain_vector_store)
[docs]
def get_chroma(self) -> LangChainVectorStoreAdapter:
"""Creates an in-memory vector store for Chroma.
:raises ImportError: langchain required
:return: vector store adapter for LangChain's Chroma
:rtype: LangChainVectorStoreAdapter
"""
if not is_lib_installed(ext := "langchain-chroma"):
raise MissingExtension(ext, extra_info="rag")
from langchain_chroma import Chroma
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.chroma_adapter import (
ChromaVectorStore,
)
parsed_params = self.properties
parsed_params.pop("datasource_type", None)
parsed_params.pop("text_field", None)
# Parse collection name
# 'collection_name' kwargs for Chroma has priority over generic 'index_name'
collection_name = parsed_params.pop("index_name", None)
if collection_name:
parsed_params["collection_name"] = parsed_params.get(
"collection_name", collection_name
)
# Parse distance metric - set it in collection_metadata
# Distance metric for Chroma is determined by collection metadata
# See: Chroma._select_relevance_score_fn()
distance_metric = parsed_params.pop("distance_metric", None)
if distance_metric == "euclidean":
collection_metadata = {"hnsw:space": "l2"}
elif distance_metric == "cosine":
collection_metadata = {"hnsw:space": "cosine"}
else:
collection_metadata = None
parsed_params["collection_metadata"] = parsed_params.get(
"collection_metadata", collection_metadata
)
# Set embedding from params
parsed_params["embedding_function"] = parsed_params.pop("embeddings", None)
if parsed_params["embedding_function"] is None:
raise ValueError("Embedding function is required for Chroma.")
_doc_keys = ("document_name_field", "chunk_sequence_number_field")
document_key_mapping_args = {
k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params
}
return ChromaVectorStore(
vector_store=Chroma(**parsed_params), **document_key_mapping_args
)
[docs]
def get_milvus(self) -> LangChainVectorStoreAdapter:
"""Creates a Milvus vector store.
:raises ImportError: langchain required
:return: vector store adapter for LangChain's Milvus
:rtype: LangChainVectorStoreAdapter
"""
from langchain_milvus import Milvus
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.milvus_adapter import (
MilvusVectorStore,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.milvus_utils import (
DEFAULT_INDEX_PARAM,
)
parsed_params = self._get_milvus_connection_params()
parsed_params.pop("datasource_type", None)
# Connection 'index_name' is 'collection_name' in Milvus
if index_name := parsed_params.pop("index_name", None):
parsed_params["collection_name"] = index_name
elif not parsed_params.get("collection_name"):
raise ValueError("Provide 'index_name' or 'collection_name'.")
# Parse distance metric
# Distance metric is set in `index_params`.
# Here we replace the default `index_params` with different metric type.
# See: `Milvus._create_index()`.
distance_metric = parsed_params.pop("distance_metric", None)
if distance_metric == "cosine":
index_params = DEFAULT_INDEX_PARAM
else:
index_params = None
parsed_params["index_params"] = parsed_params.get("index_params", index_params)
if "embedding_function" in parsed_params:
if "embeddings" in parsed_params:
raise ValueError(
"Either `embeddings` or `embedding_function` must be specified, but not both."
)
else:
parsed_params["embedding_function"] = parsed_params.pop("embeddings", None)
_doc_keys = ("document_name_field", "chunk_sequence_number_field")
document_key_mapping_args = {
k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params
}
return MilvusVectorStore(
vector_store=Milvus(**parsed_params), **document_key_mapping_args
)
def _get_milvus_connection_params(self) -> dict:
parsed_params = self.properties
# Prepare connection_args (if not present)
if "connection_args" not in parsed_params:
parsed_params["connection_args"] = {}
# Set secure=True also when user set it in Connection UI
if "ssl" in parsed_params:
is_ssl = parsed_params.pop("ssl")
parsed_params["secure"] = True if is_ssl == "true" else False
# Get SSL certificate saved to file
if "ssl_certificate" in parsed_params:
parsed_params["server_pem_path"] = save_ssl_certificate_as_file(
parsed_params.pop("ssl_certificate")
)
# Connection 'username' is 'user' in Milvus
if "username" in parsed_params:
parsed_params["user"] = parsed_params.pop("username")
# Connection 'database' is 'db_name' in Milvus
if "database" in parsed_params:
parsed_params["db_name"] = parsed_params.pop("database")
# Move each param that was in parsed_params to connection_args if we expect it here
for param in [
"uri",
"host",
"port",
"user",
"password",
"db_name",
"secure",
"client_key_path",
"client_pem_path",
"ca_pem_path",
"server_pem_path",
"server_name",
]:
if param in parsed_params.keys():
parsed_params["connection_args"][param] = parsed_params.pop(param)
def build_address(connection_args: dict) -> str:
"""Build an address string from host and port."""
host = connection_args.get("host", "localhost")
port = connection_args.get("port", 19530)
return f"{host}:{port}"
parsed_params["connection_args"]["address"] = build_address(
parsed_params["connection_args"]
)
return parsed_params
[docs]
def get_elasticsearch(self) -> LangChainVectorStoreAdapter:
"""Creates an Elasticsearch vector store.
:raises ImportError: langchain required
:return: vector store adapter for LangChain's Elasticsearch
:rtype: LangChainVectorStoreAdapter
"""
if not is_lib_installed(ext := "langchain-elasticsearch"):
raise MissingExtension(ext, extra_info="rag")
from langchain_elasticsearch import (
DenseVectorScriptScoreStrategy,
DistanceMetric,
ElasticsearchStore,
RetrievalStrategy,
SparseVectorStrategy,
)
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.es_adapter import (
ElasticsearchVectorStore,
)
# Parse ES connection data - select proper connection type
parsed_params = self._get_elasticsearch_connection_params()
parsed_params = self.properties
parsed_params.pop("datasource_type", None)
# Parse distance metric
# Match with ES DistanceMetric.
# Default is cosine.
match parsed_params.pop("distance_metric", None):
case "euclidean":
distance_metric = DistanceMetric.EUCLIDEAN_DISTANCE
case "cosine":
distance_metric = DistanceMetric.COSINE
case _:
distance_metric = DistanceMetric.COSINE
parsed_params["distance_strategy"] = parsed_params.pop(
"distance_strategy", distance_metric
)
# Determine retrieval strategy type from parameters
if "strategy" not in parsed_params or not isinstance(
parsed_params["strategy"], RetrievalStrategy
):
if "model_id" in parsed_params:
parsed_params["strategy"] = SparseVectorStrategy(
model_id=parsed_params.pop("model_id")
)
else:
retrieval_strategy_warning = (
"The default retrieval strategy will soon be updated from `DenseVectorScriptScoreStrategy` to `DenseVectorStrategy` "
"for consistency with the `ElasticsearchVectorStore` class. "
"Using the latter is recommended to avoid potential issues in the future."
)
warn(retrieval_strategy_warning)
parsed_params["strategy"] = DenseVectorScriptScoreStrategy(
distance=distance_metric,
)
# Set embedding from params
if parsed_params.get("embedding") is None:
parsed_params["embedding"] = parsed_params.pop("embeddings", None)
_doc_keys = ("document_name_field", "chunk_sequence_number_field")
document_key_mapping_args = {
k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params
}
return ElasticsearchVectorStore(
vector_store=ElasticsearchStore(**parsed_params),
**document_key_mapping_args,
)
def _get_elasticsearch_connection_params(self) -> dict:
parsed_params = self.properties
# Always use empty es_params if not provided
if "es_params" not in parsed_params:
parsed_params["es_params"] = {}
# Drop unnecessary stuff from connection asset if they are present
parsed_params.pop("auth_method", None)
parsed_params.pop("use_anonymous_access", None)
# Parse ES connection data - select proper connection type
# Connecting by 'url': username/password or api_key
if "url" in parsed_params:
# Get URL of ES instance
parsed_params["es_url"] = parsed_params.pop("url")
# Detect credentials given in connection asset
if "username" in parsed_params and "password" in parsed_params:
# Connect by username and password extracted from connection
parsed_params["es_user"] = parsed_params.pop("username")
parsed_params["es_password"] = parsed_params.pop("password")
parsed_params.pop("api_key", None)
elif "api_key" in parsed_params:
# Connect by api key
parsed_params["es_api_key"] = parsed_params.pop("api_key")
parsed_params.pop("username", None)
parsed_params.pop("password", None)
else:
raise ValueError(
"""To connect to given hostname ['url'] provide
either ['username', 'password'] or ['api_key'].
Make sure those fields are present in connection details or parameters given
upon VectorStore initialization. """
)
elif "es_url" in parsed_params:
if "es_user" in parsed_params and "es_password" in parsed_params:
pass
elif "es_api_key" in parsed_params:
pass
else:
raise ValueError(
"""To connect to given hostname ['es_url'] provide
either ['es_user', 'es_password'] or ['es_api_key'].
Make sure those fields are present in parameters given
upon VectorStore initialization. """
)
# Connecting by '(es_)cloud_id' to Elasticsearch cloud
elif "cloud_id" in parsed_params and "api_key" in parsed_params:
parsed_params["es_cloud_id"] = parsed_params.pop("cloud_id", None)
parsed_params["es_api_key"] = parsed_params.pop("api_key", None)
elif "es_cloud_id" in parsed_params and "es_api_key" in parsed_params:
pass
else:
raise ValueError(
"""Connection data was not sufficent. Either provide:
- ['url', 'username', 'password'],
- ['url', 'api_key'],
- ['cloud_id', 'api_key']
or
- ['es_url', 'es_user', 'es_password'],
- ['es_url', 'es_api_key'],
- ['es_cloud_id', 'es_api_key'],
in your connection asset or in params for VectorStore."""
)
if not parsed_params.get("index_name"):
raise ValueError("Provide 'index_name'.")
# Parse SSL certificate
ssl_certificate_content = parsed_params.pop("ssl_certificate", None)
if ssl_certificate_content:
parsed_params["es_params"]["ca_certs"] = save_ssl_certificate_as_file(
ssl_certificate_content
)
return parsed_params
[docs]
def get_db2(self) -> LangChainVectorStoreAdapter:
"""Creates a DV2 vector store.
:raises ImportError: langchain-db2 required
:return: vector store adapter for LangChain's DB2
:rtype: LangChainVectorStoreAdapter
"""
if not is_lib_installed(ext := "langchain-db2"):
raise MissingExtension(ext, extra_info="rag")
from langchain_db2 import DB2VS
from langchain_db2.db2vs import DistanceStrategy
from ibm_watsonx_ai.foundation_models.extensions.rag.vector_stores.adapters.db2_adapter import (
DB2VectorStore,
)
parsed_params = self._get_db2_connection_params()
parsed_params.pop("datasource_type", None)
# Connection 'index_name' is 'table_name' in DB2
if index_name := parsed_params.pop("index_name", None):
parsed_params["table_name"] = index_name
elif not parsed_params.get("table_name"):
raise ValueError("Provide 'index_name' or 'table_name'.")
# Parse distance_metric
# Match with DB2 DistanceStrategy.
distance_metric: Optional[DistanceStrategy] = None
match parsed_params.pop("distance_metric", None):
case "euclidean":
distance_metric = DistanceStrategy.EUCLIDEAN_DISTANCE
case "max_inner_product":
distance_metric = DistanceStrategy.MAX_INNER_PRODUCT
case "dot":
distance_metric = DistanceStrategy.DOT_PRODUCT
case "jaccard":
distance_metric = DistanceStrategy.JACCARD
case "cosine":
distance_metric = DistanceStrategy.COSINE
case _:
pass
distance_strategy = parsed_params.get("distance_strategy", None)
if distance_strategy or distance_metric:
parsed_params["distance_strategy"] = distance_strategy or distance_metric
# Set embedding_function
if "embedding_function" not in parsed_params:
parsed_params["embedding_function"] = parsed_params.pop("embeddings", None)
elif "embeddings" in parsed_params:
raise ValueError(
"Either `embeddings` or `embedding_function` must be specified, but not both."
)
_doc_keys = ("document_name_field", "chunk_sequence_number_field")
document_key_mapping_args = {
k: parsed_params.pop(k) for k in _doc_keys if k in parsed_params
}
return DB2VectorStore(
vector_store=DB2VS(**parsed_params),
**document_key_mapping_args,
)
def _get_db2_connection_params(self) -> dict:
parsed_params = self.properties
# Prepare connection_args (if not present)
if "connection_args" not in parsed_params:
parsed_params["connection_args"] = {}
# Connection 'ssl' is 'security' in DB2
ssl_flag = parsed_params.pop("ssl", None)
if ssl_flag is True or (
isinstance(ssl_flag, str) and ssl_flag.strip().lower() == "true"
):
parsed_params["security"] = "ssl"
# Move each param that was in parsed_params to connection_args if we expect it here
for param in [
"database",
"host",
"port",
"username",
"password",
"security",
]:
if param in parsed_params.keys():
parsed_params["connection_args"][param] = parsed_params.pop(param)
return parsed_params