Coverage for mcpgateway / utils / sqlalchemy_modifier.py: 100%
101 statements
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
« prev ^ index » next coverage.py v7.13.5, created at 2026-04-06 00:56 +0100
1# -*- coding: utf-8 -*-
2"""Location: ./mcpgateway/utils/sqlalchemy_modifier.py
3Copyright 2025
4SPDX-License-Identifier: Apache-2.0
5Authors: Madhav Kandukuri
7SQLAlchemy modifiers
9- json_contains_expr: handles json_contains logic for different dialects
10- json_contains_tag_expr: handles tag filtering for dict-format tags [{id, label}]
11"""
13# Standard
14import itertools
15import re
16import threading
17from typing import Any, Iterable, List, Union
18import uuid
20# Third-Party
21from sqlalchemy import and_, func, or_, text
22from sqlalchemy.sql.elements import TextClause
# Shared, monotonically increasing sequence used to build unique bind
# parameter prefixes. The lock is held around every next() call on the
# counter so that concurrent callers never observe the same number.
_bind_counter = itertools.count()
_bind_counter_lock = threading.Lock()
def _ensure_list(values: Union[str, Iterable[str], None]) -> List[str]:
    """
    Normalize input into a list of strings.

    Args:
        values: A single string, any iterable of strings, or `None`.

    Returns:
        A list of strings. A single string is wrapped in a one-item list,
        any other iterable is materialized with `list()`, and `None` yields
        an empty list.
    """
    if values is None:
        return []
    if isinstance(values, str):
        # A str is itself iterable; wrap it so it is not split into characters.
        return [values]
    return list(values)
def _generate_unique_prefix(col_ref: str) -> str:
    """
    Build a bind-parameter prefix that differs on every call.

    The prefix is the column reference with every character outside
    ``[a-zA-Z0-9]`` replaced by ``_``, followed by the next value of the
    module-level counter. Because the counter value is part of the result,
    the prefix stays unique even when:
    - the same column is filtered several times in one query
    - two different column refs sanitize to the same text (a_b.c vs a.b_c)

    Args:
        col_ref: Column reference like "resources.tags"

    Returns:
        Unique prefix like "resources_tags_42"
    """
    cleaned = re.sub(r"[^a-zA-Z0-9]", "_", col_ref)
    # Take the next sequence number while holding the counter lock.
    with _bind_counter_lock:
        return f"{cleaned}_{next(_bind_counter)}"
