Coverage for mcpgateway / routers / cancellation_router.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2# mcpgateway/routers/cancellation_router.py 

3"""Location: ./mcpgateway/routers/cancellation_router.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Cancellation router to support gateway-authoritative cancellation actions. 

9 

10Endpoints: 

11- POST /cancellation/cancel -> Request cancellation for a run/requestId 

12- GET /cancellation/status/{request_id} -> Get status for a registered run 

13 

14Security: endpoints require RBAC permission `admin.system_config` by default. 

15""" 

16# Standard 

17from typing import Optional 

18 

19# Third-Party 

20from fastapi import APIRouter, Depends, HTTPException, status 

21from pydantic import BaseModel, ConfigDict, Field 

22 

23# First-Party 

24import mcpgateway.main as main_module 

25from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission 

26from mcpgateway.services.cancellation_service import cancellation_service 

27from mcpgateway.services.logging_service import LoggingService 

28 

# Every route in this module is mounted under the /cancellation prefix and
# grouped under the "Cancellation" tag.
router = APIRouter(tags=["Cancellation"], prefix="/cancellation")

# Initialize logging through the gateway's LoggingService wrapper.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

34 

35 

class CancelRequest(BaseModel):
    """
    Body accepted by the POST /cancellation/cancel endpoint.

    Attributes:
        request_id: Identifier of the run/request to cancel. Its alias is
            ``requestId``; because ``populate_by_name`` is enabled, the Python
            field name ``request_id`` is accepted as well.
        reason: Optional free-text explanation for the cancellation.
    """

    # Allow population by either the alias (requestId) or the field name.
    model_config = ConfigDict(populate_by_name=True)

    request_id: str = Field(..., alias="requestId")
    reason: Optional[str] = Field(default=None)

49 

50 

class CancelResponse(BaseModel):
    """
    Result of a cancellation request.

    Attributes:
        status: Outcome of the request. In this module, ``cancel_run`` sets it
            to ``"cancelled"`` when the local cancellation service reported
            success and to ``"queued"`` otherwise.
        request_id: Identifier of the run the request targeted (alias
            ``requestId``).
        reason: The reason supplied with the request, if any.
    """

    # Allow population by either the alias (requestId) or the field name.
    model_config = ConfigDict(populate_by_name=True)

    status: str  # one of: "cancelled", "queued"
    request_id: str = Field(..., alias="requestId")
    reason: Optional[str] = Field(default=None)

66 

67 

@router.post("/cancel", response_model=CancelResponse)
@require_permission("admin.system_config")
async def cancel_run(payload: CancelRequest, _user=Depends(get_current_user_with_permissions)) -> CancelResponse:
    """
    Cancel a run by its request ID.

    The run is first cancelled locally through ``cancellation_service``. An MCP
    ``notifications/cancelled`` message is then broadcast to every session
    reported by the session registry, regardless of the local outcome. The
    broadcast is best-effort: failures to enumerate sessions, or to reach an
    individual session, are logged as warnings and never fail the request.

    Args:
        payload: The cancellation request payload.
        _user: The current user (dependency injection).

    Returns:
        CancelResponse: ``status="cancelled"`` if the local cancellation service
        reported success, otherwise ``status="queued"``. The request ID and
        reason are echoed back.
    """
    request_id = payload.request_id
    reason = payload.reason

    # Try local cancellation first
    local_cancelled = await cancellation_service.cancel_run(request_id, reason=reason)

    # Build the MCP-style notification to broadcast to sessions (servers/peers).
    # The MCP schema declares `reason` as an optional *string*, so the key is
    # omitted when there is no reason. Sending `"reason": null` would violate
    # the schema, and strictly-validating peers may reject it.
    params = {"requestId": request_id}
    if reason is not None:
        params["reason"] = reason
    notification = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": params}

    # Broadcast best-effort to all sessions
    try:
        session_ids = await main_module.session_registry.get_all_session_ids()
        for sid in session_ids:
            try:
                await main_module.session_registry.broadcast(sid, notification)
            except Exception as e:
                # Per-session errors are non-fatal for cancellation (best-effort)
                logger.warning(f"Failed to broadcast cancellation notification to session {sid}: {e}")
    except Exception as e:
        # Continue if sessions cannot be enumerated; local cancellation already ran
        logger.warning(f"Failed to enumerate sessions for cancellation notification: {e}")

    return CancelResponse(status=("cancelled" if local_cancelled else "queued"), request_id=request_id, reason=reason)

104 

105 

@router.get("/status/{request_id}")
@require_permission("admin.system_config")
async def get_status(request_id: str, _user=Depends(get_current_user_with_permissions)):
    """
    Look up the status record of a registered run.

    Args:
        request_id: Identifier of the run whose status is requested.
        _user: The current user (dependency injection).

    Returns:
        dict: A new dict holding every entry of the mapping returned by
        ``cancellation_service.get_status`` except ``cancel_callback``
        (e.g. keys such as 'name', 'registered_at', 'cancelled').

    Raises:
        HTTPException: 404 "Run not found" when the run is not registered, or
            when the service returns no status for it.
    """
    # The status lookup only happens for registered runs; both "not registered"
    # and "no status" fall through to the same 404 below.
    registered = await cancellation_service.is_registered(request_id)
    status_obj = await cancellation_service.get_status(request_id) if registered else None
    if status_obj is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found")

    # cancel_callback holds a function reference, which cannot be JSON-serialized.
    return {key: value for key, value in status_obj.items() if key != "cancel_callback"}