diff --git a/go.mod b/go.mod index 056a340030..d9cb4789ec 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/prometheus/procfs v0.19.2 github.com/safchain/ethtool v0.7.0 golang.org/x/sys v0.40.0 + gopkg.in/yaml.v2 v2.4.0 howett.net/plist v1.0.1 ) diff --git a/go.sum b/go.sum index 1ef3b5a97a..98e0b5d61b 100644 --- a/go.sum +++ b/go.sum @@ -139,6 +139,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/node_exporter.go b/node_exporter.go index 2c0e12ccc1..3b4c10a9a8 100644 --- a/node_exporter.go +++ b/node_exporter.go @@ -15,7 +15,9 @@ package main import ( "fmt" + "io" "log/slog" + "net" "net/http" _ "net/http/pprof" "os" @@ -23,6 +25,9 @@ import ( "runtime" "slices" "sort" + "strings" + + "gopkg.in/yaml.v2" "github.com/prometheus/common/promslog" "github.com/prometheus/common/promslog/flag" @@ -39,6 +44,12 @@ import ( "github.com/prometheus/node_exporter/collector" ) +const ( + defaultIPHeaderXForwardedFor = "X-Forwarded-For" + defaultIPHeaderXRealIP = "X-Real-IP" + defaultIPHeaderXForwarded = "X-Forwarded" +) + // handler wraps an unfiltered http.Handler but uses a filtered handler, // created on the fly, if filtering is requested. Create instances with // newHandler. @@ -52,14 +63,18 @@ type handler struct { includeExporterMetrics bool maxRequests int logger *slog.Logger + allowedNetworks []*net.IPNet + ipHeaders []string } -func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logger) *handler { +func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logger, allowedNetworks []*net.IPNet, ipHeaders []string) *handler { h := &handler{ exporterMetricsRegistry: prometheus.NewRegistry(), includeExporterMetrics: includeExporterMetrics, maxRequests: maxRequests, logger: logger, + allowedNetworks: allowedNetworks, + ipHeaders: ipHeaders, } if h.includeExporterMetrics { h.exporterMetricsRegistry.MustRegister( @@ -75,8 +90,72 @@ func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logge return h } +func (h *handler) getClientIP(r *http.Request) net.IP { + headers := h.ipHeaders + if len(headers) == 0 { + headers = []string{ + defaultIPHeaderXForwardedFor, + defaultIPHeaderXRealIP, + defaultIPHeaderXForwarded, + } + } + + for _, header := range headers { + ipStr := r.Header.Get(header) + if ipStr == "" { + continue + } + + if header == defaultIPHeaderXForwardedFor { + ips := strings.Split(ipStr, ",") + if len(ips) > 0 { + ipStr = strings.TrimSpace(ips[0]) + } + } + + if ip := net.ParseIP(ipStr); ip != nil { + return ip + } + } + + host, _, err := net.SplitHostPort(r.RemoteAddr) + if err != nil { + return net.ParseIP(r.RemoteAddr) + } + return net.ParseIP(host) +} + +func (h *handler) isIPAllowed(ip net.IP) bool { + if len(h.allowedNetworks) == 0 { + return true + } + for _, network := range h.allowedNetworks { + if network.Contains(ip) { + h.logger.Debug("IP allowed by network", "ip", ip.String(), "network", network.String()) + return true + } + } + h.logger.Debug("IP not in any allowed network", "ip", ip.String()) + return false +} + // ServeHTTP implements http.Handler. func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if len(h.allowedNetworks) > 0 { + clientIP := h.getClientIP(r) + if clientIP == nil { + h.logger.Debug("Could not parse client IP address", "remote_addr", r.RemoteAddr) + http.Error(w, "Access denied: could not parse client IP address", http.StatusForbidden) + return + } + if !h.isIPAllowed(clientIP) { + h.logger.Debug("Access denied for IP", "ip", clientIP.String(), "remote_addr", r.RemoteAddr) + http.Error(w, "Access denied", http.StatusForbidden) + return + } + h.logger.Debug("Access allowed for IP", "ip", clientIP.String(), "remote_addr", r.RemoteAddr) + } + collects := r.URL.Query()["collect[]"] h.logger.Debug("collect query:", "collects", collects) @@ -91,8 +170,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if len(collects) > 0 && len(excludes) > 0 { h.logger.Debug("rejecting combined collect and exclude queries") - w.WriteHeader(http.StatusBadRequest) - w.Write([]byte("Combined collect and exclude queries are not allowed.")) + http.Error(w, "Combined collect and exclude queries are not allowed.", http.StatusBadRequest) return } @@ -108,12 +186,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { filters = &f } - // To serve filtered metrics, we create a filtering handler on the fly. filteredHandler, err := h.innerHandler(*filters...) if err != nil { - h.logger.Warn("Couldn't create filtered metrics handler:", "err", err) - w.WriteHeader(http.StatusBadRequest) - fmt.Fprintf(w, "Couldn't create filtered metrics handler: %s", err) + h.logger.Warn("Couldn't create filtered metrics handler", "err", err) + http.Error(w, fmt.Sprintf("Couldn't create filtered metrics handler: %s", err), http.StatusBadRequest) return } filteredHandler.ServeHTTP(w, r) @@ -179,6 +255,107 @@ func (h *handler) innerHandler(filters ...string) (http.Handler, error) { return handler, nil } +func parseAllowedNetworks(networkStrings []string) ([]*net.IPNet, error) { + if len(networkStrings) == 0 { + return nil, nil + } + + networks := make([]*net.IPNet, 0, len(networkStrings)) + for _, networkStr := range networkStrings { + networkStr = strings.TrimSpace(networkStr) + if networkStr == "" { + continue + } + + if !strings.Contains(networkStr, "/") { + ip := net.ParseIP(networkStr) + if ip == nil { + return nil, fmt.Errorf("invalid IP address: %s", networkStr) + } + if ip.To4() != nil { + networkStr = networkStr + "/32" + } else { + networkStr = networkStr + "/128" + } + } + + _, network, err := net.ParseCIDR(networkStr) + if err != nil { + return nil, fmt.Errorf("invalid network CIDR %s: %w", networkStr, err) + } + networks = append(networks, network) + } + + return networks, nil +} + +type whitelistConfig struct { + AllowedNetworks []string `yaml:"allowed_networks"` + IPHeaders []string `yaml:"ip_headers"` +} + +func loadWhitelistConfig(configPath string) (*whitelistConfig, error) { + if configPath == "" { + return nil, nil + } + + file, err := os.Open(configPath) + if err != nil { + return nil, fmt.Errorf("failed to open config file: %w", err) + } + defer file.Close() + + data, err := io.ReadAll(file) + if err != nil { + return nil, fmt.Errorf("failed to read config file: %w", err) + } + + var config struct { + Whitelist whitelistConfig `yaml:"whitelist"` + } + if err := yaml.Unmarshal(data, &config); err != nil { + return nil, fmt.Errorf("failed to parse config file: %w", err) + } + + if len(config.Whitelist.AllowedNetworks) == 0 && len(config.Whitelist.IPHeaders) == 0 { + return nil, nil + } + + return &config.Whitelist, nil +} + +func loadWhitelistSettings(configPath string, networksFlag string, logger *slog.Logger) ([]*net.IPNet, []string, error) { + var networks []*net.IPNet + var ipHeaders []string + + config, err := loadWhitelistConfig(configPath) + if err != nil { + return nil, nil, fmt.Errorf("failed to load whitelist config: %w", err) + } + + if config != nil { + if len(config.AllowedNetworks) > 0 { + networks, err = parseAllowedNetworks(config.AllowedNetworks) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse allowed networks from config: %w", err) + } + } + if len(config.IPHeaders) > 0 { + ipHeaders = config.IPHeaders + } + } + + if networksFlag != "" { + networkStrings := strings.Split(networksFlag, ",") + networks, err = parseAllowedNetworks(networkStrings) + if err != nil { + return nil, nil, fmt.Errorf("failed to parse allowed networks from flag: %w", err) + } + } + + return networks, ipHeaders, nil +} + func main() { var ( metricsPath = kingpin.Flag( @@ -193,6 +370,14 @@ func main() { "web.max-requests", "Maximum number of parallel scrape requests. Use 0 to disable.", ).Default("40").Int() + allowedNetworks = kingpin.Flag( + "web.allowed-networks", + "Comma-separated list of allowed IP networks in CIDR notation (e.g., 192.168.1.0/24,10.0.0.0/8). Single IPs are also accepted and will be treated as /32 for IPv4 or /128 for IPv6.", + ).String() + whitelistConfigPath = kingpin.Flag( + "web.whitelist-config", + "Path to YAML configuration file for IP whitelist settings.", + ).String() disableDefaultCollectors = kingpin.Flag( "collector.disable-defaults", "Set all collectors to disabled by default.", @@ -211,6 +396,22 @@ func main() { kingpin.Parse() logger := promslog.New(promslogConfig) + networks, ipHeaders, err := loadWhitelistSettings(*whitelistConfigPath, *allowedNetworks, logger) + if err != nil { + logger.Error("Failed to load whitelist settings", "error", err) + os.Exit(1) + } + + if len(networks) > 0 { + logger.Info("IP whitelist enabled", "networks", len(networks)) + for _, network := range networks { + logger.Info("Allowed network", "network", network.String()) + } + if len(ipHeaders) > 0 { + logger.Info("IP headers configured", "headers", strings.Join(ipHeaders, ", ")) + } + } + if *disableDefaultCollectors { collector.DisableDefaultCollectors() } @@ -222,7 +423,9 @@ func main() { runtime.GOMAXPROCS(*maxProcs) logger.Debug("Go MAXPROCS", "procs", runtime.GOMAXPROCS(0)) - http.Handle(*metricsPath, newHandler(!*disableExporterMetrics, *maxRequests, logger)) + metricsHandler := newHandler(!*disableExporterMetrics, *maxRequests, logger, networks, ipHeaders) + http.Handle(*metricsPath, metricsHandler) + if *metricsPath != "/" { landingConfig := web.LandingConfig{ Name: "Node Exporter", @@ -240,7 +443,9 @@ func main() { logger.Error(err.Error()) os.Exit(1) } - http.Handle("/", landingPage) + landingHandler := newHandler(false, 0, logger, networks, ipHeaders) + landingHandler.unfilteredHandler = landingPage + http.Handle("/", landingHandler) } server := &http.Server{}