Coverage for mcpgateway / utils / sqlalchemy_modifier.py: 100%

120 statements  

« prev     ^ index     » next       coverage.py v7.13.1, created at 2026-02-11 07:10 +0000

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/sqlalchemy_modifier.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Madhav Kandukuri 

6 

7SQLAlchemy modifiers 

8 

9- json_contains_expr: handles json_contains logic for different dialects 

10- json_contains_tag_expr: handles tag filtering for dict-format tags [{id, label}] 

11""" 

12 

13# Standard 

14import itertools 

15import re 

16import threading 

17from typing import Any, Iterable, List, Union 

18import uuid 

19 

20# Third-Party 

21import orjson 

22from sqlalchemy import and_, func, or_, text 

23from sqlalchemy.sql.elements import TextClause 

24 

# Module-wide state for building unique SQL bind-parameter names.
# _generate_unique_prefix draws one number from the counter per call while
# holding the lock, so every caller receives a distinct value.
_bind_counter_lock = threading.Lock()
_bind_counter = itertools.count()

28 

29 

def _ensure_list(values: Union[str, Iterable[str], None]) -> List[str]:
    """
    Normalize input into a list of strings.

    A lone string is treated as a single value instead of being iterated
    character by character (which is what ``list("abc")`` would do).

    Args:
        values: A single string, any iterable of strings, or ``None``.

    Returns:
        ``[]`` for ``None``, ``[values]`` for a string, and ``list(values)``
        for any other iterable. The result is always a new list, so one-shot
        iterators (e.g. generators) are consumed exactly once.

    Examples:
        >>> _ensure_list(None)
        []
        >>> _ensure_list("api")
        ['api']
        >>> _ensure_list(("api", "db"))
        ['api', 'db']
    """
    if values is None:
        return []
    if isinstance(values, str):
        return [values]
    return list(values)

48 

49 

def _generate_unique_prefix(col_ref: str) -> str:
    """
    Build a bind-parameter prefix that is unique for the life of the process.

    Every character of ``col_ref`` outside ``[a-zA-Z0-9]`` becomes an
    underscore, and a number taken from the module-wide counter is appended.
    The numeric suffix keeps names distinct even when the same column is
    filtered several times in one query, or when two different column
    references sanitize to the same text (``a_b.c`` and ``a.b_c`` both
    become ``a_b_c``).

    Args:
        col_ref: Column reference such as "resources.tags".

    Returns:
        A prefix such as "resources_tags_42".
    """
    safe_name = re.sub(r"[^a-zA-Z0-9]", "_", col_ref)
    # The lock serializes draws from the shared counter across threads.
    with _bind_counter_lock:
        seq = next(_bind_counter)
    return "_".join((safe_name, str(seq)))

69 

70 

def _sqlite_tag_any_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build a SQLite ``EXISTS`` predicate that is true when a JSON array column
    contains at least one of ``n`` bound tag values.

    An array element matches when it is a plain string equal to a bound value
    (legacy string tags) or a JSON object whose ``$.id`` equals a bound value
    (object-style tags such as ``{"id": "api"}``). ``json_extract`` is only
    evaluated under ``CASE WHEN type = 'object'``, so string elements never
    reach it and cannot trigger malformed-JSON errors.

    Only the ``id`` member is compared; ``label`` is ignored on purpose (see
    the design notes in json_contains_tag_expr).

    Bind parameters are named ``:{prefix}_p0`` .. ``:{prefix}_p{n-1}``
    (e.g. ``:resources_tags_42_p0``). Values for them must be supplied by the
    caller through ``TextClause.bindparams``.

    Args:
        col_ref (str): Column reference interpolated verbatim into the SQL
            (e.g. "resources.tags"); it must not come from user input.
        prefix (str): Unique bind-parameter prefix (from _generate_unique_prefix).
        n (int): Number of tag values. One value is compared with ``=``;
            otherwise an ``IN (...)`` list is used.

    Returns:
        sqlalchemy.sql.elements.TextClause: The ``EXISTS (...)`` predicate.
    """
    if n == 1:
        comparison = f"= :{prefix}_p0"
    else:
        bind_names = ",".join(f":{prefix}_p{i}" for i in range(n))
        comparison = f"IN ({bind_names})"
    object_id = "(CASE WHEN type = 'object' THEN json_extract(value, '$.id') END)"
    return text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value {comparison} OR {object_id} {comparison})")  # nosec B608 - col_ref is schema metadata; tag values are bound parameters

106 

107 

