Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .env.template
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,11 @@ version=9.0
#internal_ai_app_key=""
#internal_ai_token_url=""
#internal_ai_base_url=""

#bedrock_model_id=""
#bedrock_aws_region=""
#bedrock_base_model_id=""

# Sonnet fallback for tests that require a more capable model
#bedrock_sonnet_model_id=""
#bedrock_sonnet_base_model_id=""
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ compat = ["six>=1.17.0"]
ai = ["httpx==0.28.1", "langchain>=1.2.13", "mcp>=1.26.0", "pydantic>=2.7.4"]
anthropic = ["splunk-sdk[ai]>=2.1.1", "langchain-anthropic>=1.4.0"]
openai = ["splunk-sdk[ai]>=2.1.1", "langchain-openai>=1.1.12"]
bedrock = ["splunk-sdk[anthropic]>=2.1.1", "langchain-aws>=0.2.0"]

# Treat the same as NPM's `devDependencies`
[dependency-groups]
Expand All @@ -50,7 +51,7 @@ release = ["build>=1.4.2", "jinja2>=3.1.6", "sphinx>=9.1.0", "twine>=6.2.0"]
lint = ["basedpyright>=1.38.4", "ruff>=0.15.8"]
dev = [
"rich>=14.3.3",
"splunk-sdk[openai, anthropic]",
"splunk-sdk[openai, anthropic, bedrock]",
{ include-group = "test" },
{ include-group = "lint" },
{ include-group = "release" },
Expand Down
2 changes: 1 addition & 1 deletion splunklib/ai/engines/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@
LC_ModelRequest = Langchain_ModelRequest["InvokeContext"]

# Set to True to enable debugging mode.
# NOTE(review): this flag was committed as True — the guard immediately
# below disallows _DEBUG == True in CI, so it must stay False in commits.
_DEBUG = False

# Disallow _DEBUG == True in CI.
# Github actions sets the CI env var.
Expand Down
45 changes: 42 additions & 3 deletions tests/ai_test_model.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import collections.abc
from typing import override
from dataclasses import dataclass
from typing import Any, override

import httpx
from httpx import Auth, Request, Response
from langchain_core.language_models import BaseChatModel
from pydantic import BaseModel

from splunklib.ai import OpenAIModel
from splunklib.ai import AnthropicModel, OpenAIModel
from splunklib.ai.model import PredefinedModel


Expand All @@ -18,14 +20,51 @@ class InternalAIModel(BaseModel):
base_url: str


@dataclass(frozen=True)
class AnthropicBedrockModel(AnthropicModel):
    """Anthropic model accessed via AWS Bedrock, for testing only.

    Inherits the AnthropicModel interface but builds a
    ``ChatBedrockConverse`` client instead of a direct Anthropic client.
    ``api_key``/``base_url`` are unused for Bedrock and default to "".
    """

    api_key: str = ""
    base_url: str = ""
    # AWS region passed to Bedrock as ``region_name`` when non-empty.
    aws_region: str = ""
    # Base model id used when ``model`` is an ARN (inference profile),
    # since ARNs don't encode the underlying provider/base model.
    base_model_id: str = ""

    def _to_langchain_model(self) -> BaseChatModel:
        """Build the langchain Bedrock chat model for this configuration.

        Raises:
            ImportError: if the optional ``langchain_aws`` extra is missing.
        """
        # Keep the try narrow: only the import should be translated into the
        # "extra not installed" error. An ImportError raised transitively by
        # client construction must not be masked as a missing extra.
        try:
            from langchain_aws import ChatBedrockConverse
        except ImportError as err:
            raise ImportError(
                "AWS Bedrock support is not installed.\n"
                "To enable Bedrock models, install the optional extra:\n"
                'pip install "splunk-sdk[bedrock]"\n'
                "# or if using uv:\n"
                "uv add splunk-sdk[bedrock]"
            ) from err

        kwargs: dict[str, Any] = {"model": self.model}
        if self.aws_region:
            kwargs["region_name"] = self.aws_region
        if self.temperature is not None:
            kwargs["temperature"] = self.temperature
        if self.model.startswith("arn:"):
            # ARN model ids need the provider and base model spelled out.
            kwargs["provider"] = "anthropic"
            kwargs["base_model_id"] = (
                self.base_model_id or "anthropic.claude-haiku-4-5-20251001"
            )
        return ChatBedrockConverse(**kwargs)


class TestLLMSettings(BaseModel):
    # TODO: Currently we only support our internal OpenAI-compatible model,
    # once we are close to GA we should also support OpenAI and probably Ollama,
    # such that external developers can also run our test suite locally.
    #
    # At most one of these is expected to be configured; `create_model`
    # checks `anthropic_bedrock` first, so it wins if both are set.
    internal_ai: InternalAIModel | None = None
    anthropic_bedrock: AnthropicBedrockModel | None = None


async def create_model(s: TestLLMSettings) -> PredefinedModel:
if s.anthropic_bedrock is not None:
return s.anthropic_bedrock
if s.internal_ai is not None:
return await _buildInternalAIModel(
token_url=s.internal_ai.token_url,
Expand All @@ -46,7 +85,7 @@ def __init__(self, token: str) -> None:
    @override
    def auth_flow(
        self, request: Request
    ) -> collections.abc.Generator[Request, Response]:
        # Authenticate by attaching the stored token as the provider-specific
        # "api-key" header; no response-driven retry is needed.
        # NOTE(review): the two-argument Generator[...] subscription (implicit
        # None return type) only works at runtime on Python 3.13+ — confirm
        # the project's minimum supported Python version before merging.
        request.headers["api-key"] = self.token
        yield request

Expand Down
60 changes: 59 additions & 1 deletion tests/ai_testlib.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from typing import override

from splunklib.ai.model import PredefinedModel
from tests.ai_test_model import InternalAIModel, TestLLMSettings, create_model
from tests.ai_test_model import (
AnthropicBedrockModel,
InternalAIModel,
TestLLMSettings,
create_model,
)
from tests.testlib import SDKTestCase


class AITestCase(SDKTestCase):
_model: PredefinedModel | None = None
_sonnet_model: PredefinedModel | None = None

@override
def setUp(self) -> None:
Expand All @@ -20,6 +27,24 @@ def setUp(self) -> None:

@property
def test_llm_settings(self) -> TestLLMSettings:
bedrock_model_id: str = self.opts.kwargs.get(
"bedrock_model_id", ""
) # ignore: [reportUnknownVariableType]
if bedrock_model_id:
aws_region: str = self.opts.kwargs.get(
"bedrock_aws_region", ""
) # ignore: [reportUnknownVariableType]
base_model_id: str = self.opts.kwargs.get(
"bedrock_base_model_id", ""
) # ignore: [reportUnknownVariableType]
return TestLLMSettings(
anthropic_bedrock=AnthropicBedrockModel(
model=bedrock_model_id, # ignore: [reportUnknownVariableType]
aws_region=aws_region, # ignore: [reportUnknownVariableType]
base_model_id=base_model_id, # ignore: [reportUnknownVariableType]
)
)

client_id: str = self.opts.kwargs["internal_ai_client_id"]
client_secret: str = self.opts.kwargs["internal_ai_client_secret"]
app_key: str = self.opts.kwargs["internal_ai_app_key"]
Expand All @@ -42,3 +67,36 @@ async def model(self) -> PredefinedModel:
model = await create_model(self.test_llm_settings)
self._model = model
return model

async def sonnet_model(self) -> PredefinedModel:
    """Returns a Sonnet model for tests that require a more capable model.

    Falls back to the default model if no Sonnet config is provided.
    The built model is cached on the instance after the first call.
    """
    # Serve the cached instance when one was already built.
    if self._sonnet_model is not None:
        return self._sonnet_model

    sonnet_model_id: str = self.opts.kwargs.get("bedrock_sonnet_model_id", "")
    if not sonnet_model_id:
        # No Sonnet configuration at all — reuse the default test model.
        return await self.model()

    aws_region: str = self.opts.kwargs.get("bedrock_aws_region", "")
    base_model_id: str = self.opts.kwargs.get("bedrock_sonnet_base_model_id", "")
    bedrock = AnthropicBedrockModel(
        model=sonnet_model_id,
        aws_region=aws_region,
        base_model_id=base_model_id,
    )
    built = await create_model(TestLLMSettings(anthropic_bedrock=bedrock))
    self._sonnet_model = built
    return built

@property
def supports_provider_strategy(self) -> bool:
    """Whether the configured model supports ProviderStrategy (native JSON output).

    Bedrock-backed Anthropic models go through ToolStrategy instead, so any
    configuration with `anthropic_bedrock` set reports False.
    """
    settings = self.test_llm_settings
    return settings.anthropic_bedrock is None
26 changes: 26 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from collections.abc import Generator

import pytest
from langchain_core.language_models import BaseChatModel

from splunklib.ai.engines import langchain as lc_engine
from splunklib.ai.model import PredefinedModel
from tests.ai_test_model import AnthropicBedrockModel

_original_create_langchain_model = lc_engine._create_langchain_model # pyright: ignore[reportPrivateUsage]


def _patched_create_langchain_model(model: PredefinedModel) -> BaseChatModel:
    """Build a langchain chat model, routing Bedrock test models specially.

    Non-Bedrock models are delegated to the original factory unchanged.
    """
    if not isinstance(model, AnthropicBedrockModel):
        return _original_create_langchain_model(model)
    return model._to_langchain_model()  # pyright: ignore[reportPrivateUsage]


@pytest.fixture(autouse=True)
def _patch_langchain_model_factory(request: pytest.FixtureRequest) -> Generator[None]:
    """Swap the engine's model factory for ai integration tests only.

    Installs the Bedrock-aware factory before the test and always restores
    the original afterwards.
    """
    # request.fspath is deprecated; request.path is a pathlib.Path, and
    # comparing path *parts* keeps the check correct on Windows, where the
    # substring "integration/ai" would never match a backslash-joined path.
    parts = request.path.parts
    in_ai_integration = any(
        parts[i : i + 2] == ("integration", "ai") for i in range(len(parts) - 1)
    )
    if not in_ai_integration:
        yield
        return
    lc_engine._create_langchain_model = _patched_create_langchain_model  # pyright: ignore[reportPrivateUsage]
    try:
        yield
    finally:
        # Restore even if the test body raised, so the patch cannot leak
        # into unrelated tests.
        lc_engine._create_langchain_model = _original_create_langchain_model  # pyright: ignore[reportPrivateUsage]
22 changes: 17 additions & 5 deletions tests/integration/ai/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
before_agent,
before_model,
)
from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage
from splunklib.ai.messages import AIMessage, AgentResponse, HumanMessage, StructuredOutputMessage
from splunklib.ai.middleware import AgentRequest, ModelMiddlewareHandler, ModelRequest, ModelResponse, model_middleware
from tests.ai_testlib import AITestCase

Expand Down Expand Up @@ -127,7 +127,10 @@ async def after_agent_hook(resp: AgentResponse) -> None:
person = resp.structured_output
assert type(person) is Person
assert person.name.lower() == "stefan"
assert len(resp.messages) == 2
# ProviderStrategy: 2 messages (human + AI).
# ToolStrategy: 3 messages (human + AI tool_use + StructuredOutputMessage).
uses_tool_strategy = any(isinstance(m, StructuredOutputMessage) for m in resp.messages)
assert len(resp.messages) == (3 if uses_tool_strategy else 2)

@after_agent
async def after_async_agent_hook(resp: AgentResponse) -> None:
Expand All @@ -137,7 +140,10 @@ async def after_async_agent_hook(resp: AgentResponse) -> None:
person = resp.structured_output
assert type(person) is Person
assert person.name.lower() == "stefan"
assert len(resp.messages) == 2
# ProviderStrategy: 2 messages (human + AI).
# ToolStrategy: 3 messages (human + AI tool_use + StructuredOutputMessage).
uses_tool_strategy = any(isinstance(m, StructuredOutputMessage) for m in resp.messages)
assert len(resp.messages) == (3 if uses_tool_strategy else 2)

async with Agent(
model=(await self.model()),
Expand All @@ -159,8 +165,14 @@ async def after_async_agent_hook(resp: AgentResponse) -> None:
]
)

