Skip to content

Commit 3835149

Browse files
committed
marshal: preserve safe recursive hashable refs
1 parent 7c214ea commit 3835149

5 files changed

Lines changed: 978 additions & 63 deletions

File tree

Include/internal/pycore_dict.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,8 @@ typedef struct {
423423
Py_hash_t ma_hash;
424424
} PyFrozenDictObject;
425425

426+
#define _Py_FROZENDICT_HASH_CONSTRUCTING ((Py_hash_t)-2)
427+
426428
#define _PyFrozenDictObject_CAST(op) \
427429
(assert(PyFrozenDict_Check(op)), _Py_CAST(PyFrozenDictObject*, (op)))
428430

Lib/test/test_marshal.py

Lines changed: 317 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
from test.support import is_apple_mobile, os_helper, requires_debug_ranges, is_emscripten
33
from test.support.script_helper import assert_python_ok
44
import array
5+
import itertools
56
import io
67
import marshal
78
import sys
9+
import typing
810
import unittest
911
import os
1012
import types
@@ -731,24 +733,322 @@ def test_read_object_from_file(self):
731733
os_helper.unlink(os_helper.TESTFN)
732734

733735

734-
class SelfRefTupleTest(unittest.TestCase):
735-
"""Regression test for gh-148653: TYPE_TUPLE with FLAG_REF back-reference.
736-
737-
R_REF registered the tuple in p->refs before its slots were populated.
738-
A TYPE_REF back-reference to the partial tuple could reach a hashing
739-
site (PySet_Add) with NULL slots, crashing with SIGSEGV.
740-
741-
The fix uses the two-phase r_ref_reserve/r_ref_insert pattern so the
742-
Py_None placeholder is detected by the TYPE_REF handler, raising
743-
ValueError instead.
744-
"""
736+
WrapperTargetKind = typing.Literal["tuple", "slice", "frozendict"]
737+
MutableTargetKind = typing.Literal["list", "dict_value"]
738+
BridgeStepKind = typing.Literal["list", "dict_value"]
739+
WrapperRootLayoutKind = typing.Literal[
740+
"target", "first_bridge", "outer_list_pair", "outer_dict_pair"]
741+
MutableRootLayoutKind = typing.Literal["target", "outer_list", "outer_dict"]
742+
BackrefMultiplicityKind = typing.Literal["single", "duplicate"]
743+
744+
745+
class SemanticRecursiveCase(typing.NamedTuple):
746+
name: str
747+
min_version: int
748+
build: typing.Callable[[], object]
749+
750+
751+
class RecursivePayloadCase(typing.NamedTuple):
752+
name: str
753+
payload: bytes
754+
expected: typing.Callable[[], object]
755+
756+
757+
def _marshal_atomic_signature(value: object) -> tuple[str, object] | None:
758+
if value is None:
759+
return ("none", None)
760+
if value is Ellipsis:
761+
return ("ellipsis", None)
762+
if value is StopIteration:
763+
return ("stopiter", None)
764+
if isinstance(value, bool):
765+
return ("bool", value)
766+
if isinstance(value, int):
767+
return ("int", value)
768+
if isinstance(value, float):
769+
return ("float", value)
770+
if isinstance(value, complex):
771+
return ("complex", value)
772+
if isinstance(value, str):
773+
return ("str", value)
774+
if isinstance(value, bytes):
775+
return ("bytes", value)
776+
return None
777+
778+
779+
def _marshal_mapping_signature(
780+
mapping: dict[object, object] | frozendict,
781+
encode: typing.Callable[[object], object],
782+
) -> tuple[tuple[tuple[str, object], object], ...]:
783+
entries = []
784+
for key, value in mapping.items():
785+
key_sig = _marshal_atomic_signature(key)
786+
if key_sig is None:
787+
raise TypeError(
788+
"recursive marshal test only supports atomic mapping keys")
789+
entries.append((key_sig, encode(value)))
790+
entries.sort()
791+
return tuple(entries)
792+
793+
794+
def _marshal_graph_signature(root: object) -> object:
795+
seen: dict[int, int] = {}
796+
nodes: list[object] = []
797+
798+
def encode(value: object) -> object:
799+
atomic = _marshal_atomic_signature(value)
800+
if atomic is not None:
801+
return atomic
802+
803+
obj_id = id(value)
804+
if obj_id in seen:
805+
return ("ref", seen[obj_id])
806+
807+
node_id = len(nodes)
808+
seen[obj_id] = node_id
809+
nodes.append(("pending",))
810+
811+
if isinstance(value, list):
812+
node = ("list", tuple(encode(item) for item in value))
813+
elif isinstance(value, tuple):
814+
node = ("tuple", tuple(encode(item) for item in value))
815+
elif isinstance(value, frozendict):
816+
node = ("frozendict", _marshal_mapping_signature(value, encode))
817+
elif isinstance(value, dict):
818+
node = ("dict", _marshal_mapping_signature(value, encode))
819+
elif isinstance(value, slice):
820+
node = ("slice", (
821+
("start", encode(value.start)),
822+
("stop", encode(value.stop)),
823+
("step", encode(value.step)),
824+
))
825+
else:
826+
raise TypeError(
827+
f"unsupported recursive marshal test node type: {type(value)!r}")
828+
829+
nodes[node_id] = node
830+
return ("ref", node_id)
831+
832+
return (encode(root), tuple(nodes))
833+
834+
835+
def _make_bridge(bridge_kind: BridgeStepKind) -> list[object] | dict[str, object]:
836+
if bridge_kind == "list":
837+
return []
838+
return {}
839+
840+
841+
def _link_bridge(bridge_kind: BridgeStepKind,
842+
bridge: list[object] | dict[str, object],
843+
value: object,
844+
multiplicity: BackrefMultiplicityKind) -> None:
845+
if bridge_kind == "list":
846+
bridge = typing.cast(list[object], bridge)
847+
bridge.append(value)
848+
if multiplicity == "duplicate":
849+
bridge.append(value)
850+
else:
851+
bridge = typing.cast(dict[str, object], bridge)
852+
bridge["x"] = value
853+
if multiplicity == "duplicate":
854+
bridge["y"] = value
855+
856+
857+
def _build_wrapper_target(
858+
target_kind: WrapperTargetKind,
859+
bridge: list[object] | dict[str, object],
860+
) -> object:
861+
if target_kind == "tuple":
862+
return (bridge,)
863+
if target_kind == "slice":
864+
return slice(None, bridge, None)
865+
return frozendict({None: bridge})
866+
867+
868+
def _build_wrapper_recursive_case(
869+
target_kind: WrapperTargetKind,
870+
bridge_path: tuple[BridgeStepKind, ...],
871+
root_layout: WrapperRootLayoutKind,
872+
multiplicity: BackrefMultiplicityKind,
873+
) -> object:
874+
bridges = [_make_bridge(bridge_kind) for bridge_kind in bridge_path]
875+
target = _build_wrapper_target(target_kind, bridges[0])
876+
for index, bridge_kind in enumerate(bridge_path[:-1]):
877+
_link_bridge(bridge_kind, bridges[index], bridges[index + 1], "single")
878+
_link_bridge(bridge_path[-1], bridges[-1], target, multiplicity)
879+
880+
if root_layout == "target":
881+
return target
882+
if root_layout == "first_bridge":
883+
return bridges[0]
884+
if root_layout == "outer_list_pair":
885+
return [target, bridges[-1]]
886+
return {"target": target, "bridge": bridges[-1]}
887+
888+
889+
def _build_mutable_recursive_case(
890+
target_kind: MutableTargetKind,
891+
root_layout: MutableRootLayoutKind,
892+
multiplicity: BackrefMultiplicityKind,
893+
) -> object:
894+
if target_kind == "list":
895+
target = []
896+
_link_bridge("list", target, target, multiplicity)
897+
else:
898+
target = {}
899+
_link_bridge("dict_value", target, target, multiplicity)
900+
901+
if root_layout == "target":
902+
return target
903+
if root_layout == "outer_list":
904+
return [target]
905+
return {"target": target}
906+
907+
908+
def _iter_semantic_recursive_cases() -> typing.Iterator[SemanticRecursiveCase]:
909+
mutable_target_kinds = typing.cast(
910+
tuple[MutableTargetKind, ...], ("list", "dict_value"))
911+
mutable_root_layouts = typing.cast(
912+
tuple[MutableRootLayoutKind, ...], ("target", "outer_list", "outer_dict"))
913+
wrapper_target_kinds = typing.cast(
914+
tuple[WrapperTargetKind, ...], ("tuple", "slice", "frozendict"))
915+
wrapper_root_layouts = typing.cast(
916+
tuple[WrapperRootLayoutKind, ...],
917+
("target", "first_bridge", "outer_list_pair", "outer_dict_pair"))
918+
multiplicities = typing.cast(
919+
tuple[BackrefMultiplicityKind, ...], ("single", "duplicate"))
920+
bridge_steps = typing.cast(tuple[BridgeStepKind, ...], ("list", "dict_value"))
921+
bridge_paths = tuple(
922+
typing.cast(tuple[BridgeStepKind, ...], bridge_path)
923+
for path_len in (1, 2)
924+
for bridge_path in itertools.product(bridge_steps, repeat=path_len)
925+
)
926+
927+
for target_kind, root_layout, multiplicity in itertools.product(
928+
mutable_target_kinds, mutable_root_layouts, multiplicities):
929+
def build(target_kind: MutableTargetKind = target_kind,
930+
root_layout: MutableRootLayoutKind = root_layout,
931+
multiplicity: BackrefMultiplicityKind = multiplicity) -> object:
932+
return _build_mutable_recursive_case(
933+
target_kind, root_layout, multiplicity)
934+
935+
yield SemanticRecursiveCase(
936+
name=f"{target_kind}_self_{root_layout}_{multiplicity}",
937+
min_version=3,
938+
build=build,
939+
)
940+
941+
for target_kind, bridge_path, root_layout, multiplicity in itertools.product(
942+
wrapper_target_kinds, bridge_paths, wrapper_root_layouts, multiplicities):
943+
if target_kind == "tuple":
944+
min_version = 3
945+
elif target_kind == "slice":
946+
min_version = 5
947+
else:
948+
min_version = 6
949+
950+
def build(target_kind: WrapperTargetKind = target_kind,
951+
bridge_path: tuple[BridgeStepKind, ...] = bridge_path,
952+
root_layout: WrapperRootLayoutKind = root_layout,
953+
multiplicity: BackrefMultiplicityKind = multiplicity) -> object:
954+
return _build_wrapper_recursive_case(
955+
target_kind, bridge_path, root_layout, multiplicity)
956+
957+
bridge_path_name = "_".join(bridge_path)
958+
yield SemanticRecursiveCase(
959+
name=(
960+
f"{target_kind}_via_{bridge_path_name}_"
961+
f"{root_layout}_{multiplicity}"
962+
),
963+
min_version=min_version,
964+
build=build,
965+
)
966+
967+
968+
def _iter_valid_recursive_payload_cases() -> typing.Iterator[RecursivePayloadCase]:
969+
yield RecursivePayloadCase(
970+
name="tuple_with_duplicate_backrefs_in_list_payload",
971+
payload=(b'\xa9\x01'
972+
b'[\x02\x00\x00\x00'
973+
b'r\x00\x00\x00\x00'
974+
b'r\x00\x00\x00\x00'),
975+
expected=lambda: _build_wrapper_recursive_case(
976+
"tuple", ("list",), "target", "duplicate"),
977+
)
978+
yield RecursivePayloadCase(
979+
name="root_list_with_inner_tuple_backref_payload",
980+
payload=(b'\xdb\x02\x00\x00\x00'
981+
b'\xa9\x01'
982+
b'\xdb\x01\x00\x00\x00'
983+
b'r\x01\x00\x00\x00'
984+
b'r\x02\x00\x00\x00'),
985+
expected=lambda: _build_wrapper_recursive_case(
986+
"tuple", ("list",), "outer_list_pair", "single"),
987+
)
988+
989+
990+
def _iter_invalid_recursive_payloads() -> typing.Iterator[tuple[str, bytes]]:
991+
yield (
992+
"tuple_self_in_set",
993+
b'\xa8\x02\x00\x00\x00N<\x01\x00\x00\x00r\x00\x00\x00\x00',
994+
)
995+
yield (
996+
"tuple_direct_self_reference",
997+
b'\xa9\x01r\x00\x00\x00\x00',
998+
)
999+
yield (
1000+
"tuple_as_incomplete_dict_key",
1001+
b'\xa9\x01{r\x00\x00\x00\x00N0',
1002+
)
1003+
if marshal.version >= 5:
1004+
yield (
1005+
"slice_direct_self_reference",
1006+
b'\xbaNr\x00\x00\x00\x00N',
1007+
)
1008+
if marshal.version >= 6:
1009+
yield (
1010+
"frozendict_direct_self_value",
1011+
b'\xfdNr\x00\x00\x00\x000',
1012+
)
1013+
1014+
1015+
class RecursiveGraphTest(unittest.TestCase):
1016+
def assert_marshal_graph_roundtrip(self, sample: object, version: int) -> None:
1017+
expected = _marshal_graph_signature(sample)
1018+
loaded = marshal.loads(marshal.dumps(sample, version))
1019+
self.assertEqual(expected, _marshal_graph_signature(loaded))
1020+
try:
1021+
with open(os_helper.TESTFN, "wb") as file:
1022+
marshal.dump(sample, file, version)
1023+
with open(os_helper.TESTFN, "rb") as file:
1024+
loaded = marshal.load(file)
1025+
self.assertEqual(expected, _marshal_graph_signature(loaded))
1026+
finally:
1027+
os_helper.unlink(os_helper.TESTFN)
7451028

