Coverage for mcpgateway / utils / sqlalchemy_modifier.py: 100%

101 statements  

« prev     ^ index     » next       coverage.py v7.13.5, created at 2026-04-06 00:56 +0100

1# -*- coding: utf-8 -*- 

2"""Location: ./mcpgateway/utils/sqlalchemy_modifier.py 

3Copyright 2025 

4SPDX-License-Identifier: Apache-2.0 

5Authors: Madhav Kandukuri 

6 

7SQLAlchemy modifiers 

8 

9- json_contains_expr: handles json_contains logic for different dialects 

10- json_contains_tag_expr: handles tag filtering for dict-format tags [{id, label}] 

11""" 

12 

13# Standard 

14import itertools 

15import re 

16import threading 

17from typing import Any, Iterable, List, Union 

18import uuid 

19 

20# Third-Party 

21from sqlalchemy import and_, func, or_, text 

22from sqlalchemy.sql.elements import TextClause 

23 

24# Thread-safe counter for generating unique bind parameter prefixes 

25_bind_counter = itertools.count() 

26_bind_counter_lock = threading.Lock() 

27 

28 

29def _ensure_list(values: Union[str, Iterable[str]]) -> List[str]: 

30 """ 

31 Normalize input into a list of strings. 

32 

33 Args: 

34 values: A single string or any iterable of strings. If `None`, an empty 

35 list is returned. 

36 

37 Returns: 

38 A list of strings. If `values` is a string it will be wrapped in a 

39 single-item list; if it's already an iterable, it will be converted to 

40 a list. If `values` is `None`, returns an empty list. 

41 """ 

42 if values is None: 

43 return [] 

44 if isinstance(values, str): 

45 return [values] 

46 return list(values) 

47 

48 

def _generate_unique_prefix(col_ref: str) -> str:
    """
    Build a bind-parameter prefix that is distinct on every call.

    The column reference is reduced to ASCII letters, digits and underscores
    and then suffixed with the next value of the module-level counter. The
    counter part keeps prefixes apart even when the same column is filtered
    more than once in a query, or when two different references sanitize to
    the same text (e.g. "a_b.c" and "a.b_c").

    Args:
        col_ref: Column reference like "resources.tags"

    Returns:
        Unique prefix like "resources_tags_42"
    """
    # Anything that is not an ASCII letter or digit becomes an underscore so
    # the result can be used inside a ":name" bind placeholder.
    safe_name = re.sub(r"[^a-zA-Z0-9]", "_", col_ref)

    # Draw the sequence number while holding the module lock.
    with _bind_counter_lock:
        sequence_no = next(_bind_counter)

    return f"{safe_name}_{sequence_no}"

68 

69 

def _sqlite_tag_any_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite predicate that is true when a JSON array column holds
    ANY of `n` tag values.

    The predicate is a single EXISTS over `json_each(col_ref)`. An element
    matches either as a plain string (`value` compared directly) or as an
    object whose 'id' member equals the tag; the 'label' member is not
    consulted. `json_extract` is wrapped in `CASE WHEN type = 'object'` so
    it is only evaluated for object elements.

    Placeholders are named `:{prefix}_p0` .. `:{prefix}_p{n-1}`; values are
    not bound here, the caller attaches them.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ANY of the given tags.
    """
    if n != 1:
        # General case: the same placeholder list feeds both IN (...) tests.
        placeholders = ",".join([f":{prefix}_p{idx}" for idx in range(n)])
        return text(
            f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders}) OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) IN ({placeholders}))"  # nosec B608
        )

    # Exactly one tag: plain equality against the single placeholder.
    return text(
        f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p0 OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p0)"  # nosec B608
    )

105 

106 

def _sqlite_tag_all_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite predicate that is true when a JSON array column holds
    ALL of `n` tag values.

    One EXISTS subquery over `json_each(col_ref)` is produced per tag and the
    subqueries are joined with AND. Within each subquery an element matches
    either as a plain string or as an object whose 'id' member equals the
    tag; `json_extract` is wrapped in `CASE WHEN type = 'object'` so it is
    only evaluated for object elements.

    Placeholders are named `:{prefix}_p0` .. `:{prefix}_p{n-1}`; values are
    not bound here, the caller attaches them.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ALL of the given tags.
    """
    per_tag_checks = [
        f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p{idx} OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p{idx})"  # nosec B608
        for idx in range(n)
    ]
    return text(" AND ".join(per_tag_checks))

