Coverage for mcpgateway / services / import_service.py: 100%
833 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
« prev ^ index » next coverage.py v7.13.4, created at 2026-03-09 03:05 +0000
1# -*- coding: utf-8 -*-
2# pylint: disable=import-outside-toplevel,no-name-in-module
3"""Location: ./mcpgateway/services/import_service.py
4Copyright 2025
5SPDX-License-Identifier: Apache-2.0
6Authors: Mihai Criveti
8Import Service Implementation.
9This module implements comprehensive configuration import functionality according to the import specification.
10It handles:
11- Import file validation and schema compliance
12- Entity creation and updates with conflict resolution
13- Dependency resolution and processing order
14- Authentication data decryption and re-encryption
15- Dry-run functionality for validation
16- Cross-environment key rotation support
17- Import status tracking and progress reporting
18"""
20# Standard
21import base64
22from datetime import datetime, timedelta, timezone
23from enum import Enum
24import logging
25from typing import Any, Dict, List, Optional
26import uuid
28# Third-Party
29from sqlalchemy.orm import Session
31# First-Party
32from mcpgateway.db import A2AAgent, EmailUser, Gateway, Prompt, Resource, Server, Tool
33from mcpgateway.schemas import AuthenticationValues, GatewayCreate, GatewayUpdate, PromptCreate, PromptUpdate, ResourceCreate, ResourceUpdate, ServerCreate, ServerUpdate, ToolCreate, ToolUpdate
34from mcpgateway.services.gateway_service import GatewayNameConflictError
35from mcpgateway.services.prompt_service import PromptNameConflictError
36from mcpgateway.services.resource_service import ResourceURIConflictError
37from mcpgateway.services.server_service import ServerNameConflictError
38from mcpgateway.services.tool_service import ToolNameConflictError
39from mcpgateway.utils.services_auth import decode_auth, encode_auth
41logger = logging.getLogger(__name__)
class ConflictStrategy(str, Enum):
    """Policy applied when an imported entity clashes with an existing one.

    Every member is also a plain string, so a strategy can be constructed
    directly from its textual value.

    Examples:
        >>> ConflictStrategy.SKIP.value
        'skip'
        >>> ConflictStrategy.UPDATE.value
        'update'
        >>> ConflictStrategy.RENAME.value
        'rename'
        >>> ConflictStrategy.FAIL.value
        'fail'
        >>> ConflictStrategy("update")
        <ConflictStrategy.UPDATE: 'update'>
    """

    # Keep the existing entity and record the import entry as skipped.
    SKIP = "skip"
    # Overwrite the existing entity with the imported definition.
    UPDATE = "update"
    # Create the imported entity under a new, suffixed name.
    RENAME = "rename"
    # Treat the clash as an error.
    FAIL = "fail"
class ImportError(Exception):  # pylint: disable=redefined-builtin
    """Root of the import service's exception hierarchy.

    Note that inside this module the name deliberately shadows Python's
    built-in ``ImportError``.

    Examples:
        >>> error = ImportError("Something went wrong")
        >>> str(error)
        'Something went wrong'
        >>> isinstance(error, Exception)
        True
    """
class ImportValidationError(ImportError):
    """Signals that an import payload does not match the expected structure.

    Examples:
        >>> error = ImportValidationError("Invalid schema")
        >>> str(error)
        'Invalid schema'
        >>> isinstance(error, ImportError)
        True
    """
class ImportConflictError(ImportError):
    """Signals a naming conflict that the chosen strategy refuses to resolve.

    Examples:
        >>> error = ImportConflictError("Name conflict: tool_name")
        >>> str(error)
        'Name conflict: tool_name'
        >>> isinstance(error, ImportError)
        True
    """
class ImportStatus:
    """Progress and outcome record for a single import operation.

    Holds entity counters, collected error/warning messages and start/finish
    timestamps, and can render itself as a plain dictionary.
    """

    def __init__(self, import_id: str):
        """Create a fresh status record in the ``pending`` state.

        Args:
            import_id: Unique identifier for the import operation

        Examples:
            >>> status = ImportStatus("import_123")
            >>> status.import_id
            'import_123'
            >>> status.status
            'pending'
            >>> status.total_entities
            0
        """
        self.import_id = import_id
        self.status = "pending"

        # Entity counters; all start at zero.
        self.total_entities = 0
        self.processed_entities = 0
        self.created_entities = 0
        self.updated_entities = 0
        self.skipped_entities = 0
        self.failed_entities = 0

        # Human-readable messages collected while the import runs.
        self.errors: List[str] = []
        self.warnings: List[str] = []

        # Timezone-aware (UTC) start time; completed_at stays None until set by the caller.
        self.started_at = datetime.now(timezone.utc)
        self.completed_at: Optional[datetime] = None

    def to_dict(self) -> Dict[str, Any]:
        """Render the status as a dictionary suitable for API responses.

        Returns:
            Dictionary with the identifier, state, a nested ``progress``
            mapping of counters, the message lists and ISO-8601 timestamps
            (``completed_at`` is ``None`` while unset).
        """
        progress = {
            "total": self.total_entities,
            "processed": self.processed_entities,
            "created": self.created_entities,
            "updated": self.updated_entities,
            "skipped": self.skipped_entities,
            "failed": self.failed_entities,
        }
        finished_at = self.completed_at
        finished_iso = finished_at.isoformat() if finished_at else None

        return {
            "import_id": self.import_id,
            "status": self.status,
            "progress": progress,
            "errors": self.errors,
            "warnings": self.warnings,
            "started_at": self.started_at.isoformat(),
            "completed_at": finished_iso,
        }
157class ImportService:
158 """Service for importing ContextForge configuration and data.
160 This service provides comprehensive import functionality including:
161 - Import file validation and schema compliance
162 - Entity creation and updates with conflict resolution
163 - Dependency resolution and correct processing order
164 - Secure authentication data handling with re-encryption
165 - Dry-run capabilities for validation without changes
166 - Progress tracking and status reporting
167 - Cross-environment key rotation support
168 """
170 def __init__(self):
171 """Initialize the import service with required dependencies.
173 Creates instances of all entity services and initializes the active imports tracker.
175 Examples:
176 >>> service = ImportService()
177 >>> service.active_imports
178 {}
179 >>> hasattr(service, 'tool_service')
180 True
181 >>> hasattr(service, 'gateway_service')
182 True
183 """
184 # Prefer globally-initialized singletons from mcpgateway.main to ensure
185 # services share initialized EventService/Redis clients. Import lazily
186 # to avoid circular import at module load time. Fall back to local
187 # instances if singletons are not available (tests, isolated usage).
188 # Use globally-exported singletons from service modules so they
189 # share initialized EventService/Redis clients created at app startup.
190 # First-Party
191 from mcpgateway.services.gateway_service import gateway_service
192 from mcpgateway.services.prompt_service import prompt_service
193 from mcpgateway.services.resource_service import resource_service
194 from mcpgateway.services.root_service import root_service
195 from mcpgateway.services.server_service import server_service
196 from mcpgateway.services.tool_service import tool_service
198 self.gateway_service = gateway_service
199 self.tool_service = tool_service
200 self.resource_service = resource_service
201 self.prompt_service = prompt_service
202 self.server_service = server_service
203 self.root_service = root_service
204 self.active_imports: Dict[str, ImportStatus] = {}
    async def initialize(self) -> None:
        """Initialize the import service.

        Only emits an informational log entry; no other setup happens here.
        """
        logger.info("Import service initialized")
    async def shutdown(self) -> None:
        """Shutdown the import service.

        Only emits an informational log entry; no resources are released here.
        """
        logger.info("Import service shutdown")
