forked from microsoft/agent-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathswitch_case_edge_group.py
More file actions
238 lines (190 loc) · 9.79 KB
/
switch_case_edge_group.py
File metadata and controls
238 lines (190 loc) · 9.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import os
from dataclasses import dataclass
from typing import Any, Literal
from uuid import uuid4
from agent_framework import ( # Core chat primitives used to form LLM requests
Agent,
AgentExecutor,
AgentExecutorRequest, # Message bundle sent to an AgentExecutor
AgentExecutorResponse, # Result returned by an AgentExecutor
Case,
Default, # Default branch when no cases match
Message,
WorkflowBuilder, # Fluent builder for assembling the graph
WorkflowContext, # Per-run context and event bus
executor, # Decorator to turn a function into a workflow executor
)
from agent_framework.azure import AzureOpenAIResponsesClient # Thin client for Azure OpenAI chat models
from azure.identity import AzureCliCredential # Uses your az CLI login for credentials
from dotenv import load_dotenv
from pydantic import BaseModel # Structured outputs with validation
from typing_extensions import Never
# Load environment variables from .env file
load_dotenv()
"""
Sample: Switch-Case Edge Group with an explicit Uncertain branch.
The workflow stores a single email in workflow state, asks a spam detection agent for a three way decision,
then routes with a switch-case group: NotSpam to the drafting assistant, Spam to a spam handler, and
Default to an Uncertain handler.
Purpose:
Demonstrate deterministic one of N routing with switch-case edges. Show how to:
- Persist input once in workflow state, then pass around a small typed pointer that carries the email id.
- Validate agent JSON with Pydantic models for robust parsing.
- Keep executor responsibilities narrow. Transform model output to a typed DetectionResult, then route based
on that type.
- Use ctx.yield_output() to provide workflow results - the workflow completes when idle with no pending work.
Prerequisites:
- AZURE_AI_PROJECT_ENDPOINT must be your Azure AI Foundry Agent Service (V2) project endpoint.
- Familiarity with WorkflowBuilder, executors, edges, and events.
- Understanding of switch-case edge groups and how Case and Default are evaluated in order.
- Working Azure OpenAI configuration for AzureOpenAIResponsesClient, with Azure CLI login and required environment variables.
- Access to workflow/resources/ambiguous_email.txt, or accept the inline fallback string.
"""
EMAIL_STATE_PREFIX = "email:"
CURRENT_EMAIL_ID_KEY = "current_email_id"
class DetectionResultAgent(BaseModel):
"""Structured output returned by the spam detection agent."""
# The agent classifies the email and provides a rationale.
spam_decision: Literal["NotSpam", "Spam", "Uncertain"]
reason: str
class EmailResponse(BaseModel):
"""Structured output returned by the email assistant agent."""
# The drafted professional reply.
response: str
@dataclass
class DetectionResult:
# Internal typed payload used for routing and downstream handling.
spam_decision: str
reason: str
email_id: str
@dataclass
class Email:
# In memory record of the email content stored in workflow state.
email_id: str
email_content: str
def get_case(expected_decision: str):
"""Factory that returns a predicate matching a specific spam_decision value."""
def condition(message: Any) -> bool:
# Only match when the upstream payload is a DetectionResult with the expected decision.
return isinstance(message, DetectionResult) and message.spam_decision == expected_decision
return condition
@executor(id="store_email")
async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
# Persist the raw email once. Store under a unique key and set the current pointer for convenience.
new_email = Email(email_id=str(uuid4()), email_content=email_text)
ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
# Kick off the detector by forwarding the email as a user message to the spam_detection_agent.
await ctx.send_message(
AgentExecutorRequest(messages=[Message("user", text=new_email.email_content)], should_respond=True)
)
@executor(id="to_detection_result")
async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowContext[DetectionResult]) -> None:
# Parse the detector JSON into a typed model. Attach the current email id for downstream lookups.
parsed = DetectionResultAgent.model_validate_json(response.agent_response.text)
email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY)
await ctx.send_message(DetectionResult(spam_decision=parsed.spam_decision, reason=parsed.reason, email_id=email_id))
@executor(id="submit_to_email_assistant")
async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
# Only proceed for the NotSpam branch. Guard against accidental misrouting.
if detection.spam_decision != "NotSpam":
raise RuntimeError("This executor should only handle NotSpam messages.")
# Load the original content from workflow state using the id carried in DetectionResult.
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
await ctx.send_message(
AgentExecutorRequest(messages=[Message("user", text=email.email_content)], should_respond=True)
)
@executor(id="finalize_and_send")
async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContext[Never, str]) -> None:
# Terminal step for the drafting branch. Yield the email response as output.
parsed = EmailResponse.model_validate_json(response.agent_response.text)
await ctx.yield_output(f"Email sent: {parsed.response}")
@executor(id="handle_spam")
async def handle_spam(detection: DetectionResult, ctx: WorkflowContext[Never, str]) -> None:
# Spam path terminal. Include the detector's rationale.
if detection.spam_decision == "Spam":
await ctx.yield_output(f"Email marked as spam: {detection.reason}")
else:
raise RuntimeError("This executor should only handle Spam messages.")
@executor(id="handle_uncertain")
async def handle_uncertain(detection: DetectionResult, ctx: WorkflowContext[Never, str]) -> None:
# Uncertain path terminal. Surface the original content to aid human review.
if detection.spam_decision == "Uncertain":
email: Email | None = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
await ctx.yield_output(
f"Email marked as uncertain: {detection.reason}. Email content: {getattr(email, 'email_content', '')}"
)
else:
raise RuntimeError("This executor should only handle Uncertain messages.")
def create_spam_detection_agent() -> Agent:
"""Create and return the spam detection agent."""
return AzureOpenAIResponsesClient(
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
credential=AzureCliCredential(),
).as_agent(
instructions=(
"You are a spam detection assistant that identifies spam emails. "
"Be less confident in your assessments. "
"Always return JSON with fields 'spam_decision' (one of NotSpam, Spam, Uncertain) "
"and 'reason' (string)."
),
name="spam_detection_agent",
default_options={"response_format": DetectionResultAgent},
)
def create_email_assistant_agent() -> Agent:
"""Create and return the email assistant agent."""
return AzureOpenAIResponsesClient(
project_endpoint=os.environ["AZURE_AI_PROJECT_ENDPOINT"],
deployment_name=os.environ["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
credential=AzureCliCredential(),
).as_agent(
instructions=("You are an email assistant that helps users draft responses to emails with professionalism."),
name="email_assistant_agent",
default_options={"response_format": EmailResponse},
)
async def main():
"""Main function to run the workflow."""
# Build workflow: store -> detection agent -> to_detection_result -> switch (NotSpam or Spam or Default).
# The switch-case group evaluates cases in order, then falls back to Default when none match.
spam_detection_agent = AgentExecutor(create_spam_detection_agent())
email_assistant_agent = AgentExecutor(create_email_assistant_agent())
workflow = (
WorkflowBuilder(start_executor=store_email)
.add_edge(store_email, spam_detection_agent)
.add_edge(spam_detection_agent, to_detection_result)
.add_switch_case_edge_group(
to_detection_result,
[
Case(condition=get_case("NotSpam"), target=submit_to_email_assistant),
Case(condition=get_case("Spam"), target=handle_spam),
Default(target=handle_uncertain),
],
)
.add_edge(submit_to_email_assistant, email_assistant_agent)
.add_edge(email_assistant_agent, finalize_and_send)
.build()
)
# Read ambiguous email if available. Otherwise use a simple inline sample.
resources_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.realpath(__file__))), "resources", "ambiguous_email.txt"
)
if os.path.exists(resources_path):
with open(resources_path, encoding="utf-8") as f: # noqa: ASYNC230
email = f.read()
else:
print("Unable to find resource file, using default text.")
email = (
"Hey there, I noticed you might be interested in our latest offer—no pressure, but it expires soon. "
"Let me know if you'd like more details."
)
# Run and print the outputs from whichever branch completes.
events = await workflow.run(email)
outputs = events.get_outputs()
if outputs:
for output in outputs:
print(f"Workflow output: {output}")
if __name__ == "__main__":
asyncio.run(main())