Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions Include/internal/pycore_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ struct _pycontextobject {
PyHamtObject *ctx_vars;
PyObject *ctx_weakreflist;
int ctx_entered;
// Nesting depth in the context inheritance tree. Assigned at creation
// time: an empty/base context has depth 0, and a context produced by
// copying another (copy_context(), Context.copy(), thread/async context
// inheritance) has depth one greater than its source.
uint64_t ctx_depth;
};


Expand Down Expand Up @@ -58,5 +63,9 @@ PyAPI_FUNC(PyObject*) _PyContext_NewHamtForTests(void);
PyAPI_FUNC(int) _PyContext_Enter(PyThreadState *ts, PyObject *octx);
PyAPI_FUNC(int) _PyContext_Exit(PyThreadState *ts, PyObject *octx);

/* Return the depth (see struct _pycontextobject.ctx_depth) of the current
context, or 0 if there is no current context. */
PyAPI_FUNC(uint64_t) _PyContext_CurrentDepth(void);


#endif /* !Py_INTERNAL_CONTEXT_H */
23 changes: 22 additions & 1 deletion Lib/_pydecimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,7 @@ class FloatOperation(DecimalException, TypeError):
# current context.

import contextvars
from _contextvars import _current_context_depth

_current_context_var = contextvars.ContextVar('decimal_context')

Expand All @@ -362,18 +363,32 @@ def getcontext():
a new context and sets this thread's context.
New contexts are copies of DefaultContext.
"""
cur_depth = _current_context_depth()
try:
return _current_context_var.get()
context = _current_context_var.get()
except LookupError:
context = Context()
context._local_depth = cur_depth
_current_context_var.set(context)
return context
if context._local_depth != cur_depth:
# The context value was inherited from another task/thread. Because
# the Context() instance is mutable, copy it to ensure that if it is
# changed, those changes are isolated from other tasks/threads.
context = context.copy()
context._local_depth = cur_depth
_current_context_var.set(context)
return context


def setcontext(context):
"""Set this thread's context to context."""
if context in (DefaultContext, BasicContext, ExtendedContext):
context = context.copy()
context.clear_flags()
# Mark the context as owned by the current context scope, so a following
# getcontext() returns this very object rather than a copy.
context._local_depth = _current_context_depth()
_current_context_var.set(context)