214 def validate_import_data(self, import_data: Dict[str, Any]) -> None:
215 """Validate import data against the expected schema.
217 Args:
218 import_data: The import data to validate
220 Raises:
221 ImportValidationError: If validation fails
223 Examples:
224 >>> service = ImportService()
225 >>> valid_data = {
226 ... "version": "2025-03-26",
227 ... "exported_at": "2025-01-01T00:00:00Z",
228 ... "entities": {"tools": []}
229 ... }
230 >>> service.validate_import_data(valid_data) # Should not raise
232 >>> invalid_data = {"missing": "version"}
233 >>> try:
234 ... service.validate_import_data(invalid_data)
235 ... except ImportValidationError as e:
236 ... "Missing required field" in str(e)
237 True
238 """
239 logger.debug("Validating import data structure")
241 # Check required top-level fields
242 required_fields = ["version", "exported_at", "entities"]
243 for field in required_fields:
244 if field not in import_data:
245 raise ImportValidationError(f"Missing required field: {field}")
247 # Validate version compatibility
248 if not import_data.get("version"):
249 raise ImportValidationError("Version field cannot be empty")
251 # Validate entities structure
252 entities = import_data.get("entities", {})
253 if not isinstance(entities, dict):
254 raise ImportValidationError("Entities must be a dictionary")
256 # Validate each entity type
257 valid_entity_types = ["tools", "gateways", "servers", "prompts", "resources", "roots"]
258 for entity_type, entity_list in entities.items():
259 if entity_type not in valid_entity_types:
260 raise ImportValidationError(f"Unknown entity type: {entity_type}")
262 if not isinstance(entity_list, list):
263 raise ImportValidationError(f"Entity type '{entity_type}' must be a list")
265 # Validate individual entities
266 for i, entity in enumerate(entity_list):
267 if not isinstance(entity, dict):
268 raise ImportValidationError(f"Entity {i} in '{entity_type}' must be a dictionary")
270 # Check required fields based on entity type
271 self._validate_entity_fields(entity_type, entity, i)
273 logger.debug("Import data validation passed")
275 def _validate_entity_fields(self, entity_type: str, entity: Dict[str, Any], index: int) -> None:
276 """Validate required fields for a specific entity type.
278 Args:
279 entity_type: Type of entity (tools, gateways, etc.)
280 entity: Entity data dictionary
281 index: Index of entity in list for error messages
283 Raises:
284 ImportValidationError: If required fields are missing
285 """
286 required_fields = {
287 "tools": ["name", "url", "integration_type"],
288 "gateways": ["name", "url"],
289 "servers": ["name"],
290 "prompts": ["name", "template"],
291 "resources": ["name", "uri"],
292 "roots": ["uri", "name"],
293 }
295 if entity_type in required_fields:
296 for field in required_fields[entity_type]:
297 if field not in entity:
298 raise ImportValidationError(f"Entity {index} in '{entity_type}' missing required field: {field}")
300 async def import_configuration(
301 self,
302 db: Session,
303 import_data: Dict[str, Any],
304 conflict_strategy: ConflictStrategy = ConflictStrategy.UPDATE,
305 dry_run: bool = False,
306 rekey_secret: Optional[str] = None,
307 imported_by: str = "system",
308 selected_entities: Optional[Dict[str, List[str]]] = None,
309 ) -> ImportStatus:
310 """Import configuration data with conflict resolution.
312 Args:
313 db: Database session
314 import_data: The validated import data
315 conflict_strategy: How to handle naming conflicts
316 dry_run: If True, validate but don't make changes
317 rekey_secret: New encryption secret for cross-environment imports
318 imported_by: Username of the person performing the import
319 selected_entities: Dict of entity types to specific entity names/ids to import
321 Returns:
322 ImportStatus: Status object tracking import progress and results
324 Raises:
325 ImportError: If import fails
326 """
327 import_id = str(uuid.uuid4())
328 status = ImportStatus(import_id)
329 self.active_imports[import_id] = status
331 try:
332 logger.info(f"Starting configuration import {import_id} by {imported_by} (dry_run={dry_run})")
334 # Validate import data
335 self.validate_import_data(import_data)
337 # Calculate total entities to process
338 entities = import_data.get("entities", {})
339 status.total_entities = self._calculate_total_entities(entities, selected_entities)
341 status.status = "running"
343 # Process entities in dependency order
344 processing_order = ["roots", "gateways", "tools", "resources", "prompts", "servers"]
346 for entity_type in processing_order:
347 if entity_type in entities:
348 await self._process_entities(db, entity_type, entities[entity_type], conflict_strategy, dry_run, rekey_secret, status, selected_entities, imported_by)
349 # Flush after each entity type to make records visible for associations
350 if not dry_run:
351 db.flush()
353 # Assign all imported items to user's team with team visibility (after all entities processed)
354 if not dry_run:
355 await self._assign_imported_items_to_team(db, imported_by, imported_after=status.started_at)
357 # Mark as completed
358 status.status = "completed"
359 status.completed_at = datetime.now(timezone.utc)
361 logger.info(f"Import {import_id} completed: created={status.created_entities}, updated={status.updated_entities}, skipped={status.skipped_entities}, failed={status.failed_entities}")
363 return status
365 except Exception as e:
366 status.status = "failed"
367 status.completed_at = datetime.now(timezone.utc)
368 status.errors.append(f"Import failed: {str(e)}")
369 logger.error(f"Import {import_id} failed: {str(e)}")
370 raise ImportError(f"Import failed: {str(e)}")
372 def _get_entity_identifier(self, entity_type: str, entity: Dict[str, Any]) -> str:
373 """Get the unique identifier for an entity based on its type.
375 Args:
376 entity_type: Type of entity
377 entity: Entity data dictionary
379 Returns:
380 Unique identifier string for the entity
382 Examples:
383 >>> service = ImportService()
384 >>> tool_entity = {"name": "my_tool", "url": "https://example.com"}
385 >>> service._get_entity_identifier("tools", tool_entity)
386 'my_tool'
388 >>> resource_entity = {"name": "my_resource", "uri": "/api/data"}
389 >>> service._get_entity_identifier("resources", resource_entity)
390 '/api/data'
392 >>> root_entity = {"name": "workspace", "uri": "file:///workspace"}
393 >>> service._get_entity_identifier("roots", root_entity)
394 'file:///workspace'
396 >>> unknown_entity = {"data": "test"}
397 >>> service._get_entity_identifier("unknown", unknown_entity)
398 ''
399 """
400 if entity_type in ["tools", "gateways", "servers", "prompts"]:
401 return entity.get("name", "")
402 if entity_type == "resources":
403 return entity.get("uri", "")
404 if entity_type == "roots":
405 return entity.get("uri", "")
406 return ""
408 def _calculate_total_entities(self, entities: Dict[str, List[Dict[str, Any]]], selected_entities: Optional[Dict[str, List[str]]]) -> int:
409 """Calculate total entities to process based on selection criteria.
411 Args:
412 entities: Dictionary of entities from import data
413 selected_entities: Optional entity selection filter
415 Returns:
416 Total number of entities to process
418 Examples:
419 No selection counts all entities:
420 >>> svc = ImportService()
421 >>> entities = {
422 ... 'tools': [{"name": "t1"}, {"name": "t2"}],
423 ... 'resources': [{"uri": "/r1"}],
424 ... }
425 >>> svc._calculate_total_entities(entities, selected_entities=None)
426 3
428 Selection for a subset by name/identifier:
429 >>> selected = {'tools': ['t2'], 'resources': ['/r1']}
430 >>> svc._calculate_total_entities(entities, selected)
431 2
433 Selection for only a type (empty list means all of that type):
434 >>> selected = {'tools': []}
435 >>> svc._calculate_total_entities(entities, selected)
436 2
437 """
438 if selected_entities:
439 total = 0
440 for entity_type, entity_list in entities.items():
441 if entity_type in selected_entities:
442 selected_names = selected_entities[entity_type]
443 if selected_names:
444 # Count entities that match selection
445 for entity in entity_list:
446 entity_name = self._get_entity_identifier(entity_type, entity)
447 if entity_name in selected_names:
448 total += 1
449 else:
450 total += len(entity_list)
451 return total
452 return sum(len(entity_list) for entity_list in entities.values())
454 async def _process_entities(
455 self,
456 db: Session,
457 entity_type: str,
458 entity_list: List[Dict[str, Any]],
459 conflict_strategy: ConflictStrategy,
460 dry_run: bool,
461 rekey_secret: Optional[str],
462 status: ImportStatus,
463 selected_entities: Optional[Dict[str, List[str]]],
464 imported_by: str,
465 ) -> None:
466 """Process a list of entities of a specific type using bulk operations.
468 This method now uses bulk registration for tools, resources, and prompts
469 to achieve 10-50x performance improvements over individual processing.
471 Args:
472 db: Database session
473 entity_type: Type of entities being processed
474 entity_list: List of entity data dictionaries
475 conflict_strategy: How to handle naming conflicts
476 dry_run: Whether this is a dry run
477 rekey_secret: New encryption secret if re-keying
478 status: Import status tracker
479 selected_entities: Optional entity selection filter
480 imported_by: Username of the person performing the import
481 """
482 logger.debug(f"Processing {len(entity_list)} {entity_type} entities")
484 # Filter entities based on selection
485 filtered_entities = []
486 for entity_data in entity_list:
487 # Check if this entity is selected for import
488 if selected_entities and entity_type in selected_entities:
489 selected_names = selected_entities[entity_type]
490 if selected_names: # If specific entities are selected
491 entity_name = self._get_entity_identifier(entity_type, entity_data)
492 if entity_name not in selected_names:
493 continue # Skip this entity
495 # Handle authentication re-encryption if needed
496 if rekey_secret and self._has_auth_data(entity_data):
497 entity_data = self._rekey_auth_data(entity_data, rekey_secret)
499 # Never trust imported ownership/team fields from payload.
500 entity_data = self._sanitize_import_scope_fields(entity_type, entity_data)
502 filtered_entities.append(entity_data)
504 if not filtered_entities:
505 logger.debug(f"No {entity_type} entities to process after filtering")
506 return
508 # Use bulk operations for tools, resources, and prompts
509 if entity_type == "tools":
510 await self._process_tools_bulk(db, filtered_entities, conflict_strategy, dry_run, status, imported_by)
511 elif entity_type == "resources":
512 await self._process_resources_bulk(db, filtered_entities, conflict_strategy, dry_run, status, imported_by)
513 elif entity_type == "prompts":
514 await self._process_prompts_bulk(db, filtered_entities, conflict_strategy, dry_run, status, imported_by)
515 else:
516 # Fall back to individual processing for other entity types
517 for entity_data in filtered_entities:
518 try:
519 await self._process_single_entity(db, entity_type, entity_data, conflict_strategy, dry_run, status, imported_by)
520 status.processed_entities += 1
521 except Exception as e:
522 status.failed_entities += 1
523 status.errors.append(f"Failed to process {entity_type} entity: {str(e)}")
524 logger.error(f"Failed to process {entity_type} entity: {str(e)}")
526 def _has_auth_data(self, entity_data: Dict[str, Any]) -> bool:
527 """Check if entity has authentication data that needs re-encryption.
529 Args:
530 entity_data: Entity data dictionary
532 Returns:
533 True if entity has auth data, False otherwise
535 Examples:
536 >>> service = ImportService()
537 >>> entity_with_auth = {"name": "test", "auth_value": "encrypted_data"}
538 >>> bool(service._has_auth_data(entity_with_auth))
539 True
541 >>> entity_without_auth = {"name": "test"}
542 >>> service._has_auth_data(entity_without_auth)
543 False
545 >>> entity_empty_auth = {"name": "test", "auth_value": ""}
546 >>> bool(service._has_auth_data(entity_empty_auth))
547 False
549 >>> entity_none_auth = {"name": "test", "auth_value": None}
550 >>> bool(service._has_auth_data(entity_none_auth))
551 False
552 """
553 return "auth_value" in entity_data and entity_data.get("auth_value")
555 def _sanitize_import_scope_fields(self, entity_type: str, entity_data: Dict[str, Any]) -> Dict[str, Any]:
556 """Drop untrusted ownership scope fields from imported entity payloads.
558 Import ownership/team assignment is derived from the authenticated
559 importer context, not from import file metadata.
561 Args:
562 entity_type: Entity family being imported.
563 entity_data: Source entity payload from import file.
565 Returns:
566 Sanitized entity payload copy.
567 """
568 scoped_entity_types = {"tools", "gateways", "servers", "resources", "prompts", "a2a_agents"}
569 if entity_type not in scoped_entity_types:
570 return entity_data
572 sanitized = dict(entity_data)
573 sanitized.pop("team_id", None)
574 sanitized.pop("owner_email", None)
575 sanitized.pop("visibility", None)
576 sanitized.pop("team", None)
577 return sanitized
579 def _rekey_auth_data(self, entity_data: Dict[str, Any], new_secret: str) -> Dict[str, Any]:
580 """Re-encrypt authentication data with a new secret key.
582 Args:
583 entity_data: Entity data dictionary
584 new_secret: New encryption secret
586 Returns:
587 Updated entity data with re-encrypted auth
589 Raises:
590 ImportError: If re-encryption fails
592 Examples:
593 Returns original entity when no auth data present:
594 >>> svc = ImportService()
595 >>> svc._has_auth_data({'name': 'x'})
596 False
597 >>> svc._rekey_auth_data({'name': 'x'}, 'new')
598 {'name': 'x'}
600 Rekeys when auth data is present (encode/decode patched):
601 >>> from unittest.mock import patch
602 >>> data = {'name': 'x', 'auth_value': 'enc_old'}
603 >>> with patch('mcpgateway.services.import_service.decode_auth', return_value='plain'):
604 ... with patch('mcpgateway.services.import_service.encode_auth', return_value='enc_new'):
605 ... result = svc._rekey_auth_data(dict(data), 'new-secret')
606 >>> result['auth_value']
607 'enc_new'
608 """
609 if not self._has_auth_data(entity_data):
610 return entity_data
612 try:
613 # Decrypt with current key, re-encrypt with new key.
614 # Pass secrets explicitly to avoid mutating global settings state,
615 # which would corrupt concurrent encode/decode operations.
616 old_auth_value = entity_data["auth_value"]
617 decrypted_auth = decode_auth(old_auth_value)
618 new_auth_value = encode_auth(decrypted_auth, secret=new_secret)
619 entity_data["auth_value"] = new_auth_value
621 logger.debug("Successfully re-keyed authentication data")
622 return entity_data
624 except Exception as e:
625 raise ImportError(f"Failed to re-key authentication data: {str(e)}")
627 async def _process_single_entity(
628 self, db: Session, entity_type: str, entity_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str
629 ) -> None:
630 """Process a single entity with conflict resolution.
632 Args:
633 db: Database session
634 entity_type: Type of entity
635 entity_data: Entity data dictionary
636 conflict_strategy: How to handle conflicts
637 dry_run: Whether this is a dry run
638 status: Import status tracker
639 imported_by: Username of the person performing the import
641 Raises:
642 ImportError: If processing fails
643 """
644 try:
645 if entity_type == "tools":
646 await self._process_tool(db, entity_data, conflict_strategy, dry_run, status)
647 elif entity_type == "gateways":
648 await self._process_gateway(db, entity_data, conflict_strategy, dry_run, status, imported_by)
649 elif entity_type == "servers":
650 await self._process_server(db, entity_data, conflict_strategy, dry_run, status, imported_by)
651 elif entity_type == "prompts":
652 await self._process_prompt(db, entity_data, conflict_strategy, dry_run, status)
653 elif entity_type == "resources":
654 await self._process_resource(db, entity_data, conflict_strategy, dry_run, status)
655 elif entity_type == "roots":
656 await self._process_root(entity_data, conflict_strategy, dry_run, status)
658 except Exception as e:
659 raise ImportError(f"Failed to process {entity_type}: {str(e)}")
661 async def _process_tool(self, db: Session, tool_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus) -> None:
662 """Process a tool entity.
664 Args:
665 db: Database session
666 tool_data: Tool data dictionary
667 conflict_strategy: How to handle conflicts
668 dry_run: Whether this is a dry run
669 status: Import status tracker
671 Raises:
672 ImportError: If processing fails
673 ImportConflictError: If conflict cannot be resolved
674 """
675 tool_name = tool_data["name"]
677 if dry_run:
678 status.warnings.append(f"Would import tool: {tool_name}")
679 return
681 try:
682 # Convert to ToolCreate schema
683 create_data = self._convert_to_tool_create(tool_data)
685 # Try to create the tool
686 try:
687 await self.tool_service.register_tool(db, create_data)
688 status.created_entities += 1
689 logger.debug(f"Created tool: {tool_name}")
691 except ToolNameConflictError:
692 # Handle conflict based on strategy
693 if conflict_strategy == ConflictStrategy.SKIP:
694 status.skipped_entities += 1
695 status.warnings.append(f"Skipped existing tool: {tool_name}")
696 elif conflict_strategy == ConflictStrategy.UPDATE:
697 # For conflict resolution, we need to find existing tool ID
698 # This is a simplified approach - in practice you'd query the database
699 try:
700 # Try to get tools and find by name
701 tools, _ = await self.tool_service.list_tools(db, include_inactive=True)
702 existing_tool = next((t for t in tools if t.original_name == tool_name), None)
703 if existing_tool:
704 update_data = self._convert_to_tool_update(tool_data)
705 await self.tool_service.update_tool(db, existing_tool.id, update_data)
706 status.updated_entities += 1
707 logger.debug(f"Updated tool: {tool_name}")
708 else:
709 status.warnings.append(f"Could not find existing tool to update: {tool_name}")
710 status.skipped_entities += 1
711 except Exception as update_error:
712 logger.warning(f"Failed to update tool {tool_name}: {str(update_error)}")
713 status.warnings.append(f"Could not update tool {tool_name}: {str(update_error)}")
714 status.skipped_entities += 1
715 elif conflict_strategy == ConflictStrategy.RENAME:
716 # Rename and create
717 new_name = f"{tool_name}_imported_{int(datetime.now().timestamp())}"
718 create_data.name = new_name
719 await self.tool_service.register_tool(db, create_data)
720 status.created_entities += 1
721 status.warnings.append(f"Renamed tool {tool_name} to {new_name}")
722 elif conflict_strategy == ConflictStrategy.FAIL:
723 raise ImportConflictError(f"Tool name conflict: {tool_name}")
725 except Exception as e:
726 raise ImportError(f"Failed to process tool {tool_name}: {str(e)}")
728 async def _process_gateway(self, db: Session, gateway_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str) -> None:
729 """Process a gateway entity.
731 Args:
732 db: Database session
733 gateway_data: Gateway data dictionary
734 conflict_strategy: How to handle conflicts
735 dry_run: Whether this is a dry run
736 status: Import status tracker
737 imported_by: Username of the person performing the import
739 Raises:
740 ImportError: If processing fails
741 ImportConflictError: If conflict cannot be resolved
742 """
743 gateway_name = gateway_data["name"]
745 if dry_run is True:
746 status.warnings.append(f"Would import gateway: {gateway_name}")
747 return
749 try:
750 # Convert to GatewayCreate schema
751 create_data = self._convert_to_gateway_create(gateway_data)
753 try:
754 await self.gateway_service.register_gateway(db, create_data, created_by=imported_by, created_via="import")
755 status.created_entities += 1
756 logger.debug(f"Created gateway: {gateway_name}")
758 except GatewayNameConflictError:
759 if conflict_strategy == ConflictStrategy.SKIP:
760 status.skipped_entities += 1
761 status.warnings.append(f"Skipped existing gateway: {gateway_name}")
762 elif conflict_strategy == ConflictStrategy.UPDATE:
763 try:
764 # Find existing gateway by name
765 gateways, _ = await self.gateway_service.list_gateways(db, include_inactive=True)
766 existing_gateway = next((g for g in gateways if g.name == gateway_name), None)
767 if existing_gateway:
768 update_data = self._convert_to_gateway_update(gateway_data)
769 await self.gateway_service.update_gateway(db, existing_gateway.id, update_data)
770 status.updated_entities += 1
771 logger.debug(f"Updated gateway: {gateway_name}")
772 else:
773 status.warnings.append(f"Could not find existing gateway to update: {gateway_name}")
774 status.skipped_entities += 1
775 except Exception as update_error:
776 logger.warning(f"Failed to update gateway {gateway_name}: {str(update_error)}")
777 status.warnings.append(f"Could not update gateway {gateway_name}: {str(update_error)}")
778 status.skipped_entities += 1
779 elif conflict_strategy == ConflictStrategy.RENAME:
780 new_name = f"{gateway_name}_imported_{int(datetime.now().timestamp())}"
781 create_data.name = new_name
782 await self.gateway_service.register_gateway(db, create_data, created_by=imported_by, created_via="import")
783 status.created_entities += 1
784 status.warnings.append(f"Renamed gateway {gateway_name} to {new_name}")
785 elif conflict_strategy == ConflictStrategy.FAIL:
786 raise ImportConflictError(f"Gateway name conflict: {gateway_name}")
788 except Exception as e:
789 raise ImportError(f"Failed to process gateway {gateway_name}: {str(e)}")
791 async def _process_server(self, db: Session, server_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str) -> None:
792 """Process a server entity.
794 Args:
795 db: Database session
796 server_data: Server data dictionary
797 conflict_strategy: How to handle conflicts
798 dry_run: Whether this is a dry run
799 status: Import status tracker
800 imported_by: Username of the person performing the import
802 Raises:
803 ImportError: If processing fails
804 ImportConflictError: If conflict cannot be resolved
805 """
806 server_name = server_data["name"]
808 if dry_run:
809 status.warnings.append(f"Would import server: {server_name}")
810 return
812 try:
813 create_data = await self._convert_to_server_create(db, server_data)
815 try:
816 await self.server_service.register_server(db, create_data, created_by=imported_by, created_via="import")
817 status.created_entities += 1
818 logger.debug(f"Created server: {server_name}")
820 except ServerNameConflictError:
821 if conflict_strategy == ConflictStrategy.SKIP:
822 status.skipped_entities += 1
823 status.warnings.append(f"Skipped existing server: {server_name}")
824 elif conflict_strategy == ConflictStrategy.UPDATE:
825 try:
826 # Find existing server by name
827 servers = await self.server_service.list_servers(db, include_inactive=True)
828 existing_server = next((s for s in servers if s.name == server_name), None)
829 if existing_server:
830 update_data = await self._convert_to_server_update(db, server_data)
831 await self.server_service.update_server(db, existing_server.id, update_data, imported_by)
832 status.updated_entities += 1
833 logger.debug(f"Updated server: {server_name}")
834 else:
835 status.warnings.append(f"Could not find existing server to update: {server_name}")
836 status.skipped_entities += 1
837 except Exception as update_error:
838 logger.warning(f"Failed to update server {server_name}: {str(update_error)}")
839 status.warnings.append(f"Could not update server {server_name}: {str(update_error)}")
840 status.skipped_entities += 1
841 elif conflict_strategy == ConflictStrategy.RENAME:
842 new_name = f"{server_name}_imported_{int(datetime.now().timestamp())}"
843 create_data.name = new_name
844 await self.server_service.register_server(db, create_data, created_by=imported_by, created_via="import")
845 status.created_entities += 1
846 status.warnings.append(f"Renamed server {server_name} to {new_name}")
847 elif conflict_strategy == ConflictStrategy.FAIL:
848 raise ImportConflictError(f"Server name conflict: {server_name}")
850 except Exception as e:
851 raise ImportError(f"Failed to process server {server_name}: {str(e)}")
853 async def _process_prompt(self, db: Session, prompt_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus) -> None:
854 """Process a prompt entity.
856 Args:
857 db: Database session
858 prompt_data: Prompt data dictionary
859 conflict_strategy: How to handle conflicts
860 dry_run: Whether this is a dry run
861 status: Import status tracker
863 Raises:
864 ImportError: If processing fails
865 ImportConflictError: If conflict cannot be resolved
866 """
867 prompt_name = prompt_data["name"]
869 if dry_run:
870 status.warnings.append(f"Would import prompt: {prompt_name}")
871 return
873 try:
874 create_data = self._convert_to_prompt_create(prompt_data)
876 try:
877 await self.prompt_service.register_prompt(db, create_data)
878 status.created_entities += 1
879 logger.debug(f"Created prompt: {prompt_name}")
881 except PromptNameConflictError:
882 if conflict_strategy == ConflictStrategy.SKIP:
883 status.skipped_entities += 1
884 status.warnings.append(f"Skipped existing prompt: {prompt_name}")
885 elif conflict_strategy == ConflictStrategy.UPDATE:
886 update_data = self._convert_to_prompt_update(prompt_data)
887 await self.prompt_service.update_prompt(db, prompt_name, update_data)
888 status.updated_entities += 1
889 logger.debug(f"Updated prompt: {prompt_name}")
890 elif conflict_strategy == ConflictStrategy.RENAME:
891 new_name = f"{prompt_name}_imported_{int(datetime.now().timestamp())}"
892 create_data.name = new_name
893 await self.prompt_service.register_prompt(db, create_data)
894 status.created_entities += 1
895 status.warnings.append(f"Renamed prompt {prompt_name} to {new_name}")
896 elif conflict_strategy == ConflictStrategy.FAIL:
897 raise ImportConflictError(f"Prompt name conflict: {prompt_name}")
899 except Exception as e:
900 raise ImportError(f"Failed to process prompt {prompt_name}: {str(e)}")
902 async def _process_resource(self, db: Session, resource_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus) -> None:
903 """Process a resource entity.
905 Args:
906 db: Database session
907 resource_data: Resource data dictionary
908 conflict_strategy: How to handle conflicts
909 dry_run: Whether this is a dry run
910 status: Import status tracker
912 Raises:
913 ImportError: If processing fails
914 ImportConflictError: If conflict cannot be resolved
915 """
916 resource_uri = resource_data["uri"]
918 if dry_run:
919 status.warnings.append(f"Would import resource: {resource_uri}")
920 return
922 try:
923 create_data = self._convert_to_resource_create(resource_data)
925 try:
926 await self.resource_service.register_resource(db, create_data)
927 status.created_entities += 1
928 logger.debug(f"Created resource: {resource_uri}")
930 except ResourceURIConflictError:
931 if conflict_strategy == ConflictStrategy.SKIP:
932 status.skipped_entities += 1
933 status.warnings.append(f"Skipped existing resource: {resource_uri}")
934 elif conflict_strategy == ConflictStrategy.UPDATE:
935 update_data = self._convert_to_resource_update(resource_data)
936 await self.resource_service.update_resource(db, resource_uri, update_data)
937 status.updated_entities += 1
938 logger.debug(f"Updated resource: {resource_uri}")
939 elif conflict_strategy == ConflictStrategy.RENAME:
940 new_uri = f"{resource_uri}_imported_{int(datetime.now().timestamp())}"
941 create_data.uri = new_uri
942 await self.resource_service.register_resource(db, create_data)
943 status.created_entities += 1
944 status.warnings.append(f"Renamed resource {resource_uri} to {new_uri}")
945 elif conflict_strategy == ConflictStrategy.FAIL:
946 raise ImportConflictError(f"Resource URI conflict: {resource_uri}")
948 except Exception as e:
949 raise ImportError(f"Failed to process resource {resource_uri}: {str(e)}")
951 async def _process_tools_bulk(self, db: Session, tools_data: List[Dict[str, Any]], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str) -> None:
952 """Process multiple tools using bulk operations.
954 Args:
955 db: Database session
956 tools_data: List of tool data dictionaries
957 conflict_strategy: How to handle conflicts
958 dry_run: Whether this is a dry run
959 status: Import status tracker
960 imported_by: Username of the person performing the import
961 """
962 if dry_run:
963 for tool_data in tools_data:
964 status.warnings.append(f"Would import tool: {tool_data.get('name', 'unknown')}")
965 return
967 try:
968 # Convert all tool data to ToolCreate schemas
969 tools_to_register = []
970 for tool_data in tools_data:
971 try:
972 create_data = self._convert_to_tool_create(tool_data)
973 tools_to_register.append(create_data)
974 except Exception as e:
975 status.failed_entities += 1
976 status.errors.append(f"Failed to convert tool {tool_data.get('name', 'unknown')}: {str(e)}")
977 logger.warning(f"Failed to convert tool data: {str(e)}")
979 if not tools_to_register:
980 return
982 # Use a batch ID so we can scope the post-import fixup below.
983 batch_id = str(uuid.uuid4())
985 # Use bulk registration
986 result = await self.tool_service.register_tools_bulk(
987 db=db,
988 tools=tools_to_register,
989 created_by=imported_by,
990 created_via="import",
991 import_batch_id=batch_id,
992 conflict_strategy=conflict_strategy.value,
993 )
995 # Restore original_description from export data for newly created tools.
996 # register_tools_bulk sets original_description=description, but the
997 # export payload may carry the real upstream original_description.
998 if result.get("created", 0) > 0:
999 orig_desc_map = {d["name"]: d["original_description"] for d in tools_data if d.get("original_description") and d.get("original_description") != d.get("description")}
1000 if orig_desc_map:
1001 # Third-Party
1002 from sqlalchemy import select
1004 for tool_name, orig_desc in orig_desc_map.items():
1005 stmt = select(Tool).where(Tool.original_name == tool_name, Tool.import_batch_id == batch_id)
1006 db_tool = db.execute(stmt).scalar_one_or_none()
1007 if db_tool:
1008 db_tool.original_description = orig_desc
1009 db.commit()
1011 # Update status based on results
1012 status.created_entities += result["created"]
1013 status.updated_entities += result["updated"]
1014 status.skipped_entities += result["skipped"]
1015 status.failed_entities += result["failed"]
1016 status.processed_entities += result["created"] + result["updated"] + result["skipped"]
1018 # Add any errors to status
1019 for error in result.get("errors", []):
1020 status.errors.append(error)
1022 logger.info(f"Bulk processed {len(tools_data)} tools: {result['created']} created, {result['updated']} updated, {result['skipped']} skipped, {result['failed']} failed")
1024 except Exception as e:
1025 status.failed_entities += len(tools_data)
1026 status.errors.append(f"Bulk tool processing failed: {str(e)}")
1027 logger.error(f"Failed to bulk process tools: {str(e)}")
1028 # Don't raise - allow import to continue with other entities
1030 async def _process_resources_bulk(self, db: Session, resources_data: List[Dict[str, Any]], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str) -> None:
1031 """Process multiple resources using bulk operations.
1033 Args:
1034 db: Database session
1035 resources_data: List of resource data dictionaries
1036 conflict_strategy: How to handle conflicts
1037 dry_run: Whether this is a dry run
1038 status: Import status tracker
1039 imported_by: Username of the person performing the import
1040 """
1041 if dry_run:
1042 for resource_data in resources_data:
1043 status.warnings.append(f"Would import resource: {resource_data.get('uri', 'unknown')}")
1044 return
1046 try:
1047 # Convert all resource data to ResourceCreate schemas
1048 resources_to_register = []
1049 for resource_data in resources_data:
1050 try:
1051 create_data = self._convert_to_resource_create(resource_data)
1052 resources_to_register.append(create_data)
1053 except Exception as e:
1054 status.failed_entities += 1
1055 status.errors.append(f"Failed to convert resource {resource_data.get('uri', 'unknown')}: {str(e)}")
1056 logger.warning(f"Failed to convert resource data: {str(e)}")
1058 if not resources_to_register:
1059 return
1061 # Use bulk registration
1062 result = await self.resource_service.register_resources_bulk(
1063 db=db,
1064 resources=resources_to_register,
1065 created_by=imported_by,
1066 created_via="import",
1067 conflict_strategy=conflict_strategy.value,
1068 )
1070 # Update status based on results
1071 status.created_entities += result["created"]
1072 status.updated_entities += result["updated"]
1073 status.skipped_entities += result["skipped"]
1074 status.failed_entities += result["failed"]
1075 status.processed_entities += result["created"] + result["updated"] + result["skipped"]
1077 # Add any errors to status
1078 for error in result.get("errors", []):
1079 status.errors.append(error)
1081 logger.info(f"Bulk processed {len(resources_data)} resources: {result['created']} created, {result['updated']} updated, {result['skipped']} skipped, {result['failed']} failed")
1083 except Exception as e:
1084 status.failed_entities += len(resources_data)
1085 status.errors.append(f"Bulk resource processing failed: {str(e)}")
1086 logger.error(f"Failed to bulk process resources: {str(e)}")
1087 # Don't raise - allow import to continue with other entities
1089 async def _process_prompts_bulk(self, db: Session, prompts_data: List[Dict[str, Any]], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus, imported_by: str) -> None:
1090 """Process multiple prompts using bulk operations.
1092 Args:
1093 db: Database session
1094 prompts_data: List of prompt data dictionaries
1095 conflict_strategy: How to handle conflicts
1096 dry_run: Whether this is a dry run
1097 status: Import status tracker
1098 imported_by: Username of the person performing the import
1099 """
1100 if dry_run:
1101 for prompt_data in prompts_data:
1102 status.warnings.append(f"Would import prompt: {prompt_data.get('name', 'unknown')}")
1103 return
1105 try:
1106 # Convert all prompt data to PromptCreate schemas
1107 prompts_to_register = []
1108 for prompt_data in prompts_data:
1109 try:
1110 create_data = self._convert_to_prompt_create(prompt_data)
1111 prompts_to_register.append(create_data)
1112 except Exception as e:
1113 status.failed_entities += 1
1114 status.errors.append(f"Failed to convert prompt {prompt_data.get('name', 'unknown')}: {str(e)}")
1115 logger.warning(f"Failed to convert prompt data: {str(e)}")
1117 if not prompts_to_register:
1118 return
1120 # Use bulk registration
1121 result = await self.prompt_service.register_prompts_bulk(
1122 db=db,
1123 prompts=prompts_to_register,
1124 created_by=imported_by,
1125 created_via="import",
1126 conflict_strategy=conflict_strategy.value,
1127 )
1129 # Update status based on results
1130 status.created_entities += result["created"]
1131 status.updated_entities += result["updated"]
1132 status.skipped_entities += result["skipped"]
1133 status.failed_entities += result["failed"]
1134 status.processed_entities += result["created"] + result["updated"] + result["skipped"]
1136 # Add any errors to status
1137 for error in result.get("errors", []):
1138 status.errors.append(error)
1140 logger.info(f"Bulk processed {len(prompts_data)} prompts: {result['created']} created, {result['updated']} updated, {result['skipped']} skipped, {result['failed']} failed")
1142 except Exception as e:
1143 status.failed_entities += len(prompts_data)
1144 status.errors.append(f"Bulk prompt processing failed: {str(e)}")
1145 logger.error(f"Failed to bulk process prompts: {str(e)}")
1146 # Don't raise - allow import to continue with other entities
1148 async def _process_root(self, root_data: Dict[str, Any], conflict_strategy: ConflictStrategy, dry_run: bool, status: ImportStatus) -> None:
1149 """Process a root entity.
1151 Args:
1152 root_data: Root data dictionary
1153 conflict_strategy: How to handle conflicts
1154 dry_run: Whether this is a dry run
1155 status: Import status tracker
1157 Raises:
1158 ImportError: If processing fails
1159 ImportConflictError: If conflict cannot be resolved
1160 """
1161 root_uri = root_data["uri"]
1163 if dry_run:
1164 status.warnings.append(f"Would import root: {root_uri}")
1165 return
1167 try:
1168 await self.root_service.add_root(root_uri, root_data.get("name"))
1169 status.created_entities += 1
1170 logger.debug(f"Created root: {root_uri}")
1172 except Exception as e:
1173 if conflict_strategy == ConflictStrategy.SKIP:
1174 status.skipped_entities += 1
1175 status.warnings.append(f"Skipped existing root: {root_uri}")
1176 elif conflict_strategy == ConflictStrategy.FAIL:
1177 raise ImportConflictError(f"Root URI conflict: {root_uri}")
1178 else:
1179 raise ImportError(f"Failed to process root {root_uri}: {str(e)}")
def _convert_to_tool_create(self, tool_data: Dict[str, Any]) -> ToolCreate:
    """Build a ToolCreate schema from an imported tool dictionary.

    ``name`` and ``url`` are required keys. ``integration_type`` defaults to
    "REST", ``request_type`` to "GET" and ``tags`` to an empty list; all
    other fields are passed through as found (None when absent).

    Args:
        tool_data: Tool data dictionary from import

    Returns:
        ToolCreate schema object
    """
    # Authentication is only attached when both the type and the value are present.
    auth_type = tool_data.get("auth_type")
    auth_value = tool_data.get("auth_value")
    auth_info = None
    if auth_type and auth_value:
        auth_info = AuthenticationValues(auth_type=auth_type, auth_value=auth_value)

    fields = {
        "name": tool_data["name"],
        "displayName": tool_data.get("displayName"),
        "url": tool_data["url"],
        "description": tool_data.get("description"),
        "integration_type": tool_data.get("integration_type", "REST"),
        "request_type": tool_data.get("request_type", "GET"),
        "headers": tool_data.get("headers"),
        "input_schema": tool_data.get("input_schema"),
        "output_schema": tool_data.get("output_schema"),
        "annotations": tool_data.get("annotations"),
        "jsonpath_filter": tool_data.get("jsonpath_filter"),
        "auth": auth_info,
        "tags": tool_data.get("tags", []),
    }
    return ToolCreate(**fields)