138 

139 

def json_contains_tag_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when JSON column `col`
    contains tags matching the given values. Handles both legacy List[str]
    and new List[Dict[str, str]] (with 'id' field) tag formats.

    Object-style tags are matched on their 'id' field only; 'label' is not
    consulted by the generated SQL.

    Args:
        session: database session; only `session.get_bind().dialect.name` is read
        col: column that contains JSON array of tags
        values: a single tag ID or an iterable of tag IDs to match against
        match_any: Boolean to set OR (True) or AND (False) matching

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized `TextClause` using `json_each()`:

        >>> from unittest.mock import Mock, patch
        >>> from sqlalchemy.sql.elements import TextClause
        >>> from mcpgateway.utils import sqlalchemy_modifier as mod
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "tags"
        >>> with patch.object(mod, "_generate_unique_prefix", return_value="resources_tags_0"):
        ...     expr = mod.json_contains_tag_expr(session, col, ["api", "db"], match_any=True)
        ...     (isinstance(expr, TextClause), "json_each(resources.tags)" in expr.text, expr.compile().params["resources_tags_0_p0"])
        (True, True, 'api')

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> mod.json_contains_tag_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains_tag: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- PostgreSQL ----------
    # The column is cast to JSONB before use so the JSONB operators and
    # functions below apply even when the column itself is declared as JSON.
    if dialect == "postgresql":
        # Third-Party
        from sqlalchemy import bindparam, cast, exists, select
        from sqlalchemy.dialects.postgresql import JSONB
        from sqlalchemy.sql import literal_column

        # Bind names come from the module's counter-based prefix (as in the
        # SQLite branch below) instead of a truncated random UUID, so names
        # stay distinct when several tag filters end up in one statement.
        prefix = _generate_unique_prefix("tag")

        conditions = []
        for i, tag_value in enumerate(values_list):
            string_param = f"{prefix}_s{i}"
            dict_param = f"{prefix}_d{i}"

            # Legacy string tags: JSONB containment (@>) against a
            # one-element array built from the bound value.
            string_match = cast(col, JSONB).op("@>")(cast(func.jsonb_build_array(bindparam(string_param, value=tag_value)), JSONB))

            # Dict tags: EXISTS over the array elements, comparing each
            # element's 'id' (extracted as text with ->>) to the bound value.
            # table_valued() gives an explicit column reference (elem.c.value).
            elem_table = func.jsonb_array_elements(cast(col, JSONB)).table_valued("value").alias("elem")
            dict_match = exists(select(literal_column("1")).select_from(elem_table).where(elem_table.c.value.op("->>")(literal_column("'id'")) == bindparam(dict_param, value=tag_value)))

            conditions.append(or_(string_match, dict_match))

        if match_any:
            return or_(*conditions)
        return and_(*conditions)

    # ---------- SQLite (json1) ----------
    # The helper templates guard json_extract with CASE WHEN type = 'object'
    # so it is not applied to plain string elements.
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        n = len(values_list)
        if n == 0:
            # Defensive re-check; the guard at the top normally rejects this.
            raise ValueError("values must be non-empty")

        # Unique prefix avoids bind name collisions when multiple tag filters
        # are combined in the same query (even on the same column, or when
        # different column refs sanitize to the same string).
        prefix = _generate_unique_prefix(col_ref)
        params = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}

        build_template = _sqlite_tag_any_template if match_any else _sqlite_tag_all_template
        return build_template(col_ref, prefix, n).bindparams(**params)

    raise RuntimeError(f"Unsupported dialect for json_contains_tag: {dialect}")

254 

255 

