-
Notifications
You must be signed in to change notification settings - Fork 140
Expand file tree
/
Copy pathtest_coroutine_util.py
More file actions
148 lines (103 loc) · 3.32 KB
/
Copy pathtest_coroutine_util.py
File metadata and controls
148 lines (103 loc) · 3.32 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
# SPDX-FileCopyrightText: Copyright (c) <2026> NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
from cuda.tile._coroutine_util import resume_after, run_coroutine
import pytest
import traceback
async def series(n):
if n == 0:
return 0
r = await resume_after(series(n - 1))
return r + n
def test_run_coroutine():
n = 10000
res = run_coroutine(series(n))
assert res == sum(range(n + 1))
async def raise_if_zero(n):
if n == 0:
raise ValueError("Hello")
await resume_after(raise_if_zero(n - 1))
def test_propagate_exception():
with pytest.raises(ValueError, match="Hello"):
run_coroutine(raise_if_zero(5))
async def raise_then_catch(n):
if n == 0:
raise ValueError("Hello")
if n == 1:
try:
await resume_after(raise_then_catch(0))
except ValueError as e:
assert str(e) == "Hello"
return 100
assert False
r = await resume_after(raise_then_catch(n - 1))
return r + n
def test_raise_then_catch():
res = run_coroutine(raise_then_catch(4))
assert res == 100 + 2 + 3 + 4
async def return_123():
return 123
async def raise_then_catch_and_call_another(n):
if n == 0:
raise ValueError("Hello")
if n == 1:
try:
await resume_after(raise_then_catch_and_call_another(0))
except ValueError as e:
assert str(e) == "Hello"
x = await resume_after(return_123())
return x
assert False
r = await resume_after(raise_then_catch_and_call_another(n - 1))
return r + n
def test_raise_then_catch_and_call_another():
res = run_coroutine(raise_then_catch_and_call_another(4))
assert res == 123 + 2 + 3 + 4
async def two_calls():
t1 = await resume_after(series(3))
t2 = await resume_after(series(4))
return t1, t2
def test_return_values():
res = run_coroutine(two_calls())
assert res == (1 + 2 + 3, 1 + 2 + 3 + 4)
async def raise_in_leaf():
raise ValueError("leaf")
async def call_leaf():
await resume_after(raise_in_leaf())
def test_traceback_preserved():
try:
run_coroutine(call_leaf())
except ValueError as e:
traceback.print_tb(e.__traceback__)
frame_names = [f.name for f in traceback.extract_tb(e.__traceback__)]
assert "raise_in_leaf" in frame_names
else:
assert False
class WeirdAwaitable:
def __await__(self):
return iter([123])
async def weird_await():
await WeirdAwaitable()
async def call_weird_await():
await weird_await()
def test_unexpected_awaitable():
try:
run_coroutine(call_weird_await())
except TypeError as e:
assert "Expected a continuation coroutine" in str(e)
traceback.print_tb(e.__traceback__)
frame_names = [f.name for f in traceback.extract_tb(e.__traceback__)]
assert "call_weird_await" in frame_names
else:
assert False
async def resume_after_call_weird_await(flag):
try:
await resume_after(call_weird_await())
finally:
flag[0] = True
def test_cleanup_after_internal_error():
flag = [False]
coro = resume_after_call_weird_await(flag)
with pytest.raises(TypeError, match="Expected a continuation coroutine"):
run_coroutine(coro)
assert flag[0] is True