Coverage for mcpgateway / utils / sqlalchemy_modifier.py: 100%
120 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-02-11 07:10 +0000
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/sqlalchemy_modifier.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Madhav Kandukuri
7SQLAlchemy modifiers
9- json_contains_expr: handles json_contains logic for different dialects
10- json_contains_tag_expr: handles tag filtering for dict-format tags [{id, label}]
11"""
13# Standard
14import itertools
15import re
16import threading
17from typing import Any, Iterable, List, Union
18import uuid
20# Third-Party
21import orjson
22from sqlalchemy import and_, func, or_, text
23from sqlalchemy.sql.elements import TextClause
# Thread-safe counter for generating unique bind parameter prefixes.
# _generate_unique_prefix() draws the next integer from `_bind_counter` while
# holding `_bind_counter_lock`, and appends it to the sanitized column name.
_bind_counter = itertools.count()
_bind_counter_lock = threading.Lock()
def _ensure_list(values: Union[str, Iterable[str]]) -> List[str]:
    """
    Coerce a "one or many" string argument into a concrete list.

    Args:
        values: A single string, any iterable of strings, or `None`.

    Returns:
        An empty list when `values` is `None`; a one-element list when it is
        a single string; otherwise `list(values)`.
    """
    if isinstance(values, str):
        # A bare string is ONE value, not an iterable of characters.
        return [values]
    return [] if values is None else list(values)
def _generate_unique_prefix(col_ref: str) -> str:
    """
    Produce a bind-parameter prefix that is unique for every call.

    The prefix is the column reference with every non-alphanumeric character
    replaced by an underscore, followed by the next value of the module-level
    counter. The counter suffix keeps names distinct even when:
    - the same column is filtered several times in one query
    - two column refs sanitize to the same string (e.g., a_b.c vs a.b_c)

    Args:
        col_ref: Column reference like "resources.tags"

    Returns:
        Unique prefix like "resources_tags_42"
    """
    safe_name = re.sub(r"[^a-zA-Z0-9]", "_", col_ref)
    # Draw the sequence number under the lock so concurrent callers never
    # receive the same value.
    with _bind_counter_lock:
        serial = next(_bind_counter)
    return "{}_{}".format(safe_name, serial)
def _sqlite_tag_any_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite SQL template that is true when a JSON array column
    holds ANY of `n` tag values.

    Each array element is compared twice: directly (legacy string tags) and
    via its 'id' field (object-style tags such as {"id": "api"}). The
    `json_extract` call is wrapped in `CASE WHEN type = 'object'` so it is
    never applied to plain string elements, which would raise a malformed
    JSON error. Only 'id' is inspected, never 'label'.

    Bind parameters are named `:{prefix}_p0` .. `:{prefix}_p{n-1}` so that
    several tag filters can coexist in one query without name collisions.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ANY of the given tags.
    """
    if n == 1:
        # Single value: plain equality against the one bind parameter.
        single_sql = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p0 OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p0)"  # nosec B608
        return text(single_sql)

    # Several values: one IN (...) list, reused for both comparison forms.
    placeholders = ",".join(f":{prefix}_p{idx}" for idx in range(n))
    multi_sql = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders}) OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) IN ({placeholders}))"  # nosec B608
    return text(multi_sql)
