diff --git a/lib/browserrouting/rawcurl.go b/lib/browserrouting/rawcurl.go index ce4d6eb..6cbec47 100644 --- a/lib/browserrouting/rawcurl.go +++ b/lib/browserrouting/rawcurl.go @@ -23,9 +23,9 @@ func NewHTTPClient(browserBaseURL, jwt string, underlying *http.Client) *http.Cl if underlying == nil { underlying = http.DefaultClient } - rt := underlying.Transport + rt := rawCURLUnderlyingTransport(underlying.Transport) if rt == nil { - rt = http.DefaultTransport + rt = rawCURLUnderlyingTransport(http.DefaultTransport) } return &http.Client{ Transport: newRawCURLRoundTripper(browserBaseURL, jwt, rt), @@ -33,6 +33,15 @@ func NewHTTPClient(browserBaseURL, jwt string, underlying *http.Client) *http.Cl } } +func rawCURLUnderlyingTransport(rt http.RoundTripper) http.RoundTripper { + if transport, ok := rt.(*http.Transport); ok { + clone := transport.Clone() + clone.DisableCompression = true + return clone + } + return rt +} + // newRawCURLRoundTripper returns an [http.RoundTripper] that maps each request // to {base_url}/curl/raw?jwt=...&url=, preserving method, // headers, and body. The caller's request URL must be an absolute http(s) URL. @@ -76,6 +85,18 @@ func (t *RawCURLRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro out.URL = proxyURL out.Host = proxyURL.Host out.RequestURI = "" + if !hasHeader(out.Header, "User-Agent") { + out.Header["User-Agent"] = nil + } return t.underlying.RoundTrip(out) } + +func hasHeader(headers http.Header, name string) bool { + for key := range headers { + if strings.EqualFold(key, name) { + return true + } + } + return false +} diff --git a/lib/browserrouting/rawcurl_test.go b/lib/browserrouting/rawcurl_test.go index 57982b2..4a5e22e 100644 --- a/lib/browserrouting/rawcurl_test.go +++ b/lib/browserrouting/rawcurl_test.go @@ -44,6 +44,64 @@ func TestRawCURLRoundTripper(t *testing.T) { } } +func TestRawCURLRoundTripperSuppressesGoTransportHeaders(t *testing.T) { + var sawUserAgent, sawAcceptEncoding string + up := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sawUserAgent = r.Header.Get("User-Agent") + sawAcceptEncoding = r.Header.Get("Accept-Encoding") + w.WriteHeader(http.StatusNoContent) + })) + defer up.Close() + + client := NewHTTPClient(up.URL+"/browser/kernel", "jwt1", http.DefaultClient) + req, err := http.NewRequest(http.MethodGet, "https://example.org/foo", nil) + if err != nil { + t.Fatal(err) + } + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + if sawUserAgent != "" { + t.Fatalf("unexpected user-agent: %q", sawUserAgent) + } + if sawAcceptEncoding != "" { + t.Fatalf("unexpected accept-encoding: %q", sawAcceptEncoding) + } +} + +func TestRawCURLRoundTripperPreservesExplicitTransportHeaders(t *testing.T) { + var sawUserAgent, sawAcceptEncoding string + up := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sawUserAgent = r.Header.Get("User-Agent") + sawAcceptEncoding = r.Header.Get("Accept-Encoding") + w.WriteHeader(http.StatusNoContent) + })) + defer up.Close() + + client := NewHTTPClient(up.URL+"/browser/kernel", "jwt1", http.DefaultClient) + req, err := http.NewRequest(http.MethodGet, "https://example.org/foo", nil) + if err != nil { + t.Fatal(err) + } + req.Header.Set("User-Agent", "custom-agent") + req.Header.Set("Accept-Encoding", "identity") + res, err := client.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + if sawUserAgent != "custom-agent" { + t.Fatalf("user-agent: %q", sawUserAgent) + } + if sawAcceptEncoding != "identity" { + t.Fatalf("accept-encoding: %q", sawAcceptEncoding) + } +} + func TestRawCURLRoundTripperRelativeURL(t *testing.T) { rt := newRawCURLRoundTripper("https://x/browser/kernel", "j", http.DefaultTransport) req, _ := http.NewRequest(http.MethodGet, "/relative", nil)