response = result.final_message.content.strip().lower().replace(".", "")
assert '{"name":"stefan"}' == response
# With ProviderStrategy the final message is plain JSON text.
# With ToolStrategy the structured output is in result.structured_output.
person = result.structured_output
if person is not None:
assert person.name.lower() == "stefan"
else:
response = result.final_message.content.strip().lower().replace(".", "")
assert '{"name":"stefan"}' == response
assert hook_calls == 4

@pytest.mark.asyncio
Expand Down
36 changes: 24 additions & 12 deletions tests/integration/ai/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,8 +358,9 @@ class NicknameGeneratorInput(BaseModel):
Agent(
model=await self.model(),
system_prompt=(
"You are a helpful assistant that generates nicknames. A valid "
+ "nickname consists of the provided name suffixed with '-zilla.'"
"You are a helpful assistant that generates nicknames. "
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
),
service=self.service,
name="NicknameGeneratorAgent",
Expand Down Expand Up @@ -406,15 +407,16 @@ async def test_middleware(
first_response = await handler(request)
second_response = await handler(request)
assert isinstance(first_response.result, SubagentTextResult)
assert second_response == first_response
assert isinstance(second_response.result, SubagentTextResult)
return second_response

async with (
Agent(
model=await self.model(),
system_prompt=(
"You are a helpful assistant that generates nicknames. A valid "
+ "nickname consists of the provided name suffixed with '-zilla.'"
"You are a helpful assistant that generates nicknames. "
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
),
service=self.service,
name="NicknameGeneratorAgent",
Expand Down Expand Up @@ -472,8 +474,9 @@ async def test_middleware(
Agent(
model=await self.model(),
system_prompt=(
"You are a helpful assistant that generates nicknames. A valid "
+ "nickname consists of the provided name suffixed with '-zilla.'"
"You are a helpful assistant that generates nicknames. "
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
),
service=self.service,
name="NicknameGeneratorAgent",
Expand Down Expand Up @@ -601,8 +604,9 @@ async def test_middleware(
Agent(
model=await self.model(),
system_prompt=(
"You are a helpful assistant that generates nicknames. A valid "
+ "nickname consists of the provided name suffixed with '-zilla.'"
"You are a helpful assistant that generates nicknames. "
+ "The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
+ "For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
),
service=self.service,
name="NicknameGeneratorAgent",
Expand Down Expand Up @@ -777,8 +781,9 @@ async def mutating_middleware(
Agent(
model=await self.model(),
system_prompt=(
"You are a helpful assistant that generates nicknames. A valid "
"nickname consists of the provided name suffixed with '-zilla.'"
"You are a helpful assistant that generates nicknames. "
"The nickname MUST be formatted as exactly '<name>-zilla' with a hyphen. "
"For example: Chris -> Chris-zilla, Alice -> Alice-zilla."
),
service=self.service,
name="NicknameGeneratorAgent",
Expand All @@ -796,7 +801,14 @@ async def mutating_middleware(
result = await supervisor.invoke(
[HumanMessage(content="Generate a nickname for Bob")]
)
assert "Alice-zilla" in result.final_message.content
# The middleware mutated the arg to "Alice", so the subagent must have
# received "Alice" and returned "Alice-zilla". Check the subagent message.
subagent_msg = next(
(m for m in result.messages if isinstance(m, SubagentMessage)), None
)
assert subagent_msg is not None
assert isinstance(subagent_msg.result, SubagentTextResult)
assert "Alice-zilla" in subagent_msg.result.content

@pytest.mark.asyncio
async def test_model_middleware_structured_output(self) -> None:
Expand Down
Loading
Loading