Source code for genai.extensions.localserver.local_api_server

import contextlib
import datetime
import logging
import threading
import time
import typing
import uuid
from typing import Optional

from fastapi.utils import is_body_allowed_for_status_code
from starlette.responses import Response

from genai._utils.general import cast_list
from genai._utils.responses import BaseErrorResponse, get_api_error_class_by_status_code
from genai.credentials import Credentials
from genai.extensions.localserver.local_base_model import LocalModel
from genai.extensions.localserver.schema import (
    TextGenerationCreateRequest,
    TextTokenizationCreateRequest,
)
from genai.schema import (
    ConcurrencyLimit,
    InternalServerErrorResponse,
    NotFoundResponse,
    TextGenerationCreateResponse,
    TextGenerationLimit,
    TextGenerationLimitRetrieveResponse,
    TextTokenizationCreateResponse,
    UnauthorizedResponse,
)

try:
    import uvicorn
    from fastapi import APIRouter, FastAPI, HTTPException, Request, status
    from fastapi.responses import JSONResponse
    from starlette.middleware.base import (
        BaseHTTPMiddleware,
        DispatchFunction,
        RequestResponseEndpoint,
    )
except ImportError as iex:
    raise ImportError(f"Could not import {iex.name}: Please install ibm-generative-ai[localserver] extension.")  # noqa: B904
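# The missing extras can be installed with: pip install "ibm-generative-ai[localserver]"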

logger = logging.getLogger(__name__)


class ApiErrorMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: FastAPI, dispatch: typing.Optional[DispatchFunction] = None) -> None:
        super().__init__(app, dispatch)
        app.add_exception_handler(HTTPException, self._handle_exception)

    async def _handle_exception(self, request: Request, exc: Exception):
        detail = str(exc)
        status_code = getattr(exc, "status_code", 500)
        headers = getattr(exc, "headers", None)
        if isinstance(exc, HTTPException):
            detail = exc.detail
            status_code = exc.status_code
        if not is_body_allowed_for_status_code(status_code):
            return Response(status_code=status_code, headers=headers)

        # Map the status code to the matching error response model; fall back to
        # a generic 500 response if the payload cannot be validated.
        response_body: BaseErrorResponse
        try:
            cls: type[BaseErrorResponse] = get_api_error_class_by_status_code(status_code) or BaseErrorResponse
            response_body = (
                cls(
                    error=detail,
                    message=detail,
                    status_code=status_code,
                    extensions={},  # type: ignore
                )
                if isinstance(detail, str)
                else cls.model_validate(detail)
            )
        except Exception:  # noqa
            response_body = InternalServerErrorResponse(message=detail, error=detail, extensions={})  # type: ignore

        return JSONResponse(
            status_code=status_code,
            headers=headers,
            content=response_body.model_dump(),
        )

    async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
        try:
            return await call_next(request)
        except Exception as exc:
            return await self._handle_exception(request, exc)
class ApiAuthMiddleware(BaseHTTPMiddleware):
    def __init__(self, app: FastAPI, api_key: Optional[str] = None, insecure: bool = False):
        super().__init__(app)
        self.api_key = api_key
        self.insecure = insecure

    async def dispatch(self, request: Request, call_next):
        # Use .get() so a request without an Authorization header is rejected
        # with a 401 instead of raising a KeyError.
        auth_header = request.headers.get("Authorization", "")
        logger.debug(f"Incoming Request: Auth: {auth_header}, Endpoint: {request.url}")
        if not self.insecure:
            if str(self.api_key) not in auth_header:
                logger.warning(f"A client used an incorrect API key: {auth_header}")
                response_obj = UnauthorizedResponse(
                    error="Unauthorized",
                    message="API key not found",
                    extensions={},  # type: ignore
                )
                logger.debug(response_obj.model_dump())
                return JSONResponse(
                    content=response_obj.model_dump(),
                    status_code=status.HTTP_401_UNAUTHORIZED,
                )
        logger.debug("Calling next in chain")
        response = await call_next(request)
        logger.debug("Returning response")
        return response
class LocalLLMServer:
    def __init__(
        self,
        models: list[type[LocalModel]],
        port: int = 8080,
        interface: str = "0.0.0.0",
        api_key: Optional[str] = None,
        insecure_api: bool = False,
    ):
        self.models = {model.model_id: model() for model in models}
        self.port = port
        self.interface = interface
        self.insecure_api = insecure_api
        # Set the API key: prefer the one supplied by the caller; otherwise fall
        # back to "test" (insecure mode) or a random UUID (secure mode). The
        # parentheses matter: without them the conditional binds first and a
        # caller-supplied key would be silently discarded in secure mode.
        self.api_key: str = api_key or ("test" if insecure_api else str(uuid.uuid4()))

        self.app = FastAPI(
            title="IBM Generative AI Local Model Server",
            debug=False,
            openapi_url=None,
            docs_url=None,
            redoc_url=None,
        )
        self.app.add_middleware(ApiErrorMiddleware)

        self.router = APIRouter()
        self.router.add_api_route("/v2/text/tokenization", self._route_text_tokenization, methods=["POST"])
        self.router.add_api_route("/v2/text/generation", self._route_text_generation, methods=["POST"])
        self.router.add_api_route("/v2/text/generation/limits", self._route_generation_limits, methods=["GET"])
        self.app.include_router(self.router)

        self.app.add_middleware(ApiAuthMiddleware, api_key=self.api_key, insecure=insecure_api)

        self.uvicorn_config = uvicorn.Config(self.app, host=interface, port=port, log_level="error", access_log=False)
        self.uvicorn_server = uvicorn.Server(self.uvicorn_config)
        self.endpoint = f"http://{self.interface}:{self.port}"
        logger.debug(f"{__name__}: API: {self.endpoint}, Insecure: {insecure_api}")
    @contextlib.contextmanager
    def run_locally(self):
        thread = threading.Thread(target=self.start_server)
        thread.start()
        try:
            while not self.uvicorn_server.started:
                time.sleep(1e-3)
            yield
        finally:
            self.uvicorn_server.should_exit = True
            thread.join()
    def get_credentials(self):
        return Credentials(api_key=self.api_key, api_endpoint=self.endpoint)
    def start_server(self):
        logger.debug("Starting server background process")
        logger.info(f"Credentials: Endpoint: {self.endpoint} API Key: {self.api_key}")
        logger.info(
            f'In Python Script: credentials = Credentials(api_key="{self.api_key}", api_endpoint="{self.endpoint}")'
        )
        self.uvicorn_server.run()
    def stop_server(self):
        logger.debug("Stopping server background process")
        self.uvicorn_server.should_exit = True
        # Wait until the server has finished starting up, so the exit flag is
        # guaranteed to be observed by uvicorn's main loop.
        while not self.uvicorn_server.started:
            time.sleep(1e-3)
    def _get_model(self, model_id) -> LocalModel:
        model = self.models.get(model_id)
        if not model:
            raise HTTPException(
                status_code=404,
                detail=NotFoundResponse(
                    message="Model not found",
                    error="Not Found",
                    extensions={"state": {"model_id": model_id}},  # type: ignore
                ).model_dump(),
            )
        return model

    async def _route_text_tokenization(self, body: TextTokenizationCreateRequest) -> TextTokenizationCreateResponse:
        logger.info(f"Tokenize called: {body}")
        model = self._get_model(body.model_id)
        # `text` rather than `input` to avoid shadowing the builtin.
        results = [model.tokenize(input_text=text, parameters=body.parameters) for text in cast_list(body.input)]
        return TextTokenizationCreateResponse(
            model_id=body.model_id,
            created_at=datetime.datetime.now().isoformat(),
            results=results,
        )

    async def _route_text_generation(self, body: TextGenerationCreateRequest) -> TextGenerationCreateResponse:
        logger.info(f"Text Generation Called: Model: {body.model_id}, Input: {body.input}")
        model = self._get_model(body.model_id)
        results = [model.generate(input_text=text, parameters=body.parameters) for text in cast_list(body.input)]
        return TextGenerationCreateResponse(
            id=str(uuid.uuid4()),
            model_id=body.model_id,
            created_at=datetime.datetime.now(datetime.timezone.utc),
            results=results,
        )

    async def _route_generation_limits(self):
        logger.info("Generate Limits Called")
        # The local server enforces no real concurrency limit; report a static one.
        return TextGenerationLimitRetrieveResponse(
            result=TextGenerationLimit(concurrency=ConcurrencyLimit(limit=100, remaining=100))
        )
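# Example usage (a minimal sketch, not part of the module): `ExampleModel` is a
# hypothetical LocalModel subclass supplied by the caller.
#
#     from genai import Client
#
#     server = LocalLLMServer(models=[ExampleModel])
#     with server.run_locally():
#         client = Client(credentials=server.get_credentials())
#         # ... issue text generation / tokenization requests against the local endpoint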