del contextvars # Don't contaminate the namespace
Expand Down Expand Up @@ -3869,6 +3884,10 @@ class Context(object):
clamp - If 1, change exponents if too high (Default 0)
"""

# Depth of the contextvars context this object was bound into by
# getcontext()/setcontext().
_local_depth = 0

def __init__(self, prec=None, rounding=None, Emin=None, Emax=None,
capitals=None, clamp=None, flags=None, traps=None,
_ignored_flags=None):
Expand Down Expand Up @@ -3951,6 +3970,8 @@ def __setattr__(self, name, value):
return self._set_signal_dict(name, value)
elif name == '_ignored_flags':
return object.__setattr__(self, name, value)
elif name == '_local_depth':
return object.__setattr__(self, name, value)
else:
raise AttributeError(
"'decimal.Context' object has no attribute '%s'" % name)
Expand Down
93 changes: 93 additions & 0 deletions Lib/test/test_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,99 @@ def __eq__(self, other):
ctx2.run(var.set, ReentrantHash())
ctx1 == ctx2

def test_context_depth_increases_on_run(self):
# Entering a copied context reports a depth one greater than the
# context it was copied from; the depth is restored after the run.
from _contextvars import _current_context_depth as depth
base = depth()
got = []
contextvars.copy_context().run(lambda: got.append(depth()))
self.assertEqual(got, [base + 1])
self.assertEqual(depth(), base)

def test_context_depth_nested_run(self):
# Nested copied contexts increase the depth by one per level.
from _contextvars import _current_context_depth as depth
base = depth()
got = []

def outer():
got.append(depth())
contextvars.copy_context().run(
lambda: got.append(depth()))
contextvars.copy_context().run(outer)
self.assertEqual(got, [base + 1, base + 2])

def test_context_depth_reenter_same_context(self):
# The depth is fixed when the context is created, so running the
# same context object again reports the same depth.
from _contextvars import _current_context_depth as depth
ctx = contextvars.copy_context()
got = []
ctx.run(lambda: got.append(depth()))
ctx.run(lambda: got.append(depth()))
self.assertEqual(got[0], got[1])

def test_context_depth_copy_method(self):
# Context.copy() produces a context one level deeper than its source.
from _contextvars import _current_context_depth as depth
base = depth()
# copy_context() -> depth base+1; .copy() of that -> depth base+2,
# regardless of where it is entered (depth is a creation property).
ctx = contextvars.copy_context().copy()
got = []
ctx.run(lambda: got.append(depth()))
self.assertEqual(got, [base + 2])

def test_context_depth_empty_context_is_zero(self):
# A freshly created (empty) Context has depth 0 and reports it
# independently of the context it is entered from.
from _contextvars import _current_context_depth as depth
got = []
contextvars.Context().run(lambda: got.append(depth()))
# Entered from within a deeper context, still its own depth (0).
contextvars.copy_context().run(
lambda: contextvars.Context().run(
lambda: got.append(depth())))
self.assertEqual(got, [0, 0])

def test_context_depth_inherited_value_is_shared(self):
# A value set before copying is inherited by the copied context as
# the same object, while the depth differs -- this is the signal a
# mutable value (e.g. a decimal context) uses to copy for isolation.
from _contextvars import _current_context_depth as depth
v = contextvars.ContextVar('v')
sentinel = object()
v.set(sentinel)
base = depth()
ctx = contextvars.copy_context()

def check():
self.assertEqual(depth(), base + 1)
self.assertIs(v.get(), sentinel)
ctx.run(check)

@threading_helper.requires_working_threading()
def test_context_depth_with_threads(self):
# A thread running a copied context sees a deeper context than the
# parent, and each thread's depth is independent.
import threading
from _contextvars import _current_context_depth as depth
base = depth()
results = {}

def thread_func(name):
results[name] = depth()

t1 = threading.Thread(target=contextvars.copy_context().run,
args=(lambda: thread_func('t1'),))
t2 = threading.Thread(target=contextvars.copy_context().run,
args=(lambda: thread_func('t2'),))
t1.start(); t2.start()
t1.join(); t2.join()
self.assertEqual(results['t1'], base + 1)
self.assertEqual(results['t2'], base + 1)


# HAMT Tests

Expand Down
92 changes: 92 additions & 0 deletions Lib/test/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -1770,6 +1770,98 @@ def test_threading(self):
DefaultContext.Emax = save_emax
DefaultContext.Emin = save_emin

@threading_helper.requires_working_threading()
def test_inherited_context_isolation(self):
# Test that when threads inherit contextvars (e.g. via
# sys.flags.thread_inherit_context), each thread gets its own
# copy of the decimal context so mutations don't leak between
# threads. Also verifies correct behavior with asyncio tasks.
Decimal = self.decimal.Decimal
getcontext = self.decimal.getcontext
setcontext = self.decimal.setcontext
Context = self.decimal.Context
Underflow = self.decimal.Underflow

# Set up parent context with specific precision
parent_ctx = getcontext()
parent_ctx.prec = 20

barrier = threading.Barrier(2, timeout=2)
results = {}

def child(name, prec_delta):
barrier.wait()
ctx = getcontext()
# Each child should see a context with the parent's precision
results[name + '_initial_prec'] = ctx.prec
results[name + '_ctx_id'] = id(ctx)
# Mutate this thread's context
ctx.prec += prec_delta
results[name + '_modified_prec'] = ctx.prec

# Spawn threads that inherit the parent's contextvars.
t1 = threading.Thread(target=child, args=('t1', 5),
context=contextvars.copy_context())
t2 = threading.Thread(target=child, args=('t2', 10),
context=contextvars.copy_context())
t1.start()
t2.start()
t1.join()
t2.join()

# Each thread should have started with the parent's precision
self.assertEqual(results['t1_initial_prec'], 20)
self.assertEqual(results['t2_initial_prec'], 20)

# Each thread should have its own context (different id)
self.assertNotEqual(results['t1_ctx_id'], results['t2_ctx_id'])

# Mutations should be independent
self.assertEqual(results['t1_modified_prec'], 25)
self.assertEqual(results['t2_modified_prec'], 30)

# Parent context should be unaffected
self.assertEqual(getcontext().prec, 20)

def test_inherited_context_isolation_async(self):
# An asyncio child task inherits the parent task's context object
# (create_task copies the current context). Each task must get its
# own decimal context so mutations stay isolated. This is the case
# where every task step runs at the same context nesting level, so a
# per-entry depth would collide -- the depth is assigned when the
# context is *copied*, which keeps parent and child distinct.
import asyncio
getcontext = self.decimal.getcontext

async def child(results):
ctx = getcontext()
results['child_initial_prec'] = ctx.prec
results['child_ctx_id'] = id(ctx)
ctx.prec = 7
results['child_modified_prec'] = ctx.prec

async def parent():
results = {}
ctx = getcontext()
ctx.prec = 33
results['parent_ctx_id'] = id(ctx)
await asyncio.create_task(child(results))
results['parent_after_prec'] = getcontext().prec
return results

# Pass loop_factory so asyncio.run() doesn't lazily initialize the
# global event loop policy, which would be reported as "env changed"
# by regrtest (e.g. on iOS/Android where the policy starts as None).
results = asyncio.run(parent(), loop_factory=asyncio.EventLoop)

# Child inherits the parent's precision value...
self.assertEqual(results['child_initial_prec'], 33)
# ...but in its own context object.
self.assertNotEqual(results['parent_ctx_id'], results['child_ctx_id'])
# The child's mutation does not leak back to the parent.
self.assertEqual(results['child_modified_prec'], 7)
self.assertEqual(results['parent_after_prec'], 33)


@requires_cdecimal
class CThreadingTest(ThreadingTest, unittest.TestCase):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Ensure that :func:`decimal.getcontext` returns a per-task copy of the
:class:`decimal.Context` so that mutations are isolated between asyncio
tasks and threads using :data:`sys.flags.thread_inherit_context`.
37 changes: 33 additions & 4 deletions Modules/_decimal/_decimal.c
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#endif

#include <Python.h>
#include "pycore_context.h" // _PyContext_CurrentDepth()
#include "pycore_object.h" // _PyObject_VisitType()
#include "pycore_pystate.h" // _PyThreadState_GET()
#include "pycore_tuple.h" // _PyTuple_FromPair
Expand Down Expand Up @@ -224,6 +225,11 @@ typedef struct PyDecContextObject {
int capitals;
PyThreadState *tstate;
decimal_state *modstate;
/* Depth of the contextvars context this context object was bound into
(see _pycontextobject.ctx_depth). Used to detect when the current
context object was inherited from an outer context/task and therefore
must be copied to keep mutations isolated. */
uint64_t ctx_depth;
} PyDecContextObject;

#define _PyDecContextObject_CAST(op) ((PyDecContextObject *)(op))
Expand All @@ -247,6 +253,7 @@ typedef struct {
#define SdFlags(v) (*_PyDecSignalDictObject_CAST(v)->flags)
#define CTX(v) (&_PyDecContextObject_CAST(v)->ctx)
#define CtxCaps(v) (_PyDecContextObject_CAST(v)->capitals)
#define CtxDepth(v) (_PyDecContextObject_CAST(v)->ctx_depth)

static inline decimal_state *
get_module_state_from_ctx(PyObject *v)
Expand Down Expand Up @@ -1477,6 +1484,7 @@ context_new(PyTypeObject *type,
CtxCaps(self) = 1;
self->tstate = NULL;
self->modstate = state;
self->ctx_depth = 0;

if (type == state->PyDecContext_Type) {
PyObject_GC_Track(self);
Expand Down Expand Up @@ -1915,13 +1923,17 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
}
#else
static PyObject *
init_current_context(decimal_state *state)
init_current_context(decimal_state *state, PyObject *prev_context,
uint64_t depth)
{
PyObject *tl_context = context_copy(state, state->default_context_template);
PyObject *tl_context = context_copy(state, prev_context);
if (tl_context == NULL) {
return NULL;
}
CTX(tl_context)->status = 0;
/* Stamp the copy with the current context's depth so that subsequent
lookups in this same context recognize it as locally owned. */
CtxDepth(tl_context) = depth;

PyObject *tok = PyContextVar_Set(state->current_context_var, tl_context);
if (tok == NULL) {
Expand All @@ -1941,11 +1953,24 @@ current_context(decimal_state *state)
return NULL;
}

uint64_t cur_depth = _PyContext_CurrentDepth();

if (tl_context != NULL) {
return tl_context;
if (CtxDepth(tl_context) == cur_depth) {
/* The context object was created for this same context scope. */
return tl_context;
}
/* The context object was inherited from an outer context (e.g. a
parent thread or asyncio task); copy it so that mutations stay
isolated from the context that shared it. */
PyObject *new_context = init_current_context(state, tl_context,
cur_depth);
Py_DECREF(tl_context);
return new_context;
}

return init_current_context(state);
return init_current_context(state, state->default_context_template,
cur_depth);
}

/* ctxobj := borrowed reference to the current context */
Expand Down Expand Up @@ -1988,6 +2013,10 @@ PyDec_SetCurrentContext(PyObject *self, PyObject *v)
Py_INCREF(v);
}

/* Mark the context object as owned by the current context scope, so a
following getcontext() returns this very object rather than a copy. */
CtxDepth(v) = _PyContext_CurrentDepth();

PyObject *tok = PyContextVar_Set(state->current_context_var, v);
Py_DECREF(v);
if (tok == NULL) {
Expand Down
Loading
Loading