746-
def test_self_ref_tuple(self):
747-
# TYPE_TUPLE|FLAG_REF n=2; NONE; TYPE_SET n=1; TYPE_REF(0)
748-
payload = (b'\xa8\x02\x00\x00\x00N'
749-
b'<\x01\x00\x00\x00r\x00\x00\x00\x00')
750-
with self.assertRaises(ValueError):
751-
marshal.loads(payload)
1029+
def test_constructible_recursive_case_count(self):
1030+
self.assertEqual(len(tuple(_iter_semantic_recursive_cases())), 156)
1031+
1032+
def test_constructible_recursive_roundtrips(self):
1033+
for case in _iter_semantic_recursive_cases():
1034+
for version in range(case.min_version, marshal.version + 1):
1035+
with self.subTest(case=case.name, version=version):
1036+
self.assert_marshal_graph_roundtrip(case.build(), version)
1037+
1038+
def test_handpicked_recursive_payloads(self):
1039+
for case in _iter_valid_recursive_payload_cases():
1040+
with self.subTest(case.name):
1041+
loaded = marshal.loads(case.payload)
1042+
self.assertEqual(
1043+
_marshal_graph_signature(case.expected()),
1044+
_marshal_graph_signature(loaded),
1045+
)
1046+
1047+
def test_invalid_recursive_payloads(self):
1048+
for name, payload in _iter_invalid_recursive_payloads():
1049+
with self.subTest(name):
1050+
with self.assertRaises(ValueError):
1051+
marshal.loads(payload)
7521052

7531053

7541054
if __name__ == "__main__":

0 commit comments

Comments
 (0)