diff --git a/Include/internal/pycore_context.h b/Include/internal/pycore_context.h index a833f790a621b1..8d76a9adf4a21c 100644 --- a/Include/internal/pycore_context.h +++ b/Include/internal/pycore_context.h @@ -24,6 +24,7 @@ struct _pycontextobject { PyObject_HEAD PyContext *ctx_prev; PyHamtObject *ctx_vars; + PyHamtObject *ctx_vars_origin; /* snapshot of ctx_vars at Enter time */ PyObject *ctx_weakreflist; int ctx_entered; }; @@ -58,5 +59,24 @@ PyAPI_FUNC(PyObject*) _PyContext_NewHamtForTests(void); PyAPI_FUNC(int) _PyContext_Enter(PyThreadState *ts, PyObject *octx); PyAPI_FUNC(int) _PyContext_Exit(PyThreadState *ts, PyObject *octx); +/* Get a value for the variable and check if it was changed. + + Like PyContextVar_Get, but also reports whether the variable was + changed in the current context scope via a single HAMT lookup. + + Returns -1 if an error occurred during lookup. + + Returns 0 if no error occurred. In this case: + + - *value will be set the same as for PyContextVar_Get. + - *changed will be set to 1 if the variable was changed in the + current context scope, 0 otherwise. If the variable was not + found, *changed is always 0. + + '*value' will be a new ref, if not NULL. +*/ +PyAPI_FUNC(int) _PyContextVar_GetChanged( + PyObject *var, PyObject *default_value, PyObject **value, int *changed); + #endif /* !Py_INTERNAL_CONTEXT_H */ diff --git a/Lib/_pydecimal.py b/Lib/_pydecimal.py index 8c0afd14d616e8..e81622c062ead5 100644 --- a/Lib/_pydecimal.py +++ b/Lib/_pydecimal.py @@ -363,11 +363,19 @@ def getcontext(): New contexts are copies of DefaultContext. """ try: - return _current_context_var.get() + context, changed = _current_context_var._get_changed() except LookupError: context = Context() _current_context_var.set(context) return context + if not changed: + # The context value was inherited from another task/thread. Because + # the Context() instance is mutable, copy it to ensure that if it is + # changed, those changes are isolated from other tasks/threads. + context = context.copy() + _current_context_var.set(context) + return context + def setcontext(context): """Set this thread's context to context.""" diff --git a/Lib/test/test_context.py b/Lib/test/test_context.py index ef20495dcc01ea..181a4fd3f4b7c1 100644 --- a/Lib/test/test_context.py +++ b/Lib/test/test_context.py @@ -586,6 +586,227 @@ def __eq__(self, other): ctx2.run(var.set, ReentrantHash()) ctx1 == ctx2 + def test__get_changed_outside_run(self): + # Outside any Context.run(), bindings are considered "changed" + v = contextvars.ContextVar('v', default='dflt') + val, changed = v._get_changed() + self.assertEqual(val, 'dflt') + self.assertFalse(changed) # default value, not changed + v.set(42) + val, changed = v._get_changed() + self.assertEqual(val, 42) + self.assertTrue(changed) # set in base context + + def test__get_changed_inherited(self): + # Inherited bindings are not considered "changed" + v = contextvars.ContextVar('v') + v.set('parent') + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertEqual(val, 'parent') + self.assertFalse(changed) + ctx.run(check) + + def test__get_changed_after_set(self): + # After set() inside Context.run(), changed is True + v = contextvars.ContextVar('v') + v.set('parent') + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertFalse(changed) + v.set('child') + val, changed = v._get_changed() + self.assertEqual(val, 'child') + self.assertTrue(changed) + ctx.run(check) + + def test__get_changed_new_var_in_run(self): + # A variable set for the first time inside run() is "changed" + v = contextvars.ContextVar('v') + ctx = contextvars.copy_context() + + def check(): + with self.assertRaises(LookupError): + v._get_changed() + v.set('new') + val, changed = v._get_changed() + self.assertEqual(val, 'new') + self.assertTrue(changed) + ctx.run(check) + + def test__get_changed_not_set_with_default(self): + # A variable not set but with default: changed is False + v = contextvars.ContextVar('v', default='dflt') + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertEqual(val, 'dflt') + self.assertFalse(changed) + ctx.run(check) + + def test__get_changed_not_set_no_default(self): + # A variable that has never been set and has no default + v = contextvars.ContextVar('v') + ctx = contextvars.copy_context() + + def check(): + with self.assertRaises(LookupError): + v._get_changed() + ctx.run(check) + + def test__get_changed_explicit_default_arg(self): + # Passing a default argument to _get_changed() + v = contextvars.ContextVar('v') + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed('fallback') + self.assertEqual(val, 'fallback') + self.assertFalse(changed) + ctx.run(check) + + def test__get_changed_set_same_object(self): + # Setting to the exact same object does not count as "changed" + # because the HAMT recognizes the identical key-value pair + obj = object() + v = contextvars.ContextVar('v') + v.set(obj) + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertIs(val, obj) + self.assertFalse(changed) + v.set(obj) # same object + val, changed = v._get_changed() + self.assertIs(val, obj) + self.assertFalse(changed) + ctx.run(check) + + def test__get_changed_set_different_object(self): + # Setting to a different object counts as "changed" + v = contextvars.ContextVar('v') + v.set([1, 2, 3]) + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertFalse(changed) + v.set([1, 2, 3]) # equal value, different object + val, changed = v._get_changed() + self.assertTrue(changed) + ctx.run(check) + + def test__get_changed_after_reset(self): + # After reset(), the variable reverts to its inherited state + v = contextvars.ContextVar('v') + v.set('original') + ctx = contextvars.copy_context() + + def check(): + val, changed = v._get_changed() + self.assertFalse(changed) + tok = v.set('modified') + val, changed = v._get_changed() + self.assertTrue(changed) + v.reset(tok) + val, changed = v._get_changed() + self.assertFalse(changed) + ctx.run(check) + + def test__get_changed_multiple_vars(self): + # Changing one variable does not affect _get_changed() for others + v1 = contextvars.ContextVar('v1') + v2 = contextvars.ContextVar('v2') + v1.set('a') + v2.set('b') + ctx = contextvars.copy_context() + + def check(): + _, changed1 = v1._get_changed() + _, changed2 = v2._get_changed() + self.assertFalse(changed1) + self.assertFalse(changed2) + v1.set('a2') + _, changed1 = v1._get_changed() + _, changed2 = v2._get_changed() + self.assertTrue(changed1) + self.assertFalse(changed2) + ctx.run(check) + + def test__get_changed_nested_run(self): + # _get_changed() reflects the innermost Context.run() scope + v = contextvars.ContextVar('v') + v.set('root') + ctx1 = contextvars.copy_context() + + def outer(): + _, changed = v._get_changed() + self.assertFalse(changed) + v.set('outer') + _, changed = v._get_changed() + self.assertTrue(changed) + ctx2 = contextvars.copy_context() + + def inner(): + # inherited 'outer' from ctx1, not changed in ctx2 + val, changed = v._get_changed() + self.assertEqual(val, 'outer') + self.assertFalse(changed) + v.set('inner') + val, changed = v._get_changed() + self.assertEqual(val, 'inner') + self.assertTrue(changed) + ctx2.run(inner) + + # after inner run exits, outer's state is restored + _, changed = v._get_changed() + self.assertTrue(changed) + ctx1.run(outer) + + @threading_helper.requires_working_threading() + def test__get_changed_with_threads(self): + # _get_changed() works correctly in a thread with copied context + import threading + v = contextvars.ContextVar('v') + v.set('parent') + ctx = contextvars.copy_context() + results = {} + + def thread_func(): + val, changed = v._get_changed() + results['inherited'] = changed + results['value'] = val + v.set('thread') + val, changed = v._get_changed() + results['after_set'] = changed + + t = threading.Thread(target=ctx.run, args=(thread_func,)) + t.start() + t.join() + self.assertFalse(results['inherited']) + self.assertEqual(results['value'], 'parent') + self.assertTrue(results['after_set']) + + def test__get_changed_empty_context_run(self): + # Running in a brand new empty context + v = contextvars.ContextVar('v') + ctx = contextvars.Context() + + def check(): + with self.assertRaises(LookupError): + v._get_changed() + v.set('value') + val, changed = v._get_changed() + self.assertEqual(val, 'value') + self.assertTrue(changed) + ctx.run(check) + # HAMT Tests diff --git a/Lib/test/test_decimal.py b/Lib/test/test_decimal.py index fe8c8ce12da0bf..dc365a7cd8228d 100644 --- a/Lib/test/test_decimal.py +++ b/Lib/test/test_decimal.py @@ -1770,6 +1770,59 @@ def test_threading(self): DefaultContext.Emax = save_emax DefaultContext.Emin = save_emin + @threading_helper.requires_working_threading() + def test_inherited_context_isolation(self): + # Test that when threads inherit contextvars (e.g. via + # sys.flags.thread_inherit_context), each thread gets its own + # copy of the decimal context so mutations don't leak between + # threads. Also verifies correct behavior with asyncio tasks. + Decimal = self.decimal.Decimal + getcontext = self.decimal.getcontext + setcontext = self.decimal.setcontext + Context = self.decimal.Context + Underflow = self.decimal.Underflow + + # Set up parent context with specific precision + parent_ctx = getcontext() + parent_ctx.prec = 20 + + barrier = threading.Barrier(2, timeout=2) + results = {} + + def child(name, prec_delta): + barrier.wait() + ctx = getcontext() + # Each child should see a context with the parent's precision + results[name + '_initial_prec'] = ctx.prec + results[name + '_ctx_id'] = id(ctx) + # Mutate this thread's context + ctx.prec += prec_delta + results[name + '_modified_prec'] = ctx.prec + + # Spawn threads that inherit the parent's contextvars. + t1 = threading.Thread(target=child, args=('t1', 5), + context=contextvars.copy_context()) + t2 = threading.Thread(target=child, args=('t2', 10), + context=contextvars.copy_context()) + t1.start() + t2.start() + t1.join() + t2.join() + + # Each thread should have started with the parent's precision + self.assertEqual(results['t1_initial_prec'], 20) + self.assertEqual(results['t2_initial_prec'], 20) + + # Each thread should have its own context (different id) + self.assertNotEqual(results['t1_ctx_id'], results['t2_ctx_id']) + + # Mutations should be independent + self.assertEqual(results['t1_modified_prec'], 25) + self.assertEqual(results['t2_modified_prec'], 30) + + # Parent context should be unaffected + self.assertEqual(getcontext().prec, 20) + @requires_cdecimal class CThreadingTest(ThreadingTest, unittest.TestCase): diff --git a/Misc/NEWS.d/next/Library/2026-03-26-10-27-07.gh-issue-141148._XpYnI.rst b/Misc/NEWS.d/next/Library/2026-03-26-10-27-07.gh-issue-141148._XpYnI.rst new file mode 100644 index 00000000000000..8589be19132a74 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2026-03-26-10-27-07.gh-issue-141148._XpYnI.rst @@ -0,0 +1,3 @@ +Ensure that :func:`decimal.getcontext` returns a per-task copy of the +:class:`decimal.Context` so that mutations are isolated between asyncio +tasks and threads using :data:`sys.flags.thread_inherit_context`. diff --git a/Modules/_decimal/_decimal.c b/Modules/_decimal/_decimal.c index 4db1b60be77844..75b52d85c18782 100644 --- a/Modules/_decimal/_decimal.c +++ b/Modules/_decimal/_decimal.c @@ -30,6 +30,7 @@ #endif #include +#include "pycore_context.h" // _PyContextVar_GetChanged() #include "pycore_object.h" // _PyObject_VisitType() #include "pycore_pystate.h" // _PyThreadState_GET() #include "pycore_tuple.h" // _PyTuple_FromPair @@ -1915,9 +1916,9 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v) } #else static PyObject * -init_current_context(decimal_state *state) +init_current_context(decimal_state *state, PyObject *prev_context) { - PyObject *tl_context = context_copy(state, state->default_context_template); + PyObject *tl_context = context_copy(state, prev_context); if (tl_context == NULL) { return NULL; } @@ -1937,15 +1938,25 @@ static inline PyObject * current_context(decimal_state *state) { PyObject *tl_context; - if (PyContextVar_Get(state->current_context_var, NULL, &tl_context) < 0) { + int changed; + if (_PyContextVar_GetChanged(state->current_context_var, NULL, &tl_context, + &changed) < 0) { return NULL; } if (tl_context != NULL) { - return tl_context; + if (!changed) { + /* inherited context object from another thread for async task */ + PyObject *new_context = init_current_context(state, tl_context); + Py_DECREF(tl_context); + return new_context; + } + else { + return tl_context; + } } - return init_current_context(state); + return init_current_context(state, state->default_context_template); } /* ctxobj := borrowed reference to the current context */ diff --git a/Python/clinic/context.c.h b/Python/clinic/context.c.h index ece7341d65d5fb..f0501fce6e90fa 100644 --- a/Python/clinic/context.c.h +++ b/Python/clinic/context.c.h @@ -209,6 +209,52 @@ _contextvars_ContextVar_reset(PyObject *self, PyObject *token) return return_value; } +PyDoc_STRVAR(_contextvars_ContextVar__get_changed__doc__, +"_get_changed($self, default=, /)\n" +"--\n" +"\n" +"Return a tuple of (value, changed) for the context variable.\n" +"\n" +"Like ContextVar.get(), but additionally indicates whether the variable was\n" +"changed in the current context scope. *changed* is True if ContextVar.set()\n" +"has been called on the variable within the current Context.run() call with\n" +"a value that is a different object than the inherited one.\n" +"\n" +"If there is no value for the variable in the current context, the method will:\n" +" * return the value of the default argument of the method, if provided; or\n" +" * return the default value for the context variable, if it was created\n" +" with one; or\n" +" * raise a LookupError.\n" +"\n" +"When the value is found via a default, *changed* is always False."); + +#define _CONTEXTVARS_CONTEXTVAR__GET_CHANGED_METHODDEF \ + {"_get_changed", _PyCFunction_CAST(_contextvars_ContextVar__get_changed), METH_FASTCALL, _contextvars_ContextVar__get_changed__doc__}, + +static PyObject * +_contextvars_ContextVar__get_changed_impl(PyContextVar *self, + PyObject *default_value); + +static PyObject * +_contextvars_ContextVar__get_changed(PyObject *self, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *default_value = NULL; + + if (!_PyArg_CheckPositional("_get_changed", nargs, 0, 1)) { + goto exit; + } + if (nargs < 1) { + goto skip_optional; + } + default_value = args[0]; +skip_optional: + return_value = _contextvars_ContextVar__get_changed_impl((PyContextVar *)self, default_value); + +exit: + return return_value; +} + PyDoc_STRVAR(token_enter__doc__, "__enter__($self, /)\n" "--\n" @@ -259,4 +305,4 @@ token_exit(PyObject *self, PyObject *const *args, Py_ssize_t nargs) exit: return return_value; } -/*[clinic end generated code: output=90ec3e4375804e9b input=a9049054013a1b77]*/ +/*[clinic end generated code: output=2eda90fe52e6ed6f input=a9049054013a1b77]*/ diff --git a/Python/context.c b/Python/context.c index 593e6ef90037cf..1659a5ab2ccecd 100644 --- a/Python/context.c +++ b/Python/context.c @@ -209,6 +209,7 @@ _PyContext_Enter(PyThreadState *ts, PyObject *octx) } ctx->ctx_prev = (PyContext *)ts->context; /* borrow */ + ctx->ctx_vars_origin = (PyHamtObject *)Py_NewRef(ctx->ctx_vars); ts->context = Py_NewRef(ctx); context_switched(ts); return 0; @@ -248,6 +249,7 @@ _PyContext_Exit(PyThreadState *ts, PyObject *octx) Py_SETREF(ts->context, (PyObject *)ctx->ctx_prev); ctx->ctx_prev = NULL; + Py_CLEAR(ctx->ctx_vars_origin); FT_ATOMIC_STORE_INT(ctx->ctx_entered, 0); context_switched(ts); return 0; @@ -410,6 +412,115 @@ PyContextVar_Reset(PyObject *ovar, PyObject *otok) } +/* Check if var's current value (cur_val) differs from the origin snapshot. + ctx must be the current context and cur_val must be the value already + looked up in ctx->ctx_vars. Returns 1 if changed, 0 if not, -1 on error. */ +static int +contextvar_check_changed(PyContext *ctx, PyContextVar *var, PyObject *cur_val) +{ + /* No origin snapshot means this context was never entered via + Context.run(), so all bindings are considered "changed". */ + if (ctx->ctx_vars_origin == NULL) { + return 1; + } + + /* If the HAMT hasn't changed at all, no .set() calls have been made + in this context scope for any variable. */ + if (ctx->ctx_vars == ctx->ctx_vars_origin) { + return 0; + } + + /* Check if this specific variable had a different value (or was + absent) in the origin snapshot. */ + PyObject *orig_val = NULL; + int found_orig = _PyHamt_Find( + ctx->ctx_vars_origin, (PyObject *)var, &orig_val); + if (found_orig < 0) { + return -1; + } + if (found_orig == 0) { + return 1; + } + + return cur_val != orig_val; +} + + +int +_PyContextVar_GetChanged(PyObject *ovar, PyObject *def, PyObject **val, + int *changed) +{ + ENSURE_ContextVar(ovar, -1) + PyContextVar *var = (PyContextVar *)ovar; + + *changed = 0; + + PyThreadState *ts = _PyThreadState_GET(); + assert(ts != NULL); + if (ts->context == NULL) { + goto not_found; + } + + PyContext *ctx = (PyContext *)ts->context; + assert(PyContext_CheckExact(ts->context)); + +#ifndef Py_GIL_DISABLED + /* Try the cache first. When we get a cache hit we still need to + check the origin HAMT, but we skip the main HAMT lookup. */ + if (var->var_cached != NULL && + var->var_cached_tsid == ts->id && + var->var_cached_tsver == ts->context_ver) + { + *val = Py_NewRef(var->var_cached); + int res = contextvar_check_changed(ctx, var, var->var_cached); + if (res < 0) { + Py_CLEAR(*val); + return -1; + } + *changed = res; + return 0; + } +#endif + + PyObject *found_val = NULL; + int res = _PyHamt_Find(ctx->ctx_vars, (PyObject *)var, &found_val); + if (res < 0) { + *val = NULL; + return -1; + } + if (res == 1) { + assert(found_val != NULL); +#ifndef Py_GIL_DISABLED + var->var_cached = found_val; /* borrow */ + var->var_cached_tsid = ts->id; + var->var_cached_tsver = ts->context_ver; +#endif + int chg = contextvar_check_changed(ctx, var, found_val); + if (chg < 0) { + *val = NULL; + return -1; + } + *changed = chg; + *val = Py_NewRef(found_val); + return 0; + } + +not_found: + if (def == NULL) { + if (var->var_default != NULL) { + *val = Py_NewRef(var->var_default); + return 0; + } + *val = NULL; + return 0; + } + else { + *val = Py_NewRef(def); + return 0; + } +} + + /////////////////////////// PyContext /*[clinic input] @@ -433,6 +544,7 @@ _context_alloc(void) } ctx->ctx_vars = NULL; + ctx->ctx_vars_origin = NULL; ctx->ctx_prev = NULL; ctx->ctx_entered = 0; ctx->ctx_weakreflist = NULL; @@ -520,6 +632,7 @@ context_tp_clear(PyObject *op) PyContext *self = _PyContext_CAST(op); Py_CLEAR(self->ctx_prev); Py_CLEAR(self->ctx_vars); + Py_CLEAR(self->ctx_vars_origin); return 0; } @@ -529,6 +642,7 @@ context_tp_traverse(PyObject *op, visitproc visit, void *arg) PyContext *self = _PyContext_CAST(op); Py_VISIT(self->ctx_prev); Py_VISIT(self->ctx_vars); + Py_VISIT(self->ctx_vars_origin); return 0; } @@ -1090,6 +1204,52 @@ _contextvars_ContextVar_reset_impl(PyContextVar *self, PyObject *token) } +/*[clinic input] +@permit_long_docstring_body +_contextvars.ContextVar._get_changed + default: object = NULL + / + +Return a tuple of (value, changed) for the context variable. + +Like ContextVar.get(), but additionally indicates whether the variable was +changed in the current context scope. *changed* is True if ContextVar.set() +has been called on the variable within the current Context.run() call with +a value that is a different object than the inherited one. + +If there is no value for the variable in the current context, the method will: + * return the value of the default argument of the method, if provided; or + * return the default value for the context variable, if it was created + with one; or + * raise a LookupError. + +When the value is found via a default, *changed* is always False. +[clinic start generated code]*/ + +static PyObject * +_contextvars_ContextVar__get_changed_impl(PyContextVar *self, + PyObject *default_value) +/*[clinic end generated code: output=16b72be2c79429e9 input=aa6c784a3846a840]*/ +{ + PyObject *val; + int changed; + if (_PyContextVar_GetChanged( + (PyObject *)self, default_value, &val, &changed) < 0) { + return NULL; + } + + if (val == NULL) { + PyErr_SetObject(PyExc_LookupError, (PyObject *)self); + return NULL; + } + + PyObject *changed_obj = changed ? Py_True : Py_False; + PyObject *result = PyTuple_Pack(2, val, changed_obj); + Py_DECREF(val); + return result; +} + + static PyMemberDef PyContextVar_members[] = { {"name", _Py_T_OBJECT, offsetof(PyContextVar, var_name), Py_READONLY}, {NULL} @@ -1099,6 +1259,7 @@ static PyMethodDef PyContextVar_methods[] = { _CONTEXTVARS_CONTEXTVAR_GET_METHODDEF _CONTEXTVARS_CONTEXTVAR_SET_METHODDEF _CONTEXTVARS_CONTEXTVAR_RESET_METHODDEF + _CONTEXTVARS_CONTEXTVAR__GET_CHANGED_METHODDEF {"__class_getitem__", Py_GenericAlias, METH_O|METH_CLASS, PyDoc_STR("ContextVars are generic over the type of their contained values")},