diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index debcac4..c8f7480 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -27,7 +27,7 @@ jobs: if: github.event_name == 'pull_request' run: | if git show origin/main:buf.yaml >/dev/null 2>&1; then - buf breaking --against '.git#branch=main' + buf breaking --against '.git#branch=origin/main' else echo "Skipping: no buf.yaml on main yet" fi diff --git a/internal/connector/registration.go b/internal/connector/registration.go index ec66747..12f334a 100644 --- a/internal/connector/registration.go +++ b/internal/connector/registration.go @@ -3,7 +3,6 @@ package connector import ( "context" "fmt" - "net/http" "sync" "time" @@ -62,7 +61,7 @@ func (rc *RegistrationClient) Register(ctx context.Context) (string, error) { req := connect.NewRequest(&gatev1.RegisterConnectorRequest{ Connector: registrationInfoProto(rc.info), }) - rc.rpcClient.authReq(req) + controlplaneclient.ApplyAuth(req, rc.rpcClient.token, "") resp, err := client.RegisterConnector(ctx, req) if err != nil { return "", fmt.Errorf("register request: %w", err) @@ -107,7 +106,7 @@ func (rc *RegistrationClient) sendHeartbeat(ctx context.Context) error { req := connect.NewRequest(&gatev1.HeartbeatConnectorRequest{ Id: rc.connectorID, }) - rc.rpcClient.authReq(req) + controlplaneclient.ApplyAuth(req, rc.rpcClient.token, "") resp, err := client.HeartbeatConnector(ctx, req) if err != nil { return fmt.Errorf("heartbeat request: %w", err) @@ -129,7 +128,7 @@ func (rc *RegistrationClient) Deregister(ctx context.Context) error { return err } req := connect.NewRequest(&gatev1.DeregisterConnectorRequest{Id: rc.connectorID}) - rc.rpcClient.authReq(req) + controlplaneclient.ApplyAuth(req, rc.rpcClient.token, "") if _, err := client.DeregisterConnector(ctx, req); err != nil { return fmt.Errorf("deregister request: %w", err) } @@ -176,15 +175,6 @@ func newRegistrationRPCClient(controlPlaneURL, token string) *registrationRPCCli } } -func (c *registrationRPCClient) authReq(req interface{ Header() http.Header }) { - headers := controlplaneclient.OutgoingHeaders(c.token, "") - for key, values := range headers { - for _, value := range values { - req.Header().Set(key, value) - } - } -} - func (c *registrationRPCClient) service() (gatev1connect.ConnectorServiceClient, error) { c.mu.Lock() defer c.mu.Unlock() diff --git a/internal/controlplane/connect.go b/internal/controlplane/connect.go index d70addf..4fa0b1d 100644 --- a/internal/controlplane/connect.go +++ b/internal/controlplane/connect.go @@ -8,6 +8,8 @@ import ( "time" "connectrpc.com/connect" + "github.com/go-chi/chi/v5/middleware" + gatev1 "github.com/evalops/gate/internal/gen/gate/v1" "github.com/evalops/gate/internal/gen/gate/v1/gatev1connect" "github.com/evalops/gate/internal/store" @@ -49,7 +51,17 @@ func (s *Server) newConnectHandler() http.Handler { } func (s *Server) connectMuxHandler() http.Handler { - connectHandler := s.newConnectHandler() + // Wrap the ConnectRPC handler with the same middleware applied to REST routes. + // RealIP must run before the rate limiter so it sees the client IP, not the proxy IP. + connectHandler := middleware.RealIP( + s.rateLimiter.Middleware()( + maxBodySize(10 * 1024 * 1024)( + securityHeaders( + s.newConnectHandler(), + ), + ), + ), + ) return h2c.NewHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if strings.HasPrefix(r.URL.Path, "/gate.v1.") { connectHandler.ServeHTTP(w, r) diff --git a/internal/controlplaneclient/client.go b/internal/controlplaneclient/client.go index 6c7735a..152d086 100644 --- a/internal/controlplaneclient/client.go +++ b/internal/controlplaneclient/client.go @@ -347,14 +347,12 @@ func ClientTransport(controlPlaneURL string) (string, *http.Client, error) { } } -// OutgoingHeaders returns HTTP headers with auth and org ID for direct use. -func OutgoingHeaders(apiKey, orgID string) http.Header { - h := http.Header{} +// ApplyAuth sets authentication headers on a ConnectRPC request. +func ApplyAuth(req interface{ Header() http.Header }, apiKey, orgID string) { if apiKey != "" { - h.Set("Authorization", "Bearer "+apiKey) + req.Header().Set("Authorization", "Bearer "+apiKey) } if orgID != "" { - h.Set("X-Org-Id", orgID) + req.Header().Set("X-Org-Id", orgID) } - return h } diff --git a/internal/sync/connectclient.go b/internal/sync/connectclient.go index 60e2274..585d73d 100644 --- a/internal/sync/connectclient.go +++ b/internal/sync/connectclient.go @@ -3,7 +3,6 @@ package sync import ( "context" "fmt" - "net/http" "sync" "time" @@ -32,15 +31,6 @@ func newControlPlaneRPCClient(controlPlaneURL, apiKey string) *controlPlaneRPCCl } } -func (c *controlPlaneRPCClient) authReq(req interface{ Header() http.Header }) { - headers := controlplaneclient.OutgoingHeaders(c.apiKey, "") - for key, values := range headers { - for _, value := range values { - req.Header().Set(key, value) - } - } -} - func (c *controlPlaneRPCClient) SyncPolicies(ctx context.Context) ([]SyncPolicy, error) { client, err := c.ensureClient() if err != nil { @@ -48,7 +38,7 @@ func (c *controlPlaneRPCClient) SyncPolicies(ctx context.Context) ([]SyncPolicy, } req := connect.NewRequest(&gatev1.SyncPoliciesRequest{}) - c.authReq(req) + controlplaneclient.ApplyAuth(req, c.apiKey, "") resp, err := client.SyncPolicies(ctx, req) if err != nil { return nil, err @@ -84,7 +74,7 @@ func (c *controlPlaneRPCClient) SyncAccessGrants(ctx context.Context) ([]protoco } req := connect.NewRequest(&gatev1.SyncAccessGrantsRequest{}) - c.authReq(req) + controlplaneclient.ApplyAuth(req, c.apiKey, "") resp, err := client.SyncAccessGrants(ctx, req) if err != nil { return nil, err @@ -142,7 +132,7 @@ func (c *controlPlaneRPCClient) IngestAuditLogs(ctx context.Context, entries []s req := connect.NewRequest(&gatev1.IngestAuditLogsRequest{ Entries: protoEntries, }) - c.authReq(req) + controlplaneclient.ApplyAuth(req, c.apiKey, "") resp, err := client.IngestAuditLogs(ctx, req) if err != nil { return 0, err