Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions python/sqlcommenter-python/google/cloud/sqlcommenter/celery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
#!/usr/bin/python
#
# Copyright 2026
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import


from contextvars import ContextVar
from typing import Any, Dict, Optional


try:
import celery
from celery import signals
except Exception:
celery = None
signals = None


_context: ContextVar[Dict[str, Any]] = ContextVar("sqlcommenter_celery_context", default={})


def get_celery_info() -> Dict[str, Any]:
info = _context.get() or {}
return dict(info) if info else {}


def install_signals() -> None:
if celery is None or signals is None:
raise ImportError("celery is not installed.")

def _on_task_prerun(sender=None, task_id: Optional[str]=None, task=None, args=None, kwargs=None, **kw):
fw = f"celery:{getattr(celery, '__version__', 'unknown')}"
info: Dict[str, Any] = {"framework": fw}
t = task if task is not None else sender
try:
if t is not None:
name = getattr(t, "name", None)
if not name:
req = getattr(t, "request", None)
name = getattr(req, "task", None)
if name:
info["task"] = name
req = getattr(t, "request", None)
if req is not None:
rk = getattr(req, "routing_key", None)
if not rk:
di = getattr(req, "delivery_info", None) or {}
rk = di.get("routing_key") if isinstance(di, dict) else None
if rk:
info["route"] = rk
except Exception:
pass
token = _context.set(info)
if t is not None:
try:
setattr(t, "_sqlcommenter_token", token)
except Exception:
pass

def _on_task_postrun(sender=None, task_id: Optional[str]=None, task=None, args=None, kwargs=None, **kw):
t = task if task is not None else sender
token = None
if t is not None:
token = getattr(t, "_sqlcommenter_token", None)
try:
if token is not None:
_context.reset(token)
try:
delattr(t, "_sqlcommenter_token")
except Exception:
pass
else:
_context.set({})
except Exception:
pass

signals.task_prerun.connect(_on_task_prerun, weak=False)
signals.task_postrun.connect(_on_task_postrun, weak=False)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from google.cloud.sqlcommenter import add_sql_comment
from google.cloud.sqlcommenter.fastapi import get_fastapi_info
from google.cloud.sqlcommenter.flask import get_flask_info
from google.cloud.sqlcommenter.celery import get_celery_info
from google.cloud.sqlcommenter.opencensus import get_opencensus_values
from google.cloud.sqlcommenter.opentelemetry import get_opentelemetry_values

Expand All @@ -47,6 +48,8 @@ def get_framework_info():
info = get_flask_info()
if not info:
info = get_fastapi_info()
if not info:
info = get_celery_info()
return info

def before_cursor_execute(conn, cursor, sql, parameters, context, executemany):
Expand Down
1 change: 1 addition & 0 deletions python/sqlcommenter-python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def read_file(filename):
'fastapi': ['fastapi'],
'psycopg2': ['psycopg2'],
'sqlalchemy': ['sqlalchemy'],
'celery': ['celery>=5'],
'opencensus': ['opencensus'],
'opentelemetry': ["opentelemetry-api ~= 1.0"],
},
Expand Down
102 changes: 102 additions & 0 deletions python/sqlcommenter-python/tests/generic/test_celery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/usr/bin/python
#
# Copyright 2026
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import

import unittest

import google.cloud.sqlcommenter.celery as mod


class CeleryModuleTests(unittest.TestCase):
def tearDown(self) -> None:
try:
mod._context.set({})
except Exception:
pass

def test_get_celery_info_empty(self):
self.assertEqual(mod.get_celery_info(), {})

def test_get_celery_info_after_context_set(self):
token = mod._context.set({
'framework': 'celery:5.3.0',
'task': 'tasks.add',
'route': 'celery',
})
try:
self.assertEqual(
mod.get_celery_info(),
{'framework': 'celery:5.3.0', 'task': 'tasks.add', 'route': 'celery'}
)
finally:
mod._context.reset(token)

def test_install_signals_without_celery_raises(self):
original_celery, original_signals = mod.celery, mod.signals
try:
mod.celery = None
mod.signals = None
with self.assertRaises(ImportError):
mod.install_signals()
finally:
mod.celery, mod.signals = original_celery, original_signals

def test_signal_flow_sets_and_clears_context(self):
class _SigList:
def __init__(self):
self._cbs = []
def connect(self, cb, weak=False): # noqa: ARG002 - weak unused
self._cbs.append(cb)
return cb

class FakeSignals:
def __init__(self):
self.task_prerun = _SigList()
self.task_postrun = _SigList()

class FakeCelery:
__version__ = '9.9.9'

original_celery, original_signals = mod.celery, mod.signals
try:
mod.celery = FakeCelery()
mod.signals = FakeSignals()
mod.install_signals()

self.assertEqual(len(mod.signals.task_prerun._cbs), 1)
self.assertEqual(len(mod.signals.task_postrun._cbs), 1)

prerun_cb = mod.signals.task_prerun._cbs[0]
postrun_cb = mod.signals.task_postrun._cbs[0]

class FakeReq:
routing_key = 'celery'
class FakeTask:
name = 'tasks.add'
request = FakeReq()

t = FakeTask()
prerun_cb(task=t)
self.assertEqual(
mod.get_celery_info(),
{'framework': 'celery:9.9.9', 'task': 'tasks.add', 'route': 'celery'}
)

postrun_cb(task=t)
self.assertEqual(mod.get_celery_info(), {})
finally:
mod.celery, mod.signals = original_celery, original_signals
35 changes: 35 additions & 0 deletions python/sqlcommenter-python/tests/sqlalchemy/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,38 @@ def test_route_disabled(self, get_info):
"SELECT 1 /*controller='c',framework='fastapi'*/;",
with_route=False,
)


class CeleryTests(SQLAlchemyTestCase):
celery_info = {
'framework': 'celery',
'controller': 'tasks.add',
'route': 'celery',
}

@mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info)
def test_all_data(self, get_info):
self.assertSQL(
"SELECT 1 /*controller='tasks.add',framework='celery',route='celery'*/;",
)

@mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info)
def test_framework_disabled(self, get_info):
self.assertSQL(
"SELECT 1 /*controller='tasks.add',route='celery'*/;",
with_framework=False,
)

@mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info)
def test_controller_disabled(self, get_info):
self.assertSQL(
"SELECT 1 /*framework='celery',route='celery'*/;",
with_controller=False,
)

@mock.patch('google.cloud.sqlcommenter.sqlalchemy.executor.get_celery_info', return_value=celery_info)
def test_route_disabled(self, get_info):
self.assertSQL(
"SELECT 1 /*controller='tasks.add',framework='celery'*/;",
with_route=False,
)
6 changes: 3 additions & 3 deletions python/sqlcommenter-python/tox.ini
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[tox]
envlist =
py3{6,7,8,9}-django{21,22,30,31,32}
py3{6,7,8,9}-psycopg2
py3{6,7,8,9}-{flask,generic,sqlalchemy}
py3{6,7,8,9,10}-django{21,22,30,31,32}
py3{6,7,8,9,10}-psycopg2
py3{6,7,8,9,10}-{flask,generic,sqlalchemy}
flake8

#
Expand Down