Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ struct _pycontextobject {
PyObject_HEAD
PyContext *ctx_prev;
PyHamtObject *ctx_vars;
PyHamtObject *ctx_vars_origin; /* snapshot of ctx_vars at Enter time */
PyObject *ctx_weakreflist;
int ctx_entered;
};
Expand Down Expand Up @@ -58,5 +59,24 @@ PyAPI_FUNC(PyObject*) _PyContext_NewHamtForTests(void);
PyAPI_FUNC(int) _PyContext_Enter(PyThreadState *ts, PyObject *octx);
PyAPI_FUNC(int) _PyContext_Exit(PyThreadState *ts, PyObject *octx);

/* Get a value for the variable and check if it was changed.

Like PyContextVar_Get, but also reports whether the variable was
changed in the current context scope via a single HAMT lookup.

Returns -1 if an error occurred during lookup.

Returns 0 if no error occurred. In this case:

- *value will be set the same as for PyContextVar_Get.
- *changed will be set to 1 if the variable was changed in the
current context scope, 0 otherwise. If the variable was not
found, *changed is always 0.

'*value' will be a new ref, if not NULL.
*/
PyAPI_FUNC(int) _PyContextVar_GetChanged(
PyObject *var, PyObject *default_value, PyObject **value, int *changed);


#endif /* !Py_INTERNAL_CONTEXT_H */
10 changes: 9 additions & 1 deletion Lib/_pydecimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,11 +363,19 @@ def getcontext():
New contexts are copies of DefaultContext.
"""
try:
return _current_context_var.get()
context, changed = _current_context_var._get_changed()
except LookupError:
context = Context()
_current_context_var.set(context)
return context
if not changed:
# The context value was inherited from another task/thread. Because
# the Context() instance is mutable, copy it to ensure that if it is
# changed, those changes are isolated from other tasks/threads.
context = context.copy()
_current_context_var.set(context)
return context


def setcontext(context):
"""Set this thread's context to context."""
Expand Down
221 changes: 221 additions & 0 deletions Lib/test/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,227 @@ def __eq__(self, other):
ctx2.run(var.set, ReentrantHash())
ctx1 == ctx2

def test__get_changed_outside_run(self):
# Outside any Context.run(), bindings are considered "changed"
v = contextvars.ContextVar('v', default='dflt')
val, changed = v._get_changed()
self.assertEqual(val, 'dflt')
self.assertFalse(changed) # default value, not changed
v.set(42)
val, changed = v._get_changed()
self.assertEqual(val, 42)
self.assertTrue(changed) # set in base context

def test__get_changed_inherited(self):
# Inherited bindings are not considered "changed"
v = contextvars.ContextVar('v')
v.set('parent')
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertEqual(val, 'parent')
self.assertFalse(changed)
ctx.run(check)

def test__get_changed_after_set(self):
# After set() inside Context.run(), changed is True
v = contextvars.ContextVar('v')
v.set('parent')
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertFalse(changed)
v.set('child')
val, changed = v._get_changed()
self.assertEqual(val, 'child')
self.assertTrue(changed)
ctx.run(check)

def test__get_changed_new_var_in_run(self):
# A variable set for the first time inside run() is "changed"
v = contextvars.ContextVar('v')
ctx = contextvars.copy_context()

def check():
with self.assertRaises(LookupError):
v._get_changed()
v.set('new')
val, changed = v._get_changed()
self.assertEqual(val, 'new')
self.assertTrue(changed)
ctx.run(check)

def test__get_changed_not_set_with_default(self):
# A variable not set but with default: changed is False
v = contextvars.ContextVar('v', default='dflt')
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertEqual(val, 'dflt')
self.assertFalse(changed)
ctx.run(check)

def test__get_changed_not_set_no_default(self):
# A variable that has never been set and has no default
v = contextvars.ContextVar('v')
ctx = contextvars.copy_context()