def _convert_to_tool_update(self, tool_data: Dict[str, Any]) -> ToolUpdate:
    """Build a ToolUpdate schema from an imported tool dictionary.

    Every field is optional; keys missing from ``tool_data`` are passed to
    the schema as None.

    Args:
        tool_data: Tool data dictionary from import

    Returns:
        ToolUpdate schema object
    """
    # Authentication is only attached when both the type and the value are present.
    auth_type = tool_data.get("auth_type")
    auth_value = tool_data.get("auth_value")
    auth_info = None
    if auth_type and auth_value:
        auth_info = AuthenticationValues(auth_type=auth_type, auth_value=auth_value)

    passthrough_fields = (
        "name",
        "displayName",
        "url",
        "description",
        "integration_type",
        "request_type",
        "headers",
        "input_schema",
        "output_schema",
        "annotations",
        "jsonpath_filter",
    )
    fields = {field: tool_data.get(field) for field in passthrough_fields}
    fields["auth"] = auth_info
    fields["tags"] = tool_data.get("tags")
    return ToolUpdate(**fields)
def _convert_to_gateway_create(self, gateway_data: Dict[str, Any]) -> GatewayCreate:
    """Build a GatewayCreate schema from an imported gateway dictionary.

    When the import carries an ``auth_type``, the stored credentials are
    decoded with ``decode_auth`` and mapped to the schema's credential
    fields. Decoding problems are logged as warnings and the schema is
    still built with whatever was decoded up to that point.

    Args:
        gateway_data: Gateway data dictionary from import

    Returns:
        GatewayCreate schema object
    """
    auth_kwargs = {}
    auth_type = gateway_data.get("auth_type")

    if auth_type:
        auth_kwargs["auth_type"] = auth_type
        # Query-parameter credentials are only consulted for the query_param auth type.
        query_params = gateway_data.get("auth_query_params") if auth_type == "query_param" else None

        if query_params:
            try:
                # Only the first key-value pair is used (schema supports single param).
                param_key = next(iter(query_params.keys()))
                decoded = decode_auth(query_params[param_key])
                # decode_auth is expected to return a dict like {param_key: value}.
                if isinstance(decoded, dict):
                    param_value = decoded.get(param_key, "")
                else:
                    param_value = str(decoded)
                auth_kwargs["auth_query_param_key"] = param_key
                auth_kwargs["auth_query_param_value"] = param_value
                logger.debug(f"Importing gateway with query_param auth, key: {param_key}")
            except Exception as e:
                logger.warning(f"Failed to decode query param auth for gateway: {str(e)}")
        elif gateway_data.get("auth_value"):
            try:
                decoded_auth = decode_auth(gateway_data["auth_value"])
                if auth_type == "basic":
                    # "Basic <base64(username:password)>"
                    header = decoded_auth.get("Authorization", "")
                    if header.startswith("Basic "):
                        credentials = base64.b64decode(header[6:]).decode("utf-8")
                        username, password = credentials.split(":", 1)
                        auth_kwargs["auth_username"] = username
                        auth_kwargs["auth_password"] = password
                elif auth_type == "bearer":
                    # "Bearer <token>"
                    header = decoded_auth.get("Authorization", "")
                    if header.startswith("Bearer "):
                        auth_kwargs["auth_token"] = header[7:]
                elif auth_type == "authheaders":
                    if len(decoded_auth) == 1:
                        header_key, header_value = next(iter(decoded_auth.items()))
                        auth_kwargs["auth_header_key"] = header_key
                        auth_kwargs["auth_header_value"] = header_value
                    else:
                        # Multiple headers - use the list-of-pairs format
                        auth_kwargs["auth_headers"] = [{"key": k, "value": v} for k, v in decoded_auth.items()]
            except Exception as e:
                logger.warning(f"Failed to decode auth data for gateway: {str(e)}")

    return GatewayCreate(
        name=gateway_data["name"],
        url=gateway_data["url"],
        description=gateway_data.get("description"),
        transport=gateway_data.get("transport", "SSE"),
        passthrough_headers=gateway_data.get("passthrough_headers"),
        tags=gateway_data.get("tags", []),
        **auth_kwargs,
    )
