From 7c4dbde773ef25e446de4bf9c6ac59b27f43275b Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Mon, 22 Jun 2026 15:08:44 -0500 Subject: [PATCH 01/23] [AIC-2664] Impl trackers (first pass) --- .../sdk/server/ai/LDAIClient.java | 15 + .../sdk/server/ai/LDAIClientImpl.java | 57 +- .../sdk/server/ai/LDAIConfigTracker.java | 163 +++- .../ai/datamodel/LDAITrackingTypes.java | 706 ++++++++++++++++++ .../ai/internal/LDAIConfigTrackerImpl.java | 381 ++++++++++ .../ai/internal/NoOpAIConfigTracker.java | 83 +- .../server/ai/internal/ResumptionTokens.java | 297 ++++++++ .../internal/LDAIConfigTrackerImplTest.java | 695 +++++++++++++++++ .../ai/internal/ResumptionTokensTest.java | 174 +++++ 9 files changed, 2552 insertions(+), 19 deletions(-) create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java create mode 100644 lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java create mode 100644 lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index 16ce751d..a2a0f76b 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -81,4 +81,19 @@ AIJudgeConfig judgeConfig( LDContext context, AIJudgeConfigDefault defaultValue, Map variables); + + /** + * Reconstructs a tracker from a resumption token, preserving the original run's identity. + *

+ * Use this when a multi-turn or streaming AI interaction spans multiple requests. The caller + * stores the resumption token from a previous tracker (via + * {@link LDAIConfigTracker#getResumptionToken()}) and passes it back here to continue tracking + * against the same run. + * + * @param resumptionToken the token returned by a previous tracker; must not be {@code null} + * @param context the evaluation context for the new request; must not be {@code null} + * @return a tracker with the decoded run identity, never {@code null} + * @throws IllegalArgumentException if the token is malformed + */ + LDAIConfigTracker createTracker(String resumptionToken, LDContext context); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 650fdeed..8f261604 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -8,13 +8,13 @@ import com.launchdarkly.sdk.LDContext; import com.launchdarkly.sdk.LDValue; import com.launchdarkly.sdk.LDValueType; -import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; +import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.internal.AIConfigFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigParser; import com.launchdarkly.sdk.server.ai.internal.AISdkInfo; import com.launchdarkly.sdk.server.ai.internal.Interpolator; -import com.launchdarkly.sdk.server.ai.internal.NoOpAIConfigTracker; +import com.launchdarkly.sdk.server.ai.internal.LDAIConfigTrackerImpl; import com.launchdarkly.sdk.server.interfaces.LDClientInterface; import java.util.ArrayList; @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.UUID; import java.util.function.Supplier; /** @@ -51,8 +52,6 @@ public final class LDAIClientImpl implements LDAIClient { .anonymous(true) .build(); - // Tracking is implemented in a later step; until then every config hands out the no-op tracker. - private static final Supplier TRACKER_FACTORY = () -> NoOpAIConfigTracker.INSTANCE; private final LDClientInterface client; private final LDLogger logger; @@ -187,6 +186,9 @@ private AIConfig buildConfig( AIConfigFlagValue parsed, LDContext context, Map variables) { + Supplier factory = trackerFactory( + key, parsed.getVariationKey(), parsed.getVersion(), + parsed.getModel(), parsed.getProvider(), context); switch (mode) { case AGENT: return new AIAgentConfig( @@ -197,7 +199,7 @@ private AIConfig buildConfig( interpolate(parsed.getInstructions(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + factory); case JUDGE: return new AIJudgeConfig( key, @@ -206,7 +208,7 @@ private AIConfig buildConfig( parsed.getProvider(), interpolateMessages(parsed.getMessages(), variables, context), parsed.getEvaluationMetricKey(), - TRACKER_FACTORY); + factory); case COMPLETION: default: return new AICompletionConfig( @@ -217,7 +219,7 @@ private AIConfig buildConfig( interpolateMessages(parsed.getMessages(), variables, context), parsed.getJudgeConfiguration(), parsed.getTools(), - TRACKER_FACTORY); + factory); } } @@ -231,6 +233,9 @@ private AIConfig buildConfigFromDefault( AIConfigDefault defaultValue, LDContext context, Map variables) { + // Default configs still get real trackers — the configKey was requested even if no flag was found. + // variationKey is null because no flag evaluation occurred. + Supplier factory = trackerFactory(key, null, null, null, null, context); switch (mode) { case AGENT: { AIAgentConfigDefault agent = (AIAgentConfigDefault) defaultValue; @@ -242,7 +247,7 @@ private AIConfig buildConfigFromDefault( interpolate(agent.getInstructions(), variables, context), agent.getJudgeConfiguration(), agent.getTools(), - TRACKER_FACTORY); + factory); } case JUDGE: { AIJudgeConfigDefault judge = (AIJudgeConfigDefault) defaultValue; @@ -253,7 +258,7 @@ private AIConfig buildConfigFromDefault( judge.getProvider(), interpolateMessages(judge.getMessages(), variables, context), judge.getEvaluationMetricKey(), - TRACKER_FACTORY); + factory); } case COMPLETION: default: { @@ -266,11 +271,43 @@ private AIConfig buildConfigFromDefault( interpolateMessages(completion.getMessages(), variables, context), completion.getJudgeConfiguration(), completion.getTools(), - TRACKER_FACTORY); + factory); } } } + /** + * Creates a per-evaluation tracker factory. Each call to the returned {@link Supplier} produces + * a fresh {@link LDAIConfigTrackerImpl} with a new {@code runId}. + */ + private Supplier trackerFactory( + String configKey, + String variationKey, + Integer version, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, + LDContext context) { + String modelName = model != null && model.getName() != null ? model.getName() : ""; + String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; + int ver = version != null ? version : 0; + return () -> new LDAIConfigTrackerImpl( + client, + UUID.randomUUID().toString(), + configKey, + variationKey, + ver, + modelName, + providerName, + context, + null, // graphKey — set by agentGraph() in Plan 3 + logger); + } + + @Override + public LDAIConfigTracker createTracker(String resumptionToken, LDContext context) { + return LDAIConfigTrackerImpl.fromResumptionToken(resumptionToken, client, context, logger); + } + private List interpolateMessages( List messages, Map variables, LDContext context) { if (messages == null) { diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java index a298e33b..3364d21e 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java @@ -1,16 +1,165 @@ package com.launchdarkly.sdk.server.ai; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.Function; + /** * Reports events related to a single AI run of an {@link AIConfig}. *

- * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}. Each tracker - * corresponds to one AI run and is used to record metrics such as model usage, duration, and - * feedback against the AI Config it was created from. + * A tracker is obtained from a retrieved config via {@link AIConfig#createTracker()}, or + * reconstructed from a resumption token via {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)}. + * Each tracker corresponds to one AI run and is used to record metrics such as model usage, + * duration, and feedback against the AI Config it was created from. + *

+ * Most tracking methods are at-most-once: a second call to the same method on the same tracker + * is silently dropped. {@link #trackToolCall(String)} and {@link #trackJudgeResult(JudgeResult)} + * are multi-fire — each call records a distinct event. *

- * This interface is an intentional placeholder. The metric- and feedback-reporting - * methods (and resumption-token support) are introduced in a later step of the AI SDK build-out; it - * is defined here so that the public config types expose a stable {@code createTracker()} surface. - * The only implementation in this release is an internal no-op. + * Implementations are thread-safe. */ public interface LDAIConfigTracker { + + /** + * Returns the correlation metadata for this tracker's run. + * + * @return the track data, never {@code null} + */ + TrackData getTrackData(); + + /** + * Returns the resumption token for this run. + *

+ * The resumption token encodes the run's identity and can be passed to + * {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)} to reconstruct a + * tracker on a subsequent request (for example, in a streaming scenario). + * + * @return the resumption token, or {@code null} if not available + */ + String getResumptionToken(); + + /** + * Records the duration of the AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param duration the duration; ignored if {@code null} + */ + void trackDuration(Duration duration); + + /** + * Executes the given operation and records its wall-clock duration. + *

+ * The duration is recorded even if the operation throws. Equivalent to wrapping the operation + * in a try/finally that calls {@link #trackDuration(Duration)}. + * + * @param the return type of the operation + * @param operation the operation to execute and time; must not be {@code null} + * @return the result of the operation + * @throws Exception if the operation throws + */ + T trackDurationOf(Callable operation) throws Exception; + + /** + * Records the time from request start to receipt of the first token. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param duration the time to first token; ignored if {@code null} + */ + void trackTimeToFirstToken(Duration duration); + + /** + * Records that the AI generation succeeded. + *

+ * At-most-once and mutually exclusive with {@link #trackError()}: whichever is called first wins. + */ + void trackSuccess(); + + /** + * Records that the AI generation failed. + *

+ * At-most-once and mutually exclusive with {@link #trackSuccess()}: whichever is called first wins. + */ + void trackError(); + + /** + * Records user feedback for this AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param kind the feedback kind; ignored if {@code null} + */ + void trackFeedback(FeedbackKind kind); + + /** + * Records token usage for this AI generation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. Calls where all + * counts are zero do not consume the at-most-once slot. + * + * @param tokens the token usage; ignored if {@code null} + */ + void trackTokens(TokenUsage tokens); + + /** + * Records a single tool call made during this AI generation. + *

+ * Multi-fire: every call emits an event. + * + * @param toolKey the tool key; ignored if {@code null} + */ + void trackToolCall(String toolKey); + + /** + * Records multiple tool calls made during this AI generation. + *

+ * Equivalent to calling {@link #trackToolCall(String)} for each key. + * + * @param toolKeys the tool keys; ignored if {@code null} + */ + void trackToolCalls(List toolKeys); + + /** + * Records the result of a judge evaluation. + *

+ * Multi-fire per judge metric key. The result is silently skipped if it was not sampled, if + * the evaluation did not succeed, or if the metric key or score is absent. + * + * @param result the judge result; ignored if {@code null} + */ + void trackJudgeResult(JudgeResult result); + + /** + * Executes the given operation and tracks its metrics using the extracted {@link AIMetrics}. + *

+ * Tracks duration (preferring runner-reported duration when present), success or error, tokens, + * and tool calls. If the operation throws, {@link #trackError()} is called and the exception + * is re-thrown. + * + * @param the return type of the operation + * @param metricsExtractor a function that extracts {@link AIMetrics} from the operation result; + * exceptions from the extractor propagate to the caller + * @param operation the AI operation to execute; must not be {@code null} + * @return the result of the operation + * @throws Exception if the operation or the metrics extractor throws + */ + T trackMetricsOf( + Function metricsExtractor, + Callable operation) throws Exception; + + /** + * Returns a snapshot of all metrics tracked so far on this tracker. + * + * @return the metric summary, never {@code null} + */ + MetricSummary getSummary(); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java new file mode 100644 index 00000000..534e3aed --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -0,0 +1,706 @@ +package com.launchdarkly.sdk.server.ai.datamodel; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * Container for the shared, immutable AI tracking data types. + *

+ * These shapes ({@link FeedbackKind}, {@link TokenUsage}, {@link AIMetrics}, {@link JudgeResult}, + * {@link MetricSummary}, and {@link TrackData}) are used by {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker} + * and its implementations to report AI run metrics and feedback. They are grouped under this single + * type, rather than declared as separate top-level classes, to keep the package small and to free + * up generic names. + *

+ * This class is not instantiable. + */ +public final class LDAITrackingTypes { + private LDAITrackingTypes() { + } + + /** + * The kind of user feedback reported via {@code trackFeedback}. + */ + public enum FeedbackKind { + /** + * Positive (thumbs-up) feedback. + */ + POSITIVE("positive"), + + /** + * Negative (thumbs-down) feedback. + */ + NEGATIVE("negative"); + + private final String value; + + FeedbackKind(String value) { + this.value = value; + } + + /** + * Returns the wire representation of this feedback kind. + * + * @return the wire value (for example {@code "positive"}) + */ + public String getValue() { + return value; + } + } + + /** + * Token usage counts for a single AI generation. + *

+ * Instances are immutable. + */ + public static final class TokenUsage { + private final long total; + private final long input; + private final long output; + + /** + * Constructs token usage counts. + * + * @param total the total token count + * @param input the input (prompt) token count + * @param output the output (completion) token count + */ + public TokenUsage(long total, long input, long output) { + this.total = total; + this.input = input; + this.output = output; + } + + /** + * Returns the total token count. + * + * @return the total token count + */ + public long getTotal() { + return total; + } + + /** + * Returns the input (prompt) token count. + * + * @return the input token count + */ + public long getInput() { + return input; + } + + /** + * Returns the output (completion) token count. + * + * @return the output token count + */ + public long getOutput() { + return output; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof TokenUsage)) { + return false; + } + TokenUsage other = (TokenUsage) o; + return total == other.total && input == other.input && output == other.output; + } + + @Override + public int hashCode() { + return Objects.hash(total, input, output); + } + + @Override + public String toString() { + return "TokenUsage{total=" + total + ", input=" + input + ", output=" + output + '}'; + } + } + + /** + * Metrics extracted from a single AI generation, used with {@code trackMetricsOf}. + *

+ * Build instances with {@link #builder()}. + *

+ * Instances are immutable. + */ + public static final class AIMetrics { + private final boolean success; + private final TokenUsage tokens; + private final List toolCalls; + private final Long durationMs; + + private AIMetrics(Builder b) { + this.success = b.success; + this.tokens = b.tokens; + this.toolCalls = b.toolCalls == null ? null : Collections.unmodifiableList(new ArrayList<>(b.toolCalls)); + this.durationMs = b.durationMs; + } + + /** + * Returns whether the AI generation succeeded. + * + * @return {@code true} if the generation succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns the token usage for this generation. + * + * @return the token usage, or {@code null} if not reported + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the tool calls made during this generation. + * + * @return an unmodifiable list of tool call keys, or {@code null} if not reported + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the duration of the AI generation in milliseconds, as reported by the runner. + *

+ * When set, {@code trackMetricsOf} uses this value instead of its own wall-clock measurement. + * + * @return the runner-reported duration in milliseconds, or {@code null} if not reported + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Creates a new builder. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link AIMetrics}. + */ + public static final class Builder { + private boolean success; + private TokenUsage tokens; + private List toolCalls; + private Long durationMs; + + private Builder() { + } + + /** + * Sets whether the AI generation succeeded. + * + * @param success {@code true} if the generation succeeded + * @return this builder + */ + public Builder success(boolean success) { + this.success = success; + return this; + } + + /** + * Sets the token usage. + * + * @param tokens the token usage; may be {@code null} + * @return this builder + */ + public Builder tokens(TokenUsage tokens) { + this.tokens = tokens; + return this; + } + + /** + * Sets the tool calls made during this generation. + * + * @param toolCalls the tool call keys; may be {@code null} + * @return this builder + */ + public Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + /** + * Sets the runner-reported duration in milliseconds. + * + * @param durationMs the duration; may be {@code null} + * @return this builder + */ + public Builder durationMs(Long durationMs) { + this.durationMs = durationMs; + return this; + } + + /** + * Builds the immutable {@link AIMetrics}. + * + * @return a new {@link AIMetrics} + */ + public AIMetrics build() { + return new AIMetrics(this); + } + } + } + + /** + * The result of a judge evaluation, reported via {@code trackJudgeResult}. + *

+ * Build instances with {@link #builder()}. + *

+ * Instances are immutable. + */ + public static final class JudgeResult { + private final String judgeConfigKey; + private final boolean success; + private final String errorMessage; + private final boolean sampled; + private final String metricKey; + private final Double score; + private final String reasoning; + + private JudgeResult(Builder b) { + this.judgeConfigKey = b.judgeConfigKey; + this.success = b.success; + this.errorMessage = b.errorMessage; + this.sampled = b.sampled; + this.metricKey = b.metricKey; + this.score = b.score; + this.reasoning = b.reasoning; + } + + /** + * Returns the key of the judge AI Config, if known. + * + * @return the judge config key, or {@code null} if not set + */ + public String getJudgeConfigKey() { + return judgeConfigKey; + } + + /** + * Returns whether the judge evaluation succeeded. + * + * @return {@code true} if the evaluation succeeded + */ + public boolean isSuccess() { + return success; + } + + /** + * Returns an error message from the judge evaluation, if any. + * + * @return the error message, or {@code null} if none + */ + public String getErrorMessage() { + return errorMessage; + } + + /** + * Returns whether this result was selected for sampling. + * + * @return {@code true} if the result was sampled + */ + public boolean isSampled() { + return sampled; + } + + /** + * Returns the metric key to use when emitting this result. + * + * @return the metric key, or {@code null} if not set + */ + public String getMetricKey() { + return metricKey; + } + + /** + * Returns the judge score. + *

+ * A {@code null} score is distinct from a score of {@code 0.0} — a null score means no score + * was produced, while {@code 0.0} is a valid score. + * + * @return the score, or {@code null} if not set + */ + public Double getScore() { + return score; + } + + /** + * Returns the judge's reasoning, if any. + * + * @return the reasoning, or {@code null} if none + */ + public String getReasoning() { + return reasoning; + } + + /** + * Creates a new builder. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link JudgeResult}. + */ + public static final class Builder { + private String judgeConfigKey; + private boolean success; + private String errorMessage; + private boolean sampled; + private String metricKey; + private Double score; + private String reasoning; + + private Builder() { + } + + /** + * Sets the judge config key. + * + * @param judgeConfigKey the key; may be {@code null} + * @return this builder + */ + public Builder judgeConfigKey(String judgeConfigKey) { + this.judgeConfigKey = judgeConfigKey; + return this; + } + + /** + * Sets whether the judge evaluation succeeded. + * + * @param success {@code true} if succeeded + * @return this builder + */ + public Builder success(boolean success) { + this.success = success; + return this; + } + + /** + * Sets the error message. + * + * @param errorMessage the error message; may be {@code null} + * @return this builder + */ + public Builder errorMessage(String errorMessage) { + this.errorMessage = errorMessage; + return this; + } + + /** + * Sets whether this result was sampled. + * + * @param sampled {@code true} if sampled + * @return this builder + */ + public Builder sampled(boolean sampled) { + this.sampled = sampled; + return this; + } + + /** + * Sets the metric key. + * + * @param metricKey the metric key; may be {@code null} + * @return this builder + */ + public Builder metricKey(String metricKey) { + this.metricKey = metricKey; + return this; + } + + /** + * Sets the judge score. + * + * @param score the score; may be {@code null} + * @return this builder + */ + public Builder score(Double score) { + this.score = score; + return this; + } + + /** + * Sets the reasoning. + * + * @param reasoning the reasoning; may be {@code null} + * @return this builder + */ + public Builder reasoning(String reasoning) { + this.reasoning = reasoning; + return this; + } + + /** + * Builds the immutable {@link JudgeResult}. + * + * @return a new {@link JudgeResult} + */ + public JudgeResult build() { + return new JudgeResult(this); + } + } + } + + /** + * A snapshot of all metrics tracked by a single {@link com.launchdarkly.sdk.server.ai.LDAIConfigTracker}. + *

+ * Returned by {@code getSummary()}. All fields are nullable — {@code null} indicates the + * corresponding metric has not been recorded yet. {@link #getSuccess()} is a tri-state: + * {@code null} = not yet tracked, {@code true} = success was recorded, {@code false} = error + * was recorded. + *

+ * Instances are immutable. + */ + public static final class MetricSummary { + private final Boolean success; + private final TokenUsage tokens; + private final Long durationMs; + private final FeedbackKind feedback; + private final Long timeToFirstTokenMs; + private final List toolCalls; + private final String resumptionToken; + + /** + * Constructs a metric summary snapshot. + * + * @param success tri-state outcome: {@code null} = not tracked, {@code true} = success, {@code false} = error + * @param tokens the token usage, or {@code null} + * @param durationMs the duration in milliseconds, or {@code null} + * @param feedback the feedback kind, or {@code null} + * @param timeToFirstTokenMs the time to first token in milliseconds, or {@code null} + * @param toolCalls the tool calls made, or {@code null} + * @param resumptionToken the resumption token, or {@code null} + */ + public MetricSummary( + Boolean success, + TokenUsage tokens, + Long durationMs, + FeedbackKind feedback, + Long timeToFirstTokenMs, + List toolCalls, + String resumptionToken) { + this.success = success; + this.tokens = tokens; + this.durationMs = durationMs; + this.feedback = feedback; + this.timeToFirstTokenMs = timeToFirstTokenMs; + this.toolCalls = toolCalls == null ? null : Collections.unmodifiableList(new ArrayList<>(toolCalls)); + this.resumptionToken = resumptionToken; + } + + /** + * Returns the outcome of the AI generation, as a tri-state. + * + * @return {@code null} if not tracked, {@code true} if success was recorded, {@code false} if error was recorded + */ + public Boolean getSuccess() { + return success; + } + + /** + * Returns the token usage. + * + * @return the token usage, or {@code null} if not tracked + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the duration in milliseconds. + * + * @return the duration, or {@code null} if not tracked + */ + public Long getDurationMs() { + return durationMs; + } + + /** + * Returns the feedback kind. + * + * @return the feedback, or {@code null} if not tracked + */ + public FeedbackKind getFeedback() { + return feedback; + } + + /** + * Returns the time to first token in milliseconds. + * + * @return the time to first token, or {@code null} if not tracked + */ + public Long getTimeToFirstTokenMs() { + return timeToFirstTokenMs; + } + + /** + * Returns the tool calls made during the generation. + * + * @return an unmodifiable list of tool call keys, or {@code null} if none were tracked + */ + public List getToolCalls() { + return toolCalls; + } + + /** + * Returns the resumption token for this tracker. + * + * @return the resumption token, or {@code null} if not available + */ + public String getResumptionToken() { + return resumptionToken; + } + } + + /** + * Correlation metadata attached to every metric event emitted by a tracker. + *

+ * Instances are immutable. + */ + public static final class TrackData { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String modelName; + private final String providerName; + private final String graphKey; + + /** + * Constructs track data. + * + * @param runId the unique run identifier; must not be {@code null} + * @param configKey the AI Config key; must not be {@code null} + * @param variationKey the variation key, or {@code null} when a default config is used + * @param version the config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param graphKey the agent graph key, or {@code null} when not part of a graph + */ + public TrackData( + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + String graphKey) { + this.runId = Objects.requireNonNull(runId, "runId"); + this.configKey = Objects.requireNonNull(configKey, "configKey"); + this.variationKey = variationKey; + this.version = version; + this.modelName = modelName == null ? "" : modelName; + this.providerName = providerName == null ? "" : providerName; + this.graphKey = graphKey; + } + + /** + * Returns the unique run identifier. + * + * @return the run ID, never {@code null} + */ + public String getRunId() { + return runId; + } + + /** + * Returns the AI Config key. + * + * @return the config key, never {@code null} + */ + public String getConfigKey() { + return configKey; + } + + /** + * Returns the variation key. + * + * @return the variation key, or {@code null} when a default config is used + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the config version. + * + * @return the version + */ + public int getVersion() { + return version; + } + + /** + * Returns the model name. + * + * @return the model name, or empty string when unknown + */ + public String getModelName() { + return modelName; + } + + /** + * Returns the provider name. + * + * @return the provider name, or empty string when unknown + */ + public String getProviderName() { + return providerName; + } + + /** + * Returns the agent graph key. + * + * @return the graph key, or {@code null} when not part of a graph + */ + public String getGraphKey() { + return graphKey; + } + + /** + * Builds an {@link LDValue} representation of this track data using camelCase keys. + *

+ * {@code variationKey} and {@code graphKey} are omitted when {@code null}. + * + * @return an {@link LDValue} object containing all non-null fields + */ + public LDValue toLDValue() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("configKey", configKey) + .put("version", version) + .put("modelName", modelName) + .put("providerName", providerName); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + if (graphKey != null) { + b.put("graphKey", graphKey); + } + return b.build(); + } + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java new file mode 100644 index 00000000..3c7faf6a --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -0,0 +1,381 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.Callable; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +/** + * The default {@link LDAIConfigTracker} implementation. + *

+ * Tracks AI run metrics and emits them as LaunchDarkly custom events via the wrapped + * {@link LDClientInterface}. At-most-once semantics for each metric type are enforced using + * {@link AtomicReference#compareAndSet} — a single atomic operation that serves as both guard + * and value store, eliminating the race window present in a two-step check-then-act pattern. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class LDAIConfigTrackerImpl implements LDAIConfigTracker { + + private static final String DURATION_TOTAL = "$ld:ai:duration:total"; + private static final String TOKENS_TTF = "$ld:ai:tokens:ttf"; + private static final String GENERATION_SUCCESS = "$ld:ai:generation:success"; + private static final String GENERATION_ERROR = "$ld:ai:generation:error"; + private static final String FEEDBACK_POSITIVE = "$ld:ai:feedback:user:positive"; + private static final String FEEDBACK_NEGATIVE = "$ld:ai:feedback:user:negative"; + private static final String TOKENS_TOTAL = "$ld:ai:tokens:total"; + private static final String TOKENS_INPUT = "$ld:ai:tokens:input"; + private static final String TOKENS_OUTPUT = "$ld:ai:tokens:output"; + private static final String TOOL_CALL = "$ld:ai:tool_call"; + + private final LDClientInterface client; + private final LDContext context; + private final LDLogger logger; + + // Identity fields + private final String runId; + private final String configKey; + private final String variationKey; // nullable — null when using a default config + private final int version; + private final String modelName; // empty string when unknown + private final String providerName; // empty string when unknown + private final String graphKey; // nullable + + // Computed once at construction + private final String resumptionToken; + + // At-most-once slots: null = not yet recorded, non-null = recorded with this value. + // AtomicReference.compareAndSet(null, value) is a single atomic operation — both guard and + // value store — eliminating the race window in an AtomicBoolean + volatile approach. + private final AtomicReference durationMs = new AtomicReference<>(); + private final AtomicReference timeToFirstTokenMs = new AtomicReference<>(); + // Shared by trackSuccess and trackError: true = success, false = error + private final AtomicReference outcome = new AtomicReference<>(); + private final AtomicReference feedbackRef = new AtomicReference<>(); + private final AtomicReference tokensRef = new AtomicReference<>(); + + // Multi-fire accumulator — not at-most-once + private final CopyOnWriteArrayList toolCalls = new CopyOnWriteArrayList<>(); + + /** + * Creates a tracker for a new AI run. + * + * @param client the LaunchDarkly client used to emit events; must not be {@code null} + * @param runId the unique run identifier (UUID v4); must not be {@code null} + * @param configKey the AI Config key; must not be {@code null} + * @param variationKey the variation key, or {@code null} when using a default config + * @param version the config version + * @param modelName the model name, or empty string when unknown + * @param providerName the provider name, or empty string when unknown + * @param context the evaluation context; must not be {@code null} + * @param graphKey the agent graph key, or {@code null} when not part of a graph + * @param logger the logger; must not be {@code null} + */ + public LDAIConfigTrackerImpl( + LDClientInterface client, + String runId, + String configKey, + String variationKey, + int version, + String modelName, + String providerName, + LDContext context, + String graphKey, + LDLogger logger) { + this.client = Objects.requireNonNull(client, "client"); + this.runId = Objects.requireNonNull(runId, "runId"); + this.configKey = Objects.requireNonNull(configKey, "configKey"); + this.variationKey = variationKey; + this.version = version; + this.modelName = modelName == null ? "" : modelName; + this.providerName = providerName == null ? "" : providerName; + this.context = Objects.requireNonNull(context, "context"); + this.graphKey = graphKey; + this.logger = Objects.requireNonNull(logger, "logger"); + + // Compute once at construction — all inputs are immutable. + this.resumptionToken = ResumptionTokens.encode(runId, configKey, variationKey, version, graphKey); + } + + /** + * Reconstructs a tracker from a resumption token, preserving the original run's identity. + * + * @param token the resumption token + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @param logger the logger; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static LDAIConfigTrackerImpl fromResumptionToken( + String token, LDClientInterface client, LDContext context, LDLogger logger) { + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + return new LDAIConfigTrackerImpl( + client, + d.getRunId(), + d.getConfigKey(), + d.getVariationKey(), + d.getVersion(), + "", // modelName not carried in token + "", // providerName not carried in token + context, + d.getGraphKey(), + logger); + } + + @Override + public TrackData getTrackData() { + return new TrackData(runId, configKey, variationKey, version, modelName, providerName, graphKey); + } + + @Override + public String getResumptionToken() { + return resumptionToken; + } + + @Override + public void trackDuration(Duration duration) { + if (duration == null) { + logger.warn("Skipping trackDuration: duration was null."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + if (!durationMs.compareAndSet(null, ms)) { + logger.warn("Skipping trackDuration: duration already recorded on this tracker."); + return; + } + client.trackMetric(DURATION_TOTAL, context, baseData().build(), ms); + } + + @Override + public T trackDurationOf(Callable operation) throws Exception { + Objects.requireNonNull(operation, "operation"); + long start = System.nanoTime(); + try { + return operation.call(); + } finally { + long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + trackDuration(Duration.ofMillis(elapsedMs)); + } + } + + @Override + public void trackTimeToFirstToken(Duration duration) { + if (duration == null) { + logger.warn("Skipping trackTimeToFirstToken: duration was null."); + return; + } + long ms = Math.max(0L, duration.toMillis()); + if (!timeToFirstTokenMs.compareAndSet(null, ms)) { + logger.warn("Skipping trackTimeToFirstToken: time-to-first-token already recorded on this tracker."); + return; + } + client.trackMetric(TOKENS_TTF, context, baseData().build(), ms); + } + + @Override + public void trackSuccess() { + if (!outcome.compareAndSet(null, Boolean.TRUE)) { + logger.warn("Skipping trackSuccess: outcome already recorded on this tracker."); + return; + } + client.trackMetric(GENERATION_SUCCESS, context, baseData().build(), 1); + } + + @Override + public void trackError() { + if (!outcome.compareAndSet(null, Boolean.FALSE)) { + logger.warn("Skipping trackError: outcome already recorded on this tracker."); + return; + } + client.trackMetric(GENERATION_ERROR, context, baseData().build(), 1); + } + + @Override + public void trackFeedback(FeedbackKind kind) { + if (kind == null) { + logger.warn("Skipping trackFeedback: kind was null."); + return; + } + // Resolve event name BEFORE claiming the guard — an exception here must not burn the slot. + String eventName = kind == FeedbackKind.POSITIVE ? FEEDBACK_POSITIVE : FEEDBACK_NEGATIVE; + if (!feedbackRef.compareAndSet(null, kind)) { + logger.warn("Skipping trackFeedback: feedback already recorded on this tracker."); + return; + } + client.trackMetric(eventName, context, baseData().build(), 1); + } + + @Override + public void trackTokens(TokenUsage tokens) { + if (tokens == null) { + logger.warn("Skipping trackTokens: tokens was null."); + return; + } + boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; + if (!hasPositive) { + // Do not burn the at-most-once slot when all counts are zero. + return; + } + if (!tokensRef.compareAndSet(null, tokens)) { + logger.warn("Skipping trackTokens: token usage already recorded on this tracker."); + return; + } + if (tokens.getTotal() > 0) { + client.trackMetric(TOKENS_TOTAL, context, baseData().build(), tokens.getTotal()); + } + if (tokens.getInput() > 0) { + client.trackMetric(TOKENS_INPUT, context, baseData().build(), tokens.getInput()); + } + if (tokens.getOutput() > 0) { + client.trackMetric(TOKENS_OUTPUT, context, baseData().build(), tokens.getOutput()); + } + } + + @Override + public void trackToolCall(String toolKey) { + if (toolKey == null) { + logger.warn("Skipping trackToolCall: toolKey was null."); + return; + } + toolCalls.add(toolKey); + LDValue data = baseData().put("toolKey", toolKey).build(); + client.trackMetric(TOOL_CALL, context, data, 1); + } + + @Override + public void trackToolCalls(List toolKeys) { + if (toolKeys == null) { + return; + } + for (String key : toolKeys) { + trackToolCall(key); + } + } + + @Override + public void trackJudgeResult(JudgeResult result) { + if (result == null) { + logger.warn("Skipping trackJudgeResult: result was null."); + return; + } + if (!result.isSampled()) { + return; + } + if (!result.isSuccess()) { + return; + } + if (result.getMetricKey() == null) { + return; + } + if (result.getScore() == null) { + return; + } + ObjectBuilder data = baseData(); + if (result.getJudgeConfigKey() != null) { + data.put("judgeConfigKey", result.getJudgeConfigKey()); + } + client.trackMetric(result.getMetricKey(), context, data.build(), result.getScore()); + } + + @Override + public T trackMetricsOf( + Function metricsExtractor, + Callable operation) throws Exception { + Objects.requireNonNull(metricsExtractor, "metricsExtractor"); + Objects.requireNonNull(operation, "operation"); + + long start = System.nanoTime(); + T result; + try { + result = operation.call(); + } catch (Exception e) { + // Operation failed — track measured duration + error, then re-throw. + long elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + trackDuration(Duration.ofMillis(elapsed)); + trackError(); + throw e; + } + + // Extractor exceptions propagate to the caller — do NOT catch them here. + // Do NOT call trackError() on extractor failure; the AI operation itself succeeded. + AIMetrics metrics = metricsExtractor.apply(result); + + // Duration: prefer runner-reported value (§1.1.13.2), fall back to wall-clock. + if (metrics.getDurationMs() != null) { + trackDuration(Duration.ofMillis(metrics.getDurationMs())); + } else { + long elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); + trackDuration(Duration.ofMillis(elapsed)); + } + + if (metrics.isSuccess()) { + trackSuccess(); + } else { + trackError(); + } + + if (metrics.getTokens() != null) { + trackTokens(metrics.getTokens()); + } + if (metrics.getToolCalls() != null) { + trackToolCalls(metrics.getToolCalls()); + } + + return result; + } + + @Override + public MetricSummary getSummary() { + List snapshot = toolCalls.isEmpty() + ? null + : Collections.unmodifiableList(new ArrayList<>(toolCalls)); + return new MetricSummary( + outcome.get(), + tokensRef.get(), + durationMs.get(), + feedbackRef.get(), + timeToFirstTokenMs.get(), + snapshot, + resumptionToken); + } + + /** + * Returns a pre-populated {@link LDValue.ObjectBuilder} containing the base track-data fields. + * Individual track methods add per-event fields before calling {@link LDValue.ObjectBuilder#build()}. + */ + private ObjectBuilder baseData() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("configKey", configKey) + .put("version", version) + .put("modelName", modelName) + .put("providerName", providerName); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + if (graphKey != null) { + b.put("graphKey", graphKey); + } + return b; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java index 1cbc3c51..a1d65d5a 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java @@ -1,10 +1,22 @@ package com.launchdarkly.sdk.server.ai.internal; import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; + +import java.time.Duration; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.Function; /** - * The no-op {@link LDAIConfigTracker} used until metric reporting is implemented in a later step of - * the AI SDK. It is immutable and stateless, so a single shared instance is safe to reuse. + * The no-op {@link LDAIConfigTracker} used when tracking is not applicable (for example, for + * disabled configs or in testing contexts). It is immutable and stateless, so a single shared + * instance is safe to reuse. *

* This class is an internal implementation detail and is not part of the supported API. */ @@ -14,6 +26,73 @@ public final class NoOpAIConfigTracker implements LDAIConfigTracker { */ public static final NoOpAIConfigTracker INSTANCE = new NoOpAIConfigTracker(); + private static final TrackData EMPTY_TRACK_DATA = new TrackData("", "", null, 0, "", "", null); + private static final MetricSummary EMPTY_SUMMARY = + new MetricSummary(null, null, null, null, null, null, null); + private NoOpAIConfigTracker() { } + + @Override + public TrackData getTrackData() { + return EMPTY_TRACK_DATA; + } + + @Override + public String getResumptionToken() { + return null; + } + + @Override + public void trackDuration(Duration duration) { + } + + @Override + public T trackDurationOf(Callable operation) throws Exception { + return operation.call(); + } + + @Override + public void trackTimeToFirstToken(Duration duration) { + } + + @Override + public void trackSuccess() { + } + + @Override + public void trackError() { + } + + @Override + public void trackFeedback(FeedbackKind kind) { + } + + @Override + public void trackTokens(TokenUsage tokens) { + } + + @Override + public void trackToolCall(String toolKey) { + } + + @Override + public void trackToolCalls(List toolKeys) { + } + + @Override + public void trackJudgeResult(JudgeResult result) { + } + + @Override + public T trackMetricsOf( + Function metricsExtractor, + Callable operation) throws Exception { + return operation.call(); + } + + @Override + public MetricSummary getSummary() { + return EMPTY_SUMMARY; + } } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java new file mode 100644 index 00000000..82a90a7b --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -0,0 +1,297 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +/** + * Encodes and decodes resumption tokens for {@link LDAIConfigTrackerImpl}. + *

+ * A resumption token is a URL-safe Base64 (RFC 4648, no padding) encoding of a canonical JSON + * object that carries the run's identity fields. Tokens can be stored by callers and passed back + * to {@link com.launchdarkly.sdk.server.ai.LDAIClient#createTracker} to reconstruct a tracker + * across requests (for example, in a streaming or multi-turn scenario). + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +final class ResumptionTokens { + private static final int MAX_TOKEN_BYTES = 4096; + private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); + private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); + + private ResumptionTokens() { + } + + /** + * Encodes a resumption token from the given run identity fields. + *

+ * Field order in the JSON: {@code runId}, {@code configKey}, {@code variationKey} (omitted if + * {@code null}), {@code version}, {@code graphKey} (omitted if {@code null}). + * + * @param runId the run ID + * @param configKey the AI Config key + * @param variationKey the variation key, or {@code null} to omit + * @param version the config version + * @param graphKey the graph key, or {@code null} to omit + * @return the URL-safe Base64-encoded token + */ + static String encode(String runId, String configKey, String variationKey, + int version, String graphKey) { + StringBuilder sb = new StringBuilder(); + sb.append("{\"runId\":\"").append(escapeJson(runId)).append('"'); + sb.append(",\"configKey\":\"").append(escapeJson(configKey)).append('"'); + if (variationKey != null) { + sb.append(",\"variationKey\":\"").append(escapeJson(variationKey)).append('"'); + } + sb.append(",\"version\":").append(version); + if (graphKey != null) { + sb.append(",\"graphKey\":\"").append(escapeJson(graphKey)).append('"'); + } + sb.append('}'); + return ENCODER.encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Decodes a resumption token previously produced by {@link #encode}. + * + * @param token the URL-safe Base64 token + * @return the decoded fields + * @throws IllegalArgumentException if the token is malformed, oversized, or missing required fields + */ + static Decoded decode(String token) { + if (token == null) { + throw new IllegalArgumentException("Resumption token must not be null"); + } + if (token.length() > MAX_TOKEN_BYTES) { + throw new IllegalArgumentException("Resumption token exceeds maximum length of " + MAX_TOKEN_BYTES + " bytes"); + } + + String json; + try { + byte[] bytes = DECODER.decode(token); + json = new String(bytes, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Resumption token is not valid Base64: " + e.getMessage(), e); + } + + return parseJson(json); + } + + /** + * Minimal JSON parser for the fixed token structure. Handles only the fields we write. + */ + private static Decoded parseJson(String json) { + json = json.trim(); + if (!json.startsWith("{") || !json.endsWith("}")) { + throw new IllegalArgumentException("Resumption token JSON must be an object"); + } + + String runId = null; + String configKey = null; + String variationKey = null; + Integer version = null; + String graphKey = null; + + // Walk through the JSON object fields + int pos = 1; // skip opening '{' + while (pos < json.length() - 1) { + pos = skipWhitespace(json, pos); + if (pos >= json.length() - 1) { + break; + } + if (json.charAt(pos) == ',') { + pos++; + pos = skipWhitespace(json, pos); + } + if (pos >= json.length() - 1) { + break; + } + + // Read key + if (json.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos + " in resumption token"); + } + int[] keyEnd = new int[1]; + String key = readString(json, pos, keyEnd); + pos = keyEnd[0]; + + pos = skipWhitespace(json, pos); + if (pos >= json.length() || json.charAt(pos) != ':') { + throw new IllegalArgumentException("Expected ':' after key in resumption token"); + } + pos++; // skip ':' + pos = skipWhitespace(json, pos); + + // Read value + if (json.charAt(pos) == '"') { + int[] valEnd = new int[1]; + String value = readString(json, pos, valEnd); + pos = valEnd[0]; + switch (key) { + case "runId": runId = value; break; + case "configKey": configKey = value; break; + case "variationKey": variationKey = value; break; + case "graphKey": graphKey = value; break; + default: break; + } + } else { + // numeric value + int start = pos; + while (pos < json.length() && json.charAt(pos) != ',' && json.charAt(pos) != '}') { + pos++; + } + String numStr = json.substring(start, pos).trim(); + if ("version".equals(key)) { + try { + version = Integer.parseInt(numStr); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Field 'version' must be an integer in resumption token", e); + } + } + } + } + + if (runId == null) { + throw new IllegalArgumentException("Resumption token missing required field 'runId'"); + } + if (configKey == null) { + throw new IllegalArgumentException("Resumption token missing required field 'configKey'"); + } + if (version == null) { + throw new IllegalArgumentException("Resumption token missing required field 'version'"); + } + + // Clamp version: a missing version field defaults to 0 in older tokens, but configs default to 1. + int clampedVersion = Math.max(1, version); + + return new Decoded(runId, configKey, variationKey, clampedVersion, graphKey); + } + + private static int skipWhitespace(String s, int pos) { + while (pos < s.length() && Character.isWhitespace(s.charAt(pos))) { + pos++; + } + return pos; + } + + /** + * Reads a JSON string starting at {@code pos} (which must point to the opening {@code "}). + * Populates {@code end[0]} with the position after the closing {@code "}. + */ + private static String readString(String s, int pos, int[] end) { + if (s.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos); + } + pos++; // skip opening quote + StringBuilder sb = new StringBuilder(); + while (pos < s.length()) { + char c = s.charAt(pos); + if (c == '"') { + end[0] = pos + 1; + return sb.toString(); + } else if (c == '\\') { + pos++; + if (pos >= s.length()) { + throw new IllegalArgumentException("Unterminated escape in resumption token"); + } + char escaped = s.charAt(pos); + switch (escaped) { + case '"': sb.append('"'); break; + case '\\': sb.append('\\'); break; + case '/': sb.append('/'); break; + case 'b': sb.append('\b'); break; + case 'f': sb.append('\f'); break; + case 'n': sb.append('\n'); break; + case 'r': sb.append('\r'); break; + case 't': sb.append('\t'); break; + case 'u': + if (pos + 4 >= s.length()) { + throw new IllegalArgumentException("Incomplete Unicode escape in resumption token"); + } + String hex = s.substring(pos + 1, pos + 5); + try { + sb.append((char) Integer.parseInt(hex, 16)); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Invalid Unicode escape \\u" + hex + " in resumption token"); + } + pos += 4; + break; + default: + throw new IllegalArgumentException("Unknown escape character '\\" + escaped + "' in resumption token"); + } + } else { + sb.append(c); + } + pos++; + } + throw new IllegalArgumentException("Unterminated string in resumption token"); + } + + /** + * Escapes a string for inclusion in a JSON string literal. Handles the characters required by + * RFC 8259 §7. + */ + static String escapeJson(String s) { + if (s == null) { + return ""; + } + StringBuilder sb = new StringBuilder(s.length()); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"': sb.append("\\\""); break; + case '\\': sb.append("\\\\"); break; + case '\b': sb.append("\\b"); break; + case '\f': sb.append("\\f"); break; + case '\n': sb.append("\\n"); break; + case '\r': sb.append("\\r"); break; + case '\t': sb.append("\\t"); break; + default: + if (c < 0x20) { + sb.append(String.format("\\u%04x", (int) c)); + } else { + sb.append(c); + } + } + } + return sb.toString(); + } + + /** + * The decoded fields from a resumption token. + */ + static final class Decoded { + private final String runId; + private final String configKey; + private final String variationKey; + private final int version; + private final String graphKey; + + Decoded(String runId, String configKey, String variationKey, int version, String graphKey) { + this.runId = runId; + this.configKey = configKey; + this.variationKey = variationKey; + this.version = version; + this.graphKey = graphKey; + } + + String getRunId() { + return runId; + } + + String getConfigKey() { + return configKey; + } + + String getVariationKey() { + return variationKey; + } + + int getVersion() { + return version; + } + + String getGraphKey() { + return graphKey; + } + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java new file mode 100644 index 00000000..9594c33c --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -0,0 +1,695 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import com.launchdarkly.logging.LDLogLevel; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LogCapture; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.time.Duration; +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +@SuppressWarnings("javadoc") +public class LDAIConfigTrackerImplTest { + private LDClientInterface client; + private LogCapture logCapture; + private LDLogger logger; + private LDAIConfigTrackerImpl tracker; + + private static final LDContext CONTEXT = LDContext.create("user-key"); + private static final String RUN_ID = "test-run-id"; + private static final String CONFIG_KEY = "my-config"; + private static final String VARIATION_KEY = "variation-abc"; + private static final int VERSION = 3; + private static final String MODEL_NAME = "gpt-4"; + private static final String PROVIDER_NAME = "openai"; + + @Before + public void setUp() { + client = mock(LDClientInterface.class); + logCapture = Logs.capture(); + logger = LDLogger.withAdapter(logCapture, "test"); + tracker = makeTracker(VARIATION_KEY); + } + + private LDAIConfigTrackerImpl makeTracker(String variationKey) { + return new LDAIConfigTrackerImpl( + client, RUN_ID, CONFIG_KEY, variationKey, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, null, logger); + } + + private List warnings() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private LDValue baseExpectedData() { + return LDValue.buildObject() + .put("runId", RUN_ID) + .put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY) + .put("version", VERSION) + .put("modelName", MODEL_NAME) + .put("providerName", PROVIDER_NAME) + .build(); + } + + // ---- getTrackData / getResumptionToken ------------------------------------ + + @Test + public void getTrackDataReturnsCorrectFields() { + TrackData data = tracker.getTrackData(); + assertThat(data.getRunId(), is(RUN_ID)); + assertThat(data.getConfigKey(), is(CONFIG_KEY)); + assertThat(data.getVariationKey(), is(VARIATION_KEY)); + assertThat(data.getVersion(), is(VERSION)); + assertThat(data.getModelName(), is(MODEL_NAME)); + assertThat(data.getProviderName(), is(PROVIDER_NAME)); + assertThat(data.getGraphKey(), is(nullValue())); + } + + @Test + public void getTrackDataOmitsVariationKeyWhenNull() { + LDAIConfigTrackerImpl t = makeTracker(null); + assertThat(t.getTrackData().getVariationKey(), is(nullValue())); + LDValue ldv = t.getTrackData().toLDValue(); + assertThat(ldv.get("variationKey").isNull(), is(true)); // absent key returns LDValue.ofNull() + } + + @Test + public void getResumptionTokenIsNotNull() { + assertThat(tracker.getResumptionToken(), is(notNullValue())); + } + + @Test + public void resumptionTokenRoundTrips() throws Exception { + String token = tracker.getResumptionToken(); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(RUN_ID)); + assertThat(d.getConfigKey(), is(CONFIG_KEY)); + assertThat(d.getVariationKey(), is(VARIATION_KEY)); + assertThat(d.getVersion(), is(VERSION)); + assertThat(d.getGraphKey(), is(nullValue())); + } + + @Test + public void fromResumptionTokenRestoresCorrectFields() { + String token = tracker.getResumptionToken(); + LDAIConfigTrackerImpl restored = + LDAIConfigTrackerImpl.fromResumptionToken(token, client, CONTEXT, logger); + TrackData data = restored.getTrackData(); + assertThat(data.getRunId(), is(RUN_ID)); + assertThat(data.getConfigKey(), is(CONFIG_KEY)); + assertThat(data.getVariationKey(), is(VARIATION_KEY)); + assertThat(data.getVersion(), is(VERSION)); + assertThat(data.getModelName(), is("")); // not in token + assertThat(data.getProviderName(), is("")); // not in token + } + + // ---- trackDuration -------------------------------------------------------- + + @Test + public void trackDurationEmitsCorrectEvent() { + tracker.trackDuration(Duration.ofMillis(500)); + verify(client).trackMetric( + eq("$ld:ai:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(500.0)); + } + + @Test + public void trackDurationClampsNegativeToZero() { + tracker.trackDuration(Duration.ofMillis(-100)); + verify(client).trackMetric( + eq("$ld:ai:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(0.0)); + } + + @Test + public void trackDurationAtMostOnce() { + tracker.trackDuration(Duration.ofMillis(100)); + tracker.trackDuration(Duration.ofMillis(200)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(warnings().get(0), containsString("duration")); + } + + @Test + public void trackDurationNullIsIgnoredWithWarning() { + tracker.trackDuration(null); + verify(client, never()).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackDurationOf ------------------------------------------------------ + + @Test + public void trackDurationOfReturnResultAndTracksDuration() throws Exception { + String result = tracker.trackDurationOf(() -> "hello"); + assertThat(result, is("hello")); + verify(client, times(1)).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + + @Test + public void trackDurationOfTracksDurationEvenOnException() { + try { + tracker.trackDurationOf(() -> { + throw new RuntimeException("boom"); + }); + } catch (Exception ignored) { + } + verify(client, times(1)).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + + // ---- trackTimeToFirstToken ------------------------------------------------ + + @Test + public void trackTimeToFirstTokenEmitsCorrectEvent() { + tracker.trackTimeToFirstToken(Duration.ofMillis(250)); + verify(client).trackMetric( + eq("$ld:ai:tokens:ttf"), eq(CONTEXT), eq(baseExpectedData()), eq(250.0)); + } + + @Test + public void trackTimeToFirstTokenAtMostOnce() { + tracker.trackTimeToFirstToken(Duration.ofMillis(100)); + tracker.trackTimeToFirstToken(Duration.ofMillis(200)); + verify(client, times(1)).trackMetric(eq("$ld:ai:tokens:ttf"), any(), any(), anyDouble()); + } + + @Test + public void trackTimeToFirstTokenNullIsIgnoredWithWarning() { + tracker.trackTimeToFirstToken(null); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:ttf"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackSuccess / trackError -------------------------------------------- + + @Test + public void trackSuccessEmitsCorrectEvent() { + tracker.trackSuccess(); + verify(client).trackMetric( + eq("$ld:ai:generation:success"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackErrorEmitsCorrectEvent() { + tracker.trackError(); + verify(client).trackMetric( + eq("$ld:ai:generation:error"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackSuccessAtMostOnce() { + tracker.trackSuccess(); + tracker.trackSuccess(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackErrorAtMostOnce() { + tracker.trackError(); + tracker.trackError(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackSuccessAndErrorShareGuard_successFirst() { + tracker.trackSuccess(); + tracker.trackError(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackSuccessAndErrorShareGuard_errorFirst() { + tracker.trackError(); + tracker.trackSuccess(); + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackFeedback -------------------------------------------------------- + + @Test + public void trackFeedbackPositiveEmitsCorrectEvent() { + tracker.trackFeedback(FeedbackKind.POSITIVE); + verify(client).trackMetric( + eq("$ld:ai:feedback:user:positive"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackFeedbackNegativeEmitsCorrectEvent() { + tracker.trackFeedback(FeedbackKind.NEGATIVE); + verify(client).trackMetric( + eq("$ld:ai:feedback:user:negative"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackFeedbackAtMostOnce() { + tracker.trackFeedback(FeedbackKind.POSITIVE); + tracker.trackFeedback(FeedbackKind.NEGATIVE); + verify(client, times(1)).trackMetric( + eq("$ld:ai:feedback:user:positive"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:feedback:user:negative"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackFeedbackNullIsIgnoredWithWarning_slotNotBurned() { + tracker.trackFeedback(null); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + // Slot should not be burned — a subsequent valid call should still work + tracker.trackFeedback(FeedbackKind.POSITIVE); + verify(client, times(1)).trackMetric(eq("$ld:ai:feedback:user:positive"), any(), any(), anyDouble()); + } + + // ---- trackTokens ---------------------------------------------------------- + + @Test + public void trackTokensEmitsEventsForPositiveCounts() { + tracker.trackTokens(new TokenUsage(100, 60, 40)); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), eq(CONTEXT), eq(baseExpectedData()), eq(100.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:input"), eq(CONTEXT), eq(baseExpectedData()), eq(60.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), eq(CONTEXT), eq(baseExpectedData()), eq(40.0)); + } + + @Test + public void trackTokensSkipsZeroCounts() { + tracker.trackTokens(new TokenUsage(0, 0, 40)); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:input"), any(), any(), anyDouble()); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), any(), any(), eq(40.0)); + } + + @Test + public void trackTokensAllZeroDoesNotBurnSlot() { + tracker.trackTokens(new TokenUsage(0, 0, 0)); + // Slot not burned — next valid call should succeed + tracker.trackTokens(new TokenUsage(10, 5, 5)); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), eq(10.0)); + } + + @Test + public void trackTokensAtMostOnce() { + tracker.trackTokens(new TokenUsage(10, 5, 5)); + tracker.trackTokens(new TokenUsage(20, 10, 10)); + verify(client, times(1)).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + @Test + public void trackTokensNullIsIgnoredWithWarning() { + tracker.trackTokens(null); + verify(client, never()).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackToolCall -------------------------------------------------------- + + @Test + public void trackToolCallEmitsOnEveryCall() { + LDValue expectedDataWithTool = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("toolKey", "search") + .build(); + + tracker.trackToolCall("search"); + tracker.trackToolCall("search"); + tracker.trackToolCall("fetch"); + + verify(client, times(2)).trackMetric( + eq("$ld:ai:tool_call"), eq(CONTEXT), eq(expectedDataWithTool), eq(1.0)); + LDValue fetchData = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("toolKey", "fetch") + .build(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:tool_call"), eq(CONTEXT), eq(fetchData), eq(1.0)); + } + + @Test + public void trackToolCallsDelegate() { + tracker.trackToolCalls(Arrays.asList("a", "b")); + verify(client, times(2)).trackMetric(eq("$ld:ai:tool_call"), any(), any(), anyDouble()); + } + + @Test + public void trackToolCallNullIsIgnoredWithWarning() { + tracker.trackToolCall(null); + verify(client, never()).trackMetric(eq("$ld:ai:tool_call"), any(), any(), anyDouble()); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackJudgeResult ----------------------------------------------------- + + @Test + public void trackJudgeResultEmitsWhenSampledAndSucceeded() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true) + .metricKey("judge-score").score(0.85) + .judgeConfigKey("my-judge") + .build(); + + LDValue expectedData = LDValue.buildObject() + .put("runId", RUN_ID).put("configKey", CONFIG_KEY) + .put("variationKey", VARIATION_KEY).put("version", VERSION) + .put("modelName", MODEL_NAME).put("providerName", PROVIDER_NAME) + .put("judgeConfigKey", "my-judge") + .build(); + + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("judge-score"), eq(CONTEXT), eq(expectedData), eq(0.85)); + } + + @Test + public void trackJudgeResultSkipsWhenNotSampled() { + JudgeResult result = JudgeResult.builder() + .sampled(false).success(true).metricKey("k").score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenNotSuccess() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(false).metricKey("k").score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenMetricKeyNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey(null).score(1.0).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(any(), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultSkipsWhenScoreNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(null).build(); + tracker.trackJudgeResult(result); + verify(client, never()).trackMetric(eq("k"), any(), any(), anyDouble()); + } + + @Test + public void trackJudgeResultFiresWhenScoreIsZero() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(0.0).build(); + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("k"), any(), any(), eq(0.0)); + } + + @Test + public void trackJudgeResultOmitsJudgeConfigKeyWhenNull() { + JudgeResult result = JudgeResult.builder() + .sampled(true).success(true).metricKey("k").score(1.0).judgeConfigKey(null).build(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + tracker.trackJudgeResult(result); + verify(client).trackMetric(eq("k"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("judgeConfigKey").isNull(), is(true)); + } + + @Test + public void trackJudgeResultIsNotAtMostOnce() { + JudgeResult r1 = JudgeResult.builder().sampled(true).success(true).metricKey("k1").score(1.0).build(); + JudgeResult r2 = JudgeResult.builder().sampled(true).success(true).metricKey("k2").score(2.0).build(); + tracker.trackJudgeResult(r1); + tracker.trackJudgeResult(r2); + verify(client).trackMetric(eq("k1"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("k2"), any(), any(), eq(2.0)); + } + + @Test + public void trackJudgeResultNullIsIgnoredWithWarning() { + tracker.trackJudgeResult(null); + assertThat(warnings().size(), greaterThanOrEqualTo(1)); + } + + // ---- trackMetricsOf ------------------------------------------------------- + + @Test + public void trackMetricsOfTracksSuccessAndDurationAndTokens() throws Exception { + AIMetrics metrics = AIMetrics.builder() + .success(true) + .tokens(new TokenUsage(10, 6, 4)) + .build(); + + String result = tracker.trackMetricsOf(r -> metrics, () -> "ok"); + assertThat(result, is("ok")); + + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + verify(client).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), eq(10.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:input"), any(), any(), eq(6.0)); + verify(client).trackMetric(eq("$ld:ai:tokens:output"), any(), any(), eq(4.0)); + } + + @Test + public void trackMetricsOfUsesRunnerReportedDurationWhenPresent() throws Exception { + AIMetrics metrics = AIMetrics.builder().success(true).durationMs(999L).build(); + tracker.trackMetricsOf(r -> metrics, () -> "ok"); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), eq(999.0)); + } + + @Test + public void trackMetricsOfTracksErrorAndRethrowsOnOperationException() { + try { + tracker.trackMetricsOf( + r -> AIMetrics.builder().success(true).build(), + () -> { throw new RuntimeException("ai failed"); }); + } catch (Exception e) { + assertThat(e.getMessage(), is("ai failed")); + } + verify(client).trackMetric(eq("$ld:ai:generation:error"), any(), any(), eq(1.0)); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + @Test + public void trackMetricsOfExtractorExceptionPropagatesAndDoesNotCallTrackError() { + try { + tracker.trackMetricsOf( + r -> { throw new RuntimeException("extractor failed"); }, + () -> "ok"); + } catch (Exception e) { + assertThat(e.getMessage(), is("extractor failed")); + } + verify(client, never()).trackMetric(eq("$ld:ai:generation:error"), any(), any(), anyDouble()); + verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + @Test + public void trackMetricsOfTracksToolCalls() throws Exception { + AIMetrics metrics = AIMetrics.builder() + .success(true) + .toolCalls(Arrays.asList("search", "fetch")) + .build(); + tracker.trackMetricsOf(r -> metrics, () -> "ok"); + verify(client, times(2)).trackMetric(eq("$ld:ai:tool_call"), any(), any(), eq(1.0)); + } + + // ---- getSummary ----------------------------------------------------------- + + @Test + public void getSummaryReturnsNullsBeforeAnyTracking() { + MetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(nullValue())); + assertThat(summary.getDurationMs(), is(nullValue())); + assertThat(summary.getTokens(), is(nullValue())); + assertThat(summary.getFeedback(), is(nullValue())); + assertThat(summary.getTimeToFirstTokenMs(), is(nullValue())); + assertThat(summary.getToolCalls(), is(nullValue())); + assertThat(summary.getResumptionToken(), is(notNullValue())); + } + + @Test + public void getSummaryReflectsAllTrackedValues() { + tracker.trackDuration(Duration.ofMillis(300)); + tracker.trackTimeToFirstToken(Duration.ofMillis(50)); + tracker.trackSuccess(); + tracker.trackFeedback(FeedbackKind.POSITIVE); + tracker.trackTokens(new TokenUsage(30, 20, 10)); + tracker.trackToolCall("search"); + tracker.trackToolCall("fetch"); + + MetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(Boolean.TRUE)); + assertThat(summary.getDurationMs(), is(300L)); + assertThat(summary.getTimeToFirstTokenMs(), is(50L)); + assertThat(summary.getFeedback(), is(FeedbackKind.POSITIVE)); + assertThat(summary.getTokens().getTotal(), is(30L)); + assertThat(summary.getToolCalls(), containsInAnyOrder("search", "fetch")); + assertThat(summary.getResumptionToken(), is(tracker.getResumptionToken())); + } + + @Test + public void getSummarySuccessIsFalseWhenErrorTracked() { + tracker.trackError(); + assertThat(tracker.getSummary().getSuccess(), is(Boolean.FALSE)); + } + + @Test + public void getSummaryToolCallsIsImmutableSnapshot() { + tracker.trackToolCall("a"); + List snapshot1 = tracker.getSummary().getToolCalls(); + tracker.trackToolCall("b"); + List snapshot2 = tracker.getSummary().getToolCalls(); + assertThat(snapshot1.size(), is(1)); + assertThat(snapshot2.size(), is(2)); + } + + // ---- variationKey omission ------------------------------------------------ + + @Test + public void variationKeyOmittedFromPayloadWhenNull() { + LDAIConfigTrackerImpl t = makeTracker(null); + t.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("variationKey").isNull(), is(true)); + } + + @Test + public void variationKeyIncludedInPayloadWhenPresent() { + tracker.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("variationKey").stringValue(), is(VARIATION_KEY)); + } + + // ---- graphKey inclusion --------------------------------------------------- + + @Test + public void graphKeyIncludedInPayloadWhenSet() { + LDAIConfigTrackerImpl t = new LDAIConfigTrackerImpl( + client, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, "my-graph", logger); + t.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("graphKey").stringValue(), is("my-graph")); + } + + @Test + public void graphKeyOmittedFromPayloadWhenNull() { + tracker.trackSuccess(); + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:generation:success"), any(), dataCaptor.capture(), anyDouble()); + assertThat(dataCaptor.getValue().get("graphKey").isNull(), is(true)); + } + + // ---- concurrency: at-most-once under contention --------------------------- + + @Test + public void trackDurationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + AtomicInteger callCount = new AtomicInteger(0); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackDuration(Duration.ofMillis(100)); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + ArgumentCaptor valueCaptor = ArgumentCaptor.forClass(Double.class); + verify(client, times(1)).trackMetric( + eq("$ld:ai:duration:total"), any(), any(), valueCaptor.capture()); + } + + @Test + public void trackSuccessAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackSuccess(); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + verify(client, times(1)).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); + } + + // ---- constructor null checks ---------------------------------------------- + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullClient() { + new LDAIConfigTrackerImpl(null, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, CONTEXT, null, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullContext() { + new LDAIConfigTrackerImpl(client, RUN_ID, CONFIG_KEY, VARIATION_KEY, VERSION, + MODEL_NAME, PROVIDER_NAME, null, null, logger); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java new file mode 100644 index 00000000..737773a8 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java @@ -0,0 +1,174 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class ResumptionTokensTest { + + // ---- encode + decode round-trips ------------------------------------------ + + @Test + public void roundTripWithAllFields() { + String token = ResumptionTokens.encode("run-1", "config-key", "var-abc", 2, "graph-x"); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("config-key")); + assertThat(d.getVariationKey(), is("var-abc")); + assertThat(d.getVersion(), is(2)); + assertThat(d.getGraphKey(), is("graph-x")); + } + + @Test + public void roundTripWithNullVariationKey() { + String token = ResumptionTokens.encode("run-1", "config-key", null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is("run-1")); + assertThat(d.getConfigKey(), is("config-key")); + assertThat(d.getVariationKey(), is(nullValue())); + assertThat(d.getVersion(), is(1)); + assertThat(d.getGraphKey(), is(nullValue())); + } + + @Test + public void roundTripWithNullGraphKey() { + String token = ResumptionTokens.encode("run-2", "cfg", "v1", 3, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getGraphKey(), is(nullValue())); + assertThat(d.getVariationKey(), is("v1")); + } + + @Test + public void variationKeyOmittedFromTokenWhenNull() { + // Tokens with null variationKey should NOT contain the "variationKey" JSON field. + String token = ResumptionTokens.encode("r", "c", null, 1, null); + // Decode and check no variationKey + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getVariationKey(), is(nullValue())); + } + + @Test + public void graphKeyOmittedFromTokenWhenNull() { + String token = ResumptionTokens.encode("r", "c", "v", 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getGraphKey(), is(nullValue())); + } + + // ---- special character escaping ------------------------------------------- + + @Test + public void roundTripWithSpecialCharactersInKeys() { + String runId = "run\"with\\special\nchars"; + String configKey = "config\twith\rtabs"; + String token = ResumptionTokens.encode(runId, configKey, null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(runId)); + assertThat(d.getConfigKey(), is(configKey)); + } + + @Test + public void roundTripWithUnicodeInKeys() { + String runId = "run-\u00e9\u4e2d\u6587"; + String token = ResumptionTokens.encode(runId, "cfg", null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getRunId(), is(runId)); + } + + // ---- version clamping ----------------------------------------------------- + + @Test + public void versionBelowOneIsClampedToOne() { + // Construct a token with version 0 directly to simulate an old token. + String json = "{\"runId\":\"r\",\"configKey\":\"c\",\"version\":0}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getVersion(), is(1)); + } + + @Test + public void versionOneIsNotClamped() { + String token = ResumptionTokens.encode("r", "c", null, 1, null); + assertThat(ResumptionTokens.decode(token).getVersion(), is(1)); + } + + // ---- decode error handling ------------------------------------------------ + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNull() { + ResumptionTokens.decode(null); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsOversizedToken() { + // Build a token larger than 4096 bytes + String largeValue = new String(new char[5000]).replace('\0', 'x'); + String json = "{\"runId\":\"" + largeValue + "\",\"configKey\":\"c\",\"version\":1}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsInvalidBase64() { + ResumptionTokens.decode("not-valid-base64!!!!"); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingRunId() { + String json = "{\"configKey\":\"c\",\"version\":1}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingConfigKey() { + String json = "{\"runId\":\"r\",\"version\":1}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsMissingVersion() { + String json = "{\"runId\":\"r\",\"configKey\":\"c\"}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNonIntegerVersion() { + String json = "{\"runId\":\"r\",\"configKey\":\"c\",\"version\":\"one\"}"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + @Test(expected = IllegalArgumentException.class) + public void decodeRejectsNonObjectJson() { + String json = "[\"not\",\"an\",\"object\"]"; + String token = java.util.Base64.getUrlEncoder().withoutPadding() + .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); + ResumptionTokens.decode(token); + } + + // ---- escapeJson helper ---------------------------------------------------- + + @Test + public void escapeJsonHandlesControlCharacters() { + assertThat(ResumptionTokens.escapeJson("\n\r\t"), is("\\n\\r\\t")); + assertThat(ResumptionTokens.escapeJson("\"hello\""), is("\\\"hello\\\"")); + assertThat(ResumptionTokens.escapeJson("back\\slash"), is("back\\\\slash")); + } + + @Test + public void escapeJsonReturnsEmptyStringForNull() { + assertThat(ResumptionTokens.escapeJson(null), is("")); + } +} From a0c87845f13f3bffff53ce5c6c547b661abb9965 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 10:11:53 -0500 Subject: [PATCH 02/23] fix: default tracker version to 1 and remove version clamp from token decode --- .../launchdarkly/sdk/server/ai/LDAIClientImpl.java | 2 +- .../sdk/server/ai/internal/ResumptionTokens.java | 5 +---- .../server/ai/internal/ResumptionTokensTest.java | 14 ++------------ 3 files changed, 4 insertions(+), 17 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 8f261604..8bf81e71 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -289,7 +289,7 @@ private Supplier trackerFactory( LDContext context) { String modelName = model != null && model.getName() != null ? model.getName() : ""; String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; - int ver = version != null ? version : 0; + int ver = version != null ? version : 1; return () -> new LDAIConfigTrackerImpl( client, UUID.randomUUID().toString(), diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index 82a90a7b..de3885c2 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -160,10 +160,7 @@ private static Decoded parseJson(String json) { throw new IllegalArgumentException("Resumption token missing required field 'version'"); } - // Clamp version: a missing version field defaults to 0 in older tokens, but configs default to 1. - int clampedVersion = Math.max(1, version); - - return new Decoded(runId, configKey, variationKey, clampedVersion, graphKey); + return new Decoded(runId, configKey, variationKey, version, graphKey); } private static int skipWhitespace(String s, int pos) { diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java index 737773a8..21a263b1 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java @@ -78,20 +78,10 @@ public void roundTripWithUnicodeInKeys() { assertThat(d.getRunId(), is(runId)); } - // ---- version clamping ----------------------------------------------------- + // ---- version round-trip --------------------------------------------------- @Test - public void versionBelowOneIsClampedToOne() { - // Construct a token with version 0 directly to simulate an old token. - String json = "{\"runId\":\"r\",\"configKey\":\"c\",\"version\":0}"; - String token = java.util.Base64.getUrlEncoder().withoutPadding() - .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); - ResumptionTokens.Decoded d = ResumptionTokens.decode(token); - assertThat(d.getVersion(), is(1)); - } - - @Test - public void versionOneIsNotClamped() { + public void versionIsPreservedOnRoundTrip() { String token = ResumptionTokens.encode("r", "c", null, 1, null); assertThat(ResumptionTokens.decode(token).getVersion(), is(1)); } From 9ae20ca1109936b053f7df4e8f20fec58884ad7c Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:33:05 -0500 Subject: [PATCH 03/23] agent graph support (first pass) --- .../sdk/server/ai/AIGraphMetricSummary.java | 84 +++ .../sdk/server/ai/AIGraphTracker.java | 295 +++++++++++ .../sdk/server/ai/AgentGraphDefinition.java | 289 +++++++++++ .../sdk/server/ai/AgentGraphNode.java | 59 +++ .../launchdarkly/sdk/server/ai/GraphEdge.java | 43 ++ .../sdk/server/ai/LDAIClient.java | 42 ++ .../sdk/server/ai/LDAIClientImpl.java | 172 ++++++- .../ai/internal/AgentGraphFlagValue.java | 220 ++++++++ .../server/ai/internal/ResumptionTokens.java | 163 +++++- .../sdk/server/ai/AIGraphTrackerTest.java | 480 ++++++++++++++++++ .../server/ai/AgentGraphDefinitionTest.java | 445 ++++++++++++++++ .../sdk/server/ai/LDAIClientImplTest.java | 192 +++++++ .../ai/internal/AgentGraphFlagValueTest.java | 276 ++++++++++ 13 files changed, 2745 insertions(+), 15 deletions(-) create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java create mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java create mode 100644 lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java create mode 100644 lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java create mode 100644 lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java new file mode 100644 index 00000000..05e45475 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphMetricSummary.java @@ -0,0 +1,84 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; + +import java.util.List; + +/** + * A snapshot of the metrics tracked so far by an {@link AIGraphTracker}. + *

+ * All fields are nullable: a {@code null} value means the corresponding metric has not been + * recorded yet on the tracker. {@link #getResumptionToken()} is always present. + *

+ * Instances are immutable. + */ +public final class AIGraphMetricSummary { + private final Boolean success; + private final Double durationMs; + private final TokenUsage tokens; + private final List path; + private final String resumptionToken; + + AIGraphMetricSummary( + Boolean success, + Double durationMs, + TokenUsage tokens, + List path, + String resumptionToken) { + this.success = success; + this.durationMs = durationMs; + this.tokens = tokens; + this.path = path; + this.resumptionToken = resumptionToken; + } + + /** + * Returns the invocation outcome: {@code true} if {@code trackInvocationSuccess} was called, + * {@code false} if {@code trackInvocationFailure} was called, or {@code null} if neither has + * been called yet. + * + * @return the success flag, or {@code null} if not yet recorded + */ + public Boolean getSuccess() { + return success; + } + + /** + * Returns the tracked graph-level duration in milliseconds, or {@code null} if not recorded. + * + * @return the duration in ms, or {@code null} + */ + public Double getDurationMs() { + return durationMs; + } + + /** + * Returns the tracked token usage, or {@code null} if not recorded. + * + * @return the token usage, or {@code null} + */ + public TokenUsage getTokens() { + return tokens; + } + + /** + * Returns the tracked node path (ordered list of node keys visited), or {@code null} if not + * recorded. + * + * @return an unmodifiable list of node keys, or {@code null} + */ + public List getPath() { + return path; + } + + /** + * Returns the resumption token for this graph run, which can be passed to + * {@link LDAIClient#createGraphTracker(String, com.launchdarkly.sdk.LDContext)} to reconstruct + * the tracker on a subsequent request. + * + * @return the resumption token; never {@code null} + */ + public String getResumptionToken() { + return resumptionToken; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java new file mode 100644 index 00000000..299143ce --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java @@ -0,0 +1,295 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.logging.LDLogAdapter; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LDSLF4J; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.ArrayBuilder; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.ai.internal.ResumptionTokens; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Reports graph-level events for a single invocation of an {@link AgentGraphDefinition}. + *

+ * An {@code AIGraphTracker} is obtained from an enabled graph definition via + * {@link AgentGraphDefinition#createTracker()}, or reconstructed from a resumption token via + * {@link LDAIClient#createGraphTracker(String, LDContext)}. + *

+ * Graph-level methods (invocation, duration, tokens, path) are at-most-once: a second call on + * the same tracker is silently dropped. Edge-level methods (redirect, handoff) are multi-fire — + * each call records a distinct event. + *

+ * Implementations are thread-safe. + */ +public final class AIGraphTracker { + + private static final String GRAPH_INVOCATION_SUCCESS = "$ld:ai:graph:invocation_success"; + private static final String GRAPH_INVOCATION_FAILURE = "$ld:ai:graph:invocation_failure"; + private static final String GRAPH_DURATION_TOTAL = "$ld:ai:graph:duration:total"; + private static final String GRAPH_TOTAL_TOKENS = "$ld:ai:graph:total_tokens"; + private static final String GRAPH_PATH = "$ld:ai:graph:path"; + private static final String GRAPH_REDIRECT = "$ld:ai:graph:redirect"; + private static final String GRAPH_HANDOFF_SUCCESS = "$ld:ai:graph:handoff_success"; + private static final String GRAPH_HANDOFF_FAILURE = "$ld:ai:graph:handoff_failure"; + + private final LDClientInterface client; + private final LDContext context; + private final LDLogger logger; + + private final String runId; + private final String graphKey; + private final String variationKey; + private final int version; + + private final String resumptionToken; + + // At-most-once guards: null = not yet recorded, non-null = recorded. + // trackInvocationSuccess and trackInvocationFailure share invocationRecorded: + // true = success was recorded, false = failure was recorded. + private final AtomicReference invocationRecorded = new AtomicReference<>(); + private final AtomicReference durationRecorded = new AtomicReference<>(); + private final AtomicReference tokensRecorded = new AtomicReference<>(); + private final AtomicReference> pathRecorded = new AtomicReference<>(); + + AIGraphTracker( + LDClientInterface client, + String runId, + String graphKey, + String variationKey, + int version, + LDContext context, + LDLogger logger) { + this.client = Objects.requireNonNull(client, "client"); + this.runId = Objects.requireNonNull(runId, "runId"); + this.graphKey = Objects.requireNonNull(graphKey, "graphKey"); + this.variationKey = variationKey; + this.version = version; + this.context = Objects.requireNonNull(context, "context"); + this.logger = Objects.requireNonNull(logger, "logger"); + + this.resumptionToken = ResumptionTokens.encodeGraph(runId, graphKey, variationKey, version); + } + + /** + * Reconstructs a graph tracker from a resumption token, preserving the original run identity. + * + * @param token the resumption token produced by {@link #getResumptionToken()} + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static AIGraphTracker fromResumptionToken( + String token, LDClientInterface client, LDContext context) { + ResumptionTokens.DecodedGraph d = ResumptionTokens.decodeGraph(token); + int version = Math.max(1, d.getVersion()); + return new AIGraphTracker( + client, + d.getRunId(), + d.getGraphKey(), + d.getVariationKey(), + version, + context, + defaultLogger()); + } + + /** + * Records that the graph invocation succeeded. + *

+ * At-most-once and mutually exclusive with {@link #trackInvocationFailure()}: whichever is + * called first wins. + */ + public void trackInvocationSuccess() { + if (!invocationRecorded.compareAndSet(null, Boolean.TRUE)) { + logger.warn("Skipping trackInvocationSuccess: invocation already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_INVOCATION_SUCCESS, context, baseData().build(), 1); + } + + /** + * Records that the graph invocation failed. + *

+ * At-most-once and mutually exclusive with {@link #trackInvocationSuccess()}: whichever is + * called first wins. + */ + public void trackInvocationFailure() { + if (!invocationRecorded.compareAndSet(null, Boolean.FALSE)) { + logger.warn("Skipping trackInvocationFailure: invocation already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_INVOCATION_FAILURE, context, baseData().build(), 1); + } + + /** + * Records the total wall-clock duration of the graph invocation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param durationMs the duration in milliseconds + */ + public void trackDuration(double durationMs) { + if (!durationRecorded.compareAndSet(null, durationMs)) { + logger.warn("Skipping trackDuration: duration already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_DURATION_TOTAL, context, baseData().build(), durationMs); + } + + /** + * Records the total token usage for the graph invocation. + *

+ * At-most-once: subsequent calls are silently dropped. Calls where all counts are zero do not + * consume the at-most-once slot. + * + * @param tokens the token usage; ignored if {@code null} + */ + public void trackTotalTokens(TokenUsage tokens) { + if (tokens == null) { + logger.warn("Skipping trackTotalTokens: tokens was null."); + return; + } + boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; + if (!hasPositive) { + return; + } + if (!tokensRecorded.compareAndSet(null, tokens)) { + logger.warn("Skipping trackTotalTokens: token usage already recorded on this graph tracker."); + return; + } + client.trackMetric(GRAPH_TOTAL_TOKENS, context, baseData().build(), tokens.getTotal()); + } + + /** + * Records the ordered path of node keys visited during the graph invocation. + *

+ * At-most-once: subsequent calls on the same tracker are silently dropped. + * + * @param path the ordered list of node keys; ignored if {@code null} or empty + */ + public void trackPath(List path) { + if (path == null || path.isEmpty()) { + logger.warn("Skipping trackPath: path was null or empty."); + return; + } + List snapshot = Collections.unmodifiableList(path); + if (!pathRecorded.compareAndSet(null, snapshot)) { + logger.warn("Skipping trackPath: path already recorded on this graph tracker."); + return; + } + ArrayBuilder ab = LDValue.buildArray(); + for (String s : path) { + ab.add(LDValue.of(s)); + } + LDValue data = baseData().put("path", ab.build()).build(); + client.trackMetric(GRAPH_PATH, context, data, 1); + } + + /** + * Records a redirect event, where the graph transitioned from one node to a different target + * than the edge originally specified. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param redirectedTarget the key of the node that was actually used + */ + public void trackRedirect(String sourceKey, String redirectedTarget) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("redirectedTarget", redirectedTarget) + .build(); + client.trackMetric(GRAPH_REDIRECT, context, data, 1); + } + + /** + * Records a successful handoff from one node to another. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param targetKey the key of the target node + */ + public void trackHandoffSuccess(String sourceKey, String targetKey) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("targetKey", targetKey) + .build(); + client.trackMetric(GRAPH_HANDOFF_SUCCESS, context, data, 1); + } + + /** + * Records a failed handoff from one node to another. + *

+ * Multi-fire: every call emits an event. + * + * @param sourceKey the key of the source node + * @param targetKey the key of the target node + */ + public void trackHandoffFailure(String sourceKey, String targetKey) { + LDValue data = baseData() + .put("sourceKey", sourceKey) + .put("targetKey", targetKey) + .build(); + client.trackMetric(GRAPH_HANDOFF_FAILURE, context, data, 1); + } + + /** + * Returns a snapshot of all graph-level metrics tracked so far on this tracker. + * + * @return the metric summary; never {@code null} + */ + public AIGraphMetricSummary getSummary() { + return new AIGraphMetricSummary( + invocationRecorded.get(), + durationRecorded.get(), + tokensRecorded.get(), + pathRecorded.get(), + resumptionToken); + } + + /** + * Returns the resumption token for this graph run. + *

+ * The token encodes the run identity and can be passed to + * {@link LDAIClient#createGraphTracker(String, LDContext)} to reconstruct the tracker across + * requests. + * + * @return the resumption token; never {@code null} + */ + public String getResumptionToken() { + return resumptionToken; + } + + private ObjectBuilder baseData() { + ObjectBuilder b = LDValue.buildObject() + .put("runId", runId) + .put("graphKey", graphKey) + .put("version", version); + if (variationKey != null) { + b.put("variationKey", variationKey); + } + return b; + } + + private static LDLogger defaultLogger() { + LDLogAdapter adapter; + try { + Class.forName("org.slf4j.LoggerFactory"); + adapter = LDSLF4J.adapter(); + } catch (ClassNotFoundException e) { + adapter = Logs.toConsole(); + } + return LDLogger.withAdapter(adapter, "LaunchDarkly.AI"); + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java new file mode 100644 index 00000000..b18b68a2 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinition.java @@ -0,0 +1,289 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.function.BiFunction; +import java.util.function.Supplier; + +/** + * The fully resolved definition of an agent graph, containing all nodes and their edges. + *

+ * An {@code AgentGraphDefinition} is obtained from {@link LDAIClient#agentGraph}. When + * {@link #isEnabled()} returns {@code false}, the graph definition was not fetchable or failed + * validation; in that case all node collections are empty and traversal methods are no-ops. + *

+ * Traversal methods ({@link #traverse} and {@link #reverseTraverse}) are BFS-based and + * cycle-safe: each node is visited at most once. + *

+ * This class is thread-safe. All returned collections are unmodifiable. + */ +public final class AgentGraphDefinition { + private final AgentGraphFlagValue flagValue; + private final Map nodes; + private final boolean enabled; + private final Supplier trackerFactory; + + AgentGraphDefinition( + AgentGraphFlagValue flagValue, + Map nodes, + boolean enabled, + Supplier trackerFactory) { + this.flagValue = flagValue; + this.nodes = nodes; + this.enabled = enabled; + this.trackerFactory = trackerFactory; + } + + /** + * Returns {@code true} if this graph definition is enabled and all nodes were successfully + * fetched. + * + * @return whether the graph is enabled + */ + public boolean isEnabled() { + return enabled; + } + + /** + * Returns the root node of the graph (the entry point). + * + * @return the root node, or {@code null} if the graph is disabled or the root key is not in the + * node map + */ + public AgentGraphNode rootNode() { + return nodes.get(flagValue.getRoot()); + } + + /** + * Returns the node with the given key, or {@code null} if not found. + * + * @param nodeKey the node key to look up + * @return the node, or {@code null} + */ + public AgentGraphNode getNode(String nodeKey) { + return nodes.get(nodeKey); + } + + /** + * Returns the immediate child nodes of the node with the given key, following its outgoing + * edges. + * + * @param nodeKey the source node key + * @return an unmodifiable list of child nodes; empty if the node is terminal or not found + */ + public List getChildNodes(String nodeKey) { + AgentGraphNode node = nodes.get(nodeKey); + if (node == null) { + return Collections.emptyList(); + } + List children = new ArrayList<>(); + for (GraphEdge edge : node.getEdges()) { + AgentGraphNode child = nodes.get(edge.getKey()); + if (child != null) { + children.add(child); + } + } + return Collections.unmodifiableList(children); + } + + /** + * Returns all nodes that have an outgoing edge pointing to the given node key. + * + * @param nodeKey the target node key + * @return an unmodifiable list of parent nodes; empty if none found + */ + public List getParentNodes(String nodeKey) { + List parents = new ArrayList<>(); + for (AgentGraphNode node : nodes.values()) { + for (GraphEdge edge : node.getEdges()) { + if (nodeKey.equals(edge.getKey())) { + parents.add(node); + break; + } + } + } + return Collections.unmodifiableList(parents); + } + + /** + * Returns all terminal nodes (nodes with no outgoing edges). + * + * @return an unmodifiable list of terminal nodes; empty if the graph is disabled + */ + public List terminalNodes() { + List terminals = new ArrayList<>(); + for (AgentGraphNode node : nodes.values()) { + if (node.isTerminal()) { + terminals.add(node); + } + } + return Collections.unmodifiableList(terminals); + } + + /** + * Returns the internal parsed flag value for this graph. This is an internal type and is not + * part of the supported public API. + * + * @return the parsed flag value + */ + AgentGraphFlagValue getConfig() { + return flagValue; + } + + /** + * Creates a new {@link AIGraphTracker} for this graph invocation. + *

+ * Each call produces a fresh tracker with a new run ID. Returns {@code null} if the graph is + * disabled. + * + * @return a new tracker, or {@code null} if disabled + */ + public AIGraphTracker createTracker() { + if (!enabled || trackerFactory == null) { + return null; + } + return trackerFactory.get(); + } + + /** + * Performs a BFS traversal of the graph starting from the root node. + *

+ * For each node visited, {@code fn} is called with the node and the mutable context map. The + * return value of {@code fn} is stored in the context map under the node's key, making it + * available to subsequently visited nodes. Each node is visited exactly once (cycle-safe). + *

+ * This is a no-op when the graph is disabled or the root node is absent. + * + * @param fn the visitor function; receives the current node and the context map + * @param ctx the mutable context map; values from earlier nodes are available to later ones + */ + public void traverse(BiFunction, Object> fn, + Map ctx) { + AgentGraphNode root = rootNode(); + if (root == null) { + return; + } + + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + visited.add(root.getKey()); + queue.add(root); + + while (!queue.isEmpty()) { + AgentGraphNode node = queue.poll(); + Object result = fn.apply(node, ctx); + ctx.put(node.getKey(), result); + + for (AgentGraphNode child : getChildNodes(node.getKey())) { + if (visited.add(child.getKey())) { + queue.add(child); + } + } + } + } + + /** + * Performs a reverse BFS traversal of the graph, starting from terminal nodes and working + * upward toward the root. + *

+ * The root node is always processed last. Each node is visited exactly once (cycle-safe). This + * is a no-op when the graph is disabled or there are no terminal nodes. + * + * @param fn the visitor function; receives the current node and the context map + * @param ctx the mutable context map; values from earlier nodes are available to later ones + */ + public void reverseTraverse(BiFunction, Object> fn, + Map ctx) { + AgentGraphNode root = rootNode(); + if (root == null) { + return; + } + + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + + // Seed from terminals, excluding root (it will be processed last). + for (AgentGraphNode terminal : terminalNodes()) { + if (!terminal.getKey().equals(root.getKey()) && visited.add(terminal.getKey())) { + queue.add(terminal); + } + } + + while (!queue.isEmpty()) { + AgentGraphNode node = queue.poll(); + Object result = fn.apply(node, ctx); + ctx.put(node.getKey(), result); + + for (AgentGraphNode parent : getParentNodes(node.getKey())) { + if (!parent.getKey().equals(root.getKey()) && visited.add(parent.getKey())) { + queue.add(parent); + } + } + } + + // Process root last (whether or not it was encountered as a parent above). + if (visited.add(root.getKey())) { + Object result = fn.apply(root, ctx); + ctx.put(root.getKey(), result); + } + } + + /** + * Builds the node map from the parsed flag value and pre-fetched agent configs. + *

+ * For each key in {@link #collectAllKeys}, looks up the agent config from {@code configs} and + * the outgoing edges from the flag value's edge map. Returns an unmodifiable map. + * + * @param flagValue the parsed flag value + * @param configs the pre-fetched agent configs keyed by node key + * @return an unmodifiable map of nodes keyed by config key + */ + static Map buildNodes( + AgentGraphFlagValue flagValue, Map configs) { + Set allKeys = collectAllKeys(flagValue); + Map result = new HashMap<>(); + for (String key : allKeys) { + AIAgentConfig config = configs.get(key); + if (config == null) { + continue; + } + List edges = flagValue.getEdges().get(key); + if (edges == null) { + edges = Collections.emptyList(); + } + result.put(key, new AgentGraphNode(key, config, edges)); + } + return Collections.unmodifiableMap(result); + } + + /** + * Collects all unique node keys referenced anywhere in the flag value: the root key, all edge + * source keys, and all edge target keys. + * + * @param flagValue the parsed flag value + * @return the set of all unique node keys + */ + static Set collectAllKeys(AgentGraphFlagValue flagValue) { + Set keys = new HashSet<>(); + String root = flagValue.getRoot(); + if (root != null && !root.isEmpty()) { + keys.add(root); + } + for (Map.Entry> entry : flagValue.getEdges().entrySet()) { + keys.add(entry.getKey()); + for (GraphEdge edge : entry.getValue()) { + keys.add(edge.getKey()); + } + } + return keys; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java new file mode 100644 index 00000000..bea5cc05 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AgentGraphNode.java @@ -0,0 +1,59 @@ +package com.launchdarkly.sdk.server.ai; + +import java.util.Collections; +import java.util.List; + +/** + * A node in an {@link AgentGraphDefinition}, wrapping a single agent config and its outgoing + * edges to other nodes. + *

+ * Nodes are retrieved from a graph definition via {@link AgentGraphDefinition#getNode(String)}, + * {@link AgentGraphDefinition#rootNode()}, etc. Instances are immutable. + */ +public final class AgentGraphNode { + private final String key; + private final AIAgentConfig config; + private final List edges; + + AgentGraphNode(String key, AIAgentConfig config, List edges) { + this.key = key; + this.config = config; + this.edges = edges == null ? Collections.emptyList() : edges; + } + + /** + * Returns the AI Config key identifying this node. + * + * @return the node key; never {@code null} + */ + public String getKey() { + return key; + } + + /** + * Returns the retrieved agent config for this node. + * + * @return the agent config; never {@code null} + */ + public AIAgentConfig getConfig() { + return config; + } + + /** + * Returns the outgoing edges from this node to other nodes. + * + * @return an unmodifiable list of outgoing edges; never {@code null} but may be empty + */ + public List getEdges() { + return edges; + } + + /** + * Returns {@code true} if this node has no outgoing edges. + * + * @return {@code true} if terminal (no edges), {@code false} otherwise + */ + public boolean isTerminal() { + return edges.isEmpty(); + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java new file mode 100644 index 00000000..84cc0477 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java @@ -0,0 +1,43 @@ +package com.launchdarkly.sdk.server.ai; + +import com.launchdarkly.sdk.LDValue; + +import java.util.Collections; +import java.util.Map; + +/** + * An edge in an agent graph, representing a directed connection from one node to a target node. + *

+ * Each edge carries the key of the target {@link AgentGraphNode} and an optional handoff map of + * arbitrary data that may be passed when transitioning to the target node. Instances are immutable. + */ +public final class GraphEdge { + private final String key; + private final Map handoff; + + public GraphEdge(String key, Map handoff) { + this.key = key; + this.handoff = handoff; + } + + /** + * Returns the key of the target node that this edge points to. + * + * @return the target node key; never {@code null} + */ + public String getKey() { + return key; + } + + /** + * Returns the handoff options for this edge. + *

+ * The handoff is an optional map of arbitrary values that may be passed when transitioning + * to the target node. If no handoff was defined for this edge, returns {@code null}. + * + * @return an unmodifiable map of handoff values, or {@code null} if absent + */ + public Map getHandoff() { + return handoff; + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index a2a0f76b..a4cac371 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -2,6 +2,7 @@ import com.launchdarkly.sdk.LDContext; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -96,4 +97,45 @@ AIJudgeConfig judgeConfig( * @throws IllegalArgumentException if the token is malformed */ LDAIConfigTracker createTracker(String resumptionToken, LDContext context); + + /** + * Fetches and validates an agent graph definition identified by {@code graphKey}. + *

+ * Evaluates the graph flag, fetches all referenced node configs, and validates the graph + * structure. If validation fails (disabled flag, empty root, unreachable nodes, or any + * non-enabled child config) the returned definition has {@link AgentGraphDefinition#isEnabled()} + * {@code == false} and an empty node map. + *

+ * Also emits a {@code $ld:ai:usage:agent-graph} usage event. + * + * @param graphKey the flag key identifying the agent graph + * @param context the evaluation context + * @param variables Mustache template variables applied to each node's instructions + * @return the resolved graph definition; never {@code null} + */ + AgentGraphDefinition agentGraph(String graphKey, LDContext context, Map variables); + + /** + * Fetches and validates an agent graph definition with no template variables. + * + * @param graphKey the flag key identifying the agent graph + * @param context the evaluation context + * @return the resolved graph definition; never {@code null} + */ + default AgentGraphDefinition agentGraph(String graphKey, LDContext context) { + return agentGraph(graphKey, context, Collections.emptyMap()); + } + + /** + * Reconstructs an {@link AIGraphTracker} from a resumption token. + *

+ * Use this to continue tracking a graph run across requests by passing the token produced by + * {@link AIGraphTracker#getResumptionToken()}. + * + * @param resumptionToken the token produced by a prior {@link AIGraphTracker} + * @param context the evaluation context + * @return a reconstructed tracker with the original run identity + * @throws IllegalArgumentException if the token is malformed + */ + AIGraphTracker createGraphTracker(String resumptionToken, LDContext context); } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 8bf81e71..391e9069 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -10,6 +10,7 @@ import com.launchdarkly.sdk.LDValueType; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigFlagValue; import com.launchdarkly.sdk.server.ai.internal.AIConfigParser; import com.launchdarkly.sdk.server.ai.internal.AISdkInfo; @@ -18,10 +19,16 @@ import com.launchdarkly.sdk.server.interfaces.LDClientInterface; import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; import java.util.LinkedHashMap; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Queue; +import java.util.Set; import java.util.UUID; import java.util.function.Supplier; @@ -45,6 +52,7 @@ public final class LDAIClientImpl implements LDAIClient { private static final String TRACK_USAGE_AGENT_CONFIG = "$ld:ai:usage:agent-config"; private static final String TRACK_USAGE_AGENT_CONFIGS = "$ld:ai:usage:agent-configs"; private static final String TRACK_USAGE_JUDGE_CONFIG = "$ld:ai:usage:judge-config"; + private static final String TRACK_USAGE_AGENT_GRAPH = "$ld:ai:usage:agent-graph"; private static final LDContext INIT_TRACK_CONTEXT = LDContext .builder("ld-internal-tracking") @@ -110,21 +118,26 @@ public AIAgentConfig agentConfig( @Override public Map agentConfigs( List agentConfigs, LDContext context) { - Map result = new LinkedHashMap<>(); int count = 0; if (agentConfigs != null) { for (AIAgentConfigRequest request : agentConfigs) { - if (request == null) { - continue; + if (request != null) { + count++; } - count++; - result.put( - request.getKey(), - evaluateAgent(request.getKey(), context, request.getDefaultValue(), request.getVariables())); } } client.trackMetric(TRACK_USAGE_AGENT_CONFIGS, context, LDValue.of(count), count); + Map result = new LinkedHashMap<>(); + if (agentConfigs != null) { + for (AIAgentConfigRequest request : agentConfigs) { + if (request != null) { + result.put( + request.getKey(), + evaluateAgent(request.getKey(), context, request.getDefaultValue(), request.getVariables())); + } + } + } return result; } @@ -142,9 +155,15 @@ public AIJudgeConfig judgeConfig( private AIAgentConfig evaluateAgent( String key, LDContext context, AIAgentConfigDefault defaultValue, Map variables) { + return evaluateAgent(key, context, defaultValue, variables, null); + } + + private AIAgentConfig evaluateAgent( + String key, LDContext context, AIAgentConfigDefault defaultValue, + Map variables, String graphKey) { AIAgentConfigDefault effectiveDefault = defaultValue != null ? defaultValue : AIAgentConfigDefault.disabled(); - return (AIAgentConfig) evaluate(key, context, effectiveDefault, Mode.AGENT, variables); + return (AIAgentConfig) evaluate(key, context, effectiveDefault, Mode.AGENT, variables, graphKey); } /** @@ -158,13 +177,23 @@ private AIConfig evaluate( AIConfigDefault defaultValue, Mode mode, Map variables) { + return evaluate(key, context, defaultValue, mode, variables, null); + } + + private AIConfig evaluate( + String key, + LDContext context, + AIConfigDefault defaultValue, + Mode mode, + Map variables, + String graphKey) { LDValue value = client.jsonValueVariation(key, context, LDValue.ofNull()); // A valid AI Config variation is always a JSON object (it carries the _ldMeta block). When the // flag is absent or cannot be evaluated the base SDK hands back our null sentinel; in that case // we return the caller's typed default directly rather than serializing it and parsing it back. if (value == null || value.getType() != LDValueType.OBJECT) { - return buildConfigFromDefault(key, mode, defaultValue, context, variables); + return buildConfigFromDefault(key, mode, defaultValue, context, variables, graphKey); } AIConfigFlagValue parsed = AIConfigParser.parse(value); @@ -174,10 +203,10 @@ private AIConfig evaluate( logger.warn( "AI Config mode mismatch for {}: expected {}, got {}. Returning default config.", key, mode.getWireValue(), flagMode.getWireValue()); - return buildConfigFromDefault(key, mode, defaultValue, context, variables); + return buildConfigFromDefault(key, mode, defaultValue, context, variables, graphKey); } - return buildConfig(key, mode, parsed, context, variables); + return buildConfig(key, mode, parsed, context, variables, graphKey); } private AIConfig buildConfig( @@ -186,9 +215,19 @@ private AIConfig buildConfig( AIConfigFlagValue parsed, LDContext context, Map variables) { + return buildConfig(key, mode, parsed, context, variables, null); + } + + private AIConfig buildConfig( + String key, + Mode mode, + AIConfigFlagValue parsed, + LDContext context, + Map variables, + String graphKey) { Supplier factory = trackerFactory( key, parsed.getVariationKey(), parsed.getVersion(), - parsed.getModel(), parsed.getProvider(), context); + parsed.getModel(), parsed.getProvider(), context, graphKey); switch (mode) { case AGENT: return new AIAgentConfig( @@ -233,9 +272,19 @@ private AIConfig buildConfigFromDefault( AIConfigDefault defaultValue, LDContext context, Map variables) { + return buildConfigFromDefault(key, mode, defaultValue, context, variables, null); + } + + private AIConfig buildConfigFromDefault( + String key, + Mode mode, + AIConfigDefault defaultValue, + LDContext context, + Map variables, + String graphKey) { // Default configs still get real trackers — the configKey was requested even if no flag was found. // variationKey is null because no flag evaluation occurred. - Supplier factory = trackerFactory(key, null, null, null, null, context); + Supplier factory = trackerFactory(key, null, null, null, null, context, graphKey); switch (mode) { case AGENT: { AIAgentConfigDefault agent = (AIAgentConfigDefault) defaultValue; @@ -287,6 +336,17 @@ private Supplier trackerFactory( com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, LDContext context) { + return trackerFactory(configKey, variationKey, version, model, provider, context, null); + } + + private Supplier trackerFactory( + String configKey, + String variationKey, + Integer version, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model model, + com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Provider provider, + LDContext context, + String graphKey) { String modelName = model != null && model.getName() != null ? model.getName() : ""; String providerName = provider != null && provider.getName() != null ? provider.getName() : ""; int ver = version != null ? version : 1; @@ -299,8 +359,92 @@ private Supplier trackerFactory( modelName, providerName, context, - null, // graphKey — set by agentGraph() in Plan 3 + graphKey, + logger); + } + + @Override + public AgentGraphDefinition agentGraph( + String graphKey, LDContext context, Map variables) { + client.trackMetric(TRACK_USAGE_AGENT_GRAPH, context, LDValue.of(graphKey), 1); + + LDValue flagValue = client.jsonValueVariation(graphKey, context, LDValue.ofNull()); + AgentGraphFlagValue parsed = AgentGraphFlagValue.parse( + (flagValue != null && flagValue.getType() == LDValueType.OBJECT) ? flagValue : null); + + Map effectiveVars = variables != null ? variables : Collections.emptyMap(); + Supplier trackerFactory = () -> new AIGraphTracker( + client, + UUID.randomUUID().toString(), + graphKey, + parsed.getVariationKey(), + parsed.getVersion(), + context, logger); + + AgentGraphDefinition disabled = new AgentGraphDefinition( + parsed, Collections.emptyMap(), false, trackerFactory); + + // Validation step 1: _ldMeta.enabled + if (!parsed.isEnabled()) { + return disabled; + } + + // Validation step 2: root must be non-empty + String root = parsed.getRoot(); + if (root == null || root.isEmpty()) { + return disabled; + } + + // Validation step 3: all keys reachable from root (no unconnected nodes) + Set allKeys = AgentGraphDefinition.collectAllKeys(parsed); + Set reachableKeys = collectReachableKeys(parsed); + for (String key : allKeys) { + if (!reachableKeys.contains(key)) { + return disabled; + } + } + + // Validation step 4: fetch each child config (without emitting usage events) + Map configs = new HashMap<>(); + for (String key : allKeys) { + AIAgentConfig config = evaluateAgent(key, context, null, effectiveVars, graphKey); + if (!config.isEnabled()) { + return disabled; + } + configs.put(key, config); + } + + Map nodes = AgentGraphDefinition.buildNodes(parsed, configs); + return new AgentGraphDefinition(parsed, nodes, true, trackerFactory); + } + + private static Set collectReachableKeys(AgentGraphFlagValue flagValue) { + Set visited = new HashSet<>(); + Queue queue = new LinkedList<>(); + String root = flagValue.getRoot(); + if (root == null || root.isEmpty()) { + return visited; + } + visited.add(root); + queue.add(root); + while (!queue.isEmpty()) { + String key = queue.poll(); + List edges = flagValue.getEdges().get(key); + if (edges != null) { + for (GraphEdge edge : edges) { + if (visited.add(edge.getKey())) { + queue.add(edge.getKey()); + } + } + } + } + return visited; + } + + @Override + public AIGraphTracker createGraphTracker(String resumptionToken, LDContext context) { + return AIGraphTracker.fromResumptionToken(resumptionToken, client, context); } @Override diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java new file mode 100644 index 00000000..335983d1 --- /dev/null +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValue.java @@ -0,0 +1,220 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.LDValueType; +import com.launchdarkly.sdk.server.ai.GraphEdge; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** + * The parsed, strongly-typed representation of an agent graph flag variation's JSON protocol. + *

+ * Mirrors the wire structure: the {@code _ldMeta} block (enabled / variationKey / version) plus + * the {@code root} config key and the {@code edges} adjacency map. Produced by {@link #parse} and + * consumed when assembling {@link com.launchdarkly.sdk.server.ai.AgentGraphDefinition}. + *

+ * Parsing is intentionally defensive: malformed, missing, or wrong-typed fields never raise an + * exception. When {@code _ldMeta.enabled} is absent the default is {@code true}; when + * {@code _ldMeta.version} is absent the default is {@code 1}. + *

+ * This class is an internal implementation detail and is not part of the supported API. + */ +public final class AgentGraphFlagValue { + private static final int DEFAULT_VERSION = 1; + + private final String root; + private final Map> edges; + private final String variationKey; + private final int version; + private final boolean enabled; + + private AgentGraphFlagValue(Builder b) { + this.root = b.root; + this.edges = b.edges == null + ? Collections.>emptyMap() + : Collections.unmodifiableMap(b.edges); + this.variationKey = b.variationKey; + this.version = b.version; + this.enabled = b.enabled; + } + + /** + * Returns a disabled flag value with empty root and no edges. Used when the raw flag value is + * not a JSON object or when validation fails. + * + * @return a disabled {@link AgentGraphFlagValue} + */ + public static AgentGraphFlagValue disabled() { + return new Builder().enabled(false).build(); + } + + /** + * Parses a raw {@link LDValue} flag variation into a strongly-typed {@link AgentGraphFlagValue}. + *

+ * Returns {@link #disabled()} when {@code value} is not a JSON object. + * + * @param value the raw flag value; may be {@code null} or any JSON type + * @return the parsed representation; never {@code null} + */ + public static AgentGraphFlagValue parse(LDValue value) { + if (value == null || value.getType() != LDValueType.OBJECT) { + return disabled(); + } + + Builder builder = new Builder(); // defaults: enabled=true, version=1 + + LDValue meta = value.get("_ldMeta"); + if (meta != null && meta.getType() == LDValueType.OBJECT) { + LDValue enabledVal = meta.get("enabled"); + if (enabledVal.getType() == LDValueType.BOOLEAN) { + builder.enabled(enabledVal.booleanValue()); + } + LDValue variationKeyVal = meta.get("variationKey"); + if (variationKeyVal.getType() == LDValueType.STRING) { + builder.variationKey(variationKeyVal.stringValue()); + } + LDValue versionVal = meta.get("version"); + if (versionVal.getType() == LDValueType.NUMBER) { + builder.version(versionVal.intValue()); + } + } + + LDValue rootVal = value.get("root"); + if (rootVal.getType() == LDValueType.STRING) { + builder.root(rootVal.stringValue()); + } + + LDValue edgesVal = value.get("edges"); + if (edgesVal.getType() == LDValueType.OBJECT) { + Map> edges = new LinkedHashMap<>(); + for (String sourceKey : edgesVal.keys()) { + LDValue edgeArray = edgesVal.get(sourceKey); + if (edgeArray.getType() != LDValueType.ARRAY) { + continue; + } + List edgeList = new ArrayList<>(); + for (LDValue edgeObj : edgeArray.values()) { + if (edgeObj == null || edgeObj.getType() != LDValueType.OBJECT) { + continue; + } + LDValue keyVal = edgeObj.get("key"); + if (keyVal.getType() != LDValueType.STRING) { + continue; + } + String targetKey = keyVal.stringValue(); + Map handoff = parseHandoff(edgeObj.get("handoff")); + edgeList.add(new GraphEdge(targetKey, handoff)); + } + edges.put(sourceKey, Collections.unmodifiableList(edgeList)); + } + builder.edges(edges); + } + + return builder.build(); + } + + private static Map parseHandoff(LDValue handoff) { + if (handoff == null || handoff.getType() != LDValueType.OBJECT) { + return null; + } + Map result = new LinkedHashMap<>(); + for (String key : handoff.keys()) { + result.put(key, handoff.get(key)); + } + return result.isEmpty() ? null : Collections.unmodifiableMap(result); + } + + /** + * Returns the root node's AI Config key. + * + * @return the root key; never {@code null}, but may be empty when not specified + */ + public String getRoot() { + return root; + } + + /** + * Returns the adjacency map of outgoing edges keyed by source node config key. + * + * @return an unmodifiable map; never {@code null} but may be empty + */ + public Map> getEdges() { + return edges; + } + + /** + * Returns the {@code _ldMeta.variationKey}. + * + * @return the variation key, or {@code null} if absent + */ + public String getVariationKey() { + return variationKey; + } + + /** + * Returns the {@code _ldMeta.version}, defaulting to 1 when absent. + * + * @return the version; always >= 1 + */ + public int getVersion() { + return version; + } + + /** + * Returns {@code true} if the graph is enabled. Defaults to {@code true} when + * {@code _ldMeta.enabled} is absent. + * + * @return whether the graph is enabled + */ + public boolean isEnabled() { + return enabled; + } + + static Builder builder() { + return new Builder(); + } + + static final class Builder { + private String root = ""; + private Map> edges; + private String variationKey; + private int version = DEFAULT_VERSION; + private boolean enabled = true; + + private Builder() { + } + + Builder root(String v) { + this.root = v; + return this; + } + + Builder edges(Map> v) { + this.edges = v; + return this; + } + + Builder variationKey(String v) { + this.variationKey = v; + return this; + } + + Builder version(int v) { + this.version = v; + return this; + } + + Builder enabled(boolean v) { + this.enabled = v; + return this; + } + + AgentGraphFlagValue build() { + return new AgentGraphFlagValue(this); + } + } +} diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index de3885c2..d64910fc 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -13,7 +13,7 @@ *

* This class is an internal implementation detail and is not part of the supported API. */ -final class ResumptionTokens { +public final class ResumptionTokens { private static final int MAX_TOKEN_BYTES = 4096; private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); @@ -253,6 +253,167 @@ static String escapeJson(String s) { return sb.toString(); } + /** + * Encodes a graph-level resumption token from the given graph identity fields. + *

+ * Field order in the JSON: {@code runId}, {@code graphKey}, {@code variationKey} (omitted if + * {@code null}), {@code version}. + * + * @param runId the run ID + * @param graphKey the agent graph key + * @param variationKey the variation key, or {@code null} to omit + * @param version the graph version + * @return the URL-safe Base64-encoded token + */ + public static String encodeGraph(String runId, String graphKey, String variationKey, int version) { + StringBuilder sb = new StringBuilder(); + sb.append("{\"runId\":\"").append(escapeJson(runId)).append('"'); + sb.append(",\"graphKey\":\"").append(escapeJson(graphKey)).append('"'); + if (variationKey != null) { + sb.append(",\"variationKey\":\"").append(escapeJson(variationKey)).append('"'); + } + sb.append(",\"version\":").append(version); + sb.append('}'); + return ENCODER.encodeToString(sb.toString().getBytes(StandardCharsets.UTF_8)); + } + + /** + * Decodes a graph resumption token previously produced by {@link #encodeGraph}. + * + * @param token the URL-safe Base64 token + * @return the decoded fields + * @throws IllegalArgumentException if the token is malformed, oversized, or missing required fields + */ + public static DecodedGraph decodeGraph(String token) { + if (token == null) { + throw new IllegalArgumentException("Graph resumption token must not be null"); + } + if (token.length() > MAX_TOKEN_BYTES) { + throw new IllegalArgumentException("Graph resumption token exceeds maximum length of " + MAX_TOKEN_BYTES + " bytes"); + } + + String json; + try { + byte[] bytes = DECODER.decode(token); + json = new String(bytes, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Graph resumption token is not valid Base64: " + e.getMessage(), e); + } + + return parseGraphJson(json); + } + + private static DecodedGraph parseGraphJson(String json) { + json = json.trim(); + if (!json.startsWith("{") || !json.endsWith("}")) { + throw new IllegalArgumentException("Graph resumption token JSON must be an object"); + } + + String runId = null; + String graphKey = null; + String variationKey = null; + Integer version = null; + + int pos = 1; + while (pos < json.length() - 1) { + pos = skipWhitespace(json, pos); + if (pos >= json.length() - 1) { + break; + } + if (json.charAt(pos) == ',') { + pos++; + pos = skipWhitespace(json, pos); + } + if (pos >= json.length() - 1) { + break; + } + + if (json.charAt(pos) != '"') { + throw new IllegalArgumentException("Expected '\"' at position " + pos + " in graph resumption token"); + } + int[] keyEnd = new int[1]; + String key = readString(json, pos, keyEnd); + pos = keyEnd[0]; + + pos = skipWhitespace(json, pos); + if (pos >= json.length() || json.charAt(pos) != ':') { + throw new IllegalArgumentException("Expected ':' after key in graph resumption token"); + } + pos++; + pos = skipWhitespace(json, pos); + + if (json.charAt(pos) == '"') { + int[] valEnd = new int[1]; + String value = readString(json, pos, valEnd); + pos = valEnd[0]; + switch (key) { + case "runId": runId = value; break; + case "graphKey": graphKey = value; break; + case "variationKey": variationKey = value; break; + default: break; + } + } else { + int start = pos; + while (pos < json.length() && json.charAt(pos) != ',' && json.charAt(pos) != '}') { + pos++; + } + String numStr = json.substring(start, pos).trim(); + if ("version".equals(key)) { + try { + version = Integer.parseInt(numStr); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("Field 'version' must be an integer in graph resumption token", e); + } + } + } + } + + if (runId == null) { + throw new IllegalArgumentException("Graph resumption token missing required field 'runId'"); + } + if (graphKey == null) { + throw new IllegalArgumentException("Graph resumption token missing required field 'graphKey'"); + } + if (version == null) { + throw new IllegalArgumentException("Graph resumption token missing required field 'version'"); + } + + return new DecodedGraph(runId, graphKey, variationKey, version); + } + + /** + * The decoded fields from a graph resumption token. + */ + public static final class DecodedGraph { + private final String runId; + private final String graphKey; + private final String variationKey; + private final int version; + + public DecodedGraph(String runId, String graphKey, String variationKey, int version) { + this.runId = runId; + this.graphKey = graphKey; + this.variationKey = variationKey; + this.version = version; + } + + public String getRunId() { + return runId; + } + + public String getGraphKey() { + return graphKey; + } + + public String getVariationKey() { + return variationKey; + } + + public int getVersion() { + return version; + } + } + /** * The decoded fields from a resumption token. */ diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java new file mode 100644 index 00000000..7d2b2fb7 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java @@ -0,0 +1,480 @@ +package com.launchdarkly.sdk.server.ai; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import com.launchdarkly.logging.LDLogLevel; +import com.launchdarkly.logging.LDLogger; +import com.launchdarkly.logging.LogCapture; +import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.LDContext; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +@SuppressWarnings("javadoc") +public class AIGraphTrackerTest { + private LDClientInterface client; + private LogCapture logCapture; + private LDLogger logger; + private AIGraphTracker tracker; + + private static final LDContext CONTEXT = LDContext.create("user-key"); + private static final String RUN_ID = "test-run-id"; + private static final String GRAPH_KEY = "my-graph"; + private static final String VARIATION_KEY = "var-abc"; + private static final int VERSION = 2; + + @Before + public void setUp() { + client = mock(LDClientInterface.class); + logCapture = Logs.capture(); + logger = LDLogger.withAdapter(logCapture, "test"); + tracker = makeTracker(VARIATION_KEY); + } + + private AIGraphTracker makeTracker(String variationKey) { + return new AIGraphTracker(client, RUN_ID, GRAPH_KEY, variationKey, VERSION, CONTEXT, logger); + } + + private List warnings() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + + private LDValue baseExpectedData() { + return LDValue.buildObject() + .put("runId", RUN_ID) + .put("graphKey", GRAPH_KEY) + .put("variationKey", VARIATION_KEY) + .put("version", VERSION) + .build(); + } + + // ---- trackInvocationSuccess ----------------------------------------------- + + @Test + public void trackInvocationSuccessEmitsCorrectEvent() { + tracker.trackInvocationSuccess(); + verify(client).trackMetric( + eq("$ld:ai:graph:invocation_success"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackInvocationSuccessIsAtMostOnce() { + tracker.trackInvocationSuccess(); + tracker.trackInvocationSuccess(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + } + + // ---- trackInvocationFailure ----------------------------------------------- + + @Test + public void trackInvocationFailureEmitsCorrectEvent() { + tracker.trackInvocationFailure(); + verify(client).trackMetric( + eq("$ld:ai:graph:invocation_failure"), eq(CONTEXT), eq(baseExpectedData()), eq(1.0)); + } + + @Test + public void trackInvocationFailureIsAtMostOnce() { + tracker.trackInvocationFailure(); + tracker.trackInvocationFailure(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + } + + @Test + public void successAndFailureShareGuard_successFirst() { + tracker.trackInvocationSuccess(); + tracker.trackInvocationFailure(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + } + + @Test + public void successAndFailureShareGuard_failureFirst() { + tracker.trackInvocationFailure(); + tracker.trackInvocationSuccess(); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + } + + // ---- trackDuration -------------------------------------------------------- + + @Test + public void trackDurationEmitsCorrectEvent() { + tracker.trackDuration(250.0); + verify(client).trackMetric( + eq("$ld:ai:graph:duration:total"), eq(CONTEXT), eq(baseExpectedData()), eq(250.0)); + } + + @Test + public void trackDurationIsAtMostOnce() { + tracker.trackDuration(100.0); + tracker.trackDuration(200.0); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:duration:total"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("duration already recorded")), is(true)); + } + + // ---- trackTotalTokens ----------------------------------------------------- + + @Test + public void trackTotalTokensEmitsCorrectEvent() { + TokenUsage tokens = new TokenUsage(30, 20, 10); + tracker.trackTotalTokens(tokens); + verify(client).trackMetric( + eq("$ld:ai:graph:total_tokens"), eq(CONTEXT), eq(baseExpectedData()), eq(30.0)); + } + + @Test + public void trackTotalTokensIsAtMostOnce() { + tracker.trackTotalTokens(new TokenUsage(10, 5, 5)); + tracker.trackTotalTokens(new TokenUsage(20, 10, 10)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("token usage already recorded")), is(true)); + } + + @Test + public void trackTotalTokensAllZeroDoesNotBurnSlot() { + tracker.trackTotalTokens(new TokenUsage(0, 0, 0)); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + // Slot not consumed — a subsequent non-zero call should fire + tracker.trackTotalTokens(new TokenUsage(5, 5, 0)); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + } + + @Test + public void trackTotalTokensNullIsIgnored() { + tracker.trackTotalTokens(null); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("tokens was null")), is(true)); + } + + // ---- trackPath ------------------------------------------------------------ + + @Test + public void trackPathEmitsCorrectEvent() { + List path = Arrays.asList("node-a", "node-b", "node-c"); + tracker.trackPath(path); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:path"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("graphKey").stringValue(), is(GRAPH_KEY)); + assertThat(data.get("path").size(), is(3)); + assertThat(data.get("path").get(0).stringValue(), is("node-a")); + assertThat(data.get("path").get(1).stringValue(), is("node-b")); + assertThat(data.get("path").get(2).stringValue(), is("node-c")); + } + + @Test + public void trackPathIsAtMostOnce() { + tracker.trackPath(Arrays.asList("node-a", "node-b")); + tracker.trackPath(Arrays.asList("node-c")); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + assertThat(warnings().stream().anyMatch(w -> w.contains("path already recorded")), is(true)); + } + + @Test + public void trackPathNullOrEmptyIsIgnored() { + tracker.trackPath(null); + tracker.trackPath(Arrays.asList()); + verify(client, never()).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + // Slot not consumed — a valid path should still fire + tracker.trackPath(Arrays.asList("node-a")); + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:path"), any(), any(), anyDouble()); + } + + // ---- trackRedirect -------------------------------------------------------- + + @Test + public void trackRedirectEmitsCorrectEvent() { + tracker.trackRedirect("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:redirect"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("redirectedTarget").stringValue(), is("target-b")); + assertThat(data.get("graphKey").stringValue(), is(GRAPH_KEY)); + } + + @Test + public void trackRedirectIsMultiFire() { + tracker.trackRedirect("source-a", "target-b"); + tracker.trackRedirect("source-a", "target-c"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:redirect"), any(), any(), anyDouble()); + } + + // ---- trackHandoffSuccess -------------------------------------------------- + + @Test + public void trackHandoffSuccessEmitsCorrectEvent() { + tracker.trackHandoffSuccess("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:handoff_success"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("targetKey").stringValue(), is("target-b")); + } + + @Test + public void trackHandoffSuccessIsMultiFire() { + tracker.trackHandoffSuccess("source-a", "target-b"); + tracker.trackHandoffSuccess("source-b", "target-c"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:handoff_success"), any(), any(), anyDouble()); + } + + // ---- trackHandoffFailure -------------------------------------------------- + + @Test + public void trackHandoffFailureEmitsCorrectEvent() { + tracker.trackHandoffFailure("source-a", "target-b"); + + ArgumentCaptor dataCaptor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric( + eq("$ld:ai:graph:handoff_failure"), eq(CONTEXT), dataCaptor.capture(), eq(1.0)); + + LDValue data = dataCaptor.getValue(); + assertThat(data.get("sourceKey").stringValue(), is("source-a")); + assertThat(data.get("targetKey").stringValue(), is("target-b")); + } + + @Test + public void trackHandoffFailureIsMultiFire() { + tracker.trackHandoffFailure("source-a", "target-b"); + tracker.trackHandoffFailure("source-c", "target-d"); + verify(client, times(2)).trackMetric( + eq("$ld:ai:graph:handoff_failure"), any(), any(), anyDouble()); + } + + // ---- getSummary ----------------------------------------------------------- + + @Test + public void getSummaryNullWhenNothingTracked() { + AIGraphMetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(nullValue())); + assertThat(summary.getDurationMs(), is(nullValue())); + assertThat(summary.getTokens(), is(nullValue())); + assertThat(summary.getPath(), is(nullValue())); + assertThat(summary.getResumptionToken(), is(notNullValue())); + } + + @Test + public void getSummaryReflectsTrackedValues() { + tracker.trackInvocationSuccess(); + tracker.trackDuration(150.0); + tracker.trackTotalTokens(new TokenUsage(40, 25, 15)); + tracker.trackPath(Arrays.asList("a", "b", "c")); + + AIGraphMetricSummary summary = tracker.getSummary(); + assertThat(summary.getSuccess(), is(Boolean.TRUE)); + assertThat(summary.getDurationMs(), is(150.0)); + assertThat(summary.getTokens().getTotal(), is(40L)); + assertThat(summary.getPath(), is(Arrays.asList("a", "b", "c"))); + assertThat(summary.getResumptionToken(), is(tracker.getResumptionToken())); + } + + @Test + public void getSummarySuccessIsFalseWhenFailureTracked() { + tracker.trackInvocationFailure(); + assertThat(tracker.getSummary().getSuccess(), is(Boolean.FALSE)); + } + + // ---- variationKey in track data ------------------------------------------ + + @Test + public void variationKeyIncludedWhenPresent() { + tracker.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("variationKey").stringValue(), is(VARIATION_KEY)); + } + + @Test + public void variationKeyOmittedWhenNull() { + AIGraphTracker t = makeTracker(null); + t.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("variationKey").isNull(), is(true)); + } + + // ---- resumption token ---------------------------------------------------- + + @Test + public void getResumptionTokenIsNotNull() { + assertThat(tracker.getResumptionToken(), is(notNullValue())); + } + + @Test + public void fromResumptionTokenRoundTrips() { + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + assertThat(reconstructed.getResumptionToken(), is(token)); + } + + @Test + public void fromResumptionTokenPreservesRunId() { + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + // Verify same events are emitted by the reconstructed tracker + reconstructed.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("runId").stringValue(), is(RUN_ID)); + assertThat(captor.getValue().get("graphKey").stringValue(), is(GRAPH_KEY)); + } + + @Test + public void fromResumptionTokenClampsVersionLessThanOne() { + AIGraphTracker t = new AIGraphTracker(client, RUN_ID, GRAPH_KEY, null, 0, CONTEXT, logger); + // Version 0 → token contains 0, but fromResumptionToken should clamp to 1 + // Actually: the tracker stores version as-is, but fromResumptionToken clamps + // Encode a token with version = 0 manually: + String token = com.launchdarkly.sdk.server.ai.internal.ResumptionTokens.encodeGraph( + RUN_ID, GRAPH_KEY, null, 0); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT); + reconstructed.trackInvocationSuccess(); + ArgumentCaptor captor = ArgumentCaptor.forClass(LDValue.class); + verify(client).trackMetric(eq("$ld:ai:graph:invocation_success"), any(), captor.capture(), anyDouble()); + assertThat(captor.getValue().get("version").intValue(), is(1)); + } + + // ---- constructor null checks --------------------------------------------- + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullClient() { + new AIGraphTracker(null, RUN_ID, GRAPH_KEY, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullRunId() { + new AIGraphTracker(client, null, GRAPH_KEY, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullGraphKey() { + new AIGraphTracker(client, RUN_ID, null, VARIATION_KEY, VERSION, CONTEXT, logger); + } + + @Test(expected = NullPointerException.class) + public void constructorRejectsNullContext() { + new AIGraphTracker(client, RUN_ID, GRAPH_KEY, VARIATION_KEY, VERSION, null, logger); + } + + // ---- concurrency: at-most-once under contention ------------------------- + + @Test + public void trackInvocationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + final int idx = i; + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + if (idx % 2 == 0) { + tracker.trackInvocationSuccess(); + } else { + tracker.trackInvocationFailure(); + } + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + // Exactly one invocation event fires total + int successCount = 0, failureCount = 0; + try { + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_success"), any(), any(), anyDouble()); + successCount = 1; + } catch (AssertionError ignored) {} + try { + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:invocation_failure"), any(), any(), anyDouble()); + failureCount = 1; + } catch (AssertionError ignored) {} + assertThat(successCount + failureCount, is(1)); + } + + @Test + public void trackDurationAtMostOnceUnderConcurrency() throws InterruptedException { + int threads = 20; + CountDownLatch ready = new CountDownLatch(threads); + CountDownLatch go = new CountDownLatch(1); + ExecutorService exec = Executors.newFixedThreadPool(threads); + + for (int i = 0; i < threads; i++) { + exec.submit(() -> { + ready.countDown(); + try { go.await(); } catch (InterruptedException ignored) {} + tracker.trackDuration(100.0); + }); + } + + ready.await(); + go.countDown(); + exec.shutdown(); + exec.awaitTermination(5, TimeUnit.SECONDS); + + verify(client, times(1)).trackMetric( + eq("$ld:ai:graph:duration:total"), any(), any(), anyDouble()); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java new file mode 100644 index 00000000..938f198c --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AgentGraphDefinitionTest.java @@ -0,0 +1,445 @@ +package com.launchdarkly.sdk.server.ai; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.Mockito.mock; + +import com.launchdarkly.sdk.ArrayBuilder; +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; +import com.launchdarkly.sdk.server.ai.internal.AgentGraphFlagValue; +import com.launchdarkly.sdk.server.interfaces.LDClientInterface; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.function.BiFunction; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class AgentGraphDefinitionTest { + + // ---- helpers -------------------------------------------------------------- + + private static AgentGraphFlagValue flagValue(String root, String[][] edges) { + ObjectBuilder edgesObj = LDValue.buildObject(); + if (edges != null) { + Map> adj = new LinkedHashMap<>(); + for (String[] edge : edges) { + if (!adj.containsKey(edge[0])) { + adj.put(edge[0], new ArrayList<>()); + } + adj.get(edge[0]).add(edge[1]); + } + for (Map.Entry> entry : adj.entrySet()) { + ArrayBuilder arr = LDValue.buildArray(); + for (String target : entry.getValue()) { + arr.add(LDValue.buildObject().put("key", target).build()); + } + edgesObj.put(entry.getKey(), arr.build()); + } + } + LDValue value = LDValue.buildObject() + .put("root", root) + .put("edges", edgesObj.build()) + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .put("version", 1) + .build()) + .build(); + return AgentGraphFlagValue.parse(value); + } + + private static AIAgentConfig makeConfig(String key, boolean enabled) { + return new AIAgentConfig(key, enabled, null, null, null, null, null, + () -> mock(com.launchdarkly.sdk.server.ai.LDAIConfigTracker.class)); + } + + private static Map configs(String... keys) { + Map m = new HashMap<>(); + for (String key : keys) { + m.put(key, makeConfig(key, true)); + } + return m; + } + + private AgentGraphDefinition buildEnabled(String root, String[][] edges, String... nodeKeys) { + AgentGraphFlagValue fv = flagValue(root, edges); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs(nodeKeys)); + return new AgentGraphDefinition(fv, nodes, true, null); + } + + // ---- collectAllKeys ------------------------------------------------------- + + @Test + public void collectAllKeysIncludesRoot() { + AgentGraphFlagValue fv = flagValue("root-node", null); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys.contains("root-node"), is(true)); + } + + @Test + public void collectAllKeysIncludesEdgeSourcesAndTargets() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "c"}}); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, containsInAnyOrder("a", "b", "c")); + } + + @Test + public void collectAllKeysWithNoEdges() { + AgentGraphFlagValue fv = flagValue("solo", null); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, containsInAnyOrder("solo")); + } + + @Test + public void collectAllKeysEmptyRootIsExcluded() { + AgentGraphFlagValue fv = AgentGraphFlagValue.disabled(); + Set keys = AgentGraphDefinition.collectAllKeys(fv); + assertThat(keys, is(empty())); + } + + // ---- buildNodes ----------------------------------------------------------- + + @Test + public void buildNodesCreatesCorrectNodeMap() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}}); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a", "b")); + assertThat(nodes.size(), is(2)); + assertThat(nodes.get("a").getKey(), is("a")); + assertThat(nodes.get("b").getKey(), is("b")); + } + + @Test + public void buildNodesAttachesEdgesToNodes() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"a", "c"}}); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a", "b", "c")); + List edges = nodes.get("a").getEdges(); + assertThat(edges.size(), is(2)); + } + + @Test + public void buildNodesSkipsMissingConfigs() { + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}}); + // Only provide config for "a", not "b" + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a")); + assertThat(nodes.size(), is(1)); + assertThat(nodes.containsKey("a"), is(true)); + assertThat(nodes.containsKey("b"), is(false)); + } + + // ---- rootNode / getNode -------------------------------------------------- + + @Test + public void rootNodeReturnsCorrectNode() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + AgentGraphNode root = graph.rootNode(); + assertThat(root, is(notNullValue())); + assertThat(root.getKey(), is("a")); + } + + @Test + public void rootNodeReturnsNullWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(graph.rootNode(), is(nullValue())); + } + + @Test + public void getNodeReturnsCorrectNode() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getNode("b").getKey(), is("b")); + } + + @Test + public void getNodeReturnsNullForUnknownKey() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.getNode("not-here"), is(nullValue())); + } + + // ---- isTerminal ---------------------------------------------------------- + + @Test + public void terminalNodeHasNoEdges() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getNode("b").isTerminal(), is(true)); + assertThat(graph.getNode("a").isTerminal(), is(false)); + } + + @Test + public void singleNodeGraphIsTerminal() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.rootNode().isTerminal(), is(true)); + } + + // ---- getChildNodes ------------------------------------------------------- + + @Test + public void getChildNodesFollowsEdges() { + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + List children = graph.getChildNodes("a"); + assertThat(children.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : children) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("b", "c")); + } + + @Test + public void getChildNodesReturnsEmptyForTerminal() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getChildNodes("b"), is(empty())); + } + + @Test + public void getChildNodesReturnsEmptyForUnknownKey() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.getChildNodes("no-such-key"), is(empty())); + } + + // ---- getParentNodes ------------------------------------------------------ + + @Test + public void getParentNodesFindsDirectParents() { + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "c"}, {"b", "c"}}, "a", "b", "c"); + List parents = graph.getParentNodes("c"); + assertThat(parents.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : parents) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("a", "b")); + } + + @Test + public void getParentNodesReturnsEmptyForRoot() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + assertThat(graph.getParentNodes("a"), is(empty())); + } + + // ---- terminalNodes ------------------------------------------------------- + + @Test + public void terminalNodesReturnsAllTerminals() { + // a -> b, a -> c; b and c are terminals + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + List terminals = graph.terminalNodes(); + assertThat(terminals.size(), is(2)); + Set keys = new HashSet<>(); + for (AgentGraphNode n : terminals) keys.add(n.getKey()); + assertThat(keys, containsInAnyOrder("b", "c")); + } + + @Test + public void terminalNodesWithSingleNodeIncludesRoot() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + assertThat(graph.terminalNodes().size(), is(1)); + assertThat(graph.terminalNodes().get(0).getKey(), is("a")); + } + + // ---- isEnabled ----------------------------------------------------------- + + @Test + public void isEnabledReflectsConstructorValue() { + AgentGraphDefinition enabled = buildEnabled("a", null, "a"); + assertThat(enabled.isEnabled(), is(true)); + + AgentGraphDefinition disabled = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(disabled.isEnabled(), is(false)); + } + + // ---- createTracker ------------------------------------------------------- + + @Test + public void createTrackerReturnsNullWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + assertThat(graph.createTracker(), is(nullValue())); + } + + @Test + public void createTrackerReturnsTrackerWhenEnabled() { + LDClientInterface client = mock(LDClientInterface.class); + AgentGraphFlagValue fv = flagValue("a", null); + Map nodes = AgentGraphDefinition.buildNodes(fv, configs("a")); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, + () -> new AIGraphTracker(client, "run-id", "graph-key", null, 1, + com.launchdarkly.sdk.LDContext.create("user"), + com.launchdarkly.logging.LDLogger.withAdapter( + com.launchdarkly.logging.Logs.none(), ""))); + assertThat(graph.createTracker(), is(notNullValue())); + } + + // ---- traverse ------------------------------------------------------------- + + @Test + public void traverseVisitsAllNodesFromRoot() { + // a -> b -> c + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"b", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + BiFunction, Object> fn = (node, ctx) -> { + visited.add(node.getKey()); + return node.getKey(); + }; + graph.traverse(fn, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a", "b", "c")); + // Root must be first + assertThat(visited.get(0), is("a")); + } + + @Test + public void traverseStoresResultsInContext() { + AgentGraphDefinition graph = buildEnabled("a", new String[][]{{"a", "b"}}, "a", "b"); + + Map ctx = new HashMap<>(); + BiFunction, Object> fn = (node, c) -> node.getKey() + "_result"; + graph.traverse(fn, ctx); + + assertThat(ctx.get("a"), is("a_result")); + assertThat(ctx.get("b"), is("b_result")); + } + + @Test + public void traverseIsNoOpWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, is(empty())); + } + + @Test + public void traverseHandlesCyclesSafely() { + // Manually build a cyclic graph: a -> b -> a + LDClientInterface client = mock(LDClientInterface.class); + Map cfgs = configs("a", "b"); + List aEdges = Collections.singletonList(new GraphEdge("b", null)); + List bEdges = Collections.singletonList(new GraphEdge("a", null)); + Map nodes = new HashMap<>(); + nodes.put("a", new AgentGraphNode("a", cfgs.get("a"), aEdges)); + nodes.put("b", new AgentGraphNode("b", cfgs.get("b"), bEdges)); + nodes = Collections.unmodifiableMap(nodes); + + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "a"}}); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, null); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited.size(), is(2)); // each node visited exactly once + } + + // ---- reverseTraverse ------------------------------------------------------ + + @Test + public void reverseTraverseProcessesRootLast() { + // a -> b -> c + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"b", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + + // c is terminal (seeded first), root "a" is last + assertThat(visited.get(visited.size() - 1), is("a")); + assertThat(visited.contains("b"), is(true)); + assertThat(visited.contains("c"), is(true)); + } + + @Test + public void reverseTraverseVisitsAllNodes() { + // a -> b, a -> c (c and b are terminals) + AgentGraphDefinition graph = buildEnabled("a", + new String[][]{{"a", "b"}, {"a", "c"}}, "a", "b", "c"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a", "b", "c")); + assertThat(visited.get(visited.size() - 1), is("a")); + } + + @Test + public void reverseTraverseSingleNodeGraph() { + AgentGraphDefinition graph = buildEnabled("a", null, "a"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, containsInAnyOrder("a")); + } + + @Test + public void reverseTraverseHandlesCyclesSafely() { + Map cfgs = configs("a", "b"); + List aEdges = Collections.singletonList(new GraphEdge("b", null)); + List bEdges = Collections.singletonList(new GraphEdge("a", null)); + Map nodes = new HashMap<>(); + nodes.put("a", new AgentGraphNode("a", cfgs.get("a"), aEdges)); + nodes.put("b", new AgentGraphNode("b", cfgs.get("b"), bEdges)); + nodes = Collections.unmodifiableMap(nodes); + + AgentGraphFlagValue fv = flagValue("a", new String[][]{{"a", "b"}, {"b", "a"}}); + AgentGraphDefinition graph = new AgentGraphDefinition(fv, nodes, true, null); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // No infinite loop; in a pure cycle neither node is terminal, so no seeds are added — + // only root is processed in the final "root last" block. + assertThat(visited.size() <= 2, is(true)); + assertThat(visited.size() >= 1, is(true)); + } + + @Test + public void reverseTraverseIsNoOpWhenDisabled() { + AgentGraphDefinition graph = new AgentGraphDefinition( + AgentGraphFlagValue.disabled(), Collections.emptyMap(), false, null); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + assertThat(visited, is(empty())); + } + + // ---- diamond graph traversal ---------------------------------------------- + + @Test + public void traverseDiamondGraph() { + // root -> a, root -> b; a -> sink, b -> sink + AgentGraphDefinition graph = buildEnabled("root", + new String[][]{{"root", "a"}, {"root", "b"}, {"a", "sink"}, {"b", "sink"}}, + "root", "a", "b", "sink"); + + List visited = new ArrayList<>(); + graph.traverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // root first, sink visited only once + assertThat(visited.get(0), is("root")); + assertThat(visited.size(), is(4)); + assertThat(new HashSet<>(visited).size(), is(4)); // all unique + } + + @Test + public void reverseTraverseDiamondGraph() { + AgentGraphDefinition graph = buildEnabled("root", + new String[][]{{"root", "a"}, {"root", "b"}, {"a", "sink"}, {"b", "sink"}}, + "root", "a", "b", "sink"); + + List visited = new ArrayList<>(); + graph.reverseTraverse((node, ctx) -> { visited.add(node.getKey()); return null; }, new HashMap<>()); + // root last, sink visited once + assertThat(visited.get(visited.size() - 1), is("root")); + assertThat(visited.size(), is(4)); + assertThat(new HashSet<>(visited).size(), is(4)); + } +} diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java index 9adc857a..85c2b81c 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/LDAIClientImplTest.java @@ -10,9 +10,11 @@ import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyDouble; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -20,8 +22,10 @@ import com.launchdarkly.logging.LDLogger; import com.launchdarkly.logging.LogCapture; import com.launchdarkly.logging.Logs; +import com.launchdarkly.sdk.ArrayBuilder; import com.launchdarkly.sdk.LDContext; import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.ObjectBuilder; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Mode; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Message; import com.launchdarkly.sdk.server.ai.datamodel.LDAIConfigTypes.Model; @@ -30,6 +34,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -292,4 +297,191 @@ public void agentConfigsUsageCountExcludesNullEntries() { private static Map variables() { return new HashMap<>(); } + + // ---- agentGraph ----------------------------------------------------------- + + private static LDValue graphFlagValue(String root, boolean enabled, String variationKey, + String... edges) { + ObjectBuilder edgesObj = LDValue.buildObject(); + // edges are pairs: [source, target, source, target, ...] + Map edgeMap = new LinkedHashMap<>(); + for (int i = 0; i + 1 < edges.length; i += 2) { + String src = edges[i], tgt = edges[i + 1]; + if (!edgeMap.containsKey(src)) { + edgeMap.put(src, LDValue.buildArray()); + } + edgeMap.get(src).add(LDValue.buildObject().put("key", tgt).build()); + } + for (Map.Entry e : edgeMap.entrySet()) { + edgesObj.put(e.getKey(), e.getValue().build()); + } + ObjectBuilder meta = LDValue.buildObject() + .put("enabled", enabled) + .put("version", 1); + if (variationKey != null) { + meta.put("variationKey", variationKey); + } + return LDValue.buildObject() + .put("root", root) + .put("edges", edgesObj.build()) + .put("_ldMeta", meta.build()) + .build(); + } + + private static LDValue agentFlagValue(boolean enabled) { + return LDValue.parse("{\"_ldMeta\":{\"enabled\":" + enabled + ",\"mode\":\"agent\"}," + + "\"instructions\":\"test instructions\"}"); + } + + @Test + public void agentGraphFiresUsageEvent() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "v1", "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(true)); + + ai.agentGraph("g", context, null); + + verify(client).trackMetric( + eq("$ld:ai:usage:agent-graph"), eq(context), eq(LDValue.of("g")), eq(1.0)); + } + + @Test + public void agentGraphDoesNotFireNodeLevelUsageEvents() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null)); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + ai.agentGraph("g", context, null); + + verify(client, never()).trackMetric(eq("$ld:ai:usage:agent-config"), any(), any(), anyDouble()); + } + + @Test + public void agentGraphReturnsEnabledGraphForValidFlag() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "v1", "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(true)); + assertThat(graph.rootNode(), is(notNullValue())); + assertThat(graph.rootNode().getKey(), is("node-a")); + assertThat(graph.getNode("node-b"), is(notNullValue())); + } + + @Test + public void agentGraphReturnsDisabledWhenFlagDisabled() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", false, null)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenRootMissing() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("", true, null)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenUnreachableNode() { + // node-a -> node-b, but flag also declares node-c which is unreachable from root + LDValue flag = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .build()) + .put("node-c", LDValue.buildArray() // unreachable source + .add(LDValue.buildObject().put("key", "node-d").build()) + .build()) + .build()) + .put("_ldMeta", LDValue.buildObject().put("enabled", true).build()) + .build(); + when(client.jsonValueVariation(eq("g"), any(), any())).thenReturn(flag); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledWhenAnyChildConfigDisabled() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null, "node-a", "node-b")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + when(client.jsonValueVariation(eq("node-b"), any(), any())).thenReturn(agentFlagValue(false)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphReturnsDisabledForNonObjectFlagValue() { + when(client.jsonValueVariation(eq("g"), any(), any())).thenReturn(LDValue.ofNull()); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + + assertThat(graph.isEnabled(), is(false)); + } + + @Test + public void agentGraphNoVariablesOverloadCallsThreeArgVersion() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, null)); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context); + assertThat(graph.isEnabled(), is(true)); + } + + @Test + public void agentGraphChildConfigsIncludeGraphKey() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "var-1")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + assertThat(graph.isEnabled(), is(true)); + + // Verify: when a node tracker is created, graphKey is present in its track data + LDAIConfigTracker nodeTracker = graph.getNode("node-a").getConfig().createTracker(); + assertThat(nodeTracker.getTrackData().getGraphKey(), is("g")); + } + + // ---- createGraphTracker -------------------------------------------------- + + @Test + public void createGraphTrackerRoundTripsToken() { + when(client.jsonValueVariation(eq("g"), any(), any())) + .thenReturn(graphFlagValue("node-a", true, "var-1")); + when(client.jsonValueVariation(eq("node-a"), any(), any())).thenReturn(agentFlagValue(true)); + + AgentGraphDefinition graph = ai.agentGraph("g", context, null); + AIGraphTracker original = graph.createTracker(); + assertThat(original, is(notNullValue())); + + String token = original.getResumptionToken(); + AIGraphTracker reconstructed = ai.createGraphTracker(token, context); + assertThat(reconstructed.getResumptionToken(), is(token)); + } + + @Test(expected = IllegalArgumentException.class) + public void createGraphTrackerThrowsForMalformedToken() { + ai.createGraphTracker("not-valid-base64!!!", context); + } + + private static java.util.LinkedHashMap newLinkedHashMap() { + return new java.util.LinkedHashMap<>(); + } } diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java new file mode 100644 index 00000000..cc690c21 --- /dev/null +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/AgentGraphFlagValueTest.java @@ -0,0 +1,276 @@ +package com.launchdarkly.sdk.server.ai.internal; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.nullValue; + +import com.launchdarkly.sdk.LDValue; +import com.launchdarkly.sdk.server.ai.GraphEdge; + +import java.util.List; +import java.util.Map; + +import org.junit.Test; + +@SuppressWarnings("javadoc") +public class AgentGraphFlagValueTest { + + // ---- disabled() factory --------------------------------------------------- + + @Test + public void disabledReturnsEnabledFalse() { + AgentGraphFlagValue v = AgentGraphFlagValue.disabled(); + assertThat(v.isEnabled(), is(false)); + assertThat(v.getRoot(), is("")); + assertThat(v.getEdges().isEmpty(), is(true)); + assertThat(v.getVariationKey(), is(nullValue())); + assertThat(v.getVersion(), is(1)); + } + + // ---- parse: non-object input ---------------------------------------------- + + @Test + public void parseNullReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(null); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parseStringReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.of("not-an-object")); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parseArrayReturnsDisabled() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildArray().add("x").build()); + assertThat(v.isEnabled(), is(false)); + } + + // ---- parse: defaults when fields absent ---------------------------------- + + @Test + public void parseEmptyObjectUsesDefaults() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildObject().build()); + assertThat(v.isEnabled(), is(true)); // default true for graphs + assertThat(v.getVersion(), is(1)); // default 1 + assertThat(v.getRoot(), is("")); // empty string default + assertThat(v.getEdges().isEmpty(), is(true)); + assertThat(v.getVariationKey(), is(nullValue())); + } + + // ---- parse: _ldMeta fields ----------------------------------------------- + + @Test + public void parsesEnabledFalseFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("enabled", false) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(false)); + } + + @Test + public void parsesEnabledTrueFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); + } + + @Test + public void parsesVariationKeyFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("variationKey", "var-xyz") + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVariationKey(), is("var-xyz")); + } + + @Test + public void parsesVersionFromMeta() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject() + .put("version", 5) + .build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVersion(), is(5)); + } + + @Test + public void missingMetaVersionDefaultsToOne() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.buildObject().build()) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getVersion(), is(1)); + } + + @Test + public void nonObjectMetaIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("_ldMeta", LDValue.of("not-an-object")) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); // fallback to default + assertThat(v.getVersion(), is(1)); + } + + // ---- parse: root --------------------------------------------------------- + + @Test + public void parsesRootString() { + LDValue value = LDValue.buildObject() + .put("root", "entry-node") + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getRoot(), is("entry-node")); + } + + @Test + public void missingRootIsEmptyString() { + AgentGraphFlagValue v = AgentGraphFlagValue.parse(LDValue.buildObject().build()); + assertThat(v.getRoot(), is("")); + } + + @Test + public void nonStringRootIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", 42) + .build(); + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getRoot(), is("")); + } + + // ---- parse: edges -------------------------------------------------------- + + @Test + public void parsesEdgesMap() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + Map> edges = v.getEdges(); + assertThat(edges.containsKey("node-a"), is(true)); + List aEdges = edges.get("node-a"); + assertThat(aEdges.size(), is(1)); + assertThat(aEdges.get(0).getKey(), is("node-b")); + assertThat(aEdges.get(0).getHandoff(), is(nullValue())); + } + + @Test + public void parsesEdgeWithHandoff() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject() + .put("key", "node-b") + .put("handoff", LDValue.buildObject() + .put("someData", LDValue.of("hello")) + .build()) + .build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + GraphEdge edge = v.getEdges().get("node-a").get(0); + assertThat(edge.getKey(), is("node-b")); + assertThat(edge.getHandoff(), is(notNullValue())); + assertThat(edge.getHandoff().get("someData").stringValue(), is("hello")); + } + + @Test + public void edgeMissingKeyIsSkipped() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("notKey", "node-b").build()) + .build()) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().get("node-a"), is(empty())); + } + + @Test + public void edgeWithNonArrayValueIsSkipped() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.of("not-an-array")) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().containsKey("node-a"), is(false)); + } + + @Test + public void nonObjectEdgesFieldIsIgnored() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.of("bad")) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.getEdges().isEmpty(), is(true)); + } + + // ---- parse: full round-trip ---------------------------------------------- + + @Test + public void parsesFullFlagValue() { + LDValue value = LDValue.buildObject() + .put("root", "node-a") + .put("edges", LDValue.buildObject() + .put("node-a", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-b").build()) + .add(LDValue.buildObject().put("key", "node-c").build()) + .build()) + .put("node-b", LDValue.buildArray() + .add(LDValue.buildObject().put("key", "node-c").build()) + .build()) + .build()) + .put("_ldMeta", LDValue.buildObject() + .put("enabled", true) + .put("variationKey", "var-1") + .put("version", 2) + .build()) + .build(); + + AgentGraphFlagValue v = AgentGraphFlagValue.parse(value); + assertThat(v.isEnabled(), is(true)); + assertThat(v.getRoot(), is("node-a")); + assertThat(v.getVariationKey(), is("var-1")); + assertThat(v.getVersion(), is(2)); + assertThat(v.getEdges().get("node-a").size(), is(2)); + assertThat(v.getEdges().get("node-b").size(), is(1)); + } +} From bed4ca29cc121c3099036af566d2d87e78f0dc30 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:40:33 -0500 Subject: [PATCH 04/23] guard against null AIMetrics --- .../sdk/server/ai/internal/LDAIConfigTrackerImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index 3c7faf6a..2a99266a 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -318,7 +318,7 @@ public T trackMetricsOf( // Extractor exceptions propagate to the caller — do NOT catch them here. // Do NOT call trackError() on extractor failure; the AI operation itself succeeded. - AIMetrics metrics = metricsExtractor.apply(result); + AIMetrics metrics = Objects.requireNonNull(metricsExtractor.apply(result), "metricsExtractor returned null"); // Duration: prefer runner-reported value (§1.1.13.2), fall back to wall-clock. if (metrics.getDurationMs() != null) { From 2b47c86bc419d5dc36a2bb63040715d86b78a254 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:44:05 -0500 Subject: [PATCH 05/23] fix: guard against blank metricKey and infinite/invalid score --- .../sdk/server/ai/datamodel/LDAITrackingTypes.java | 12 ++++++++++-- .../server/ai/internal/LDAIConfigTrackerImpl.java | 4 ++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java index 534e3aed..1f93d958 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -423,10 +423,14 @@ public Builder sampled(boolean sampled) { /** * Sets the metric key. * - * @param metricKey the metric key; may be {@code null} + * @param metricKey the metric key; may be {@code null}, but must not be blank if non-null * @return this builder + * @throws IllegalArgumentException if {@code metricKey} is non-null and blank */ public Builder metricKey(String metricKey) { + if (metricKey != null && metricKey.isBlank()) { + throw new IllegalArgumentException("metricKey must not be blank"); + } this.metricKey = metricKey; return this; } @@ -434,10 +438,14 @@ public Builder metricKey(String metricKey) { /** * Sets the judge score. * - * @param score the score; may be {@code null} + * @param score the score; may be {@code null}, but must be finite if non-null * @return this builder + * @throws IllegalArgumentException if {@code score} is non-null and non-finite (NaN or infinite) */ public Builder score(Double score) { + if (score != null && !Double.isFinite(score)) { + throw new IllegalArgumentException("score must be finite"); + } this.score = score; return this; } diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index 2a99266a..9f20af71 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -284,10 +284,10 @@ public void trackJudgeResult(JudgeResult result) { if (!result.isSuccess()) { return; } - if (result.getMetricKey() == null) { + if (result.getMetricKey() == null || result.getMetricKey().isBlank()) { return; } - if (result.getScore() == null) { + if (result.getScore() == null || !Double.isFinite(result.getScore())) { return; } ObjectBuilder data = baseData(); From 4ef3de2a2b95198cb465c82a2b842b47bcdd052f Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:46:41 -0500 Subject: [PATCH 06/23] fix: MAX_TOKEN_BYTES -> MAX_TOKEN_LENGTH --- .../sdk/server/ai/internal/ResumptionTokens.java | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index de3885c2..12fbdca9 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -14,7 +14,7 @@ * This class is an internal implementation detail and is not part of the supported API. */ final class ResumptionTokens { - private static final int MAX_TOKEN_BYTES = 4096; + private static final int MAX_TOKEN_LENGTH = 4096; private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); @@ -61,8 +61,8 @@ static Decoded decode(String token) { if (token == null) { throw new IllegalArgumentException("Resumption token must not be null"); } - if (token.length() > MAX_TOKEN_BYTES) { - throw new IllegalArgumentException("Resumption token exceeds maximum length of " + MAX_TOKEN_BYTES + " bytes"); + if (token.length() > MAX_TOKEN_LENGTH) { + throw new IllegalArgumentException("Resumption token exceeds maximum length of " + MAX_TOKEN_LENGTH + " characters"); } String json; From 1be0a1e838b310000083eefd707e5223a2f6bd6d Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:51:42 -0500 Subject: [PATCH 07/23] fix: guard against empty runId and configKey --- .../launchdarkly/sdk/server/ai/internal/ResumptionTokens.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index 12fbdca9..a5fb04df 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -150,10 +150,10 @@ private static Decoded parseJson(String json) { } } - if (runId == null) { + if (runId == null || runId.isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'runId'"); } - if (configKey == null) { + if (configKey == null || configKey.isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'configKey'"); } if (version == null) { From 8e81ea0b09a3dbb01d91f637b67afbcf01409e57 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 16:53:43 -0500 Subject: [PATCH 08/23] fix: Add warning comment to createTracker public call --- .../main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java index a2a0f76b..77031cf4 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClient.java @@ -89,6 +89,10 @@ AIJudgeConfig judgeConfig( * stores the resumption token from a previous tracker (via * {@link LDAIConfigTracker#getResumptionToken()}) and passes it back here to continue tracking * against the same run. + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. * * @param resumptionToken the token returned by a previous tracker; must not be {@code null} * @param context the evaluation context for the new request; must not be {@code null} From e81e2f58244fca423c095324ff06905e26948812 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 17:03:29 -0500 Subject: [PATCH 09/23] fix: use trim + isEmpty to support java 8 --- .../launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java | 2 +- .../sdk/server/ai/internal/LDAIConfigTrackerImpl.java | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java index 1f93d958..54a186d2 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -428,7 +428,7 @@ public Builder sampled(boolean sampled) { * @throws IllegalArgumentException if {@code metricKey} is non-null and blank */ public Builder metricKey(String metricKey) { - if (metricKey != null && metricKey.isBlank()) { + if (metricKey != null && metricKey.trim().isEmpty()) { throw new IllegalArgumentException("metricKey must not be blank"); } this.metricKey = metricKey; diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index 9f20af71..e5e7c40d 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -284,7 +284,7 @@ public void trackJudgeResult(JudgeResult result) { if (!result.isSuccess()) { return; } - if (result.getMetricKey() == null || result.getMetricKey().isBlank()) { + if (result.getMetricKey() == null || result.getMetricKey().trim().isEmpty()) { return; } if (result.getScore() == null || !Double.isFinite(result.getScore())) { From c21fdd7bf74b0ded60f7cb79cefff38b10bd055d Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 17:09:03 -0500 Subject: [PATCH 10/23] fix: stop trackMetricsOf clock before running metrics extractor --- .../ai/internal/LDAIConfigTrackerImpl.java | 5 +++-- .../ai/internal/LDAIConfigTrackerImplTest.java | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index e5e7c40d..e32cd61e 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -315,6 +315,8 @@ public T trackMetricsOf( trackError(); throw e; } + // Capture operation duration immediately so a slow extractor does not inflate the metric. + long operationElapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); // Extractor exceptions propagate to the caller — do NOT catch them here. // Do NOT call trackError() on extractor failure; the AI operation itself succeeded. @@ -324,8 +326,7 @@ public T trackMetricsOf( if (metrics.getDurationMs() != null) { trackDuration(Duration.ofMillis(metrics.getDurationMs())); } else { - long elapsed = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); - trackDuration(Duration.ofMillis(elapsed)); + trackDuration(Duration.ofMillis(operationElapsedMs)); } if (metrics.isSuccess()) { diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java index 9594c33c..96c69972 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -501,6 +501,23 @@ public void trackMetricsOfUsesRunnerReportedDurationWhenPresent() throws Excepti verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), eq(999.0)); } + @Test + public void trackMetricsOfWallClockDurationExcludesSlowExtractor() throws Exception { + // Operation returns immediately; extractor sleeps. Recorded duration must reflect only the + // operation, not the extractor work. + long extractorSleepMs = 200L; + AIMetrics metrics = AIMetrics.builder().success(true).build(); + tracker.trackMetricsOf( + r -> { try { Thread.sleep(extractorSleepMs); } catch (InterruptedException ie) { Thread.currentThread().interrupt(); } return metrics; }, + () -> "ok"); + ArgumentCaptor durationCaptor = ArgumentCaptor.forClass(Double.class); + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), durationCaptor.capture()); + assertThat( + "wall-clock duration must not include extractor time", + durationCaptor.getValue() < (double) extractorSleepMs / 2, + is(true)); + } + @Test public void trackMetricsOfTracksErrorAndRethrowsOnOperationException() { try { From 4c96dcaad70974cdc1f334490d138c302a059650 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 17:14:52 -0500 Subject: [PATCH 11/23] fix: record operation duration when trackMetricsOf extractor throws --- .../server/ai/internal/LDAIConfigTrackerImpl.java | 13 ++++++++++--- .../ai/internal/LDAIConfigTrackerImplTest.java | 12 ++++++++++++ 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index e32cd61e..ce2e280c 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -318,9 +318,16 @@ public T trackMetricsOf( // Capture operation duration immediately so a slow extractor does not inflate the metric. long operationElapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start); - // Extractor exceptions propagate to the caller — do NOT catch them here. - // Do NOT call trackError() on extractor failure; the AI operation itself succeeded. - AIMetrics metrics = Objects.requireNonNull(metricsExtractor.apply(result), "metricsExtractor returned null"); + // Extractor exceptions propagate to the caller, but the operation's duration must still be + // recorded — the AI operation itself succeeded, only the user-supplied extractor failed. + // Do NOT call trackError(); that signals the operation failed, which is not what happened. + AIMetrics metrics; + try { + metrics = Objects.requireNonNull(metricsExtractor.apply(result), "metricsExtractor returned null"); + } catch (RuntimeException e) { + trackDuration(Duration.ofMillis(operationElapsedMs)); + throw e; + } // Duration: prefer runner-reported value (§1.1.13.2), fall back to wall-clock. if (metrics.getDurationMs() != null) { diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java index 96c69972..4ad7d770 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -545,6 +545,18 @@ public void trackMetricsOfExtractorExceptionPropagatesAndDoesNotCallTrackError() verify(client, never()).trackMetric(eq("$ld:ai:generation:success"), any(), any(), anyDouble()); } + @Test + public void trackMetricsOfRecordsDurationWhenExtractorThrows() { + try { + tracker.trackMetricsOf( + r -> { throw new RuntimeException("extractor failed"); }, + () -> "ok"); + } catch (Exception e) { + // expected; we care that duration was recorded before the throw + } + verify(client).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); + } + @Test public void trackMetricsOfTracksToolCalls() throws Exception { AIMetrics metrics = AIMetrics.builder() From 4da54789cf34a447abb1ecfae358c1eee8ee2b4a Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Tue, 23 Jun 2026 17:29:34 -0500 Subject: [PATCH 12/23] fix: downgrade null-arg track logs from warn to debug per spec --- .../ai/internal/LDAIConfigTrackerImpl.java | 12 +++--- .../internal/LDAIConfigTrackerImplTest.java | 38 +++++++++++++------ 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java index ce2e280c..766d59e7 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImpl.java @@ -155,7 +155,7 @@ public String getResumptionToken() { @Override public void trackDuration(Duration duration) { if (duration == null) { - logger.warn("Skipping trackDuration: duration was null."); + logger.debug("Skipping trackDuration: duration was null."); return; } long ms = Math.max(0L, duration.toMillis()); @@ -181,7 +181,7 @@ public T trackDurationOf(Callable operation) throws Exception { @Override public void trackTimeToFirstToken(Duration duration) { if (duration == null) { - logger.warn("Skipping trackTimeToFirstToken: duration was null."); + logger.debug("Skipping trackTimeToFirstToken: duration was null."); return; } long ms = Math.max(0L, duration.toMillis()); @@ -213,7 +213,7 @@ public void trackError() { @Override public void trackFeedback(FeedbackKind kind) { if (kind == null) { - logger.warn("Skipping trackFeedback: kind was null."); + logger.debug("Skipping trackFeedback: kind was null."); return; } // Resolve event name BEFORE claiming the guard — an exception here must not burn the slot. @@ -228,7 +228,7 @@ public void trackFeedback(FeedbackKind kind) { @Override public void trackTokens(TokenUsage tokens) { if (tokens == null) { - logger.warn("Skipping trackTokens: tokens was null."); + logger.debug("Skipping trackTokens: tokens was null."); return; } boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; @@ -254,7 +254,7 @@ public void trackTokens(TokenUsage tokens) { @Override public void trackToolCall(String toolKey) { if (toolKey == null) { - logger.warn("Skipping trackToolCall: toolKey was null."); + logger.debug("Skipping trackToolCall: toolKey was null."); return; } toolCalls.add(toolKey); @@ -275,7 +275,7 @@ public void trackToolCalls(List toolKeys) { @Override public void trackJudgeResult(JudgeResult result) { if (result == null) { - logger.warn("Skipping trackJudgeResult: result was null."); + logger.debug("Skipping trackJudgeResult: result was null."); return; } if (!result.isSampled()) { diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java index 4ad7d770..ffcadaa9 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/LDAIConfigTrackerImplTest.java @@ -3,6 +3,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; @@ -82,6 +83,13 @@ private List warnings() { .collect(Collectors.toList()); } + private List debugs() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.DEBUG) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + private LDValue baseExpectedData() { return LDValue.buildObject() .put("runId", RUN_ID) @@ -172,10 +180,11 @@ public void trackDurationAtMostOnce() { } @Test - public void trackDurationNullIsIgnoredWithWarning() { + public void trackDurationNullIsIgnoredWithDebugLog() { tracker.trackDuration(null); verify(client, never()).trackMetric(eq("$ld:ai:duration:total"), any(), any(), anyDouble()); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); } // ---- trackDurationOf ------------------------------------------------------ @@ -215,10 +224,11 @@ public void trackTimeToFirstTokenAtMostOnce() { } @Test - public void trackTimeToFirstTokenNullIsIgnoredWithWarning() { + public void trackTimeToFirstTokenNullIsIgnoredWithDebugLog() { tracker.trackTimeToFirstToken(null); verify(client, never()).trackMetric(eq("$ld:ai:tokens:ttf"), any(), any(), anyDouble()); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); } // ---- trackSuccess / trackError -------------------------------------------- @@ -299,9 +309,10 @@ public void trackFeedbackAtMostOnce() { } @Test - public void trackFeedbackNullIsIgnoredWithWarning_slotNotBurned() { + public void trackFeedbackNullIsIgnoredWithDebugLog_slotNotBurned() { tracker.trackFeedback(null); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); // Slot should not be burned — a subsequent valid call should still work tracker.trackFeedback(FeedbackKind.POSITIVE); verify(client, times(1)).trackMetric(eq("$ld:ai:feedback:user:positive"), any(), any(), anyDouble()); @@ -342,10 +353,11 @@ public void trackTokensAtMostOnce() { } @Test - public void trackTokensNullIsIgnoredWithWarning() { + public void trackTokensNullIsIgnoredWithDebugLog() { tracker.trackTokens(null); verify(client, never()).trackMetric(eq("$ld:ai:tokens:total"), any(), any(), anyDouble()); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); } // ---- trackToolCall -------------------------------------------------------- @@ -382,10 +394,11 @@ public void trackToolCallsDelegate() { } @Test - public void trackToolCallNullIsIgnoredWithWarning() { + public void trackToolCallNullIsIgnoredWithDebugLog() { tracker.trackToolCall(null); verify(client, never()).trackMetric(eq("$ld:ai:tool_call"), any(), any(), anyDouble()); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); } // ---- trackJudgeResult ----------------------------------------------------- @@ -470,9 +483,10 @@ public void trackJudgeResultIsNotAtMostOnce() { } @Test - public void trackJudgeResultNullIsIgnoredWithWarning() { + public void trackJudgeResultNullIsIgnoredWithDebugLog() { tracker.trackJudgeResult(null); - assertThat(warnings().size(), greaterThanOrEqualTo(1)); + assertThat(debugs().size(), greaterThanOrEqualTo(1)); + assertThat(warnings(), is(empty())); } // ---- trackMetricsOf ------------------------------------------------------- From 394a04476c4bb8e0129f3ba15afc748e12f0001d Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 15:12:57 -0500 Subject: [PATCH 13/23] fix: remove unnecessary NoOpAIConfigTracker --- .../ai/internal/NoOpAIConfigTracker.java | 98 ------------------- 1 file changed, 98 deletions(-) delete mode 100644 lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java deleted file mode 100644 index a1d65d5a..00000000 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/NoOpAIConfigTracker.java +++ /dev/null @@ -1,98 +0,0 @@ -package com.launchdarkly.sdk.server.ai.internal; - -import com.launchdarkly.sdk.server.ai.LDAIConfigTracker; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.AIMetrics; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.FeedbackKind; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.JudgeResult; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.MetricSummary; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TokenUsage; -import com.launchdarkly.sdk.server.ai.datamodel.LDAITrackingTypes.TrackData; - -import java.time.Duration; -import java.util.List; -import java.util.concurrent.Callable; -import java.util.function.Function; - -/** - * The no-op {@link LDAIConfigTracker} used when tracking is not applicable (for example, for - * disabled configs or in testing contexts). It is immutable and stateless, so a single shared - * instance is safe to reuse. - *

- * This class is an internal implementation detail and is not part of the supported API. - */ -public final class NoOpAIConfigTracker implements LDAIConfigTracker { - /** - * The shared instance. - */ - public static final NoOpAIConfigTracker INSTANCE = new NoOpAIConfigTracker(); - - private static final TrackData EMPTY_TRACK_DATA = new TrackData("", "", null, 0, "", "", null); - private static final MetricSummary EMPTY_SUMMARY = - new MetricSummary(null, null, null, null, null, null, null); - - private NoOpAIConfigTracker() { - } - - @Override - public TrackData getTrackData() { - return EMPTY_TRACK_DATA; - } - - @Override - public String getResumptionToken() { - return null; - } - - @Override - public void trackDuration(Duration duration) { - } - - @Override - public T trackDurationOf(Callable operation) throws Exception { - return operation.call(); - } - - @Override - public void trackTimeToFirstToken(Duration duration) { - } - - @Override - public void trackSuccess() { - } - - @Override - public void trackError() { - } - - @Override - public void trackFeedback(FeedbackKind kind) { - } - - @Override - public void trackTokens(TokenUsage tokens) { - } - - @Override - public void trackToolCall(String toolKey) { - } - - @Override - public void trackToolCalls(List toolKeys) { - } - - @Override - public void trackJudgeResult(JudgeResult result) { - } - - @Override - public T trackMetricsOf( - Function metricsExtractor, - Callable operation) throws Exception { - return operation.call(); - } - - @Override - public MetricSummary getSummary() { - return EMPTY_SUMMARY; - } -} From 5381bf44c716f2192afb6a8ba4a2ef85958475be Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 15:45:23 -0500 Subject: [PATCH 14/23] fix: remove resumption-token length cap --- .../server/ai/internal/ResumptionTokens.java | 6 +----- .../ai/internal/ResumptionTokensTest.java | 20 +++++++++---------- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index a5fb04df..ed15c16a 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -14,7 +14,6 @@ * This class is an internal implementation detail and is not part of the supported API. */ final class ResumptionTokens { - private static final int MAX_TOKEN_LENGTH = 4096; private static final Base64.Encoder ENCODER = Base64.getUrlEncoder().withoutPadding(); private static final Base64.Decoder DECODER = Base64.getUrlDecoder(); @@ -55,15 +54,12 @@ static String encode(String runId, String configKey, String variationKey, * * @param token the URL-safe Base64 token * @return the decoded fields - * @throws IllegalArgumentException if the token is malformed, oversized, or missing required fields + * @throws IllegalArgumentException if the token is malformed or missing required fields */ static Decoded decode(String token) { if (token == null) { throw new IllegalArgumentException("Resumption token must not be null"); } - if (token.length() > MAX_TOKEN_LENGTH) { - throw new IllegalArgumentException("Resumption token exceeds maximum length of " + MAX_TOKEN_LENGTH + " characters"); - } String json; try { diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java index 21a263b1..bed64ca9 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokensTest.java @@ -86,6 +86,16 @@ public void versionIsPreservedOnRoundTrip() { assertThat(ResumptionTokens.decode(token).getVersion(), is(1)); } + // ---- large keys ----------------------------------------------------------- + + @Test + public void roundTripsLongKeys() { + String key = new String(new char[5000]).replace('\0', 'a'); + String token = ResumptionTokens.encode("run", key, null, 1, null); + ResumptionTokens.Decoded d = ResumptionTokens.decode(token); + assertThat(d.getConfigKey(), is(key)); + } + // ---- decode error handling ------------------------------------------------ @Test(expected = IllegalArgumentException.class) @@ -93,16 +103,6 @@ public void decodeRejectsNull() { ResumptionTokens.decode(null); } - @Test(expected = IllegalArgumentException.class) - public void decodeRejectsOversizedToken() { - // Build a token larger than 4096 bytes - String largeValue = new String(new char[5000]).replace('\0', 'x'); - String json = "{\"runId\":\"" + largeValue + "\",\"configKey\":\"c\",\"version\":1}"; - String token = java.util.Base64.getUrlEncoder().withoutPadding() - .encodeToString(json.getBytes(java.nio.charset.StandardCharsets.UTF_8)); - ResumptionTokens.decode(token); - } - @Test(expected = IllegalArgumentException.class) public void decodeRejectsInvalidBase64() { ResumptionTokens.decode("not-valid-base64!!!!"); From 0bb83794eacf73bdc50d747f34f64be6181a0482 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 16:39:11 -0500 Subject: [PATCH 15/23] fix: remove token length cap, defensive copy trackPath, downgrade null-arg logs to debug --- .../com/launchdarkly/sdk/server/ai/AIGraphTracker.java | 7 ++++--- .../sdk/server/ai/internal/ResumptionTokens.java | 4 ---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java index 299143ce..ad182908 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java @@ -12,6 +12,7 @@ import com.launchdarkly.sdk.server.ai.internal.ResumptionTokens; import com.launchdarkly.sdk.server.interfaces.LDClientInterface; +import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Objects; @@ -156,7 +157,7 @@ public void trackDuration(double durationMs) { */ public void trackTotalTokens(TokenUsage tokens) { if (tokens == null) { - logger.warn("Skipping trackTotalTokens: tokens was null."); + logger.debug("Skipping trackTotalTokens: tokens was null."); return; } boolean hasPositive = tokens.getTotal() > 0 || tokens.getInput() > 0 || tokens.getOutput() > 0; @@ -179,10 +180,10 @@ public void trackTotalTokens(TokenUsage tokens) { */ public void trackPath(List path) { if (path == null || path.isEmpty()) { - logger.warn("Skipping trackPath: path was null or empty."); + logger.debug("Skipping trackPath: path was null or empty."); return; } - List snapshot = Collections.unmodifiableList(path); + List snapshot = Collections.unmodifiableList(new ArrayList<>(path)); if (!pathRecorded.compareAndSet(null, snapshot)) { logger.warn("Skipping trackPath: path already recorded on this graph tracker."); return; diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index 4976af19..83ccf707 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -284,10 +284,6 @@ public static DecodedGraph decodeGraph(String token) { if (token == null) { throw new IllegalArgumentException("Graph resumption token must not be null"); } - if (token.length() > MAX_TOKEN_BYTES) { - throw new IllegalArgumentException("Graph resumption token exceeds maximum length of " + MAX_TOKEN_BYTES + " bytes"); - } - String json; try { byte[] bytes = DECODER.decode(token); From d0ae81b5ffc857b97b923dc8a1deded9a6dfb003 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 16:49:05 -0500 Subject: [PATCH 16/23] fix: make GraphEdge.handoff a defensive unmodifiable copy --- .../main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java index 84cc0477..116d3abb 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/GraphEdge.java @@ -3,6 +3,7 @@ import com.launchdarkly.sdk.LDValue; import java.util.Collections; +import java.util.LinkedHashMap; import java.util.Map; /** @@ -17,7 +18,9 @@ public final class GraphEdge { public GraphEdge(String key, Map handoff) { this.key = key; - this.handoff = handoff; + this.handoff = handoff != null + ? Collections.unmodifiableMap(new LinkedHashMap<>(handoff)) + : null; } /** From d22532aa06bbd0d2fd6362b2ee614e37ac801c91 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 16:49:57 -0500 Subject: [PATCH 17/23] fix: don't emit graph total-tokens event when total is zero --- .../java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java index ad182908..646796f6 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java @@ -168,7 +168,9 @@ public void trackTotalTokens(TokenUsage tokens) { logger.warn("Skipping trackTotalTokens: token usage already recorded on this graph tracker."); return; } - client.trackMetric(GRAPH_TOTAL_TOKENS, context, baseData().build(), tokens.getTotal()); + if (tokens.getTotal() > 0) { + client.trackMetric(GRAPH_TOTAL_TOKENS, context, baseData().build(), tokens.getTotal()); + } } /** From 4d6565cf36b7d0f88be49cff767fa9731aa51a0a Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 16:57:19 -0500 Subject: [PATCH 18/23] fix tests --- .../java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java index 7d2b2fb7..7c7889a8 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java @@ -184,7 +184,7 @@ public void trackTotalTokensNullIsIgnored() { tracker.trackTotalTokens(null); verify(client, never()).trackMetric( eq("$ld:ai:graph:total_tokens"), any(), any(), anyDouble()); - assertThat(warnings().stream().anyMatch(w -> w.contains("tokens was null")), is(true)); + assertThat(debugs().stream().anyMatch(w -> w.contains("tokens was null")), is(true)); } // ---- trackPath ------------------------------------------------------------ From 77e49d4566ca08cd62ab8f60bd985ee056db8df3 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 17:02:05 -0500 Subject: [PATCH 19/23] fix: guard against empty runId and graphKey --- .../sdk/server/ai/internal/ResumptionTokens.java | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java index 83ccf707..dc10817c 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/internal/ResumptionTokens.java @@ -146,10 +146,10 @@ private static Decoded parseJson(String json) { } } - if (runId == null || runId.isEmpty()) { + if (runId == null || runId.trim().isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'runId'"); } - if (configKey == null || configKey.isEmpty()) { + if (configKey == null || configKey.trim().isEmpty()) { throw new IllegalArgumentException("Resumption token missing required field 'configKey'"); } if (version == null) { @@ -360,10 +360,10 @@ private static DecodedGraph parseGraphJson(String json) { } } - if (runId == null) { + if (runId == null || runId.trim().isEmpty()) { throw new IllegalArgumentException("Graph resumption token missing required field 'runId'"); } - if (graphKey == null) { + if (graphKey == null || graphKey.trim().isEmpty()) { throw new IllegalArgumentException("Graph resumption token missing required field 'graphKey'"); } if (version == null) { From 3aa5d08e959914d71de6a79329dc2c3acfce7ecd Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 17:10:48 -0500 Subject: [PATCH 20/23] fix: Add security note to LDAIConfigTracker.getResumptionToken() --- .../com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java index 3364d21e..3f591a2c 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIConfigTracker.java @@ -41,6 +41,10 @@ public interface LDAIConfigTracker { * The resumption token encodes the run's identity and can be passed to * {@link LDAIClient#createTracker(String, com.launchdarkly.sdk.LDContext)} to reconstruct a * tracker on a subsequent request (for example, in a streaming scenario). + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. * * @return the resumption token, or {@code null} if not available */ From 121b140154e1833df9391703feef126e91b22cc4 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 17:11:14 -0500 Subject: [PATCH 21/23] fix: Add security note to MetricSummary.getResumptionToken() --- .../sdk/server/ai/datamodel/LDAITrackingTypes.java | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java index 54a186d2..fbae0264 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/datamodel/LDAITrackingTypes.java @@ -575,6 +575,10 @@ public List getToolCalls() { /** * Returns the resumption token for this tracker. + *

+ * Security note: resumption tokens embed flag-evaluation details such as the + * variation key and config version. Keep tokens server-side and do not round-trip them through + * untrusted clients where they could leak flag-targeting information. * * @return the resumption token, or {@code null} if not available */ From a0a3a5444dcd828d60257ca37eb0144dad5e2c67 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 17:13:59 -0500 Subject: [PATCH 22/23] fix: add debugs helper to tests --- .../com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java index 7c7889a8..41d0afde 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java @@ -66,6 +66,13 @@ private List warnings() { .collect(Collectors.toList()); } + private List debugs() { + return logCapture.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.DEBUG) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + } + private LDValue baseExpectedData() { return LDValue.buildObject() .put("runId", RUN_ID) From 80ef017a9865c0c86cf81859aed1283e8abba332 Mon Sep 17 00:00:00 2001 From: Matt McCarthy Date: Wed, 24 Jun 2026 17:23:31 -0500 Subject: [PATCH 23/23] fix: pass instance logger to createGraphTracker for resumed runs --- .../sdk/server/ai/AIGraphTracker.java | 18 +++++++++++++++++- .../sdk/server/ai/LDAIClientImpl.java | 2 +- .../sdk/server/ai/AIGraphTrackerTest.java | 18 ++++++++++++++++++ 3 files changed, 36 insertions(+), 2 deletions(-) diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java index 646796f6..ec234306 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/AIGraphTracker.java @@ -92,6 +92,22 @@ public final class AIGraphTracker { */ public static AIGraphTracker fromResumptionToken( String token, LDClientInterface client, LDContext context) { + return fromResumptionToken(token, client, context, defaultLogger()); + } + + /** + * Reconstructs a graph tracker from a resumption token, preserving the original run identity, + * and logging through the supplied logger. + * + * @param token the resumption token produced by {@link #getResumptionToken()} + * @param client the LaunchDarkly client; must not be {@code null} + * @param context the evaluation context; must not be {@code null} + * @param logger the logger to use for at-most-once warnings; must not be {@code null} + * @return a new tracker with the decoded run identity + * @throws IllegalArgumentException if the token is malformed + */ + public static AIGraphTracker fromResumptionToken( + String token, LDClientInterface client, LDContext context, LDLogger logger) { ResumptionTokens.DecodedGraph d = ResumptionTokens.decodeGraph(token); int version = Math.max(1, d.getVersion()); return new AIGraphTracker( @@ -101,7 +117,7 @@ public static AIGraphTracker fromResumptionToken( d.getVariationKey(), version, context, - defaultLogger()); + logger); } /** diff --git a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java index 391e9069..02581650 100644 --- a/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java +++ b/lib/sdk/server-ai/src/main/java/com/launchdarkly/sdk/server/ai/LDAIClientImpl.java @@ -444,7 +444,7 @@ private static Set collectReachableKeys(AgentGraphFlagValue flagValue) { @Override public AIGraphTracker createGraphTracker(String resumptionToken, LDContext context) { - return AIGraphTracker.fromResumptionToken(resumptionToken, client, context); + return AIGraphTracker.fromResumptionToken(resumptionToken, client, context, logger); } @Override diff --git a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java index 41d0afde..a35a4119 100644 --- a/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java +++ b/lib/sdk/server-ai/src/test/java/com/launchdarkly/sdk/server/ai/AIGraphTrackerTest.java @@ -382,6 +382,24 @@ public void fromResumptionTokenPreservesRunId() { assertThat(captor.getValue().get("graphKey").stringValue(), is(GRAPH_KEY)); } + @Test + public void fromResumptionTokenWithLoggerRoutesWarningsThroughIt() { + LogCapture captureForResumed = Logs.capture(); + LDLogger resumedLogger = LDLogger.withAdapter(captureForResumed, "test"); + String token = tracker.getResumptionToken(); + AIGraphTracker reconstructed = AIGraphTracker.fromResumptionToken(token, client, CONTEXT, resumedLogger); + + reconstructed.trackInvocationSuccess(); + reconstructed.trackInvocationSuccess(); // duplicate — should warn on resumedLogger, not the base logCapture + + List resumedWarnings = captureForResumed.getMessages().stream() + .filter(m -> m.getLevel() == LDLogLevel.WARN) + .map(LogCapture.Message::getText) + .collect(Collectors.toList()); + assertThat(resumedWarnings.stream().anyMatch(w -> w.contains("invocation already recorded")), is(true)); + assertThat(logCapture.getMessages(), is(org.hamcrest.Matchers.empty())); + } + @Test public void fromResumptionTokenClampsVersionLessThanOne() { AIGraphTracker t = new AIGraphTracker(client, RUN_ID, GRAPH_KEY, null, 0, CONTEXT, logger);