def check():
with self.assertRaises(LookupError):
v._get_changed()
ctx.run(check)

def test__get_changed_explicit_default_arg(self):
# Passing a default argument to _get_changed()
v = contextvars.ContextVar('v')
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed('fallback')
self.assertEqual(val, 'fallback')
self.assertFalse(changed)
ctx.run(check)

def test__get_changed_set_same_object(self):
# Setting to the exact same object does not count as "changed"
# because the HAMT recognizes the identical key-value pair
obj = object()
v = contextvars.ContextVar('v')
v.set(obj)
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertIs(val, obj)
self.assertFalse(changed)
v.set(obj) # same object
val, changed = v._get_changed()
self.assertIs(val, obj)
self.assertFalse(changed)
ctx.run(check)

def test__get_changed_set_different_object(self):
# Setting to a different object counts as "changed"
v = contextvars.ContextVar('v')
v.set([1, 2, 3])
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertFalse(changed)
v.set([1, 2, 3]) # equal value, different object
val, changed = v._get_changed()
self.assertTrue(changed)
ctx.run(check)

def test__get_changed_after_reset(self):
# After reset(), the variable reverts to its inherited state
v = contextvars.ContextVar('v')
v.set('original')
ctx = contextvars.copy_context()

def check():
val, changed = v._get_changed()
self.assertFalse(changed)
tok = v.set('modified')
val, changed = v._get_changed()
self.assertTrue(changed)
v.reset(tok)
val, changed = v._get_changed()
self.assertFalse(changed)
ctx.run(check)

def test__get_changed_multiple_vars(self):
# Changing one variable does not affect _get_changed() for others
v1 = contextvars.ContextVar('v1')
v2 = contextvars.ContextVar('v2')
v1.set('a')
v2.set('b')
ctx = contextvars.copy_context()

def check():
_, changed1 = v1._get_changed()
_, changed2 = v2._get_changed()
self.assertFalse(changed1)
self.assertFalse(changed2)
v1.set('a2')
_, changed1 = v1._get_changed()
_, changed2 = v2._get_changed()
self.assertTrue(changed1)
self.assertFalse(changed2)
ctx.run(check)

def test__get_changed_nested_run(self):
# _get_changed() reflects the innermost Context.run() scope
v = contextvars.ContextVar('v')
v.set('root')
ctx1 = contextvars.copy_context()

def outer():
_, changed = v._get_changed()
self.assertFalse(changed)
v.set('outer')
_, changed = v._get_changed()
self.assertTrue(changed)
ctx2 = contextvars.copy_context()

def inner():
# inherited 'outer' from ctx1, not changed in ctx2
val, changed = v._get_changed()
self.assertEqual(val, 'outer')
self.assertFalse(changed)
v.set('inner')
val, changed = v._get_changed()
self.assertEqual(val, 'inner')
self.assertTrue(changed)
ctx2.run(inner)

# after inner run exits, outer's state is restored
_, changed = v._get_changed()
self.assertTrue(changed)
ctx1.run(outer)

@threading_helper.requires_working_threading()
def test__get_changed_with_threads(self):
# _get_changed() works correctly in a thread with copied context
import threading
v = contextvars.ContextVar('v')
v.set('parent')
ctx = contextvars.copy_context()
results = {}

def thread_func():
val, changed = v._get_changed()
results['inherited'] = changed
results['value'] = val
v.set('thread')
val, changed = v._get_changed()
results['after_set'] = changed

t = threading.Thread(target=ctx.run, args=(thread_func,))
t.start()
t.join()
self.assertFalse(results['inherited'])
self.assertEqual(results['value'], 'parent')
self.assertTrue(results['after_set'])

def test__get_changed_empty_context_run(self):
# Running in a brand new empty context
v = contextvars.ContextVar('v')
ctx = contextvars.Context()