def _sqlite_tag_all_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build a SQLite SQL template for matching ALL of the given tags
    inside a JSON array column.

    This is an AND-chain of EXISTS subqueries, each of which requires the
    presence of one tag. When there is more than one subquery, the chain is
    wrapped in parentheses. The result is a raw ``TextClause``, and SQLAlchemy
    does not add grouping around one. Without the parentheses, negating the
    expression or embedding it under a tighter-binding operator would apply
    to the first subquery only.

    Each subquery accepts legacy string tags and object-style tags
    (e.g., {"id": "api"}). `json_extract` is guarded with
    `CASE WHEN type = 'object'`, so it is never evaluated on string elements.

    The generated SQL uses bind parameters named ``:{prefix}_p0`` ..
    ``:{prefix}_p{n-1}`` (e.g., :resources_tags_42_p0). This avoids
    collisions when multiple tag filters are used in the same query.

    Args:
        col_ref (str): Column reference interpolated verbatim into the SQL
            (e.g., "resources.tags"); it must not come from user input.
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ALL of the given tags.
    """
    exists_clauses = [
        f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p{i} OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p{i})"  # nosec B608
        for i in range(n)
    ]
    if len(exists_clauses) <= 1:
        # Zero or one subquery: nothing to group.
        return text(" AND ".join(exists_clauses))
    # Group the chain explicitly so it behaves as a single predicate wherever
    # the caller places it (e.g. under NOT).
    return text("(" + " AND ".join(exists_clauses) + ")")

139 

140 

def json_contains_tag_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when JSON column `col`
    contains tags matching the given values. Handles both legacy List[str]
    and new List[Dict[str, str]] (with 'id' field) tag formats.

    A stored tag matches a requested value when it is either a plain string
    equal to the value, or an object whose ``id`` member equals the value.

    Design: only ``id`` is compared, never ``label``. TagValidator.validate_list()
    creates tags with both ``id`` and ``label``. ``id`` is the normalized,
    canonical identifier used for matching; ``label`` is for display only.
    A malformed tag object without ``id`` therefore never matches at the
    database level; ``_get_tag_id()`` provides a Python-side fallback for
    such data.

    Args:
        session: database session; ``session.get_bind().dialect.name`` selects
            the SQL flavour ("mysql", "postgresql" or "sqlite").
        col: column that contains a JSON array of tags
        values: a single tag ID, or an iterable of tag IDs, to match against
        match_any: True to match rows having ANY of the values (OR),
            False to require ALL of them (AND)

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized `TextClause` using `json_each()`:

        >>> from unittest.mock import Mock, patch
        >>> from sqlalchemy.sql.elements import TextClause
        >>> from mcpgateway.utils import sqlalchemy_modifier as mod
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "tags"
        >>> with patch.object(mod, "_generate_unique_prefix", return_value="resources_tags_0"):
        ...     expr = mod.json_contains_tag_expr(session, col, ["api", "db"], match_any=True)
        ...     (isinstance(expr, TextClause), "json_each(resources.tags)" in expr.text, expr.compile().params["resources_tags_0_p0"])
        (True, True, 'api')

        MySQL checks both tag formats with exact JSON_CONTAINS comparisons:

        >>> from sqlalchemy import column
        >>> session.get_bind.return_value.dialect.name = "mysql"
        >>> sql = str(mod.json_contains_tag_expr(session, column("tags"), ["a_b"]))
        >>> (sql.count("json_contains(tags"), "json_search" in sql)
        (2, False)

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> mod.json_contains_tag_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains_tag: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- MySQL ----------
    # Both checks use JSON_CONTAINS, which compares JSON values exactly.
    # JSON_SEARCH is avoided on purpose for two reasons. It treats the search
    # string as a LIKE pattern, so '%' and '_' inside a tag act as wildcards.
    # It also matches string values at any depth, so a dict tag's 'label'
    # could satisfy an id filter.
    if dialect == "mysql":
        conditions = []
        for tag_value in values_list:
            # '["api"]' is contained in col when some element is the string "api"
            string_match = func.json_contains(col, orjson.dumps([tag_value]).decode()) == 1
            # '[{"id": "api"}]' is contained in col when some element is an object with id "api"
            dict_match = func.json_contains(col, orjson.dumps([{"id": tag_value}]).decode()) == 1
            conditions.append(or_(string_match, dict_match))
        return or_(*conditions) if match_any else and_(*conditions)

    # ---------- PostgreSQL ----------
    # .contains() only works with JSONB, but the tag column is JSON, so every
    # expression below casts the column to JSONB explicitly.
    if dialect == "postgresql":
        # Third-Party
        from sqlalchemy import bindparam, cast, exists, select
        from sqlalchemy.dialects.postgresql import JSONB
        from sqlalchemy.sql import literal_column

        conditions = []
        for tag_value in values_list:
            # Random bind names keep several tag filters in one query apart
            param_name = f"tag_{uuid.uuid4().hex[:8]}"
            param_dict = f"tag_{uuid.uuid4().hex[:8]}"

            # String tags: the JSONB array contains the one-element array [tag_value]
            string_match = cast(col, JSONB).op("@>")(cast(func.jsonb_build_array(bindparam(param_name, value=tag_value)), JSONB))

            # Dict tags: some array element has ->> 'id' equal to tag_value.
            # table_valued() exposes the function output as elem.c.value.
            elem_table = func.jsonb_array_elements(cast(col, JSONB)).table_valued("value").alias("elem")
            dict_match = exists(select(literal_column("1")).select_from(elem_table).where(elem_table.c.value.op("->>")(literal_column("'id'")) == bindparam(param_dict, value=tag_value)))

            conditions.append(or_(string_match, dict_match))
        return or_(*conditions) if match_any else and_(*conditions)

    # ---------- SQLite (json1) ----------
    # json_each() walks the array. The templates read '$.id' with json_extract
    # only under CASE WHEN type = 'object', so string elements never trigger a
    # "malformed JSON" error.
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        # Unique prefix: avoids bind-name collisions when several tag filters
        # are combined in one query. This holds even on the same column, or
        # when different column refs sanitize to the same string.
        prefix = _generate_unique_prefix(col_ref)
        params = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}
        build_template = _sqlite_tag_any_template if match_any else _sqlite_tag_all_template
        return build_template(col_ref, prefix, len(values_list)).bindparams(**params)

    raise RuntimeError(f"Unsupported dialect for json_contains_tag: {dialect}")

273 

274 

def json_contains_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when the JSON array column
    `col` contains the given scalar values. With ``match_any`` it requires
    any of them; otherwise it requires all of them. `session` is used to
    detect the dialect. Assumes `col` holds a JSON array of scalars
    (array-of-strings case).

    Args:
        session: database session; ``session.get_bind().dialect.name`` selects
            the SQL flavour ("mysql", "postgresql" or "sqlite").
        col: column that contains JSON
        values: a single value, or an iterable of values, to look for
        match_any: True for OR (any value present), False for AND (all
            values present)

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized SQL fragment using `json_each()`:

        >>> from unittest.mock import Mock
        >>> from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "scopes"
        >>> expr = json_contains_expr(session, col, ["a", "b"], match_any=True)
        >>> ("json_each(resources.scopes)" in expr.text, "value IN" in expr.text, len(expr.compile().params))
        (True, True, 2)

        PostgreSQL casts the column to JSONB, so ``@>`` containment works for
        JSON as well as JSONB columns:

        >>> from sqlalchemy import column
        >>> from sqlalchemy.dialects import postgresql
        >>> session.get_bind.return_value.dialect.name = "postgresql"
        >>> pg_expr = json_contains_expr(session, column("scopes"), ["a"], match_any=False)
        >>> sql = str(pg_expr.compile(dialect=postgresql.dialect()))
        >>> ("CAST(scopes AS JSONB)" in sql, "@>" in sql)
        (True, True)

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> json_contains_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- MySQL ----------
    # - all-of: JSON_CONTAINS(col, '["a","b"]') = 1
    # - any-of: JSON_OVERLAPS(col, '["a","b"]') = 1 (JSON_OVERLAPS needs MySQL >= 8.0.17)
    # `func` builds a call for any function name without consulting the
    # server. An unsupported JSON_OVERLAPS therefore surfaces only when the
    # query executes; it cannot be detected or caught here.
    if dialect == "mysql":
        candidate = orjson.dumps(values_list).decode()
        if match_any:
            return func.json_overlaps(col, candidate) == 1
        return func.json_contains(col, candidate) == 1

    # ---------- PostgreSQL ----------
    # JSONB containment (@>). all-of tests the whole list; any-of ORs one
    # single-element containment test per value. The column is cast to JSONB
    # because the plain json type has no @> operator.
    if dialect == "postgresql":
        # Third-Party
        from sqlalchemy import cast
        from sqlalchemy.dialects.postgresql import JSONB

        jsonb_col = cast(col, JSONB)
        if not match_any:
            return jsonb_col.contains(values_list)
        return or_(*[jsonb_col.contains([t]) for t in values_list])

    # ---------- SQLite (json1) ----------
    # SQLite has no JSON_CONTAINS, so safe parameterized SQL is built instead:
    # - any-of: one EXISTS ... WHERE value IN (:p0,:p1,...)
    # - all-of: one EXISTS per value, combined with AND
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        # A unique prefix per call keeps bind names distinct when several
        # filters are combined in the same query.
        prefix = _generate_unique_prefix(col_ref)
        params = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}

        if match_any:
            placeholders_sql = ",".join(f":{name}" for name in params)
            sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))")  # nosec B608 - Safe: uses parameterized queries with bindparams()
            # IMPORTANT: pass plain values as kwargs to bindparams
            return sq.bindparams(**params)

        exists_clauses = [
            text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{name})").bindparams(**{name: value})  # nosec B608 - Safe: uses parameterized queries with bindparams()
            for name, value in params.items()
        ]
        if len(exists_clauses) == 1:
            return exists_clauses[0]
        return and_(*exists_clauses)

    raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")