# Source code for genai.extensions.langchain.embeddings
from __future__ import annotations
from typing import List, Optional
from pydantic import BaseModel, ConfigDict
from genai._types import ModelLike
from genai.client import Client
from genai.schema import TextEmbeddingParameters
from genai.text.embedding.embedding_service import CreateExecutionOptions
__all__ = ["LangChainEmbeddingsInterface"]
try:
from langchain_core.embeddings import Embeddings
from langchain_core.runnables.config import run_in_executor
except ImportError:
raise ImportError("Could not import langchain: Please install ibm-generative-ai[langchain] extension.") # noqa: B904
[docs]
class LangChainEmbeddingsInterface(BaseModel, Embeddings):
    """
    LangChain ``Embeddings`` adapter backed by the genai text-embedding service.

    Every embedding request is forwarded to ``client.text.embedding.create``
    together with the configured ``model_id``, ``parameters`` and
    ``execution_options``.

    Example::

        from genai import Client, Credentials
        from genai.extensions.langchain import LangChainEmbeddingsInterface
        from genai.text.embedding import TextEmbeddingParameters

        client = Client(credentials=Credentials.from_env())
        embeddings = LangChainEmbeddingsInterface(
            client=client,
            model_id="sentence-transformers/all-minilm-l6-v2",
            parameters=TextEmbeddingParameters(truncate_input_tokens=True)
        )
        embeddings.embed_query("Hello world!")
        embeddings.embed_documents(["First document", "Second document"])
    """

    # extra="forbid": unknown constructor keywords are rejected.
    # protected_namespaces=(): lets the ``model_id`` field use pydantic's
    #   otherwise reserved ``model_`` prefix without a warning.
    # arbitrary_types_allowed=True: presumably required because ``Client`` is
    #   not itself a pydantic model -- TODO confirm.
    model_config = ConfigDict(extra="forbid", protected_namespaces=(), arbitrary_types_allowed=True)

    # Client whose ``text.embedding.create`` endpoint performs the embedding.
    client: Client
    # Identifier of the embedding model, passed through as ``model_id``.
    model_id: str
    # Optional embedding parameters, passed through unchanged.
    parameters: Optional[ModelLike[TextEmbeddingParameters]] = None
    # Optional execution options, passed through unchanged.
    execution_options: Optional[ModelLike[CreateExecutionOptions]] = None

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed search documents.

        Args:
            texts: Documents to embed.

        Returns:
            The embedding vectors reported by the service, in the order the
            service yields them.
        """
        return self._get_embeddings(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed query text.

        Args:
            text: Query to embed.

        Returns:
            The first embedding vector returned for ``text``.

        Raises:
            IndexError: If the service returned no embedding at all.
        """
        return self._get_embeddings([text])[0]

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed query text.

        Delegates to the synchronous :meth:`embed_query` through langchain's
        ``run_in_executor`` helper; ``None`` selects its default executor.
        """
        return await run_in_executor(None, self.embed_query, text)

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed search documents.

        Delegates to the synchronous :meth:`embed_documents` through
        langchain's ``run_in_executor`` helper; ``None`` selects its default
        executor.
        """
        return await run_in_executor(None, self.embed_documents, texts)

    def _get_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Request embeddings for ``texts`` and flatten them into one list.

        ``create`` is iterated, and each response it yields carries several
        results, so the vectors of all responses are concatenated in the
        order they arrive.
        """
        responses = self.client.text.embedding.create(
            model_id=self.model_id,
            inputs=texts,
            parameters=self.parameters,
            execution_options=self.execution_options,
        )
        return [result.embedding for response in responses for result in response.results]
# Ask pydantic to rebuild the model's schema now that the class is fully defined.
# NOTE(review): presumably needed because `from __future__ import annotations`
# leaves the field annotations as strings that still have to be resolved -- confirm.
LangChainEmbeddingsInterface.model_rebuild()