forked from microsoft/agent-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathchat_middleware.py
More file actions
253 lines (207 loc) · 9.18 KB
/
chat_middleware.py
File metadata and controls
253 lines (207 loc) · 9.18 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from collections.abc import Awaitable, Callable
from random import randint
from typing import Annotated
from agent_framework import (
ChatContext,
ChatMiddleware,
ChatResponse,
Message,
MiddlewareTermination,
chat_middleware,
tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()

# NOTE: this string follows other statements, so Python does not treat it as the module
# docstring (__doc__); it only serves as an in-file description of the sample.
"""
Chat Middleware Example

This sample demonstrates how to use chat middleware to observe and override
inputs sent to AI models. Chat middleware intercepts chat requests before they reach
the underlying AI service, allowing you to:
1. Observe and log input messages
2. Modify input messages before sending to AI
3. Override the entire response

The example covers:
- Class-based chat middleware inheriting from ChatMiddleware
- Function-based chat middleware with @chat_middleware decorator
- Middleware registration at agent level (applies to all runs)
- Middleware registration at run level (applies to specific run only)
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Fake data: a random condition and a random high between 10 and 30 (inclusive);
    # no real weather service is queried.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # Derive the upper index from the list itself so adding or removing a condition
    # can never produce an out-of-range index or leave an entry unreachable.
    condition = conditions[randint(0, len(conditions) - 1)]
    return f"The weather in {location} is {condition} with a high of {randint(10, 30)}°C."
class InputObserverMiddleware(ChatMiddleware):
    """Class-based middleware that observes and optionally rewrites input messages.

    Every message in ``context.messages`` is printed before the request continues.
    When a ``replacement`` string is configured, each user message that has text is
    replaced by a new message whose only content is that string; otherwise all
    messages are passed through unchanged.
    """

    def __init__(self, replacement: str | None = None):
        """Initialize the middleware.

        Args:
            replacement: Text substituted for every user message that has text.
                When ``None`` or empty, user messages are only observed.
        """
        self.replacement = replacement

    async def process(
        self,
        context: ChatContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Log the input messages, rewrite user messages if configured, then continue.

        Args:
            context: The chat context whose ``messages`` list is inspected and,
                when a replacement is set, updated in place.
            call_next: Invokes the next middleware or the model call.
        """
        print("[InputObserverMiddleware] Observing input messages:")
        for i, message in enumerate(context.messages):
            content = message.text if message.text else str(message.contents)
            print(f" Message {i + 1} ({message.role}): {content}")
        print(f"[InputObserverMiddleware] Total messages: {len(context.messages)}")

        # Only rebuild user messages when there is something to substitute. The rebuilt
        # Message receives just the role and a single text item, so any other contents
        # or fields of the original message would be lost; with no replacement the
        # original messages are therefore forwarded untouched.
        if self.replacement:
            modified_messages: list[Message] = []
            for message in context.messages:
                if message.role == "user" and message.text:
                    print(f"[InputObserverMiddleware] Updated: '{message.text}' -> '{self.replacement}'")
                    modified_messages.append(Message(message.role, [self.replacement]))
                else:
                    modified_messages.append(message)
            # Slice-assign so the list object referenced by the context is updated in place.
            context.messages[:] = modified_messages

        # Continue to next middleware or AI execution
        await call_next()

        # Observe that processing is complete
        print("[InputObserverMiddleware] Processing completed")
@chat_middleware
async def security_and_override_middleware(
    context: ChatContext,
    call_next: Callable[[], Awaitable[None]],
) -> None:
    """Function-based middleware that refuses requests mentioning sensitive terms.

    The text of each message in ``context.messages`` is lower-cased and searched for
    the blocked terms, messages in order and terms in list order. On the first hit a
    canned assistant refusal is stored in ``context.result`` and
    ``MiddlewareTermination`` is raised instead of calling ``call_next``. Without a
    hit, control passes to the next middleware or the model.
    """
    print("[SecurityMiddleware] Processing input...")

    sensitive_terms = ["password", "secret", "api_key", "token"]
    # Lazily lower-case only messages that actually carry text.
    lowered_texts = (msg.text.lower() for msg in context.messages if msg.text)
    matched_term = next(
        (term for text in lowered_texts for term in sensitive_terms if term in text),
        None,
    )

    if matched_term is None:
        # Nothing sensitive found: hand off to the next middleware or the model.
        await call_next()
        return

    print(f"[SecurityMiddleware] BLOCKED: Found '{matched_term}' in message")
    # Supply the response ourselves so the model is never consulted.
    context.result = ChatResponse(
        messages=[
            Message(
                role="assistant",
                text="I cannot process requests containing sensitive information. "
                "Please rephrase your question without including passwords, secrets, or other "
                "sensitive data.",
            )
        ]
    )
    # Raise (rather than set a flag) to stop the middleware pipeline here.
    raise MiddlewareTermination
async def class_based_chat_middleware() -> None:
    """Demonstrate class-based middleware registered at agent level.

    ``InputObserverMiddleware`` (with no replacement) is attached when the agent is
    created, so it applies to every run of that agent.
    """
    print("\n" + "=" * 60)
    # The heading previously read "Chat MiddlewareTypes", an artifact of a type rename.
    print("Class-based Chat Middleware (Agent Level)")
    print("=" * 60)

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="EnhancedChatAgent",
            instructions="You are a helpful AI assistant.",
            # Register class-based middleware at agent level (applies to all runs)
            middleware=[InputObserverMiddleware()],
            tools=get_weather,
        ) as agent,
    ):
        query = "What's the weather in Seattle?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Final Response: {result.text or 'No response'}")