def _sqlite_tag_all_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite SQL template that is true when a JSON array column
    holds ALL of `n` tag values.

    The result is one EXISTS subquery per required tag, joined with AND.
    Each subquery compares array elements directly (legacy string tags) and
    via their 'id' field (object-style tags such as {"id": "api"}); the
    `json_extract` call is guarded by `CASE WHEN type = 'object'` so it is
    never applied to plain string elements.

    Bind parameters are named `:{prefix}_p0` .. `:{prefix}_p{n-1}` so that
    several tag filters can coexist in one query without name collisions.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ALL of the given tags.
    """
    per_tag_checks = (
        f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p{i} OR (CASE WHEN type = 'object' THEN json_extract(value, '$.id') END) = :{prefix}_p{i})"  # nosec B608
        for i in range(n)
    )
    return text(" AND ".join(per_tag_checks))
def json_contains_tag_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when JSON column `col`
    contains tags matching the given values. Handles both legacy List[str]
    and new List[Dict[str, str]] (with 'id' field) tag formats.

    Dict-format tags are matched on their 'id' field only, never on 'label'.
    Matching is exact in every dialect: characters such as '%' and '_' in a
    tag value are compared literally.

    Args:
        session: database session (used only to detect the dialect)
        col: column that contains JSON array of tags
        values: a single tag ID or an iterable of tag IDs to match against
        match_any: Boolean to set OR (True) or AND (False) matching

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized `TextClause` using `json_each()`:

        >>> from unittest.mock import Mock, patch
        >>> from sqlalchemy.sql.elements import TextClause
        >>> from mcpgateway.utils import sqlalchemy_modifier as mod
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "tags"
        >>> with patch.object(mod, "_generate_unique_prefix", return_value="resources_tags_0"):
        ...     expr = mod.json_contains_tag_expr(session, col, ["api", "db"], match_any=True)
        ...     (isinstance(expr, TextClause), "json_each(resources.tags)" in expr.text, expr.compile().params["resources_tags_0_p0"])
        (True, True, 'api')

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> mod.json_contains_tag_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains_tag: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name
    # Per-tag conditions are OR-ed for "any" semantics and AND-ed for "all".
    combine = or_ if match_any else and_

    # ---------- MySQL ----------
    # Both tag formats are tested with JSON_CONTAINS:
    #   - JSON_CONTAINS(col, '"api"')          -> legacy plain-string tag
    #   - JSON_CONTAINS(col, '[{"id":"api"}]') -> dict tag with matching id
    # JSON_SEARCH is deliberately NOT used for the string case: its search
    # string is a LIKE pattern ('%' and '_' act as wildcards, so "my_tag"
    # would also match "my-tag"), and it scans every string scalar in the
    # document, so a dict tag's 'label' could match as well.
    if dialect == "mysql":
        conditions = []
        for tag_value in values_list:
            string_match = func.json_contains(col, orjson.dumps(tag_value).decode()) == 1
            dict_match = func.json_contains(col, orjson.dumps([{"id": tag_value}]).decode()) == 1
            conditions.append(or_(string_match, dict_match))
        return combine(*conditions)

    # ---------- PostgreSQL ----------
    # For dict-format tags: use JSON functions that work with both JSON and JSONB types
    # Note: .contains() only works with JSONB, but our column is JSON type
    #
    # Design: We only filter by 'id' field, not 'label'. This is intentional because:
    # 1. TagValidator.validate_list() always creates tags with both id and label
    # 2. The 'id' is the normalized, canonical identifier for matching
    # 3. The 'label' is for display purposes only
    # If a tag somehow exists without 'id' (malformed data), it won't match at DB level,
    # but _get_tag_id() has a Python-side fallback for graceful handling.
    if dialect == "postgresql":
        # Third-Party
        from sqlalchemy import bindparam, cast, exists, select
        from sqlalchemy.dialects.postgresql import JSONB
        from sqlalchemy.sql import literal_column

        conditions = []
        for tag_value in values_list:
            # Generate unique parameter names so repeated filters in one
            # statement do not clash.
            param_name = f"tag_{uuid.uuid4().hex[:8]}"
            param_dict = f"tag_{uuid.uuid4().hex[:8]}"

            # For string tags: use @> operator to check if the JSONB array
            # contains a one-element array holding the value.
            string_match = cast(col, JSONB).op("@>")(cast(func.jsonb_build_array(bindparam(param_name, value=tag_value)), JSONB))

            # For dict tags: use EXISTS with jsonb_array_elements to check 'id' field.
            # table_valued() gives an explicit column reference (elem.c.value).
            elem_table = func.jsonb_array_elements(cast(col, JSONB)).table_valued("value").alias("elem")
            dict_match = exists(select(literal_column("1")).select_from(elem_table).where(elem_table.c.value.op("->>")(literal_column("'id'")) == bindparam(param_dict, value=tag_value)))

            conditions.append(or_(string_match, dict_match))
        return combine(*conditions)

    # ---------- SQLite (json1) ----------
    # For dict-format tags: use json_extract to get the 'id' field
    # Use CASE WHEN type = 'object' to avoid "malformed JSON" error on string elements
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        # Generate unique prefix to avoid bind name collisions when multiple
        # tag filters are combined in the same query (even on the same column
        # or when different column refs sanitize to the same string).
        # values_list is known to be non-empty here (checked above).
        prefix = _generate_unique_prefix(col_ref)
        params = {f"{prefix}_p{i}": tag for i, tag in enumerate(values_list)}

        if match_any:
            tmpl = _sqlite_tag_any_template(col_ref, prefix, len(values_list))
        else:
            tmpl = _sqlite_tag_all_template(col_ref, prefix, len(values_list))
        return tmpl.bindparams(**params)

    raise RuntimeError(f"Unsupported dialect for json_contains_tag: {dialect}")
def json_contains_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Build a boolean SQLAlchemy expression testing whether JSON column `col`
    (an array of scalar strings) contains the requested values. The dialect
    is read from `session` and decides which SQL form is produced.

    Args:
        session: database session
        col: column that contains JSON
        values: list of values to check for in json
        match_any: Boolean to set OR or AND matching

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized SQL fragment using `json_each()`:

        >>> from unittest.mock import Mock
        >>> from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "scopes"
        >>> expr = json_contains_expr(session, col, ["a", "b"], match_any=True)
        >>> ("json_each(resources.scopes)" in expr.text, "value IN" in expr.text, len(expr.compile().params))
        (True, True, 2)

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> json_contains_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- MySQL ----------
    # any-of -> JSON_OVERLAPS(col, '["a","b"]') == 1   (MySQL >= 8.0.17)
    # all-of -> JSON_CONTAINS(col, '["a","b"]') == 1
    if dialect == "mysql":
        try:
            mysql_fn = func.json_overlaps if match_any else func.json_contains
            return mysql_fn(col, orjson.dumps(values_list).decode()) == 1
        except Exception:
            # Fallback: one JSON_CONTAINS per scalar, OR-ed (any) or AND-ed (all).
            # NOTE(review): this only triggers if *building* the expression
            # raises; presumably a server lacking JSON_OVERLAPS would fail
            # later, at execution time - confirm whether that case matters.
            joiner = or_ if match_any else and_
            return joiner(*[func.json_contains(col, orjson.dumps(item).decode()) == 1 for item in values_list])

    # ---------- PostgreSQL ----------
    # any-of -> OR of col.contains([value]) per value
    # all-of -> col.contains(list)  (works if col is JSONB)
    if dialect == "postgresql":
        if match_any:
            return or_(*[col.contains([item]) for item in values_list])
        return col.contains(values_list)

    # ---------- SQLite (json1) ----------
    # SQLite has no JSON_CONTAINS, so parameterized SQL over json_each() is used:
    # any-of -> a single EXISTS ... WHERE value IN (:p0,:p1,...)
    # all-of -> one EXISTS per value, each with its own bind param, AND-ed
    if dialect == "sqlite":
        owner = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = column_name if not owner else f"{owner}.{column_name}"

        if match_any:
            # Random token + positional index keeps every bind name distinct.
            bound = {f"t_{uuid.uuid4().hex[:8]}_{idx}": item for idx, item in enumerate(values_list)}
            placeholders_sql = ",".join(f":{name}" for name in bound)
            sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))")  # nosec B608 - Safe: uses parameterized queries with bindparams()
            # Plain values are handed to bindparams() as keyword arguments.
            return sq.bindparams(**bound)

        def _exists_for(item):
            # One EXISTS(... value = :pX) clause carrying its own bound value.
            pname = f"t_{uuid.uuid4().hex[:8]}"
            return text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{pname})").bindparams(**{pname: item})  # nosec B608 - Safe: uses parameterized queries with bindparams()

        exists_clauses = [_exists_for(item) for item in values_list]
        # A lone clause is returned as-is rather than wrapped in AND.
        return exists_clauses[0] if len(exists_clauses) == 1 else and_(*exists_clauses)

    raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")