def _convert_to_gateway_update(self, gateway_data: Dict[str, Any]) -> GatewayUpdate:
    """Build a GatewayUpdate schema from an imported gateway dictionary.

    Mirrors the create conversion, except that every plain field is
    optional. When the import carries an ``auth_type``, the stored
    credentials are decoded with ``decode_auth`` and mapped to the schema's
    credential fields; decoding problems are logged as warnings.

    Args:
        gateway_data: Gateway data dictionary from import

    Returns:
        GatewayUpdate schema object
    """
    auth_kwargs = {}
    auth_type = gateway_data.get("auth_type")

    if auth_type:
        auth_kwargs["auth_type"] = auth_type
        # Query-parameter credentials are only consulted for the query_param auth type.
        query_params = gateway_data.get("auth_query_params") if auth_type == "query_param" else None

        if query_params:
            try:
                # Only the first key-value pair is used (schema supports single param).
                param_key = next(iter(query_params.keys()))
                decoded = decode_auth(query_params[param_key])
                # decode_auth is expected to return a dict like {param_key: value}.
                if isinstance(decoded, dict):
                    param_value = decoded.get(param_key, "")
                else:
                    param_value = str(decoded)
                auth_kwargs["auth_query_param_key"] = param_key
                auth_kwargs["auth_query_param_value"] = param_value
                logger.debug(f"Importing gateway update with query_param auth, key: {param_key}")
            except Exception as e:
                logger.warning(f"Failed to decode query param auth for gateway update: {str(e)}")
        elif gateway_data.get("auth_value"):
            try:
                decoded_auth = decode_auth(gateway_data["auth_value"])
                if auth_type == "basic":
                    # "Basic <base64(username:password)>"
                    header = decoded_auth.get("Authorization", "")
                    if header.startswith("Basic "):
                        credentials = base64.b64decode(header[6:]).decode("utf-8")
                        username, password = credentials.split(":", 1)
                        auth_kwargs["auth_username"] = username
                        auth_kwargs["auth_password"] = password
                elif auth_type == "bearer":
                    # "Bearer <token>"
                    header = decoded_auth.get("Authorization", "")
                    if header.startswith("Bearer "):
                        auth_kwargs["auth_token"] = header[7:]
                elif auth_type == "authheaders":
                    if len(decoded_auth) == 1:
                        header_key, header_value = next(iter(decoded_auth.items()))
                        auth_kwargs["auth_header_key"] = header_key
                        auth_kwargs["auth_header_value"] = header_value
                    else:
                        auth_kwargs["auth_headers"] = [{"key": k, "value": v} for k, v in decoded_auth.items()]
            except Exception as e:
                logger.warning(f"Failed to decode auth data for gateway update: {str(e)}")

    return GatewayUpdate(
        name=gateway_data.get("name"),
        url=gateway_data.get("url"),
        description=gateway_data.get("description"),
        transport=gateway_data.get("transport"),
        passthrough_headers=gateway_data.get("passthrough_headers"),
        tags=gateway_data.get("tags"),
        **auth_kwargs,
    )