def _sqlite_tag_any_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite SQL template that is true when a JSON array column
    holds AT LEAST ONE of the requested tags.

    Each array element is compared twice: once as a plain value (legacy
    string tags) and once through its ``id`` member (object-style tags such
    as {"id": "api"}). The ``json_extract`` call is wrapped in
    ``CASE WHEN type = 'object'`` so it is only evaluated for object
    elements, avoiding malformed JSON errors on string values.

    Only the 'id' field is compared, never 'label'.

    Bind parameters are named ``:{prefix}_p0`` ... ``:{prefix}_p{n-1}``
    (e.g. :resources_tags_42_p0) so that several tag filters can share one
    query without their parameter names colliding.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ANY of the given tags.
    """
    if n == 1:
        # A single tag is compared with "=" rather than a one-item IN list.
        comparison = f"= :{prefix}_p0"
    else:
        bind_names = ",".join(f":{prefix}_p{i}" for i in range(n))
        comparison = f"IN ({bind_names})"

    # Yields the element's 'id' for object elements and NULL otherwise.
    object_id = "(CASE WHEN type = 'object' THEN json_extract(value, '$.id') END)"
    sql = f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value {comparison} OR {object_id} {comparison})"  # nosec B608
    return text(sql)
def _sqlite_tag_all_template(col_ref: str, prefix: str, n: int) -> TextClause:
    """
    Build the SQLite SQL template that is true when a JSON array column
    holds EVERY one of the requested tags.

    One EXISTS subquery is produced per required tag and the subqueries are
    joined with AND, so each tag must be found independently.

    Each array element is compared both as a plain value (legacy string
    tags) and through its ``id`` member (object-style tags such as
    {"id": "api"}). The ``json_extract`` call is wrapped in
    ``CASE WHEN type = 'object'`` so it is only evaluated for object
    elements, avoiding malformed JSON errors on string values.

    Bind parameters are named ``:{prefix}_p0`` ... ``:{prefix}_p{n-1}``
    (e.g. :resources_tags_42_p0) so that several tag filters can share one
    query without their parameter names colliding.

    Args:
        col_ref (str): Fully-qualified column reference (e.g., "resources.tags").
        prefix (str): Unique prefix for bind parameters (from _generate_unique_prefix).
        n (int): Number of tag values being matched.

    Returns:
        sqlalchemy.sql.elements.TextClause:
            A SQL template for matching ALL of the given tags.
    """
    # Yields the element's 'id' for object elements and NULL otherwise.
    object_id = "(CASE WHEN type = 'object' THEN json_extract(value, '$.id') END)"
    per_tag = [
        f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{prefix}_p{i} OR {object_id} = :{prefix}_p{i})"  # nosec B608
        for i in range(n)
    ]
    return text(" AND ".join(per_tag))
def json_contains_tag_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Build a SQLAlchemy boolean expression that tests whether the JSON array
    column `col` holds tags matching `values`.

    Both tag layouts are recognized: the legacy List[str] form and the
    List[Dict[str, str]] form, where the dict's 'id' member is compared.

    Args:
        session: database session (used to detect the dialect)
        col: column that contains JSON array of tags
        values: a tag ID or an iterable of tag IDs to match against
        match_any: True to match when ANY tag is present (OR), False to
            require ALL tags (AND)

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized `TextClause` using `json_each()`:

        >>> from unittest.mock import Mock, patch
        >>> from sqlalchemy.sql.elements import TextClause
        >>> from mcpgateway.utils import sqlalchemy_modifier as mod
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "tags"
        >>> with patch.object(mod, "_generate_unique_prefix", return_value="resources_tags_0"):
        ...     expr = mod.json_contains_tag_expr(session, col, ["api", "db"], match_any=True)
        ...     (isinstance(expr, TextClause), "json_each(resources.tags)" in expr.text, expr.compile().params["resources_tags_0_p0"])
        (True, True, 'api')

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> mod.json_contains_tag_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains_tag: oracle
    """
    tag_ids = _ensure_list(values)
    if not tag_ids:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- PostgreSQL ----------
    # The column is cast to JSONB so the JSONB operators/functions can be used
    # even though the column itself is declared as JSON (.contains() only
    # works with JSONB).
    #
    # Design: only the 'id' field is compared, never 'label'. This is intentional:
    #   1. TagValidator.validate_list() always creates tags with both id and label
    #   2. The 'id' is the normalized, canonical identifier for matching
    #   3. The 'label' is for display purposes only
    # A tag that somehow lacks 'id' (malformed data) will not match at DB level,
    # but _get_tag_id() has a Python-side fallback for graceful handling.
    if dialect == "postgresql":
        # Third-Party
        from sqlalchemy import bindparam, cast, exists, select
        from sqlalchemy.dialects.postgresql import JSONB
        from sqlalchemy.sql import literal_column

        def _condition_for(tag_id):
            """Return the "string element OR object with this id" test for one tag."""
            # Separate randomly named binds for the two comparisons.
            string_bind = f"tag_{uuid.uuid4().hex[:8]}"
            object_bind = f"tag_{uuid.uuid4().hex[:8]}"

            # Legacy string tags: JSONB containment (@>) against a one-element array.
            wanted_array = cast(func.jsonb_build_array(bindparam(string_bind, value=tag_id)), JSONB)
            as_string = cast(col, JSONB).op("@>")(wanted_array)

            # Dict tags: EXISTS over the unnested array elements, comparing elem->>'id'.
            # table_valued() provides an explicit column reference (elem.c.value).
            elem = func.jsonb_array_elements(cast(col, JSONB)).table_valued("value").alias("elem")
            id_matches = elem.c.value.op("->>")(literal_column("'id'")) == bindparam(object_bind, value=tag_id)
            as_object = exists(select(literal_column("1")).select_from(elem).where(id_matches))

            return or_(as_string, as_object)

        per_tag = [_condition_for(tag_id) for tag_id in tag_ids]
        combine = or_ if match_any else and_
        return combine(*per_tag)

    # ---------- SQLite (json1) ----------
    # The SQL text comes from the module's template helpers, which guard
    # json_extract with CASE WHEN type = 'object' to avoid "malformed JSON"
    # errors on string elements.
    if dialect == "sqlite":
        owning_table = getattr(col, "table", None)
        table_name = getattr(owning_table, "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        n = len(tag_ids)
        if n == 0:
            # Defensive re-check of the emptiness guard above.
            raise ValueError("values must be non-empty")

        # A unique prefix avoids bind name collisions when several tag filters
        # are combined in one query (even on the same column, or when different
        # column refs sanitize to the same string).
        prefix = _generate_unique_prefix(col_ref)
        bound_values = {f"{prefix}_p{i}": tag_id for i, tag_id in enumerate(tag_ids)}

        build_template = _sqlite_tag_any_template if match_any else _sqlite_tag_all_template
        return build_template(col_ref, prefix, n).bindparams(**bound_values)

    raise RuntimeError(f"Unsupported dialect for json_contains_tag: {dialect}")
def json_contains_expr(session, col, values: Union[str, Iterable[str]], match_any: bool = True) -> Any:
    """
    Return a SQLAlchemy expression that is True when JSON column `col`
    contains the given scalar value(s). `session` is used to detect dialect.
    Assumes `col` is a JSON/JSONB column (array-of-strings case).

    On SQLite the bind parameter names are derived from
    `_generate_unique_prefix`, so fragments produced by separate calls can
    be combined in one statement without two parameters sharing a name.

    Args:
        session: database session
        col: column that contains JSON
        values: a value or an iterable of values to check for in json
        match_any: True to match when ANY value is present (OR), False to
            require ALL values (AND)

    Returns:
        Any: SQLAlchemy boolean expression suitable for use in .where()

    Raises:
        RuntimeError: If dialect is not supported
        ValueError: If values is empty

    Examples:
        SQLite builds a parameterized SQL fragment using `json_each()`:

        >>> from unittest.mock import Mock
        >>> from mcpgateway.utils.sqlalchemy_modifier import json_contains_expr
        >>> session = Mock()
        >>> session.get_bind.return_value.dialect.name = "sqlite"
        >>> col = Mock()
        >>> col.table.name = "resources"
        >>> col.name = "scopes"
        >>> expr = json_contains_expr(session, col, ["a", "b"], match_any=True)
        >>> ("json_each(resources.scopes)" in expr.text, "value IN" in expr.text, len(expr.compile().params))
        (True, True, 2)

        Unsupported dialects raise:

        >>> session.get_bind.return_value.dialect.name = "oracle"
        >>> json_contains_expr(session, col, ["x"])
        Traceback (most recent call last):
        ...
        RuntimeError: Unsupported dialect for json_contains: oracle
    """
    values_list = _ensure_list(values)
    if not values_list:
        raise ValueError("values must be non-empty")

    dialect = session.get_bind().dialect.name

    # ---------- PostgreSQL ----------
    #  - all-of: col.contains(list)
    #  - any-of: OR of col.contains([value])
    # NOTE(review): this relies on `col` being JSONB-typed so that .contains()
    # performs JSON containment; confirm callers never pass a plain JSON column.
    if dialect == "postgresql":
        if not match_any:
            return col.contains(values_list)
        return or_(*[col.contains([t]) for t in values_list])

    # ---------- SQLite (json1) ----------
    # SQLite doesn't have JSON_CONTAINS. We build safe SQL:
    #  - any-of: single EXISTS ... WHERE value IN (:p0,:p1,...)
    #  - all-of: one EXISTS per value, each with its own bind param => AND semantics
    if dialect == "sqlite":
        table_name = getattr(getattr(col, "table", None), "name", None)
        column_name = getattr(col, "name", None) or str(col)
        col_ref = f"{table_name}.{column_name}" if table_name else column_name

        # Counter-based prefix: guarantees the bind names of this fragment
        # differ from those of every other fragment, which randomly chosen
        # short names cannot promise. Same-named binds with different values
        # would otherwise be merged when fragments are combined.
        prefix = _generate_unique_prefix(col_ref)
        bound_values = {f"{prefix}_p{i}": t for i, t in enumerate(values_list)}

        if match_any:
            placeholders_sql = ",".join(f":{pname}" for pname in bound_values)
            sq = text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value IN ({placeholders_sql}))")  # nosec B608 - Safe: uses parameterized queries with bindparams()
            # IMPORTANT: pass plain values as kwargs to bindparams
            return sq.bindparams(**bound_values)

        # all-of: AND of EXISTS(... = :pX), each clause bound to its own plain value
        exists_clauses = [
            text(f"EXISTS (SELECT 1 FROM json_each({col_ref}) WHERE value = :{pname})").bindparams(**{pname: t})  # nosec B608 - Safe: uses parameterized queries with bindparams()
            for pname, t in bound_values.items()
        ]
        if len(exists_clauses) == 1:
            # A single clause needs no AND wrapper.
            return exists_clauses[0]
        return and_(*exists_clauses)

    raise RuntimeError(f"Unsupported dialect for json_contains: {dialect}")