diff --git a/browser.go b/browser.go index 4533fc1..21ef07d 100644 --- a/browser.go +++ b/browser.go @@ -19,6 +19,7 @@ import ( "github.com/kernel/kernel-go-sdk/internal/apijson" "github.com/kernel/kernel-go-sdk/internal/apiquery" "github.com/kernel/kernel-go-sdk/internal/requestconfig" + "github.com/kernel/kernel-go-sdk/lib/browserrouting" "github.com/kernel/kernel-go-sdk/option" "github.com/kernel/kernel-go-sdk/packages/pagination" "github.com/kernel/kernel-go-sdk/packages/param" @@ -69,6 +70,9 @@ func (r *BrowserService) New(ctx context.Context, body BrowserNewParams, opts .. opts = slices.Concat(r.Options, opts) path := "browsers" err = requestconfig.ExecuteNewRequest(ctx, http.MethodPost, path, body, &res, opts...) + if err == nil && res != nil { + storeBrowserRouteCache(opts, browserrouting.Ref{SessionID: res.SessionID, BaseURL: res.BaseURL, CdpWsURL: res.CdpWsURL}) + } return res, err } @@ -81,6 +85,9 @@ func (r *BrowserService) Get(ctx context.Context, id string, query BrowserGetPar } path := fmt.Sprintf("browsers/%s", id) err = requestconfig.ExecuteNewRequest(ctx, http.MethodGet, path, query, &res, opts...) + if err == nil && res != nil { + storeBrowserRouteCache(opts, browserrouting.Ref{SessionID: res.SessionID, BaseURL: res.BaseURL, CdpWsURL: res.CdpWsURL}) + } return res, err } @@ -93,6 +100,9 @@ func (r *BrowserService) Update(ctx context.Context, id string, body BrowserUpda } path := fmt.Sprintf("browsers/%s", id) err = requestconfig.ExecuteNewRequest(ctx, http.MethodPatch, path, body, &res, opts...) 
+ if err == nil && res != nil { + storeBrowserRouteCache(opts, browserrouting.Ref{SessionID: res.SessionID, BaseURL: res.BaseURL, CdpWsURL: res.CdpWsURL}) + } return res, err } @@ -112,6 +122,9 @@ func (r *BrowserService) List(ctx context.Context, query BrowserListParams, opts return nil, err } res.SetPageConfig(cfg, raw) + for _, item := range res.Items { + storeBrowserRouteCache(opts, browserrouting.Ref{SessionID: item.SessionID, BaseURL: item.BaseURL, CdpWsURL: item.CdpWsURL}) + } return res, nil } @@ -157,6 +170,11 @@ func (r *BrowserService) DeleteByID(ctx context.Context, id string, opts ...opti } path := fmt.Sprintf("browsers/%s", id) err = requestconfig.ExecuteNewRequest(ctx, http.MethodDelete, path, nil, nil, opts...) + if err == nil { + if cache := browserRouteCacheFromOptions(opts); cache != nil { + cache.Delete(id) + } + } return err } diff --git a/browser_http_client.go b/browser_http_client.go new file mode 100644 index 0000000..d128d0c --- /dev/null +++ b/browser_http_client.go @@ -0,0 +1,34 @@ +package kernel + +import ( + "context" + "fmt" + "net/http" + "slices" + + "github.com/kernel/kernel-go-sdk/internal/requestconfig" + "github.com/kernel/kernel-go-sdk/lib/browserrouting" + "github.com/kernel/kernel-go-sdk/option" +) + +// HTTPClient returns an [http.Client] that performs HTTP requests through the +// browser VM's internal /curl/raw path using cached browser route data. +func (r *BrowserService) HTTPClient(id string, opts ...option.RequestOption) (*http.Client, error) { + opts = slices.Concat(r.Options, opts) + cache := browserRouteCacheFromOptions(opts) + if cache == nil { + return nil, fmt.Errorf("kernel: browser route cache is not configured") + } + + route, ok := cache.Load(id) + if !ok { + return nil, fmt.Errorf("kernel: browser route cache does not contain session %s", id) + } + + cfg, err := requestconfig.NewRequestConfig(context.Background(), http.MethodGet, "https://example.com", nil, nil, opts...) 
+ if err != nil { + return browserrouting.NewHTTPClient(route.BaseURL, route.JWT, nil), nil + } + + return browserrouting.NewHTTPClient(route.BaseURL, route.JWT, cfg.HTTPClient), nil +} diff --git a/browser_routing.go b/browser_routing.go new file mode 100644 index 0000000..330e951 --- /dev/null +++ b/browser_routing.go @@ -0,0 +1,84 @@ +package kernel + +import ( + "github.com/kernel/kernel-go-sdk/internal/requestconfig" + "github.com/kernel/kernel-go-sdk/lib/browserrouting" + "github.com/kernel/kernel-go-sdk/option" +) + +// BrowserRoutingConfig controls which browser subresources route directly to the browser VM. +type BrowserRoutingConfig struct { + Enabled bool + Subresources []string +} + +type browserRoutingOption struct { + cache *browserrouting.RouteCache + config BrowserRoutingConfig +} + +type browserRouteCacheOption struct { + cache *browserrouting.RouteCache +} + +// WithBrowserRouting enables direct-to-VM routing for the configured browser subresources. +func WithBrowserRouting(config BrowserRoutingConfig) option.RequestOption { + return &browserRoutingOption{config: config} +} + +func (o *browserRoutingOption) Apply(r *requestconfig.RequestConfig) error { + if !o.config.Enabled { + return nil + } + r.Middlewares = append(r.Middlewares, browserrouting.DirectVMRoutingMiddleware(o.cache, o.config.Subresources)) + return nil +} + +func (o *browserRoutingOption) browserRouteCache() *browserrouting.RouteCache { + return o.cache +} + +func (o *browserRouteCacheOption) Apply(*requestconfig.RequestConfig) error { + return nil +} + +func (o *browserRouteCacheOption) browserRouteCache() *browserrouting.RouteCache { + return o.cache +} + +func withBrowserRouteCache(cache *browserrouting.RouteCache) option.RequestOption { + return &browserRouteCacheOption{cache: cache} +} + +func browserRouteCacheFromOptions(opts []option.RequestOption) *browserrouting.RouteCache { + for _, opt := range opts { + if carrier, ok := opt.(interface{ browserRouteCache() 
*browserrouting.RouteCache }); ok { + if cache := carrier.browserRouteCache(); cache != nil { + return cache + } + } + } + return nil +} + +func storeBrowserRouteCache(opts []option.RequestOption, refs ...browserrouting.Ref) { + cache := browserRouteCacheFromOptions(opts) + for _, ref := range refs { + route, ok := browserRouteFromRef(ref) + if cache != nil && ok { + cache.Store(route) + } + } +} + +func browserRouteFromRef(ref browserrouting.Ref) (browserrouting.Route, bool) { + norm, err := ref.Normalize() + if err != nil { + return browserrouting.Route{}, false + } + return browserrouting.Route{ + SessionID: norm.SessionID, + BaseURL: norm.BaseURL, + JWT: norm.JWT, + }, true +} diff --git a/browser_routing_test.go b/browser_routing_test.go new file mode 100644 index 0000000..fede55c --- /dev/null +++ b/browser_routing_test.go @@ -0,0 +1,113 @@ +package kernel + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/kernel/kernel-go-sdk/option" +) + +func TestBrowserRoutingWarmsCacheAndRoutesAllowlistedSubresources(t *testing.T) { + var calls []struct { + Path string + Auth string + } + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + calls = append(calls, struct { + Path string + Auth string + }{Path: r.URL.Path + "?" 
+ r.URL.RawQuery, Auth: r.Header.Get("Authorization")}) + + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/browsers": + _ = json.NewEncoder(w).Encode(map[string]any{ + "session_id": "sess-1", + "base_url": srv.URL + "/browser/kernel", + "cdp_ws_url": "wss://browser-session.test/browser/cdp?jwt=token-abc", + }) + default: + _ = json.NewEncoder(w).Encode(map[string]any{ + "duration_ms": 1, + "exit_code": 0, + "stderr_b64": "", + "stdout_b64": "", + }) + } + })) + defer srv.Close() + + client := NewClient( + option.WithBaseURL(srv.URL), + option.WithAPIKey("sk_test"), + option.WithHTTPClient(srv.Client()), + WithBrowserRouting(BrowserRoutingConfig{Enabled: true, Subresources: []string{"process"}}), + ) + + if _, err := client.Browsers.New(context.Background(), BrowserNewParams{}); err != nil { + t.Fatal(err) + } + if _, err := client.Browsers.Process.Exec(context.Background(), "sess-1", BrowserProcessExecParams{Command: "echo"}); err != nil { + t.Fatal(err) + } + + if route, ok := client.BrowserRouteCache.Load("sess-1"); !ok || route.JWT != "token-abc" { + t.Fatalf("expected warmed browser route cache, got %#v ok=%v", route, ok) + } + if len(calls) != 2 { + t.Fatalf("expected 2 calls, got %d", len(calls)) + } + if calls[1].Path != "/browser/kernel/process/exec?jwt=token-abc" { + t.Fatalf("expected direct VM path, got %q", calls[1].Path) + } + if calls[1].Auth != "" { + t.Fatalf("expected authorization header removed, got %q", calls[1].Auth) + } +} + +func TestBrowserRoutingSkipsSubresourcesOutsideConfiguredAllowlist(t *testing.T) { + var paths []string + var srv *httptest.Server + srv = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + paths = append(paths, r.URL.Path) + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/browsers": + _ = json.NewEncoder(w).Encode(map[string]any{ + "session_id": "sess-1", + "base_url": srv.URL + "/browser/kernel", + "cdp_ws_url": 
"wss://browser-session.test/browser/cdp?jwt=token-abc", + }) + default: + _ = json.NewEncoder(w).Encode(map[string]any{ + "duration_ms": 1, + "exit_code": 0, + "stderr_b64": "", + "stdout_b64": "", + }) + } + })) + defer srv.Close() + + client := NewClient( + option.WithBaseURL(srv.URL), + option.WithAPIKey("sk_test"), + option.WithHTTPClient(srv.Client()), + WithBrowserRouting(BrowserRoutingConfig{Enabled: true, Subresources: []string{"computer"}}), + ) + + if _, err := client.Browsers.New(context.Background(), BrowserNewParams{}); err != nil { + t.Fatal(err) + } + if _, err := client.Browsers.Process.Exec(context.Background(), "sess-1", BrowserProcessExecParams{Command: "echo"}); err != nil { + t.Fatal(err) + } + + if got := paths[len(paths)-1]; got != "/browsers/sess-1/process/exec" { + t.Fatalf("expected control-plane path, got %q", got) + } +} diff --git a/browser_session_httpclient_test.go b/browser_session_httpclient_test.go new file mode 100644 index 0000000..63d0f99 --- /dev/null +++ b/browser_session_httpclient_test.go @@ -0,0 +1,77 @@ +package kernel + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/kernel/kernel-go-sdk/lib/browserrouting" + "github.com/kernel/kernel-go-sdk/option" +) + +func TestBrowserSessionHTTPClientRawCurl(t *testing.T) { + var sawRaw string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/browser/kernel/curl/raw" { + http.NotFound(w, r) + return + } + sawRaw = r.URL.RawQuery + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write([]byte("proxied")) + })) + defer srv.Close() + + c := NewClient( + option.WithBaseURL("https://api.example/"), + option.WithAPIKey("sk"), + option.WithHTTPClient(srv.Client()), + ) + + storeBrowserRouteCache(c.Options, browserrouting.Ref{ + SessionID: "sid", + BaseURL: srv.URL + "/browser/kernel", + CdpWsURL: "wss://x/browser/cdp?jwt=j1", + }) + + hc, err := c.Browsers.HTTPClient("sid") + if err != nil { 
+ t.Fatal(err) + } + req, err := http.NewRequest(http.MethodGet, "https://httpbin.org/get", nil) + if err != nil { + t.Fatal(err) + } + res, err := hc.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if string(body) != "proxied" { + t.Fatalf("body %q", body) + } + if sawRaw == "" { + t.Fatal("expected raw query on curl/raw") + } +} + +func TestBrowserSessionHTTPClientRequiresCachedRoute(t *testing.T) { + c := NewClient( + option.WithBaseURL("https://api.example/"), + option.WithAPIKey("sk"), + ) + + storeBrowserRouteCache(c.Options, browserrouting.Ref{ + SessionID: "sid", + BaseURL: "https://browser-session.test/browser/kernel", + CdpWsURL: "wss://x/browser/cdp?jwt=j1", + }) + c.BrowserRouteCache.Delete("sid") + + _, err := c.Browsers.HTTPClient("sid") + if err == nil { + t.Fatal("expected cached route lookup failure") + } +} diff --git a/client.go b/client.go index 880a657..1b719cf 100644 --- a/client.go +++ b/client.go @@ -9,6 +9,7 @@ import ( "slices" "github.com/kernel/kernel-go-sdk/internal/requestconfig" + "github.com/kernel/kernel-go-sdk/lib/browserrouting" "github.com/kernel/kernel-go-sdk/option" ) @@ -17,6 +18,8 @@ import ( // directly, and instead use the [NewClient] method instead. type Client struct { Options []option.RequestOption + // BrowserRouteCache stores cached base_url and jwt data for direct-to-VM routing. + BrowserRouteCache *browserrouting.RouteCache // Create and manage app deployments and stream deployment events. Deployments DeploymentService // List applications and versions. @@ -61,8 +64,20 @@ func DefaultClientOptions() []option.RequestOption { // the services and requests that this client makes. func NewClient(opts ...option.RequestOption) (r Client) { opts = append(DefaultClientOptions(), opts...) 
+ cache := browserrouting.NewRouteCache() + nextOpts := make([]option.RequestOption, 0, len(opts)+1) + for _, opt := range opts { + if routing, ok := opt.(*browserRoutingOption); ok { + cloned := *routing + cloned.cache = cache + nextOpts = append(nextOpts, &cloned) + continue + } + nextOpts = append(nextOpts, opt) + } + opts = append(nextOpts, withBrowserRouteCache(cache)) - r = Client{Options: opts} + r = Client{Options: opts, BrowserRouteCache: cache} r.Deployments = NewDeploymentService(opts...) r.Apps = NewAppService(opts...) diff --git a/examples/browser-routing/main.go b/examples/browser-routing/main.go new file mode 100644 index 0000000..3e3591c --- /dev/null +++ b/examples/browser-routing/main.go @@ -0,0 +1,47 @@ +package main + +import ( + "context" + "fmt" + "net/http" + "os" + + kernel "github.com/kernel/kernel-go-sdk" +) + +func main() { + ctx := context.Background() + client := kernel.NewClient( + kernel.WithBrowserRouting(kernel.BrowserRoutingConfig{ + Enabled: true, + Subresources: []string{"process"}, + }), + ) + + browser, err := client.Browsers.New(ctx, kernel.BrowserNewParams{ + Headless: kernel.Bool(true), + }) + if err != nil { + panic(err) + } + defer func() { + _ = client.Browsers.DeleteByID(context.Background(), browser.SessionID) + }() + + httpClient, err := client.Browsers.HTTPClient(browser.SessionID) + if err != nil { + panic(err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) + if err != nil { + panic(err) + } + + resp, err := httpClient.Do(req) + if err != nil { + panic(err) + } + defer resp.Body.Close() + + fmt.Fprintln(os.Stdout, "status", resp.StatusCode) +} diff --git a/lib/browserrouting/integration_test.go b/lib/browserrouting/integration_test.go new file mode 100644 index 0000000..f1d29db --- /dev/null +++ b/lib/browserrouting/integration_test.go @@ -0,0 +1,87 @@ +package browserrouting_test + +import ( + "context" + "io" + "net/http" + "os" + "strings" + "testing" + "time" + + 
kernel "github.com/kernel/kernel-go-sdk" + "github.com/kernel/kernel-go-sdk/option" +) + +func TestIntegrationBrowserRouting(t *testing.T) { + apiKey := strings.TrimSpace(os.Getenv("KERNEL_API_KEY")) + baseURL := strings.TrimSpace(os.Getenv("KERNEL_BASE_URL")) + if apiKey == "" || baseURL == "" { + t.Skip("set KERNEL_API_KEY and KERNEL_BASE_URL to run integration test") + } + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + client := kernel.NewClient( + option.WithAPIKey(apiKey), + option.WithBaseURL(baseURL), + kernel.WithBrowserRouting(kernel.BrowserRoutingConfig{ + Enabled: true, + Subresources: []string{"process"}, + }), + ) + + browser, err := client.Browsers.New(ctx, kernel.BrowserNewParams{ + Headless: kernel.Bool(true), + TimeoutSeconds: kernel.Int(60), + }) + if err != nil { + t.Fatalf("create browser: %v", err) + } + t.Cleanup(func() { + _ = client.Browsers.DeleteByID(context.Background(), browser.SessionID) + }) + + if browser.BaseURL == "" { + t.Fatal("expected browser base_url to be set") + } + + execRes, err := client.Browsers.Process.Exec(ctx, browser.SessionID, kernel.BrowserProcessExecParams{ + Command: "echo", + Args: []string{"hello"}, + }) + if err != nil { + t.Fatalf("process exec: %v", err) + } + if execRes.ExitCode != 0 { + t.Fatalf("expected process exit code 0, got %d", execRes.ExitCode) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) + if err != nil { + t.Fatalf("new request: %v", err) + } + httpClient, err := client.Browsers.HTTPClient(browser.SessionID) + if err != nil { + t.Fatalf("browser http client: %v", err) + } + resp, err := httpClient.Do(req) + if err != nil { + t.Fatalf("browser http client: %v", err) + } + defer resp.Body.Close() + if resp.StatusCode < 200 || resp.StatusCode >= 400 { + t.Fatalf("expected successful browser http response, got %d", resp.StatusCode) + } + if got := resp.Header.Get("Content-Type"); got == "" { + 
// jwtFromWebSocketURL extracts the session jwt query parameter from a browser
// websocket URL (for example cdp_ws_url or webdriver_ws_url).
//
// Surrounding whitespace in wsURL is ignored. An error is returned when the
// URL is empty, cannot be parsed, or has no non-empty jwt query parameter.
// Returned errors never contain the URL text, because the URL embeds the
// token.
func jwtFromWebSocketURL(wsURL string) (string, error) {
	wsURL = strings.TrimSpace(wsURL)
	if wsURL == "" {
		return "", fmt.Errorf("browserrouting: empty websocket url")
	}
	u, err := url.Parse(wsURL)
	if err != nil {
		// url.Parse reports failures as *url.Error, whose message quotes the
		// entire input, jwt included. Keep only the underlying cause so the
		// token cannot end up in logs or error reports.
		if uerr, ok := err.(*url.Error); ok && uerr.Err != nil {
			err = uerr.Err
		}
		return "", fmt.Errorf("browserrouting: parse websocket url: %w", err)
	}
	jwt := u.Query().Get("jwt")
	if jwt == "" {
		return "", fmt.Errorf("browserrouting: missing jwt query parameter")
	}
	return jwt, nil
}
// RawCURLRoundTripper implements browser-egress HTTP by tunneling through the
// browser session base_url /curl/raw endpoint.
type RawCURLRoundTripper struct {
	// browserBaseURL is the browser session base_url, stored trimmed and
	// without trailing slashes.
	browserBaseURL string
	// jwt is the session token sent as the jwt query parameter.
	jwt string
	// underlying performs the rewritten request.
	underlying http.RoundTripper
}

var _ http.RoundTripper = (*RawCURLRoundTripper)(nil)

// NewHTTPClient returns an [http.Client] that performs browser egress HTTP via
// the browser session base_url and internal /curl/raw path.
//
// The returned client reuses the transport and timeout of underlying. A nil
// underlying means [http.DefaultClient], and a nil transport means
// [http.DefaultTransport].
func NewHTTPClient(browserBaseURL, jwt string, underlying *http.Client) *http.Client {
	if underlying == nil {
		underlying = http.DefaultClient
	}
	rt := underlying.Transport
	if rt == nil {
		rt = http.DefaultTransport
	}
	return &http.Client{
		Transport: newRawCURLRoundTripper(browserBaseURL, jwt, rt),
		Timeout:   underlying.Timeout,
	}
}

// newRawCURLRoundTripper returns an [http.RoundTripper] that maps each request
// to {base_url}/curl/raw?jwt=...&url=, preserving method,
// headers, and body. The caller's request URL must be an absolute http(s) URL.
func newRawCURLRoundTripper(browserBaseURL, jwt string, underlying http.RoundTripper) *RawCURLRoundTripper {
	if underlying == nil {
		underlying = http.DefaultTransport
	}
	return &RawCURLRoundTripper{
		browserBaseURL: strings.TrimRight(strings.TrimSpace(browserBaseURL), "/"),
		jwt:            strings.TrimSpace(jwt),
		underlying:     underlying,
	}
}

// RoundTrip implements [http.RoundTripper].
//
// The request is cloned before being rewritten, so req itself is not
// modified. A request rejected before it reaches the underlying transport
// still has its body closed, as the RoundTripper contract requires.
func (t *RawCURLRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	proxyURL, err := t.proxyURLFor(req)
	if err != nil {
		if req.Body != nil {
			// The close error is deliberately dropped: the validation error
			// is what the caller needs to see.
			_ = req.Body.Close()
		}
		return nil, err
	}

	out := req.Clone(req.Context())
	out.URL = proxyURL
	out.Host = proxyURL.Host
	out.RequestURI = ""

	return t.underlying.RoundTrip(out)
}

// proxyURLFor validates req and the tripper's configuration, then returns the
// /curl/raw URL that carries the jwt and the original target URL as query
// parameters.
func (t *RawCURLRoundTripper) proxyURLFor(req *http.Request) (*url.URL, error) {
	if req.URL == nil || !req.URL.IsAbs() {
		return nil, fmt.Errorf("browserrouting: raw curl requires an absolute request URL (got %q)", req.URL)
	}
	if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
		return nil, fmt.Errorf("browserrouting: raw curl requires http or https scheme")
	}
	if t.browserBaseURL == "" {
		return nil, fmt.Errorf("browserrouting: browser base_url is required")
	}
	if t.jwt == "" {
		return nil, fmt.Errorf("browserrouting: jwt is required for raw curl")
	}

	proxyURL, err := url.Parse(t.browserBaseURL + "/curl/raw")
	if err != nil {
		return nil, err
	}
	q := url.Values{}
	q.Set("jwt", t.jwt)
	q.Set("url", req.URL.String())
	proxyURL.RawQuery = q.Encode()
	return proxyURL, nil
}
+ b, _ := io.ReadAll(res.Body) + if string(b) != "ok" { + t.Fatalf("body: %q", b) + } + if sawPath != "/browser/kernel/curl/raw" { + t.Fatalf("path: %s", sawPath) + } + if sawRaw == "" { + t.Fatal("expected query") + } +} + +func TestRawCURLRoundTripperRelativeURL(t *testing.T) { + rt := newRawCURLRoundTripper("https://x/browser/kernel", "j", http.DefaultTransport) + req, _ := http.NewRequest(http.MethodGet, "/relative", nil) + _, err := rt.RoundTrip(req) + if err == nil { + t.Fatal("expected error for relative url") + } +} diff --git a/lib/browserrouting/ref.go b/lib/browserrouting/ref.go new file mode 100644 index 0000000..d98e497 --- /dev/null +++ b/lib/browserrouting/ref.go @@ -0,0 +1,43 @@ +package browserrouting + +import ( + "fmt" + "strings" +) + +// Ref identifies a browser session for direct-to-VM HTTP calls. SessionID is +// reserved for future client-side routing; allowlisted requests rewrite the +// /browsers/{SessionID}/ path segment against the returned base_url. +type Ref struct { + SessionID string + BaseURL string + JWT string + CdpWsURL string +} + +// Normalize validates fields and fills JWT from CdpWsURL when JWT is empty. 
// Route identifies a cached direct-to-VM transport for one browser session.
type Route struct {
	SessionID string
	BaseURL   string
	JWT       string
}

// RouteCache stores browser session transport details keyed by session_id.
//
// Keys are session ids with surrounding whitespace removed; Load, Store and
// Delete all apply the same trimming. Methods may be called on a nil
// *RouteCache, which behaves as an empty cache that discards writes. The zero
// value is an empty cache ready for use.
type RouteCache struct {
	mu     sync.RWMutex
	routes map[string]Route
}

// NewRouteCache returns an empty browser route cache.
func NewRouteCache() *RouteCache {
	return &RouteCache{routes: map[string]Route{}}
}

// Load returns the cached route for the given session id. The boolean is
// false when no route is cached for that id.
func (c *RouteCache) Load(sessionID string) (Route, bool) {
	if c == nil {
		return Route{}, false
	}
	// Store trims the key, so trim the lookup the same way.
	key := strings.TrimSpace(sessionID)
	c.mu.RLock()
	defer c.mu.RUnlock()
	route, ok := c.routes[key]
	return route, ok
}

// Store normalizes and caches the given route, replacing any route already
// cached for the same session id. Every field is stored with surrounding
// whitespace removed.
func (c *RouteCache) Store(route Route) {
	if c == nil {
		return
	}
	normalized := Route{
		SessionID: strings.TrimSpace(route.SessionID),
		BaseURL:   strings.TrimSpace(route.BaseURL),
		JWT:       strings.TrimSpace(route.JWT),
	}
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.routes == nil {
		// A zero-value RouteCache has a nil map; writing to it would panic.
		c.routes = make(map[string]Route)
	}
	c.routes[normalized.SessionID] = normalized
}

// Delete removes a cached route. Deleting an id that is not cached is a no-op.
func (c *RouteCache) Delete(sessionID string) {
	if c == nil {
		return
	}
	// Store trims the key, so trim the id being removed the same way.
	key := strings.TrimSpace(sessionID)
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.routes, key)
}
== "" { + return "", "", "", false + } + if i+3 < len(parts) { + suffix = "/" + strings.Join(parts[i+3:], "/") + } + return sessionID, subresource, suffix, true + } + return "", "", "", false +} + +func joinURLPath(basePath, subresource, suffix string) string { + base := "/" + strings.Trim(strings.TrimSpace(basePath), "/") + if base == "/" { + base = "" + } + out := base + "/" + strings.TrimPrefix(subresource, "/") + if suffix != "" { + out += "/" + strings.TrimPrefix(suffix, "/") + } + return out +} diff --git a/lib/browserrouting/route_cache_test.go b/lib/browserrouting/route_cache_test.go new file mode 100644 index 0000000..aa84b58 --- /dev/null +++ b/lib/browserrouting/route_cache_test.go @@ -0,0 +1,52 @@ +package browserrouting + +import ( + "net/http" + "net/url" + "testing" +) + +func TestDirectVMRoutingMiddlewareClearsStaleRawPath(t *testing.T) { + cache := NewRouteCache() + cache.Store(Route{ + SessionID: "sess-1", + BaseURL: "https://browser.example/browser/kernel", + JWT: "jwt-123", + }) + + middleware := DirectVMRoutingMiddleware(cache, []string{"process"}) + reqURL, err := url.Parse("https://api.example/browsers/sess-1/process/exec") + if err != nil { + t.Fatal(err) + } + reqURL.RawPath = "/browsers/sess-1/process/%65xec" + + req := &http.Request{ + URL: reqURL, + Header: http.Header{"Authorization": []string{"Bearer sk_test"}}, + } + + var got *http.Request + _, err = middleware(req, func(next *http.Request) (*http.Response, error) { + got = next + return nil, nil + }) + if err != nil { + t.Fatal(err) + } + if got == nil { + t.Fatal("expected middleware to invoke next handler") + } + if got.URL.Path != "/browser/kernel/process/exec" { + t.Fatalf("expected rewritten path, got %q", got.URL.Path) + } + if got.URL.RawPath != "" { + t.Fatalf("expected stale raw path to be cleared, got %q", got.URL.RawPath) + } + if got.URL.Query().Get("jwt") != "jwt-123" { + t.Fatalf("expected jwt query param, got %q", got.URL.Query().Get("jwt")) + } + if 
got.Header.Get("Authorization") != "" { + t.Fatalf("expected authorization to be stripped, got %q", got.Header.Get("Authorization")) + } +}