async def _convert_to_server_create(self, db: Session, server_data: Dict[str, Any]) -> ServerCreate:
    """Build a ServerCreate schema, translating tool references to current tool IDs.

    References are read from ``tool_ids`` (falling back to
    ``associated_tools``) and may be tool IDs or names. Each one is matched
    against all tools, inactive ones included: first by ID, then by
    ``original_name``, then by ``name``. References that match nothing are
    logged and dropped.

    Args:
        db: Database session
        server_data: Server data dictionary from import

    Returns:
        ServerCreate schema object with resolved tool IDs
    """
    tool_references = server_data.get("tool_ids", []) or server_data.get("associated_tools", [])
    resolved_tool_ids = []

    if tool_references:
        # Fetch the tool list once and reuse it for every reference.
        all_tools, _ = await self.tool_service.list_tools(db, include_inactive=True)

        for tool_ref in tool_references:
            # Matchers in priority order; the first one that yields a tool wins.
            matchers = (
                lambda t: t.id == tool_ref,
                lambda t: t.original_name == tool_ref,
                lambda t: hasattr(t, "name") and t.name == tool_ref,
            )
            found_tool = None
            for matches in matchers:
                found_tool = next((t for t in all_tools if matches(t)), None)
                if found_tool:
                    break

            if not found_tool:
                # Don't include unresolvable references
                logger.warning(f"Could not resolve tool reference: {tool_ref}")
                continue

            resolved_tool_ids.append(found_tool.id)
            logger.debug(f"Resolved tool reference '{tool_ref}' to ID {found_tool.id}")

    return ServerCreate(
        name=server_data["name"],
        description=server_data.get("description"),
        associated_tools=resolved_tool_ids,
        tags=server_data.get("tags", []),
    )
