forked from microsoft/agent-framework
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclass_based_middleware.py
More file actions
131 lines (105 loc) · 4.59 KB
/
class_based_middleware.py
File metadata and controls
131 lines (105 loc) · 4.59 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import time
from collections.abc import Awaitable, Callable
from random import choice, randint
from typing import Annotated

from agent_framework import (
    AgentContext,
    AgentMiddleware,
    AgentResponse,
    FunctionInvocationContext,
    FunctionMiddleware,
    Message,
    tool,
)
from agent_framework.azure import AzureAIAgentClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
# Load environment variables from .env file
load_dotenv()

"""
Class-based Middleware Example

This sample demonstrates how to implement middleware using a class-based approach by inheriting
from the AgentMiddleware and FunctionMiddleware base classes. The example includes:

- SecurityAgentMiddleware: Checks the most recent message of a request for sensitive keywords
  ("password" or "secret") and, if found, returns a canned refusal instead of running the agent
- LoggingFunctionMiddleware: Logs each function call by name, followed by how long it took

This approach is useful when you need stateful middleware or complex logic that benefits
from object-oriented design patterns.
"""
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
def get_weather(
    location: Annotated[str, Field(description="The location to get the weather for.")],
) -> str:
    """Get the weather for a given location."""
    # Mock data: the condition and temperature are random; no weather service is contacted.
    conditions = ["sunny", "cloudy", "rainy", "stormy"]
    # choice() draws from the whole list, so editing `conditions` can never desynchronize
    # from a hard-coded index range (the previous randint(0, 3) did exactly that).
    return f"The weather in {location} is {choice(conditions)} with a high of {randint(10, 30)}°C."
class SecurityAgentMiddleware(AgentMiddleware):
    """Agent middleware that refuses requests mentioning passwords or secrets.

    Only the final entry of ``context.messages`` is inspected. If its text contains
    "password" or "secret" (case-insensitive), ``context.result`` is set to a canned
    refusal and ``call_next`` is never awaited, so the agent does not run.
    """

    async def process(
        self,
        context: AgentContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        # Only the newest message is examined; an empty message list passes straight through.
        newest = context.messages[-1] if context.messages else None
        text = newest.text if newest else None

        if text:
            lowered = text.lower()
            if any(keyword in lowered for keyword in ("password", "secret")):
                print("[SecurityAgentMiddleware] Security Warning: Detected sensitive information, blocking request.")
                # Provide the response ourselves instead of letting the agent produce one.
                context.result = AgentResponse(
                    messages=[Message("assistant", ["Detected sensitive information, the request is blocked."])]
                )
                # Returning without awaiting call_next() short-circuits execution.
                return

        print("[SecurityAgentMiddleware] Security check passed.")
        await call_next()
class LoggingFunctionMiddleware(FunctionMiddleware):
    """Function middleware that logs each function call and how long it took."""

    async def process(
        self,
        context: FunctionInvocationContext,
        call_next: Callable[[], Awaitable[None]],
    ) -> None:
        """Log the function name, continue the pipeline, then log the elapsed time.

        Args:
            context: Invocation context; only ``context.function.name`` is read here.
            call_next: Continues the middleware pipeline; awaited exactly once.
        """
        function_name = context.function.name
        print(f"[LoggingFunctionMiddleware] About to call function: {function_name}.")
        # perf_counter() is monotonic and high-resolution. Unlike time.time(), the measured
        # duration cannot be skewed (or go negative) by a system clock adjustment mid-call.
        start = time.perf_counter()
        await call_next()
        duration = time.perf_counter() - start
        print(f"[LoggingFunctionMiddleware] Function {function_name} completed in {duration:.5f}s.")
async def main() -> None:
    """Run the weather agent on a normal query and on a query the security check blocks.

    Both middleware instances are registered through the ``middleware`` argument of
    ``as_agent``; the output shows which requests pass and which are refused.
    """
    # The banner previously read "MiddlewareTypes", an artifact of a type-alias rename.
    print("=== Class-based Middleware Example ===")

    # For authentication, run `az login` command in terminal or replace AzureCliCredential with preferred
    # authentication option.
    async with (
        AzureCliCredential() as credential,
        AzureAIAgentClient(credential=credential).as_agent(
            name="WeatherAgent",
            instructions="You are a helpful weather assistant.",
            tools=get_weather,
            middleware=[SecurityAgentMiddleware(), LoggingFunctionMiddleware()],
        ) as agent,
    ):
        # (section header, query) pairs. The first should pass the security check; the second
        # contains "password" and should be refused by SecurityAgentMiddleware.
        scenarios = [
            ("\n--- Normal Query ---", "What's the weather like in Seattle?"),
            ("--- Security Test ---", "What's the password for the weather service?"),
        ]
        for header, query in scenarios:
            print(header)
            print(f"User: {query}")
            result = await agent.run(query)
            print(f"Agent: {result.text}\n")
if __name__ == "__main__":
    # asyncio.run() creates a fresh event loop, drives main() to completion, then closes the loop.
    asyncio.run(main())