Skip to content

Commit 04e193c

Browse files
committed
gh-141510: Fix frozendict.fromkeys() for subclasses
Copy the frozendict.
1 parent 7ac0868 commit 04e193c

File tree

2 files changed

+52
-2
lines changed

2 files changed

+52
-2
lines changed

Lib/test/test_dict.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1787,6 +1787,31 @@ def test_hash(self):
17871787
with self.assertRaisesRegex(TypeError, "unhashable type: 'list'"):
17881788
hash(fd)
17891789

1790+
def test_fromkeys(self):
1791+
self.assertEqual(frozendict.fromkeys('abc'),
1792+
frozendict(a=None, b=None, c=None))
1793+
1794+
# Subclass which overrides the constructor
1795+
class FrozenDictSubclass(frozendict):
1796+
def __new__(self):
1797+
return frozendict(x=1)
1798+
1799+
fd = FrozenDictSubclass.fromkeys("abc")
1800+
self.assertEqual(fd, frozendict(x=1, a=None, b=None, c=None))
1801+
self.assertEqual(type(fd), FrozenDictSubclass)
1802+
1803+
fd = FrozenDictSubclass.fromkeys(frozendict(y=2))
1804+
self.assertEqual(fd, frozendict(x=1, y=None))
1805+
self.assertEqual(type(fd), FrozenDictSubclass)
1806+
1807+
# Subclass which doesn't override the constructor
1808+
class FrozenDictSubclass2(frozendict):
1809+
pass
1810+
1811+
fd = FrozenDictSubclass2.fromkeys("abc")
1812+
self.assertEqual(fd, frozendict(a=None, b=None, c=None))
1813+
self.assertEqual(type(fd), FrozenDictSubclass2)
1814+
17901815

17911816
if __name__ == "__main__":
17921817
unittest.main()

Objects/dictobject.c

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ As a consequence of this, split keys have a maximum size of 16.
138138
// Forward declarations
139139
static PyObject* frozendict_new(PyTypeObject *type, PyObject *args,
140140
PyObject *kwds);
141+
static int dict_merge(PyObject *a, PyObject *b, int override);
141142

142143

143144
/*[clinic input]
@@ -3286,9 +3287,31 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
32863287
int status;
32873288

32883289
d = _PyObject_CallNoArgs(cls);
3289-
if (d == NULL)
3290+
if (d == NULL) {
32903291
return NULL;
3292+
}
32913293

3294+
// If cls is a frozendict subclass with overridden constructor,
3295+
// copy the frozendict.
3296+
PyTypeObject *cls_type = _PyType_CAST(cls);
3297+
if (PyFrozenDict_Check(d)
3298+
&& PyObject_IsSubclass(cls, (PyObject*)&PyFrozenDict_Type)
3299+
&& cls_type->tp_new != frozendict_new)
3300+
{
3301+
// Subclass-friendly copy
3302+
PyObject *copy = frozendict_new(cls_type, NULL, NULL);
3303+
if (copy == NULL) {
3304+
Py_DECREF(d);
3305+
return NULL;
3306+
}
3307+
if (dict_merge(copy, d, 1) < 0) {
3308+
Py_DECREF(d);
3309+
Py_DECREF(copy);
3310+
return NULL;
3311+
}
3312+
Py_SETREF(d, copy);
3313+
}
3314+
assert(!PyFrozenDict_Check(d) || Py_REFCNT(d) == 1);
32923315

32933316
if (PyDict_CheckExact(d)) {
32943317
if (PyDict_CheckExact(iterable)) {
@@ -3359,7 +3382,7 @@ _PyDict_FromKeys(PyObject *cls, PyObject *iterable, PyObject *value)
33593382
dict_iter_exit:;
33603383
Py_END_CRITICAL_SECTION();
33613384
}
3362-
else if (PyFrozenDict_CheckExact(d)) {
3385+
else if (PyFrozenDict_Check(d)) {
33633386
while ((key = PyIter_Next(it)) != NULL) {
33643387
// anydict_setitem_take2 consumes a reference to key
33653388
status = anydict_setitem_take2((PyDictObject *)d,
@@ -7994,6 +8017,8 @@ frozendict_new(PyTypeObject *type, PyObject *args, PyObject *kwds)
79948017
if (d == NULL) {
79958018
return NULL;
79968019
}
8020+
assert(Py_REFCNT(d) == 1);
8021+
79978022
PyFrozenDictObject *self = _PyFrozenDictObject_CAST(d);
79988023
self->ma_hash = -1;
79998024

0 commit comments

Comments
 (0)