async def _convert_to_server_update(self, db: Session, server_data: Dict[str, Any]) -> ServerUpdate:
    """Build a ServerUpdate schema, translating tool references to current tool IDs.

    References are read from ``tool_ids`` (falling back to
    ``associated_tools``) and matched against all tools, inactive ones
    included: first by ID, then by ``original_name``, then by ``name``.
    When nothing resolves, ``associated_tools`` is passed as None rather
    than an empty list.

    Args:
        db: Database session
        server_data: Server data dictionary from import

    Returns:
        ServerUpdate schema object with resolved tool IDs
    """
    tool_references = server_data.get("tool_ids", []) or server_data.get("associated_tools", [])
    resolved_tool_ids = []

    if tool_references:
        all_tools, _ = await self.tool_service.list_tools(db, include_inactive=True)

        for tool_ref in tool_references:
            # Matchers in priority order; the first one that yields a tool wins.
            matchers = (
                lambda t: t.id == tool_ref,
                lambda t: t.original_name == tool_ref,
                lambda t: hasattr(t, "name") and t.name == tool_ref,
            )
            found_tool = None
            for matches in matchers:
                found_tool = next((t for t in all_tools if matches(t)), None)
                if found_tool:
                    break

            if not found_tool:
                logger.warning(f"Could not resolve tool reference for update: {tool_ref}")
                continue

            resolved_tool_ids.append(found_tool.id)

    return ServerUpdate(
        name=server_data.get("name"),
        description=server_data.get("description"),
        associated_tools=resolved_tool_ids or None,
        tags=server_data.get("tags"),
    )
