diff --git a/Cargo.lock b/Cargo.lock index 4fbcc77..c75e6ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -954,6 +954,7 @@ dependencies = [ "console", "ctrlc", "httpmock", + "mime_guess", "reqwest", "scraper", "serde", @@ -1043,6 +1044,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "mio" version = "1.2.0" @@ -1396,6 +1407,7 @@ dependencies = [ "base64", "bytes", "futures-core", + "futures-util", "http", "http-body", "http-body-util", @@ -1404,6 +1416,7 @@ dependencies = [ "hyper-util", "js-sys", "log", + "mime_guess", "percent-encoding", "pin-project-lite", "quinn", @@ -2103,6 +2116,12 @@ version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "562d481066bde0658276a35467c4af00bdc6ee726305698a55b86e61d7ad82bb" +[[package]] +name = "unicase" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dbc4bc3a9f746d862c45cb89d705aa10f187bb96c76001afab07a0d35ce60142" + [[package]] name = "unicode-ident" version = "1.0.24" diff --git a/Cargo.toml b/Cargo.toml index c570f1a..e485159 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,7 +28,7 @@ clap_complete = "4.6.2" cliclack = "0.5.4" console = "0.16.3" ctrlc = "3.5.2" -reqwest = { version = "0.12.15", default-features = false, features = ["json", "rustls-tls"] } +reqwest = { version = "0.12.15", default-features = false, features = ["json", "multipart", "rustls-tls"] } scraper = "0.26.0" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.149" @@ -37,6 +37,7 @@ tokio = { version = "1.51.1", features = ["macros", "rt-multi-thread"] } toml = "1.1.2" tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } +mime_guess = "2.0.5" [dev-dependencies] httpmock = "0.8.3" diff --git a/README.md b/README.md index d86e210..a6b2b57 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,12 @@ run assistant with a saved assistant profile and markdown output: kagi assistant --assistant research --format markdown "summarize the latest rust release" ``` +attach local files to an assistant prompt: + +```bash +kagi assistant --attach ./a.jpg --attach ./b.pdf "tell me everything about this pdf" +``` + ask assistant about a page directly: ```bash diff --git a/src/api.rs b/src/api.rs index 4f522be..c2df6aa 100644 --- a/src/api.rs +++ b/src/api.rs @@ -1,7 +1,10 @@ use std::collections::HashMap; +use std::fs; use std::future::Future; +use std::path::{Path, PathBuf}; use std::time::Duration; +use reqwest::multipart; use reqwest::{Client, StatusCode, Url, header}; use scraper::Html; use serde::Deserialize; @@ -471,24 +474,27 @@ pub async fn execute_assistant_prompt( request: &AssistantPromptRequest, token: &str, ) -> Result { - let query = normalize_assistant_query(&request.query)?; - let thread_id = normalize_assistant_thread_id(request.thread_id.as_deref())?; - let profile = assistant_profile_payload(request); - let body = execute_assistant_stream( - &http::kagi_url(KAGI_ASSISTANT_PROMPT_PATH), - &json!({ - "focus": { - "thread_id": thread_id, - "branch_id": ASSISTANT_ZERO_BRANCH_UUID, - "prompt": query, - "message_id": Value::Null, - }, - "profile": profile, - }), - token, - "Assistant prompt", - ) - .await?; + let body = match build_assistant_prompt_payload(request)? { + AssistantPromptPayload::Json(state) => { + execute_assistant_stream( + &http::kagi_url(KAGI_ASSISTANT_PROMPT_PATH), + &state, + token, + "Assistant prompt", + ) + .await? + } + AssistantPromptPayload::Multipart { state, attachments } => { + execute_assistant_multipart_stream( + &http::kagi_url(KAGI_ASSISTANT_PROMPT_PATH), + &state, + &attachments, + token, + "Assistant prompt", + ) + .await? + } + }; parse_assistant_prompt_stream(&body) } @@ -1414,6 +1420,7 @@ pub async fn execute_ask_page( &AssistantPromptRequest { query: build_ask_page_prompt(&source_url, &question), thread_id: None, + attachments: Vec::new(), profile_id: None, model: None, lens_id: None, @@ -3019,6 +3026,97 @@ fn assistant_profile_payload(request: &AssistantPromptRequest) -> Value { Value::Object(payload) } +#[derive(Debug, Clone, PartialEq, Eq)] +struct AssistantAttachmentPayload { + path: PathBuf, + filename: String, + content_type: String, + bytes: Vec, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum AssistantPromptPayload { + Json(Value), + Multipart { + state: Value, + attachments: Vec, + }, +} + +fn assistant_prompt_state( + request: &AssistantPromptRequest, + query: String, + thread_id: Option, +) -> Value { + json!({ + "focus": { + "thread_id": thread_id, + "branch_id": ASSISTANT_ZERO_BRANCH_UUID, + "prompt": query, + "message_id": Value::Null, + }, + "profile": assistant_profile_payload(request), + }) +} + +fn build_assistant_prompt_payload( + request: &AssistantPromptRequest, +) -> Result { + let query = normalize_assistant_query(&request.query)?; + let thread_id = normalize_assistant_thread_id(request.thread_id.as_deref())?; + let state = assistant_prompt_state(request, query, thread_id); + + if request.attachments.is_empty() { + return Ok(AssistantPromptPayload::Json(state)); + } + + Ok(AssistantPromptPayload::Multipart { + state, + attachments: load_assistant_attachments(&request.attachments)?, + }) +} + +fn load_assistant_attachments( + paths: &[PathBuf], +) -> Result, KagiError> { + paths + .iter() + .map(|path| load_assistant_attachment(path)) + .collect() +} + +fn load_assistant_attachment(path: &Path) -> Result { + let filename = path + .file_name() + .and_then(|name| name.to_str()) + .map(str::trim) + .filter(|name| !name.is_empty()) + .ok_or_else(|| { + KagiError::Config(format!( + "assistant attachment '{}' must include a file name", + path.display() + )) + })? + .to_string(); + + let bytes = fs::read(path).map_err(|error| { + KagiError::Config(format!( + "failed to read assistant attachment '{}': {error}", + path.display() + )) + })?; + + Ok(AssistantAttachmentPayload { + path: path.to_path_buf(), + filename, + content_type: mime_guess::from_path(path) + .first_or_octet_stream() + .essence_str() + .to_string(), + bytes, + }) +} + async fn execute_assistant_stream( url: &str, payload: &Value, @@ -3042,6 +3140,66 @@ async fn execute_assistant_stream( .await .map_err(map_transport_error)?; + handle_assistant_stream_response(response, surface).await +} + +async fn execute_assistant_multipart_stream( + url: &str, + state: &Value, + attachments: &[AssistantAttachmentPayload], + token: &str, + surface: &str, +) -> Result { + if token.trim().is_empty() { + return Err(KagiError::Auth( + "missing Kagi session token (expected KAGI_SESSION_TOKEN)".to_string(), + )); + } + + let client = build_client()?; + let state_json = serde_json::to_vec(state).map_err(|error| { + KagiError::Config(format!( + "failed to serialize Assistant prompt upload state: {error}" + )) + })?; + let state_part = multipart::Part::bytes(state_json) + .mime_str("application/json") + .map_err(|error| { + KagiError::Config(format!( + "failed to set Assistant upload state MIME type: {error}" + )) + })?; + let mut form = multipart::Form::new().part("state", state_part); + + for attachment in attachments { + let file_part = multipart::Part::bytes(attachment.bytes.clone()) + .file_name(attachment.filename.clone()) + .mime_str(&attachment.content_type) + .map_err(|error| { + KagiError::Config(format!( + "failed to set Assistant attachment MIME type for '{}': {error}", + attachment.path.display() + )) + })?; + form = form.part("file", file_part); + } + + let response = client + .post(url) + .header(header::COOKIE, format!("kagi_session={token}")) + .header(header::ACCEPT, "application/vnd.kagi.stream") + .multipart(form) + .send() + .await + .map_err(map_transport_error)?; + + handle_assistant_stream_response(response, surface).await +} + +async fn handle_assistant_stream_response( + response: reqwest::Response, + surface: &str, +) -> Result { match response.status() { StatusCode::OK => { let body = response.text().await.map_err(|error| { @@ -4167,8 +4325,9 @@ pub struct KagiEnvelope { #[cfg(test)] mod tests { use super::{ - ApiErrorBody, KagiEnvelope, NewsFilterRequest, TRANSLATE_BOOTSTRAP_MISSING_COOKIE_ERROR, - TranslateSuggestionContext, apply_news_content_filters, build_ask_page_prompt, + ApiErrorBody, AssistantPromptPayload, KagiEnvelope, NewsFilterRequest, + TRANSLATE_BOOTSTRAP_MISSING_COOKIE_ERROR, TranslateSuggestionContext, + apply_news_content_filters, build_ask_page_prompt, build_assistant_prompt_payload, build_translate_option_state, build_translate_payload, build_translate_suggestions_payload, build_translate_word_insights_payload, capture_optional_translate_section, effective_translate_source_language, execute_news_filter_presets, extract_set_cookie_value, @@ -4211,11 +4370,14 @@ mod tests { }; use reqwest::StatusCode; use serde_json::{Value, json}; + use std::fs; + use std::path::PathBuf; use std::sync::{ Arc, atomic::{AtomicBool, Ordering}, }; use std::time::{SystemTime, UNIX_EPOCH}; + use tempfile::TempDir; struct ScopedEnvVar { key: &'static str, @@ -4749,6 +4911,147 @@ mod tests { ); } + #[test] + fn builds_json_assistant_prompt_payload_without_attachments() { + let request = AssistantPromptRequest { + query: " hello ".to_string(), + thread_id: Some(" thread-1 ".to_string()), + attachments: Vec::new(), + profile_id: Some("research".to_string()), + model: Some("gpt-5-mini".to_string()), + lens_id: Some(2), + internet_access: Some(true), + personalizations: Some(false), + }; + + match build_assistant_prompt_payload(&request).expect("payload should build") { + AssistantPromptPayload::Json(state) => { + assert_eq!(state["focus"]["prompt"], "hello"); + assert_eq!(state["focus"]["thread_id"], "thread-1"); + assert_eq!( + state["focus"]["branch_id"], + "00000000-0000-4000-0000-000000000000" + ); + assert_eq!(state["profile"]["id"], "research"); + assert_eq!(state["profile"]["model"], "gpt-5-mini"); + assert_eq!(state["profile"]["lens_id"], 2); + assert_eq!(state["profile"]["internet_access"], true); + assert_eq!(state["profile"]["personalizations"], false); + } + other => panic!("expected json assistant payload, got {other:?}"), + } + } + + #[test] + fn builds_multipart_assistant_prompt_payload_with_attachments() { + let tempdir = TempDir::new().expect("tempdir"); + let attachment_path = tempdir.path().join("note.txt"); + fs::write(&attachment_path, "attached-note").expect("attachment should write"); + + let request = AssistantPromptRequest { + query: "Reply with exactly: attached-note".to_string(), + thread_id: None, + attachments: vec![attachment_path.clone()], + profile_id: None, + model: Some("gpt-5-mini".to_string()), + lens_id: None, + internet_access: Some(false), + personalizations: Some(false), + }; + + match build_assistant_prompt_payload(&request).expect("payload should build") { + AssistantPromptPayload::Multipart { state, attachments } => { + assert_eq!( + state["focus"]["prompt"], + "Reply with exactly: attached-note" + ); + assert_eq!(state["profile"]["model"], "gpt-5-mini"); + assert_eq!(state["profile"]["internet_access"], false); + assert_eq!(attachments.len(), 1); + assert_eq!(attachments[0].path, attachment_path); + assert_eq!(attachments[0].filename, "note.txt"); + assert_eq!(attachments[0].content_type, "text/plain"); + assert_eq!(attachments[0].bytes, b"attached-note"); + } + other => panic!("expected multipart assistant payload, got {other:?}"), + } + } + + #[test] + fn rejects_missing_assistant_attachment() { + let missing = PathBuf::from("/tmp/definitely-missing-kagi-assistant-attachment.txt"); + let request = AssistantPromptRequest { + query: "hello".to_string(), + thread_id: None, + attachments: vec![missing.clone()], + profile_id: None, + model: None, + lens_id: None, + internet_access: None, + personalizations: None, + }; + + let error = + build_assistant_prompt_payload(&request).expect_err("missing attachment should fail"); + assert!( + error + .to_string() + .contains("failed to read assistant attachment") + ); + assert!(error.to_string().contains(&missing.display().to_string())); + } + + #[tokio::test] + #[allow(clippy::await_holding_lock)] + async fn assistant_prompt_uses_multipart_when_attachments_are_present() { + use httpmock::Method::POST; + use httpmock::MockServer; + + let server = MockServer::start(); + let _prompt = server.mock(|when, then| { + when.method(POST) + .path("/assistant/prompt") + .header("cookie", "kagi_session=test-session") + .header("accept", "application/vnd.kagi.stream") + .body_includes("name=\"state\"") + .body_includes("name=\"file\"; filename=\"note.txt\"") + .body_includes("\"prompt\":\"Reply with exactly: attached-note\"") + .body_includes("attached-note"); + then.status(200) + .header("content-type", "application/vnd.kagi.stream") + .body(concat!( + "hi:{\"v\":\"test\",\"trace\":\"trace-upload\"}\0\n", + "thread.json:{\"id\":\"thread-1\",\"title\":\"Upload test\",\"ack\":\"2026-04-24T00:00:00Z\",\"created_at\":\"2026-04-24T00:00:00Z\",\"expires_at\":\"2026-04-24T01:00:00Z\",\"saved\":false,\"shared\":false,\"branch_id\":\"00000000-0000-4000-0000-000000000000\",\"tag_ids\":[]}\0\n", + "new_message.json:{\"id\":\"msg-1\",\"thread_id\":\"thread-1\",\"created_at\":\"2026-04-24T00:00:00Z\",\"state\":\"done\",\"prompt\":\"Reply with exactly: attached-note\",\"reply_html\":\"attached-note\",\"md\":\"attached-note\",\"references_html\":\"\",\"references_markdown\":\"\",\"metadata_html\":\"\",\"documents\":[],\"profile\":null}\0\n" + )); + }); + + let tempdir = TempDir::new().expect("tempdir"); + let attachment_path = tempdir.path().join("note.txt"); + fs::write(&attachment_path, "attached-note").expect("attachment should write"); + + let _env_guard = lock_env(); + let _base_url_env = set_env_var("KAGI_BASE_URL", &server.base_url()); + let response = execute_assistant_prompt( + &AssistantPromptRequest { + query: "Reply with exactly: attached-note".to_string(), + thread_id: None, + attachments: vec![attachment_path], + profile_id: None, + model: Some("gpt-5-mini".to_string()), + lens_id: None, + internet_access: Some(false), + personalizations: Some(false), + }, + "test-session", + ) + .await + .expect("assistant prompt should succeed"); + + assert_eq!(response.meta.trace.as_deref(), Some("trace-upload")); + assert_eq!(response.message.markdown.as_deref(), Some("attached-note")); + } + #[test] fn normalizes_custom_bang_trigger_and_redirect_rule() { assert_eq!( @@ -4942,6 +5245,7 @@ mod tests { let request = AssistantPromptRequest { query: format!("Reply with exactly: assistant-v2-smoke-{}", live_nonce()), thread_id: None, + attachments: Vec::new(), profile_id: None, model: Some("gpt-5-mini".to_string()), lens_id: None, @@ -5037,6 +5341,7 @@ mod tests { &AssistantPromptRequest { query: "Reply with exactly: custom-assistant-smoke".to_string(), thread_id: None, + attachments: Vec::new(), profile_id: Some(created_id.clone()), model: None, lens_id: None, diff --git a/src/cli.rs b/src/cli.rs index 99e070c..bd3ea70 100644 --- a/src/cli.rs +++ b/src/cli.rs @@ -5,6 +5,7 @@ //! and per-subcommand options. use clap::{Args, Parser, Subcommand, ValueEnum}; +use std::path::PathBuf; #[derive(Debug, Clone, ValueEnum)] /// Supported shell types for tab-completion generation. @@ -578,6 +579,10 @@ pub struct AssistantArgs { #[arg(long, value_name = "THREAD_ID")] pub thread_id: Option, + /// Attach a local file to the assistant prompt (repeat for multiple files) + #[arg(long, value_name = "PATH")] + pub attach: Vec, + /// Use a saved assistant by name, id, or invoke profile slug #[arg(long, value_name = "ASSISTANT")] pub assistant: Option, @@ -1410,6 +1415,33 @@ mod tests { } } + #[test] + fn parses_assistant_attach_flags() { + let cli = Cli::try_parse_from([ + "kagi", + "assistant", + "--attach", + "./a.jpg", + "--attach", + "./b.pdf", + "tell me everything about this pdf", + ]) + .expect("assistant attach command should parse"); + + match cli.command.expect("command") { + Commands::Assistant(args) => { + assert_eq!( + args.query.as_deref(), + Some("tell me everything about this pdf") + ); + assert_eq!(args.attach.len(), 2); + assert_eq!(args.attach[0].to_string_lossy(), "./a.jpg"); + assert_eq!(args.attach[1].to_string_lossy(), "./b.pdf"); + } + other => panic!("unexpected command: {other:?}"), + } + } + #[test] fn rejects_conflicting_redirect_flags() { let error = Cli::try_parse_from([ diff --git a/src/main.rs b/src/main.rs index a2f5120..ac5c17a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -328,6 +328,7 @@ async fn run() -> Result<(), KagiError> { let request = AssistantPromptRequest { query, thread_id: args.thread_id, + attachments: args.attach, profile_id: normalize_optional_string(args.assistant), model: args.model, lens_id: args.lens, diff --git a/src/types.rs b/src/types.rs index 1fb110b..2b22567 100644 --- a/src/types.rs +++ b/src/types.rs @@ -12,6 +12,7 @@ //! - **Translation**: [`TranslateRequest`], [`TranslateResponse`] use std::collections::HashMap; +use std::path::PathBuf; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -311,6 +312,8 @@ pub struct AssistantPromptRequest { pub query: String, #[serde(skip_serializing_if = "Option::is_none")] pub thread_id: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + pub attachments: Vec, #[serde(skip_serializing_if = "Option::is_none")] pub profile_id: Option, #[serde(skip_serializing_if = "Option::is_none")]