Source code for ibm_watsonx_ai.foundation_models.extensions.rag.chunker.langchain_chunker

#  -----------------------------------------------------------------------------------------
#  (C) Copyright IBM Corp. 2024.
#  https://opensource.org/licenses/BSD-3-Clause
#  -----------------------------------------------------------------------------------------

from typing import Literal, Sequence, Any, Iterable

from langchain.text_splitter import TextSplitter
from langchain_core.documents import Document

from .base_chunker import BaseChunker


__all__ = [
    "LangChainChunker",
]


class LangChainChunker(BaseChunker[Document]):
    """
    Wrapper for LangChain TextSplitter.

    :param method: the type of LangChain TextSplitter used to perform the chunking, defaults to "recursive"
    :type method: Literal["recursive", "character", "token"], optional

    :param chunk_size: maximum size of a single chunk that is returned, defaults to 4000
    :type chunk_size: int, optional

    :param chunk_overlap: overlap in characters between chunks, defaults to 200
    :type chunk_overlap: int, optional

    :param encoding_name: encoding used in the TokenTextSplitter, defaults to "gpt2"
    :type encoding_name: str, optional

    :param model_name: model used in the TokenTextSplitter
    :type model_name: str, optional

    .. code-block:: python

        from ibm_watsonx_ai.foundation_models.extensions.rag.chunker import LangChainChunker

        text_splitter = LangChainChunker(
            method="recursive",
            chunk_size=1000,
            chunk_overlap=200,
        )

        chunks_ids = []

        for i, document in enumerate(data_loader):
            chunks = text_splitter.split_documents([document])
            chunks_ids.append(vector_store.add_documents(chunks, batch_size=300))
    """

    supported_methods = ("recursive", "character", "token")

    def __init__(
        self,
        method: Literal["recursive", "character", "token"] = "recursive",
        chunk_size: int = 4000,
        chunk_overlap: int = 200,
        encoding_name: str = "gpt2",
        model_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        self.method = method
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.encoding_name = encoding_name
        self.model_name = model_name
        # Raw string avoids an invalid escape sequence warning for the
        # sentence-boundary lookbehind pattern.
        self.separators = kwargs.pop(
            "separators", ["\n\n", r"(?<=\. )", "\n", " ", ""]
        )
        self._text_splitter = self._get_text_splitter()

    def __eq__(self, other: object) -> bool:
        if isinstance(other, LangChainChunker):
            return self.to_dict() == other.to_dict()
        else:
            return NotImplemented

    def _get_text_splitter(self) -> TextSplitter:
        """Create an instance of TextSplitter based on the settings."""

        text_splitter: TextSplitter

        match self.method:
            case "recursive":
                from langchain.text_splitter import RecursiveCharacterTextSplitter

                text_splitter = RecursiveCharacterTextSplitter(
                    chunk_size=self.chunk_size,
                    chunk_overlap=self.chunk_overlap,
                    separators=self.separators,
                    length_function=len,
                    add_start_index=True,
                )

            case "character":
                from langchain.text_splitter import CharacterTextSplitter

                text_splitter = CharacterTextSplitter(
                    chunk_size=self.chunk_size,
                    chunk_overlap=self.chunk_overlap,
                    add_start_index=True,
                )

            case "token":
                from langchain.text_splitter import TokenTextSplitter

                text_splitter = TokenTextSplitter(
                    chunk_size=self.chunk_size,
                    chunk_overlap=self.chunk_overlap,
                    encoding_name=self.encoding_name,
                    model_name=self.model_name,
                    add_start_index=True,
                )

            case _:
                raise ValueError(
                    "Chunker method '{}' is not supported. Use one of {}".format(
                        self.method, self.supported_methods
                    )
                )

        return text_splitter
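    # Illustrative sketch (not part of the library source): each supported
    # ``method`` value maps onto a LangChain splitter, for example:
    #
    #   LangChainChunker(method="recursive")  -> RecursiveCharacterTextSplitter
    #   LangChainChunker(method="character")  -> CharacterTextSplitter
    #   LangChainChunker(method="token")      -> TokenTextSplitter
    #
    # while an unsupported value, e.g. LangChainChunker(method="markdown"),
    # raises ValueError at construction time.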
    def to_dict(self) -> dict[str, Any]:
        """
        Return a dict that can be used to recreate an instance of the LangChainChunker.
        """
        params = (
            "method",
            "chunk_size",
            "chunk_overlap",
            "encoding_name",
            "model_name",
        )

        ret = {k: v for k, v in self.__dict__.items() if k in params}

        return ret
    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> "LangChainChunker":
        """Create instance from the dictionary."""
        return cls(**d)
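    # Illustrative sketch (not part of the library source): to_dict() and
    # from_dict() round-trip the chunker configuration, and __eq__ compares
    # chunkers by that configuration, e.g.
    #
    #   chunker = LangChainChunker(method="token", chunk_size=512, chunk_overlap=64)
    #   restored = LangChainChunker.from_dict(chunker.to_dict())
    #   assert restored == chunker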
    def _set_document_id_in_metadata_if_missing(
        self, documents: Iterable[Document]
    ) -> None:
        """
        Sets "document_id" in the metadata if missing.
        The document_id is the hash of the document's content.

        :param documents: sequence of documents
        :type documents: Iterable[langchain_core.documents.Document]

        :return: None
        """
        for doc in documents:
            if "document_id" not in doc.metadata:
                doc.metadata["document_id"] = str(hash(doc.page_content))

    def _set_sequence_number_in_metadata(
        self, chunks: list[Document]
    ) -> list[Document]:
        """
        Sets "sequence_number" in the metadata, ordered by the chunks' "start_index".

        :param chunks: chunks of documents that contain context in the format of text
        :type chunks: list[langchain_core.documents.Document]

        :return: list of updated chunks, sorted by document_id and sequence_number
        :rtype: list[langchain_core.documents.Document]
        """
        # Sort chunks by start_index within each document_id
        sorted_chunks = sorted(
            chunks,
            key=lambda x: (x.metadata["document_id"], x.metadata["start_index"]),
        )

        document_sequence: dict[str, int] = {}

        for chunk in sorted_chunks:
            doc_id = chunk.metadata["document_id"]
            prev_seq_num = document_sequence.get(doc_id, 0)
            seq_num = prev_seq_num + 1
            document_sequence[doc_id] = seq_num
            chunk.metadata["sequence_number"] = seq_num

        return sorted_chunks
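    # Illustrative note (not part of the library source): given chunks whose
    # metadata already holds "document_id" and "start_index", the helper above
    # assigns a 1-based "sequence_number" per document in start_index order,
    # e.g. chunks of one document at start indices 0, 950 and 1900 receive
    # sequence numbers 1, 2 and 3.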
    def split_documents(self, documents: Sequence[Document]) -> list[Document]:
        """
        Split a series of documents into smaller chunks based on the provided chunker settings.
        Each chunk's metadata includes the document_id, sequence_number, and start_index.

        :param documents: sequence of elements that contain context in the format of text
        :type documents: Sequence[langchain_core.documents.Document]

        :return: list of documents split into smaller chunks with less content
        :rtype: list[langchain_core.documents.Document]
        """
        self._set_document_id_in_metadata_if_missing(documents)
        chunks = self._text_splitter.split_documents(documents)
        sorted_chunks = self._set_sequence_number_in_metadata(chunks)

        return sorted_chunks
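

# --- Illustrative usage sketch (not part of the library source) ---
# A minimal, hedged example of splitting documents with LangChainChunker; the
# document texts below are invented for illustration, and the guard keeps the
# sketch from running on import.
if __name__ == "__main__":  # pragma: no cover
    _chunker = LangChainChunker(method="recursive", chunk_size=200, chunk_overlap=20)
    _documents = [
        Document(page_content="First example document. " * 30),
        Document(page_content="Second example document. " * 30),
    ]

    _chunks = _chunker.split_documents(_documents)

    for _chunk in _chunks:
        # Each chunk carries document_id, sequence_number, and start_index metadata.
        print(
            _chunk.metadata["document_id"],
            _chunk.metadata["sequence_number"],
            _chunk.metadata["start_index"],
        )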