def _convert_to_prompt_create(self, prompt_data: Dict[str, Any]) -> PromptCreate:
    """Build a PromptCreate schema from an imported prompt dictionary.

    The exported ``input_schema`` is turned back into the ``arguments``
    list: one entry per schema property, flagged as required when the
    property is listed under ``required``. The prompt is named after
    ``original_name`` when present, otherwise ``name``.

    Args:
        prompt_data: Prompt data dictionary from import

    Returns:
        PromptCreate schema object
    """
    schema = prompt_data.get("input_schema", {})
    arguments = []
    if isinstance(schema, dict):
        props = schema.get("properties", {})
        required = schema.get("required", [])
        arguments = [{"name": arg_name, "description": arg_spec.get("description", ""), "required": arg_name in required} for arg_name, arg_spec in props.items()]

    prompt_name = prompt_data.get("original_name") or prompt_data["name"]
    return PromptCreate(
        name=prompt_name,
        custom_name=prompt_data.get("custom_name"),
        display_name=prompt_data.get("display_name"),
        template=prompt_data["template"],
        description=prompt_data.get("description"),
        arguments=arguments,
        tags=prompt_data.get("tags", []),
    )
def _convert_to_prompt_update(self, prompt_data: Dict[str, Any]) -> PromptUpdate:
    """Build a PromptUpdate schema from an imported prompt dictionary.

    The exported ``input_schema`` is turned back into the ``arguments``
    list (one entry per schema property); when that list is empty,
    ``arguments`` is passed as None. All other fields are optional.

    Args:
        prompt_data: Prompt data dictionary from import

    Returns:
        PromptUpdate schema object
    """
    schema = prompt_data.get("input_schema", {})
    arguments = []
    if isinstance(schema, dict):
        props = schema.get("properties", {})
        required = schema.get("required", [])
        arguments = [{"name": arg_name, "description": arg_spec.get("description", ""), "required": arg_name in required} for arg_name, arg_spec in props.items()]

    prompt_name = prompt_data.get("original_name") or prompt_data.get("name")
    return PromptUpdate(
        name=prompt_name,
        custom_name=prompt_data.get("custom_name"),
        display_name=prompt_data.get("display_name"),
        template=prompt_data.get("template"),
        description=prompt_data.get("description"),
        arguments=arguments or None,
        tags=prompt_data.get("tags"),
    )
def _convert_to_resource_create(self, resource_data: Dict[str, Any]) -> ResourceCreate:
    """Build a ResourceCreate schema from an imported resource dictionary.

    ``uri`` and ``name`` are required keys. ``content`` defaults to an empty
    string and ``tags`` to an empty list.

    Args:
        resource_data: Resource data dictionary from import

    Returns:
        ResourceCreate schema object
    """
    fields = {
        "uri": resource_data["uri"],
        "name": resource_data["name"],
        "description": resource_data.get("description"),
        "mime_type": resource_data.get("mime_type"),
        "content": resource_data.get("content", ""),  # Default empty content
        "tags": resource_data.get("tags", []),
    }
    return ResourceCreate(**fields)
def _convert_to_resource_update(self, resource_data: Dict[str, Any]) -> ResourceUpdate:
    """Build a ResourceUpdate schema from an imported resource dictionary.

    Every field is optional; keys missing from ``resource_data`` are passed
    to the schema as None. The URI is not part of the update payload.

    Args:
        resource_data: Resource data dictionary from import

    Returns:
        ResourceUpdate schema object
    """
    updatable_fields = ("name", "description", "mime_type", "content", "tags")
    fields = {field: resource_data.get(field) for field in updatable_fields}
    return ResourceUpdate(**fields)
def get_import_status(self, import_id: str) -> Optional[ImportStatus]:
    """Look up the status of a single import operation.

    Args:
        import_id: Import operation ID

    Returns:
        Import status object or None if not found
    """
    tracked = self.active_imports
    return tracked.get(import_id)
def list_import_statuses(self) -> List[ImportStatus]:
    """Return every tracked import status as a new list.

    Returns:
        List of all import status objects
    """
    return [*self.active_imports.values()]
def cleanup_completed_imports(self, max_age_hours: int = 24) -> int:
    """Drop finished import statuses that are older than ``max_age_hours``.

    Only entries whose status is "completed" or "failed" and that carry a
    completion timestamp earlier than the cutoff are removed.

    Args:
        max_age_hours: Maximum age in hours for keeping completed imports

    Returns:
        Number of import statuses removed
    """
    cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)

    # Collect the keys first: the dict must not change size while it is iterated.
    expired = [key for key, entry in self.active_imports.items() if entry.status in ("completed", "failed") and entry.completed_at and entry.completed_at < cutoff]

    for key in expired:
        del self.active_imports[key]

    return len(expired)
async def preview_import(self, db: Session, import_data: Dict[str, Any]) -> Dict[str, Any]:
    """Preview import file to show what would be imported with smart categorization.

    Validates the payload, then reports per-type counts, a per-item
    analysis, gateway bundles (only when the payload has gateways), server
    dependencies (only when it has servers) and conflicts with existing
    items.

    Args:
        db: Database session
        import_data: The validated import data

    Returns:
        Dictionary with categorized items for selective import UI

    Examples:
        >>> service = ImportService()
        >>> # This would return a structure for the UI to build selection interface
    """
    self.validate_import_data(import_data)

    entities = import_data.get("entities", {})
    counts_by_type = {kind: len(rows) for kind, rows in entities.items()}

    # Analyze every entity, grouped by its type
    analyzed = {}
    for kind, rows in entities.items():
        analyzed[kind] = [await self._analyze_import_item(db, kind, row) for row in rows]

    # Gateway bundles (gateways + their tools/resources/prompts)
    bundles = self._find_gateway_bundles(entities) if "gateways" in entities else {}

    # Server dependencies
    dependencies = self._find_server_dependencies(entities) if "servers" in entities else {}

    # Conflicts with existing items
    conflicts = await self._detect_import_conflicts(db, entities)

    return {
        "summary": {"total_items": sum(counts_by_type.values()), "by_type": counts_by_type},
        "items": analyzed,
        "bundles": bundles,
        "conflicts": conflicts,
        "dependencies": dependencies,
    }