async def function_based_chat_middleware() -> None:
    """Demonstrate function-based middleware registered at agent level.

    Two queries are sent to an agent that has ``security_and_override_middleware``
    attached: an ordinary greeting, and one containing "password", which is in that
    middleware's blocked-term list.
    """
    print("\n" + "=" * 60)
    # The heading previously read "Chat MiddlewareTypes", an artifact of a type rename.
    print("Function-based Chat Middleware (Agent Level)")
    print("=" * 60)

    # (scenario title, user query), run in order and numbered from 1.
    scenarios = [
        ("Normal Query", "Hello, how are you?"),
        ("Security Violation", "What is my password for this account?"),
    ]

    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="FunctionMiddlewareAgent",
            instructions="You are a helpful AI assistant.",
            # Register function-based middleware at agent level
            middleware=[security_and_override_middleware],
        ) as agent,
    ):
        for number, (title, query) in enumerate(scenarios, start=1):
            print(f"\n--- Scenario {number}: {title} ---")
            print(f"User: {query}")
            result = await agent.run(query)
            print(f"Final Response: {result.text or 'No response'}")
async def run_level_middleware() -> None:
    """Demonstrate middleware passed per call to ``agent.run``.

    The agent is created without middleware. It then runs three times:
    1. with no middleware;
    2. with input rewriting (Tokyo query replaced by a Madrid query) plus security
       filtering, for that call only;
    3. with security filtering alone, on a query containing "secret".
    """
    print("\n" + "=" * 60)
    # Headings previously read "MiddlewareTypes", an artifact of a type rename.
    print("Run-level Chat Middleware")
    print("=" * 60)

    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="RunLevelAgent",
            instructions="You are a helpful AI assistant.",
            tools=get_weather,
            # No middleware at agent level
        ) as agent,
    ):
        # Scenario 1: Run without any middleware
        print("\n--- Scenario 1: No Middleware ---")
        query = "What's the weather in Tokyo?"
        print(f"User: {query}")
        result = await agent.run(query)
        print(f"Response: {result.text or 'No response'}")

        # Scenario 2: same query, with middleware for this call only (rewrite + security)
        print("\n--- Scenario 2: With Run-level Middleware ---")
        print(f"User: {query}")
        result = await agent.run(
            query,
            middleware=[
                InputObserverMiddleware(replacement="What's the weather in Madrid?"),
                security_and_override_middleware,
            ],
        )
        print(f"Response: {result.text or 'No response'}")

        # Scenario 3: Security test with run-level middleware
        print("\n--- Scenario 3: Security Test with Run-level Middleware ---")
        query = "Can you help me with my secret API key?"
        print(f"User: {query}")
        result = await agent.run(
            query,
            middleware=[security_and_override_middleware],
        )
        print(f"Response: {result.text or 'No response'}")
async def main() -> None:
    """Run all chat middleware examples in sequence."""
    # The title previously read "Chat MiddlewareTypes Examples" (a type-rename artifact);
    # deriving the underline from the title keeps the two in sync.
    title = "Chat Middleware Examples"
    print(title)
    print("=" * len(title))

    await class_based_chat_middleware()
    await function_based_chat_middleware()
    await run_level_middleware()


if __name__ == "__main__":
    asyncio.run(main())