Coverage for mcpgateway / routers / cancellation_router.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2# mcpgateway/routers/cancellation_router.py 

3"""Location: ./mcpgateway/routers/cancellation_router.py 

4Copyright 2025 

5SPDX-License-Identifier: Apache-2.0 

6Authors: Mihai Criveti 

7 

8Cancellation router to support gateway-authoritative cancellation actions. 

9 

10Endpoints: 

11- POST /cancellation/cancel -> Request cancellation for a run/requestId 

12- GET /cancellation/status/{request_id} -> Get status for a registered run 

13 

14Security: endpoints require RBAC permission `admin.system_config` by default. 

15""" 

16 

# Standard
from typing import Literal, Optional

# Third-Party
from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel, ConfigDict, Field

# First-Party
import mcpgateway.main as main_module
from mcpgateway.middleware.rbac import get_current_user_with_permissions, require_permission
from mcpgateway.services.cancellation_service import cancellation_service
from mcpgateway.services.logging_service import LoggingService

29 

# Module logger, obtained through the gateway's shared logging service.
logging_service = LoggingService()
logger = logging_service.get_logger(__name__)

# Every endpoint below is mounted under /cancellation and grouped under the
# "Cancellation" tag in the generated OpenAPI docs.
router = APIRouter(tags=["Cancellation"], prefix="/cancellation")

35 

36 

class CancelRequest(BaseModel):
    """
    Request body accepted by ``POST /cancellation/cancel``.

    Attributes:
        request_id: Identifier of the run/request to cancel; sent on the wire as ``requestId``.
        reason: Optional free-text explanation for the cancellation.
    """

    # Allow population via either the alias ("requestId") or the field name ("request_id").
    model_config = ConfigDict(populate_by_name=True)

    # Required: Field() with no default marks the field as mandatory.
    request_id: str = Field(alias="requestId")
    reason: Optional[str] = Field(default=None)

50 

51 

class CancelResponse(BaseModel):
    """
    Response model for cancellation requests.

    Attributes:
        status: ``"cancelled"`` when the run was cancelled locally, ``"queued"`` otherwise.
        request_id: The ID of the request the cancellation applies to (serialized as ``requestId``).
        reason: Optional reason for cancellation.
    """

    # Allow construction with either the alias ("requestId") or the field name ("request_id").
    model_config = ConfigDict(populate_by_name=True)

    # A Literal (instead of a bare str plus a comment) enforces the two allowed
    # values at validation time and publishes them as an enum in the OpenAPI schema.
    status: Literal["cancelled", "queued"]
    request_id: str = Field(..., alias="requestId")
    reason: Optional[str] = None

67 

68 

@router.post("/cancel", response_model=CancelResponse)
@require_permission("admin.system_config")
async def cancel_run(payload: CancelRequest, _user=Depends(get_current_user_with_permissions)) -> CancelResponse:
    """
    Cancel a run by its request ID.

    Cancellation is first attempted locally through ``cancellation_service``.
    An MCP ``notifications/cancelled`` message is then broadcast, best-effort,
    to every session returned by the session registry. Failures to enumerate
    sessions, or to deliver to an individual session, are logged and do not
    fail the request.

    Args:
        payload: The cancellation request payload.
        _user: The current user (dependency injection).

    Returns:
        CancelResponse: ``status="cancelled"`` if the local cancellation reported
        success, otherwise ``status="queued"``.
    """
    request_id = payload.request_id
    reason = payload.reason

    # Try local cancellation first
    local_cancelled = await cancellation_service.cancel_run(request_id, reason=reason)

    # Build MCP-style notification to broadcast to sessions (servers/peers).
    # MCP declares ``reason`` as an optional *string*; sending ``null`` can be
    # rejected by strict peers, so the key is omitted when no reason was given.
    params = {"requestId": request_id}
    if reason is not None:
        params["reason"] = reason
    notification = {"jsonrpc": "2.0", "method": "notifications/cancelled", "params": params}

    # Broadcast best-effort to all sessions
    try:
        session_ids = await main_module.session_registry.get_all_session_ids()
        for sid in session_ids:
            try:
                await main_module.session_registry.broadcast(sid, notification)
            except Exception as e:
                # Per-session errors are non-fatal for cancellation (best-effort)
                logger.warning(f"Failed to broadcast cancellation notification to session {sid}: {e}")
    except Exception as e:
        # Continue (with a warning) if we cannot enumerate sessions
        logger.warning(f"Failed to enumerate sessions for cancellation notification: {e}")

    return CancelResponse(status=("cancelled" if local_cancelled else "queued"), request_id=request_id, reason=reason)

105 

106 

@router.get("/status/{request_id}")
@require_permission("admin.system_config")
async def get_status(request_id: str, _user=Depends(get_current_user_with_permissions)):
    """
    Look up a registered run and return its status record.

    Args:
        request_id: Identifier of the run whose status is requested.
        _user: The current user (dependency injection).

    Returns:
        dict: The status mapping returned by ``cancellation_service.get_status``
        (e.g. keys: 'name', 'registered_at', 'cancelled'), minus the
        ``cancel_callback`` entry.

    Raises:
        HTTPException: 404 if the run is not registered or no status is available for it.
    """
    is_known = await cancellation_service.is_registered(request_id)
    # Only query the status when the run is registered; otherwise treat it as missing.
    record = await cancellation_service.get_status(request_id) if is_known else None
    if record is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Run not found")
    # cancel_callback is a function reference and cannot be serialized, so drop it.
    return {field: value for field, value in record.items() if field != "cancel_callback"}