def json_contains_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when JSON column `col`
    contains the given scalar value(s). `session` is used to detect dialect.
    Assumes `col` is a JSON/JSONB column (array-of-strings case).

    Args:
        session: database session; only `session.get_bind().dialect.name` is read
        col: column that contains JSON
        values: a single value or an iterable of values to check for in json
        match_any: Boolean to set OR (True) or AND (False) matching

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized SQL fragment using `json_each()`:

        >>> from unittest.mock import Mock
        >>> from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "scopes"
        >>> expr = json_contains_expr(session, col, ["a", "b"], match_any=True)
        >>> ("json_each(resources.scopes)" in expr.text, "value IN" in expr.text, len(expr.compile().params))
        (True, True, 2)

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> json_contains_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- PostgreSQL ----------
    # - all-of: col.contains(list)  (works if col is JSONB)
    # - any-of: OR of col.contains([value])
    if dialect == "postgresql":
        if not match_any:
            return col.contains(values_list)
        return or_(*[col.contains([t]) for t in values_list])

    # ---------- SQLite (json1) ----------
    # SQLite has no JSON_CONTAINS, so membership is tested via json_each():
    # - any-of: single EXISTS ... WHERE value IN (:p0,:p1,...)
    # - all-of: one EXISTS per value, combined with AND
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        # Bind names use the module's counter-based prefix plus an index
        # (same scheme as json_contains_tag_expr) rather than a truncated
        # random UUID, so names stay distinct when several filters are
        # combined in one statement.
        prefix = _generate_unique_prefix(col_ref)
        params = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}

        if match_any:
            placeholders_sql = ",".join(f":{pname}" for pname in params)
            sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))")  # nosec B608 - Safe: uses parameterized queries with bindparams()
            # Values are passed as plain kwargs to bindparams.
            return sq.bindparams(**params)

        # all-of: each value gets its own EXISTS clause with its own bound value.
        exists_clauses = [
            text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{pname})").bindparams(**{pname: t})  # nosec B608 - Safe: uses parameterized queries with bindparams()
            for pname, t in params.items()
        ]
        if len(exists_clauses) == 1:
            return exists_clauses[0]
        return and_(*exists_clauses)

    raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")

296 """ 

297 values_list = _ensure_list(values) 

298 if not values_list: 

299 raise ValueError("values must be non-empty") 

300 

301 dialect = session.get_bind().dialect.name 

302 

303 # ---------- PostgreSQL ---------- 

304 # - all-of: col.contains(list) (works if col is JSONB) 

305 # - any-of: use OR of col.contains([value]) (or use ?| operator if you prefer) 

306 if dialect == "postgresql": 

307 # prefer JSONB .contains for all-of 

308 if not match_any: 

309 return col.contains(values_list) 

310 # match_any: use OR over element-containment 

311 return or_(*[col.contains([t]) for t in values_list]) 

312 

313 # ---------- SQLite (json1) ---------- 

314 # SQLite doesn't have JSON_CONTAINS. We build safe SQL: 

315 # - any-of: single EXISTS ... WHERE value IN (:p0,:p1,...) 

316 # - all-of: multiple EXISTS with unique bind params (one EXISTS per value) => AND semantics 

317 if dialect == "sqlite": 

318 table_name = getattr(getattr(col, "table", None), "name", None) 

319 column_name = getattr(col, "name", None) or str(col) 

320 col_ref = f"{table_name}.{column_name}" if table_name else column_name 

321 

322 if match_any: 

323 # Build placeholders with unique param names and pass *values* to bindparams 

324 params = {} 

325 placeholders = [] 

326 for i, t in enumerate(values_list): 

327 pname = f"t_{uuid.uuid4().hex[:8]}_{i}" 

328 placeholders.append(f":{pname}") 

329 params[pname] = t 

330 placeholders_sql = ",".join(placeholders) 

331 sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))") # nosec B608 - Safe: uses parameterized queries with bindparams() 

332 # IMPORTANT: pass plain values as kwargs to bindparams 

333 return sq.bindparams(**params) 

334 

335 # all-of: return AND of EXISTS(... = :pX) with plain values 

336 exists_clauses = [] 

337 for t in values_list: 

338 pname = f"t_{uuid.uuid4().hex[:8]}" 

339 clause = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{pname})").bindparams(**{pname: t}) # nosec B608 - Safe: uses parameterized queries with bindparams() 

340 exists_clauses.append(clause) 

341 if len(exists_clauses) == 1: 

342 return exists_clauses[0] 

343 return and_(*exists_clauses) 

344 

345 raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")