1623 async def _analyze_import_item(self, db: Session, entity_type: str, entity: Dict[str, Any]) -> Dict[str, Any]:
1624 """Analyze a single import item for the preview.
1626 Args:
1627 db: Database session
1628 entity_type: Type of entity
1629 entity: Entity data
1631 Returns:
1632 Item analysis with metadata for UI selection
1633 """
1634 item_name = self._get_entity_identifier(entity_type, entity)
1636 # Basic item info
1637 item_info = {
1638 "id": item_name,
1639 "name": entity.get("name", item_name),
1640 "type": entity_type,
1641 "is_gateway_item": bool(entity.get("gateway_name") or entity.get("gateway_id")),
1642 "is_custom": not bool(entity.get("gateway_name") or entity.get("gateway_id")),
1643 "description": entity.get("description", ""),
1644 }
1646 # Check if it conflicts with existing items
1647 try:
1648 if entity_type == "tools":
1649 existing, _ = await self.tool_service.list_tools(db)
1650 item_info["conflicts_with"] = any(t.original_name == item_name for t in existing)
1651 elif entity_type == "gateways":
1652 existing, _ = await self.gateway_service.list_gateways(db)
1653 item_info["conflicts_with"] = any(g.name == item_name for g in existing)
1654 elif entity_type == "servers":
1655 existing = await self.server_service.list_servers(db)
1656 item_info["conflicts_with"] = any(s.name == item_name for s in existing)
1657 elif entity_type == "prompts":
1658 existing, _ = await self.prompt_service.list_prompts(db)
1659 item_info["conflicts_with"] = any(p.name == item_name for p in existing)
1660 elif entity_type == "resources":
1661 existing, _ = await self.resource_service.list_resources(db)
1662 item_info["conflicts_with"] = any(r.uri == item_name for r in existing)
1663 else:
1664 item_info["conflicts_with"] = False
1665 except Exception:
1666 item_info["conflicts_with"] = False
1668 # Add metadata for smart selection
1669 if entity_type == "servers":
1670 item_info["dependencies"] = {"tools": entity.get("associated_tools", []), "resources": entity.get("associated_resources", []), "prompts": entity.get("associated_prompts", [])}
1672 return item_info
1674 def _find_gateway_bundles(self, entities: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
1675 """Find gateway bundles (gateway + associated tools/resources/prompts).
1677 Args:
1678 entities: All entities from import data
1680 Returns:
1681 Gateway bundle information for UI
1682 """
1683 bundles = {}
1685 if "gateways" not in entities:
1686 return bundles
1688 for gateway in entities["gateways"]:
1689 gateway_name = gateway.get("name", "")
1690 bundle_items = {"tools": [], "resources": [], "prompts": []}
1692 # Find items that belong to this gateway
1693 for entity_type in ["tools", "resources", "prompts"]:
1694 if entity_type in entities:
1695 for item in entities[entity_type]:
1696 item_gateway = item.get("gateway_name") or item.get("gateway_id")
1697 if item_gateway == gateway_name:
1698 item_name = self._get_entity_identifier(entity_type, item)
1699 bundle_items[entity_type].append({"id": item_name, "name": item.get("name", item_name), "description": item.get("description", "")})
1701 if any(bundle_items.values()): # Only add if gateway has items
1702 bundles[gateway_name] = {
1703 "gateway": {"name": gateway_name, "description": gateway.get("description", "")},
1704 "items": bundle_items,
1705 "total_items": sum(len(items) for items in bundle_items.values()),
1706 }
1708 return bundles
1710 def _find_server_dependencies(self, entities: Dict[str, List[Dict[str, Any]]]) -> Dict[str, Any]:
1711 """Find server dependencies for smart selection.
1713 Args:
1714 entities: All entities from import data
1716 Returns:
1717 Server dependency information for UI
1718 """
1719 dependencies = {}
1721 if "servers" not in entities:
1722 return dependencies
1724 for server in entities["servers"]:
1725 server_name = server.get("name", "")
1726 deps = {"tools": server.get("associated_tools", []), "resources": server.get("associated_resources", []), "prompts": server.get("associated_prompts", [])}
1728 if any(deps.values()): # Only add if server has dependencies
1729 dependencies[server_name] = {
1730 "server": {"name": server_name, "description": server.get("description", "")},
1731 "requires": deps,
1732 "total_dependencies": sum(len(items) for items in deps.values()),
1733 }
1735 return dependencies
1737 async def _detect_import_conflicts(self, db: Session, entities: Dict[str, List[Dict[str, Any]]]) -> Dict[str, List[Dict[str, Any]]]:
1738 """Detect conflicts between import items and existing database items.
1740 Args:
1741 db: Database session
1742 entities: Import entities
1744 Returns:
1745 Dictionary of conflicts by entity type
1746 """
1747 conflicts = {}
1749 try:
1750 # Check tool conflicts
1751 if "tools" in entities:
1752 existing_tools, _ = await self.tool_service.list_tools(db)
1753 existing_names = {t.original_name for t in existing_tools}
1755 tool_conflicts = []
1756 for tool in entities["tools"]:
1757 tool_name = tool.get("name", "")
1758 if tool_name in existing_names:
1759 tool_conflicts.append({"name": tool_name, "type": "name_conflict", "description": tool.get("description", "")})
1761 if tool_conflicts:
1762 conflicts["tools"] = tool_conflicts
1764 # Check gateway conflicts
1765 if "gateways" in entities:
1766 existing_gateways, _ = await self.gateway_service.list_gateways(db)
1767 existing_names = {g.name for g in existing_gateways}
1769 gateway_conflicts = []
1770 for gateway in entities["gateways"]:
1771 gateway_name = gateway.get("name", "")
1772 if gateway_name in existing_names:
1773 gateway_conflicts.append({"name": gateway_name, "type": "name_conflict", "description": gateway.get("description", "")})
1775 if gateway_conflicts:
1776 conflicts["gateways"] = gateway_conflicts
1778 # Add other entity types as needed...
1780 except Exception as e:
1781 logger.warning(f"Could not detect all conflicts: {e}")
1783 return conflicts
async def _get_user_context(self, db: Session, imported_by: str) -> Optional[Dict[str, Any]]:
    """Resolve the importing user and their personal team.

    Any failure (unknown user, no personal team, lookup error) is logged
    and reported as ``None`` instead of being raised.

    Args:
        db: Database session
        imported_by: Email of importing user

    Returns:
        Dict with ``user_email``, ``team_id`` and ``team_name``, or None if
        the user or their personal team could not be resolved
    """
    context: Optional[Dict[str, Any]] = None

    try:
        importer = db.query(EmailUser).filter(EmailUser.email == imported_by).first()
        if importer:
            team = importer.get_personal_team()
            if team:
                context = {"user_email": importer.email, "team_id": team.id, "team_name": team.name}
            else:
                logger.warning(f"User {imported_by} has no personal team")
        else:
            logger.warning(f"Could not find importing user: {imported_by}")
    except Exception as e:
        logger.error(f"Failed to get user context: {e}")

    return context
def _add_multitenancy_context(self, entity_data: Dict[str, Any], user_context: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of the entity data with team/ownership defaults filled in.

    Values already present (and truthy) in the import payload win; only
    missing or empty fields are populated from the importing user's context.

    Args:
        entity_data: Original entity data (left unmodified)
        user_context: User context with ``team_id`` and ``user_email``

    Returns:
        Entity data enhanced with multitenancy fields
    """
    # Shallow copy so the caller's dict is never mutated.
    enhanced_data = {**entity_data}

    # ``or`` short-circuits, so user_context is only consulted for fields
    # the payload did not supply.
    enhanced_data["team_id"] = enhanced_data.get("team_id") or user_context["team_id"]
    enhanced_data["owner_email"] = enhanced_data.get("owner_email") or user_context["user_email"]

    # Pre-0.7.0 exports carry no visibility field; fall back to 'team'.
    enhanced_data["visibility"] = enhanced_data.get("visibility") or "team"

    # Import tracking marker, unless the payload already names a source.
    enhanced_data["federation_source"] = enhanced_data.get("federation_source") or f"imported-by-{user_context['user_email']}"

    logger.debug(f"Enhanced entity with multitenancy: team_id={enhanced_data['team_id']}, visibility={enhanced_data['visibility']}")
    return enhanced_data
async def _assign_imported_items_to_team(self, db: Session, imported_by: str, imported_after: Optional[datetime] = None) -> None:
    """Backfill team ownership on items created by this user's import.

    Considers rows of every supported model with ``created_by == imported_by``
    and ``created_via == "import"`` (optionally limited to rows created at or
    after ``imported_after``). Missing team, owner and federation source are
    filled in, and visibility is set to ``"team"`` whenever it holds any other
    value. Errors are logged and swallowed so they never fail the import.

    Args:
        db: Database session
        imported_by: Email of user who performed the import
        imported_after: Optional lower bound timestamp for this import run
    """
    try:
        # Resolve the importing user and their personal team
        user = db.query(EmailUser).filter(EmailUser.email == imported_by).first()
        if not user:
            logger.warning(f"Could not find importing user {imported_by} - skipping team assignment")
            return

        personal_team = user.get_personal_team()
        if not personal_team:
            logger.warning(f"User {imported_by} has no personal team - skipping team assignment")
            return

        logger.info(f"Assigning imported items to {imported_by}'s team: {personal_team.name}")

        def claim(record: Any) -> bool:
            """Fill ownership fields on one row and report whether anything changed.

            Args:
                record: ORM row to update in place

            Returns:
                True if at least one attribute was written
            """
            touched = False
            if not record.team_id:
                record.team_id = personal_team.id
                touched = True
            if not record.owner_email:
                record.owner_email = user.email
                touched = True
            # Any visibility other than "team" (including a missing one) is replaced.
            # NOTE(review): this also overrides an explicit value from the payload
            # (e.g. "public" or "private") - confirm that is intended.
            if getattr(record, "visibility", None) != "team":
                record.visibility = "team"
                touched = True
            if hasattr(record, "federation_source") and not record.federation_source:
                record.federation_source = f"imported-by-{imported_by}"
                touched = True
            return touched

        # (label used in log messages, ORM model) pairs to scan
        model_table = (
            ("servers", Server),
            ("tools", Tool),
            ("resources", Resource),
            ("prompts", Prompt),
            ("gateways", Gateway),
            ("a2a_agents", A2AAgent),
        )

        total_assigned = 0

        for resource_name, model in model_table:
            # A failure for one model is logged and must not stop the others.
            try:
                scope = db.query(model).filter(model.created_by == imported_by, model.created_via == "import")
                if imported_after is not None:
                    scope = scope.filter(model.created_at >= imported_after)

                rows = scope.all()
                if not rows:
                    continue

                assigned_for_type = sum(1 for row in rows if claim(row))
                if assigned_for_type:
                    logger.info(f"Assigned {assigned_for_type} imported {resource_name} to user team")
                    total_assigned += assigned_for_type

            except Exception as e:
                logger.error(f"Failed to assign {resource_name} to team: {e}")

        if total_assigned > 0:
            db.commit()
            logger.info(f"Assigned {total_assigned} imported items to {personal_team.name} with team visibility")
        else:
            logger.debug("No orphaned imported items found")

    except Exception as e:
        logger.error(f"Failed to assign imported items to team: {e}")
        # Don't fail the import for team assignment issues