Skip to content

Commit 2abd718

Browse files
committed
Use dictionary watchers to update version number on writes to dictionary
1 parent 1f67a54 commit 2abd718

File tree

11 files changed

+92
-62
lines changed

11 files changed

+92
-62
lines changed

Include/internal/pycore_dict.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,13 @@ static inline PyDictUnicodeEntry* DK_UNICODE_ENTRIES(PyDictKeysObject *dk) {
265265
#define DICT_UNIQUE_ID_SHIFT (32)
266266
#define DICT_UNIQUE_ID_MAX ((UINT64_C(1) << (64 - DICT_UNIQUE_ID_SHIFT)) - 1)
267267

268+
/* The first three dict watcher IDs are reserved for CPython,
269+
* so we don't need to check that they haven't been used */
270+
#define BUILTINS_WATCHER_ID 0
271+
#define GLOBALS_WATCHER_ID 1
272+
#define MODULE_WATCHER_ID 2
273+
#define FIRST_AVAILABLE_WATCHER 3
274+
268275

269276
PyAPI_FUNC(void)
270277
_PyDict_SendEvent(int watcher_bits,

Include/internal/pycore_moduleobject.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ extern int _PyModule_IsPossiblyShadowing(PyObject *);
1919

2020
extern int _PyModule_IsExtension(PyObject *obj);
2121

22+
extern int _PyModule_InitModuleDictWatcher(PyInterpreterState *interp);
23+
2224
typedef int (*_Py_modexecfunc)(PyObject *);
2325

2426
typedef struct {

Lib/test/test_import/test_lazy_imports.py

Lines changed: 35 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -211,42 +211,9 @@ def test_lazy_import_type_exposed(self):
211211
self.assertHasAttr(types, 'LazyImportType')
212212
self.assertEqual(types.LazyImportType.__name__, 'lazy_import')
213213

214-
def test_lazy_import_type_invalid_name(self):
215-
"""passing invalid name to lazy imports should raise a TypeError"""
216-
with self.assertRaises(TypeError) as cm:
217-
types.LazyImportType({}, None)
218-
219-
def test_lazy_import_type_invalid_fromlist_type(self):
220-
"""LazyImportType should reject invalid fromlist types."""
221-
# fromlist must be None, a string, or a tuple - not an int
222-
with self.assertRaises(TypeError) as cm:
223-
types.LazyImportType({}, "module", 0)
224-
self.assertIn("fromlist must be None, a string, or a tuple", str(cm.exception))
225-
226-
# Also test with other invalid types
227-
with self.assertRaises(TypeError):
228-
types.LazyImportType({}, "module", []) # list not allowed
229-
230-
with self.assertRaises(TypeError):
231-
types.LazyImportType({}, "module", {"x": 1}) # dict not allowed
232-
233-
def test_lazy_import_type_valid_fromlist(self):
234-
"""LazyImportType should accept valid fromlist types."""
235-
# None is valid (implicit)
236-
lazy1 = types.LazyImportType({}, "module")
237-
self.assertIsNotNone(lazy1)
238-
239-
# Explicit None is valid
240-
lazy2 = types.LazyImportType({}, "module", None)
241-
self.assertIsNotNone(lazy2)
242-
243-
# String is valid
244-
lazy3 = types.LazyImportType({}, "module", "attr")
245-
self.assertIsNotNone(lazy3)
246-
247-
# Tuple is valid
248-
lazy4 = types.LazyImportType({}, "module", ("attr1", "attr2"))
249-
self.assertIsNotNone(lazy4)
214+
def test_lazy_import_type_cant_construct(self):
215+
"""LazyImportType should not be directly constructible."""
216+
self.assertRaises(TypeError, types.LazyImportType, {}, "module")
250217

251218

252219
class SyntaxRestrictionTests(unittest.TestCase):
@@ -768,6 +735,38 @@ def test_resolve():
768735
self.assertEqual(result.returncode, 0, f"stdout: {result.stdout}, stderr: {result.stderr}")
769736
self.assertIn("OK", result.stdout)
770737

738+
def test_add_lazy_to_globals(self):
739+
code = textwrap.dedent("""
740+
import sys
741+
import types
742+
743+
lazy from test.test_import.data.lazy_imports import basic2
744+
745+
assert 'test.test_import.data.lazy_imports.basic2' not in sys.modules
746+
747+
class C: pass
748+
sneaky = C()
749+
sneaky.x = 1
750+
751+
def f():
752+
t = 0
753+
for _ in range(5):
754+
t += sneaky.x
755+
return t
756+
757+
f()
758+
globals()["sneaky"] = globals()["basic2"]
759+
assert f() == 210
760+
print("OK")
761+
""")
762+
result = subprocess.run(
763+
[sys.executable, "-c", code],
764+
capture_output=True,
765+
text=True
766+
)
767+
self.assertEqual(result.returncode, 0, f"stdout: {result.stdout}, stderr: {result.stderr}")
768+
self.assertIn("OK", result.stdout)
769+
771770

772771
class MultipleNameFromImportTests(unittest.TestCase):
773772
"""Tests for lazy from ... import with multiple names.

Objects/dictobject.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7648,7 +7648,7 @@ PyDict_AddWatcher(PyDict_WatchCallback callback)
76487648
PyInterpreterState *interp = _PyInterpreterState_GET();
76497649

76507650
/* Start at 2, as 0 and 1 are reserved for CPython */
7651-
for (int i = 2; i < DICT_MAX_WATCHERS; i++) {
7651+
for (int i = FIRST_AVAILABLE_WATCHER; i < DICT_MAX_WATCHERS; i++) {
76527652
if (!interp->dict_state.watchers[i]) {
76537653
interp->dict_state.watchers[i] = callback;
76547654
return i;

Objects/moduleobject.c

Lines changed: 30 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919

2020
#include "osdefs.h" // MAXPATHLEN
2121

22-
2322
#define _PyModule_CAST(op) \
2423
(assert(PyModule_Check(op)), _Py_CAST(PyModuleObject*, (op)))
2524

@@ -196,11 +195,40 @@ new_module_notrack(PyTypeObject *mt)
196195
return m;
197196
}
198197

198+
/* Module dict watcher callback.
199+
* When a module dictionary is modified, we need to clear the keys version
200+
* to invalidate any cached lookups that depend on the dictionary structure.
201+
*/
202+
static int
203+
module_dict_watcher(PyDict_WatchEvent event, PyObject *dict,
204+
PyObject *key, PyObject *new_value)
205+
{
206+
assert(PyDict_Check(dict));
207+
// Only if a new lazy object shows up do we need to clear the dictionary. If
208+
// this is adding a new key then the version will be reset anyway.
209+
if (event == PyDict_EVENT_MODIFIED &&
210+
new_value != NULL &&
211+
PyLazyImport_CheckExact(new_value)) {
212+
_PyDict_ClearKeysVersionLockHeld(dict);
213+
}
214+
return 0;
215+
}
216+
217+
int
218+
_PyModule_InitModuleDictWatcher(PyInterpreterState *interp)
219+
{
220+
// This is a reserved watcher for CPython so there's no need to check for non-NULL.
221+
assert(interp->dict_state.watchers[MODULE_WATCHER_ID] == NULL);
222+
interp->dict_state.watchers[MODULE_WATCHER_ID] = &module_dict_watcher;
223+
return 0;
224+
}
225+
199226
static void
200227
track_module(PyModuleObject *m)
201228
{
202229
_PyDict_EnablePerThreadRefcounting(m->md_dict);
203230
_PyObject_SetDeferredRefcount((PyObject *)m);
231+
PyDict_Watch(MODULE_WATCHER_ID, m->md_dict);
204232
PyObject_GC_Track(m);
205233
}
206234

@@ -1266,23 +1294,6 @@ _PyModule_IsPossiblyShadowing(PyObject *origin)
12661294
return result;
12671295
}
12681296

1269-
int
1270-
_PyModule_ReplaceLazyValue(PyObject *dict, PyObject *name, PyObject *value)
1271-
{
1272-
// The adaptive interpreter uses the dictionary keys version to return the
1273-
// slot at a given index from the module. When replacing a value the
1274-
// version number doesn't change, so we need to atomically clear the
1275-
// version before replacing so that it doesn't return a lazy value.
1276-
int err;
1277-
Py_BEGIN_CRITICAL_SECTION(dict);
1278-
1279-
_PyDict_ClearKeysVersionLockHeld(dict);
1280-
err = _PyDict_SetItem_LockHeld((PyDictObject *)dict, name, value);
1281-
1282-
Py_END_CRITICAL_SECTION();
1283-
return err;
1284-
}
1285-
12861297
PyObject*
12871298
_Py_module_getattro_impl(PyModuleObject *m, PyObject *name, int suppress)
12881299
{
@@ -1306,7 +1317,7 @@ _Py_module_getattro_impl(PyModuleObject *m, PyObject *name, int suppress)
13061317
return NULL;
13071318
}
13081319

1309-
if (_PyModule_ReplaceLazyValue(m->md_dict, name, new_value) < 0) {
1320+
if (PyDict_SetItem(m->md_dict, name, new_value) < 0) {
13101321
Py_CLEAR(new_value);
13111322
}
13121323
Py_DECREF(attr);

Python/bytecodes.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1829,7 +1829,7 @@ dummy_func(
18291829
Py_DECREF(v_o);
18301830
ERROR_IF(true);
18311831
}
1832-
int err = _PyModule_ReplaceLazyValue(GLOBALS(), name, l_v);
1832+
int err = PyDict_SetItem(GLOBALS(), name, l_v);
18331833
if (err < 0) {
18341834
Py_DECREF(v_o);
18351835
Py_DECREF(l_v);

Python/ceval.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4186,7 +4186,7 @@ _PyEval_LoadGlobalStackRef(PyObject *globals, PyObject *builtins, PyObject *name
41864186
*writeto = PyStackRef_NULL;
41874187
return;
41884188
}
4189-
int err = _PyModule_ReplaceLazyValue(globals, name, l_v);
4189+
int err = PyDict_SetItem(globals, name, l_v);
41904190
if (err < 0) {
41914191
Py_DECREF(l_v);
41924192
*writeto = PyStackRef_NULL;

Python/executor_cases.c.h

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/generated_cases.c.h

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Python/import.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5530,8 +5530,8 @@ publish_lazy_imports_on_module(PyThreadState *tstate,
55305530
}
55315531

55325532
// Publish on the module that was just imported.
5533-
if (_PyModule_ReplaceLazyValue(module_dict, attr_name,
5534-
lazy_module_attr) < 0) {
5533+
if (PyDict_SetItem(module_dict, attr_name,
5534+
lazy_module_attr) < 0) {
55355535
Py_DECREF(lazy_module_attr);
55365536
Py_DECREF(attr_name);
55375537
return -1;

0 commit comments

Comments
 (0)