From 88ed834c29a33681f564ea9bb14a7d30504d6757 Mon Sep 17 00:00:00 2001 From: fatelei Date: Fri, 6 Mar 2026 22:48:46 +0800 Subject: [PATCH] feat: support celery --- .../google/cloud/sqlcommenter/celery.py | 91 ++++++++++++++++ .../cloud/sqlcommenter/sqlalchemy/executor.py | 3 + python/sqlcommenter-python/setup.py | 1 + .../tests/generic/test_celery.py | 102 ++++++++++++++++++ .../tests/sqlalchemy/tests.py | 35 ++++++ python/sqlcommenter-python/tox.ini | 6 +- 6 files changed, 235 insertions(+), 3 deletions(-) create mode 100644 python/sqlcommenter-python/google/cloud/sqlcommenter/celery.py create mode 100644 python/sqlcommenter-python/tests/generic/test_celery.py diff --git a/python/sqlcommenter-python/google/cloud/sqlcommenter/celery.py b/python/sqlcommenter-python/google/cloud/sqlcommenter/celery.py new file mode 100644 index 00000000..b7d041dc --- /dev/null +++ b/python/sqlcommenter-python/google/cloud/sqlcommenter/celery.py @@ -0,0 +1,91 @@ +#!/usr/bin/python +# +# Copyright 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + + +from contextvars import ContextVar +from typing import Any, Dict, Optional + + +try: + import celery + from celery import signals +except Exception: + celery = None + signals = None + + +_context: ContextVar[Dict[str, Any]] = ContextVar("sqlcommenter_celery_context", default={}) + + +def get_celery_info() -> Dict[str, Any]: + info = _context.get() or {} + return dict(info) if info else {} + + +def install_signals() -> None: + if celery is None or signals is None: + raise ImportError("celery is not installed.") + + def _on_task_prerun(sender=None, task_id: Optional[str]=None, task=None, args=None, kwargs=None, **kw): + fw = f"celery:{getattr(celery, '__version__', 'unknown')}" + info: Dict[str, Any] = {"framework": fw} + t = task if task is not None else sender + try: + if t is not None: + name = getattr(t, "name", None) + if not name: + req = getattr(t, "request", None) + name = getattr(req, "task", None) + if name: + info["task"] = name + req = getattr(t, "request", None) + if req is not None: + rk = getattr(req, "routing_key", None) + if not rk: + di = getattr(req, "delivery_info", None) or {} + rk = di.get("routing_key") if isinstance(di, dict) else None + if rk: + info["route"] = rk + except Exception: + pass + token = _context.set(info) + if t is not None: + try: + setattr(t, "_sqlcommenter_token", token) + except Exception: + pass + + def _on_task_postrun(sender=None, task_id: Optional[str]=None, task=None, args=None, kwargs=None, **kw): + t = task if task is not None else sender + token = None + if t is not None: + token = getattr(t, "_sqlcommenter_token", None) + try: + if token is not None: + _context.reset(token) + try: + delattr(t, "_sqlcommenter_token") + except Exception: + pass + else: + _context.set({}) + except Exception: + pass + + signals.task_prerun.connect(_on_task_prerun, weak=False) + signals.task_postrun.connect(_on_task_postrun, weak=False) diff --git a/python/sqlcommenter-python/google/cloud/sqlcommenter/sqlalchemy/executor.py b/python/sqlcommenter-python/google/cloud/sqlcommenter/sqlalchemy/executor.py index da3ae510..fa585a55 100644 --- a/python/sqlcommenter-python/google/cloud/sqlcommenter/sqlalchemy/executor.py +++ b/python/sqlcommenter-python/google/cloud/sqlcommenter/sqlalchemy/executor.py @@ -24,6 +24,7 @@ from google.cloud.sqlcommenter import add_sql_comment from google.cloud.sqlcommenter.fastapi import get_fastapi_info from google.cloud.sqlcommenter.flask import get_flask_info +from google.cloud.sqlcommenter.celery import get_celery_info from google.cloud.sqlcommenter.opencensus import get_opencensus_values from google.cloud.sqlcommenter.opentelemetry import get_opentelemetry_values @@ -47,6 +48,8 @@ def get_framework_info(): info = get_flask_info() if not info: info = get_fastapi_info() + if not info: + info = get_celery_info() return info def before_cursor_execute(conn, cursor, sql, parameters, context, executemany): diff --git a/python/sqlcommenter-python/setup.py b/python/sqlcommenter-python/setup.py index a68173c6..902b0692 100644 --- a/python/sqlcommenter-python/setup.py +++ b/python/sqlcommenter-python/setup.py @@ -41,6 +41,7 @@ def read_file(filename): 'fastapi': ['fastapi'], 'psycopg2': ['psycopg2'], 'sqlalchemy': ['sqlalchemy'], + 'celery': ['celery>=5'], 'opencensus': ['opencensus'], 'opentelemetry': ["opentelemetry-api ~= 1.0"], }, diff --git a/python/sqlcommenter-python/tests/generic/test_celery.py b/python/sqlcommenter-python/tests/generic/test_celery.py new file mode 100644 index 00000000..33485dbb --- /dev/null +++ b/python/sqlcommenter-python/tests/generic/test_celery.py @@ -0,0 +1,102 @@ +#!/usr/bin/python +# +# Copyright 2026 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import + +import unittest + +import google.cloud.sqlcommenter.celery as mod + + +class CeleryModuleTests(unittest.TestCase): + def tearDown(self) -> None: + try: + mod._context.set({}) + except Exception: + pass + + def test_get_celery_info_empty(self): + self.assertEqual(mod.get_celery_info(), {}) + + def test_get_celery_info_after_context_set(self): + token = mod._context.set({ + 'framework': 'celery:5.3.0', + 'task': 'tasks.add', + 'route': 'celery', + }) + try: + self.assertEqual( + mod.get_celery_info(), + {'framework': 'celery:5.3.0', 'task': 'tasks.add', 'route': 'celery'} + ) + finally: + mod._context.reset(token) + + def test_install_signals_without_celery_raises(self): + original_celery, original_signals = mod.celery, mod.signals + try: + mod.celery = None + mod.signals = None + with self.assertRaises(ImportError): + mod.install_signals() + finally: + mod.celery, mod.signals = original_celery, original_signals + + def test_signal_flow_sets_and_clears_context(self): + class _SigList: + def __init__(self): + self._cbs = [] + def connect(self, cb, weak=False): # noqa: ARG002 - weak unused + self._cbs.append(cb) + return cb + + class FakeSignals: + def __init__(self): + self.task_prerun = _SigList() + self.task_postrun = _SigList() + + class FakeCelery: + __version__ = '9.9.9' + + original_celery, original_signals = mod.celery, mod.signals + try: + mod.celery = FakeCelery() + mod.signals = FakeSignals() + mod.install_signals() + + self.assertEqual(len(mod.signals.task_prerun._cbs), 1) + self.assertEqual(len(mod.signals.task_postrun._cbs), 1) + + prerun_cb = mod.signals.task_prerun._cbs[0] + postrun_cb = mod.signals.task_postrun._cbs[0] + + class FakeReq: + routing_key = 'celery' + class FakeTask: + name = 'tasks.add' + request = FakeReq() + + t = FakeTask() + prerun_cb(task=t) + self.assertEqual( + mod.get_celery_info(), + {'framework': 'celery:9.9.9', 'task': 'tasks.add', 'route': 'celery'} + ) + + postrun_cb(task=t) + self.assertEqual(mod.get_celery_info(), {}) + finally: + mod.celery, mod.signals = original_celery, original_signals \ No newline at end of file diff --git a/python/sqlcommenter-python/tests/sqlalchemy/tests.py b/python/sqlcommenter-python/tests/sqlalchemy/tests.py index d78b5a1e..fdbb0e6f 100644 --- a/python/sqlcommenter-python/tests/sqlalchemy/tests.py +++ b/python/sqlcommenter-python/tests/sqlalchemy/tests.py @@ -161,3 +161,38 @@ def test_route_disabled(self, get_info): "SELECT 1 /*controller='c',framework='fastapi'*/;", with_route=False, ) + + +class CeleryTests(SQLAlchemyTestCase): + celery_info = { + 'framework': 'celery', + 'controller': 'tasks.add', + 'route': 'celery', + } + + @mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info) + def test_all_data(self, get_info): + self.assertSQL( + "SELECT 1 /*controller='tasks.add',framework='celery',route='celery'*/;", + ) + + @mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info) + def test_framework_disabled(self, get_info): + self.assertSQL( + "SELECT 1 /*controller='tasks.add',route='celery'*/;", + with_framework=False, + ) + + @mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info) + def test_controller_disabled(self, get_info): + self.assertSQL( + "SELECT 1 /*framework='celery',route='celery'*/;", + with_controller=False, + ) + + @mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info) + def test_route_disabled(self, get_info): + self.assertSQL( + "SELECT 1 /*controller='tasks.add',framework='celery'*/;", + with_route=False, + ) diff --git a/python/sqlcommenter-python/tox.ini b/python/sqlcommenter-python/tox.ini index aa16b0b9..7a1e9c7e 100644 --- a/python/sqlcommenter-python/tox.ini +++ b/python/sqlcommenter-python/tox.ini @@ -1,8 +1,8 @@ [tox] envlist = - py3{6,7,8,9}-django{21,22,30,31,32} - py3{6,7,8,9}-psycopg2 - py3{6,7,8,9}-{flask,generic,sqlalchemy} + py3{6,7,8,9,10}-django{21,22,30,31,32} + py3{6,7,8,9,10}-psycopg2 + py3{6,7,8,9,10}-{flask,generic,sqlalchemy} flake8 #