def check():
with self.assertRaises(LookupError):
v._get_changed()
v.set('value')
val, changed = v._get_changed()
self.assertEqual(val, 'value')
self.assertTrue(changed)
ctx.run(check)


# HAMT Tests

Expand Down
53 changes: 53 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,59 @@ def test_threading(self):
DefaultContext.Emax = save_emax
DefaultContext.Emin = save_emin

@threading_helper.requires_working_threading()
def test_inherited_context_isolation(self):
# Test that when threads inherit contextvars (e.g. via
# sys.flags.thread_inherit_context), each thread gets its own
# copy of the decimal context so mutations don't leak between
# threads. Also verifies correct behavior with asyncio tasks.
Decimal = self.decimal.Decimal
getcontext = self.decimal.getcontext
setcontext = self.decimal.setcontext
Context = self.decimal.Context
Underflow = self.decimal.Underflow

# Set up parent context with specific precision
parent_ctx = getcontext()
parent_ctx.prec = 20

barrier = threading.Barrier(2, timeout=2)
results = {}

def child(name, prec_delta):
barrier.wait()
ctx = getcontext()
# Each child should see a context with the parent's precision
results[name + '_initial_prec'] = ctx.prec
results[name + '_ctx_id'] = id(ctx)
# Mutate this thread's context
ctx.prec += prec_delta
results[name + '_modified_prec'] = ctx.prec

# Spawn threads that inherit the parent's contextvars.
t1 = threading.Thread(target=child, args=('t1', 5),
context=contextvars.copy_context())
t2 = threading.Thread(target=child, args=('t2', 10),
context=contextvars.copy_context())
t1.start()
t2.start()
t1.join()
t2.join()

# Each thread should have started with the parent's precision
self.assertEqual(results['t1_initial_prec'], 20)
self.assertEqual(results['t2_initial_prec'], 20)

# Each thread should have its own context (different id)
self.assertNotEqual(results['t1_ctx_id'], results['t2_ctx_id'])

# Mutations should be independent
self.assertEqual(results['t1_modified_prec'], 25)
self.assertEqual(results['t2_modified_prec'], 30)

# Parent context should be unaffected
self.assertEqual(getcontext().prec, 20)


@requires_cdecimal
class CThreadingTest(ThreadingTest, unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Ensure that :func:`decimal.getcontext` returns a per-task copy of the
:class:`decimal.Context` so that mutations are isolated between asyncio
tasks and threads using :data:`sys.flags.thread_inherit_context`.
21 changes: 16 additions & 5 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif

#include <Python.h>
#include "pycore_context.h" // _PyContextVar_GetChanged()
#include "pycore_object.h" // _PyObject_VisitType()
#include "pycore_pystate.h" // _PyThreadState_GET()
#include "pycore_tuple.h" // _PyTuple_FromPair
Expand Down Expand Up @@ -1915,9 +1916,9 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
}
#else
static PyObject *
init_current_context(decimal_state *state)
init_current_context(decimal_state *state, PyObject *prev_context)
{
PyObject *tl_context = context_copy(state, state->default_context_template);
PyObject *tl_context = context_copy(state, prev_context);
if (tl_context == NULL) {
return NULL;
}
Expand All @@ -1937,15 +1938,25 @@ static inline PyObject *
current_context(decimal_state *state)
{
PyObject *tl_context;
if (PyContextVar_Get(state->current_context_var, NULL, &tl_context) < 0) {
int changed;
if (_PyContextVar_GetChanged(state->current_context_var, NULL, &tl_context,
&changed) < 0) {
return NULL;
}

if (tl_context != NULL) {
return tl_context;
if (!changed) {
/* inherited context object from another thread for async task */
PyObject *new_context = init_current_context(state, tl_context);
Py_DECREF(tl_context);
return new_context;
}
else {
return tl_context;
}
}

return init_current_context(state);
return init_current_context(state, state->default_context_template);
}

/* ctxobj := borrowed reference to the current context */
Expand Down
Loading
Loading