Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ require (
github.com/prometheus/procfs v0.19.2
github.com/safchain/ethtool v0.7.0
golang.org/x/sys v0.40.0
gopkg.in/yaml.v2 v2.4.0
howett.net/plist v1.0.1
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/yaml.v1 v1.0.0-20140924161607-9f9df34309c0/go.mod h1:WDnlLJ4WF5VGsH/HVa3CI79GS0ol3YnhVnKP89i0kNg=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
Expand Down
223 changes: 214 additions & 9 deletions node_exporter.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,19 @@ package main

import (
"fmt"
"io"
"log/slog"
"net"
"net/http"
_ "net/http/pprof"
"os"
"os/user"
"runtime"
"slices"
"sort"
"strings"

"gopkg.in/yaml.v2"

"github.com/prometheus/common/promslog"
"github.com/prometheus/common/promslog/flag"
Expand All @@ -39,6 +44,12 @@ import (
"github.com/prometheus/node_exporter/collector"
)

const (
defaultIPHeaderXForwardedFor = "X-Forwarded-For"
defaultIPHeaderXRealIP = "X-Real-IP"
defaultIPHeaderXForwarded = "X-Forwarded"
)

// handler wraps an unfiltered http.Handler but uses a filtered handler,
// created on the fly, if filtering is requested. Create instances with
// newHandler.
Expand All @@ -52,14 +63,18 @@ type handler struct {
includeExporterMetrics bool
maxRequests int
logger *slog.Logger
allowedNetworks []*net.IPNet
ipHeaders []string
}

func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logger) *handler {
func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logger, allowedNetworks []*net.IPNet, ipHeaders []string) *handler {
h := &handler{
exporterMetricsRegistry: prometheus.NewRegistry(),
includeExporterMetrics: includeExporterMetrics,
maxRequests: maxRequests,
logger: logger,
allowedNetworks: allowedNetworks,
ipHeaders: ipHeaders,
}
if h.includeExporterMetrics {
h.exporterMetricsRegistry.MustRegister(
Expand All @@ -75,8 +90,72 @@ func newHandler(includeExporterMetrics bool, maxRequests int, logger *slog.Logge
return h
}

func (h *handler) getClientIP(r *http.Request) net.IP {
headers := h.ipHeaders
if len(headers) == 0 {
headers = []string{
defaultIPHeaderXForwardedFor,
defaultIPHeaderXRealIP,
defaultIPHeaderXForwarded,
}
}

for _, header := range headers {
ipStr := r.Header.Get(header)
if ipStr == "" {
continue
}

if header == defaultIPHeaderXForwardedFor {
ips := strings.Split(ipStr, ",")
if len(ips) > 0 {
ipStr = strings.TrimSpace(ips[0])
}
}

if ip := net.ParseIP(ipStr); ip != nil {
return ip
}
}

host, _, err := net.SplitHostPort(r.RemoteAddr)
if err != nil {
return net.ParseIP(r.RemoteAddr)
}
return net.ParseIP(host)
}

func (h *handler) isIPAllowed(ip net.IP) bool {
if len(h.allowedNetworks) == 0 {
return true
}
for _, network := range h.allowedNetworks {
if network.Contains(ip) {
h.logger.Debug("IP allowed by network", "ip", ip.String(), "network", network.String())
return true
}
}
h.logger.Debug("IP not in any allowed network", "ip", ip.String())
return false
}

// ServeHTTP implements http.Handler.
func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if len(h.allowedNetworks) > 0 {
clientIP := h.getClientIP(r)
if clientIP == nil {
h.logger.Debug("Could not parse client IP address", "remote_addr", r.RemoteAddr)
http.Error(w, "Access denied: could not parse client IP address", http.StatusForbidden)
return
}
if !h.isIPAllowed(clientIP) {
h.logger.Debug("Access denied for IP", "ip", clientIP.String(), "remote_addr", r.RemoteAddr)
http.Error(w, "Access denied", http.StatusForbidden)
return
}
h.logger.Debug("Access allowed for IP", "ip", clientIP.String(), "remote_addr", r.RemoteAddr)
}

collects := r.URL.Query()["collect[]"]
h.logger.Debug("collect query:", "collects", collects)

Expand All @@ -91,8 +170,7 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {

if len(collects) > 0 && len(excludes) > 0 {
h.logger.Debug("rejecting combined collect and exclude queries")
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("Combined collect and exclude queries are not allowed."))
http.Error(w, "Combined collect and exclude queries are not allowed.", http.StatusBadRequest)
return
}

Expand All @@ -108,12 +186,10 @@ func (h *handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
filters = &f
}

// To serve filtered metrics, we create a filtering handler on the fly.
filteredHandler, err := h.innerHandler(*filters...)
if err != nil {
h.logger.Warn("Couldn't create filtered metrics handler:", "err", err)
w.WriteHeader(http.StatusBadRequest)
fmt.Fprintf(w, "Couldn't create filtered metrics handler: %s", err)
h.logger.Warn("Couldn't create filtered metrics handler", "err", err)
http.Error(w, fmt.Sprintf("Couldn't create filtered metrics handler: %s", err), http.StatusBadRequest)
return
}
filteredHandler.ServeHTTP(w, r)
Expand Down Expand Up @@ -179,6 +255,107 @@ func (h *handler) innerHandler(filters ...string) (http.Handler, error) {
return handler, nil
}

func parseAllowedNetworks(networkStrings []string) ([]*net.IPNet, error) {
if len(networkStrings) == 0 {
return nil, nil
}

networks := make([]*net.IPNet, 0, len(networkStrings))
for _, networkStr := range networkStrings {
networkStr = strings.TrimSpace(networkStr)
if networkStr == "" {
continue
}

if !strings.Contains(networkStr, "/") {
ip := net.ParseIP(networkStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address: %s", networkStr)
}
if ip.To4() != nil {
networkStr = networkStr + "/32"
} else {
networkStr = networkStr + "/128"
}
}

_, network, err := net.ParseCIDR(networkStr)
if err != nil {
return nil, fmt.Errorf("invalid network CIDR %s: %w", networkStr, err)
}
networks = append(networks, network)
}

return networks, nil
}

type whitelistConfig struct {
AllowedNetworks []string `yaml:"allowed_networks"`
IPHeaders []string `yaml:"ip_headers"`
}

func loadWhitelistConfig(configPath string) (*whitelistConfig, error) {
if configPath == "" {
return nil, nil
}

file, err := os.Open(configPath)
if err != nil {
return nil, fmt.Errorf("failed to open config file: %w", err)
}
defer file.Close()

data, err := io.ReadAll(file)
if err != nil {
return nil, fmt.Errorf("failed to read config file: %w", err)
}

var config struct {
Whitelist whitelistConfig `yaml:"whitelist"`
}
if err := yaml.Unmarshal(data, &config); err != nil {
return nil, fmt.Errorf("failed to parse config file: %w", err)
}

if len(config.Whitelist.AllowedNetworks) == 0 && len(config.Whitelist.IPHeaders) == 0 {
return nil, nil
}

return &config.Whitelist, nil
}

func loadWhitelistSettings(configPath string, networksFlag string, logger *slog.Logger) ([]*net.IPNet, []string, error) {
var networks []*net.IPNet
var ipHeaders []string

config, err := loadWhitelistConfig(configPath)
if err != nil {
return nil, nil, fmt.Errorf("failed to load whitelist config: %w", err)
}

if config != nil {
if len(config.AllowedNetworks) > 0 {
networks, err = parseAllowedNetworks(config.AllowedNetworks)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse allowed networks from config: %w", err)
}
}
if len(config.IPHeaders) > 0 {
ipHeaders = config.IPHeaders
}
}

if networksFlag != "" {
networkStrings := strings.Split(networksFlag, ",")
networks, err = parseAllowedNetworks(networkStrings)
if err != nil {
return nil, nil, fmt.Errorf("failed to parse allowed networks from flag: %w", err)
}
}

return networks, ipHeaders, nil
}

func main() {
var (
metricsPath = kingpin.Flag(
Expand All @@ -193,6 +370,14 @@ func main() {
"web.max-requests",
"Maximum number of parallel scrape requests. Use 0 to disable.",
).Default("40").Int()
allowedNetworks = kingpin.Flag(
"web.allowed-networks",
"Comma-separated list of allowed IP networks in CIDR notation (e.g., 192.168.1.0/24,10.0.0.0/8). Single IPs are also accepted and will be treated as /32 for IPv4 or /128 for IPv6.",
).String()
whitelistConfigPath = kingpin.Flag(
"web.whitelist-config",
"Path to YAML configuration file for IP whitelist settings.",
).String()
disableDefaultCollectors = kingpin.Flag(
"collector.disable-defaults",
"Set all collectors to disabled by default.",
Expand All @@ -211,6 +396,22 @@ func main() {
kingpin.Parse()
logger := promslog.New(promslogConfig)

networks, ipHeaders, err := loadWhitelistSettings(*whitelistConfigPath, *allowedNetworks, logger)
if err != nil {
logger.Error("Failed to load whitelist settings", "error", err)
os.Exit(1)
}

if len(networks) > 0 {
logger.Info("IP whitelist enabled", "networks", len(networks))
for _, network := range networks {
logger.Info("Allowed network", "network", network.String())
}
if len(ipHeaders) > 0 {
logger.Info("IP headers configured", "headers", strings.Join(ipHeaders, ", "))
}
}

if *disableDefaultCollectors {
collector.DisableDefaultCollectors()
}
Expand All @@ -222,7 +423,9 @@ func main() {
runtime.GOMAXPROCS(*maxProcs)
logger.Debug("Go MAXPROCS", "procs", runtime.GOMAXPROCS(0))

http.Handle(*metricsPath, newHandler(!*disableExporterMetrics, *maxRequests, logger))
metricsHandler := newHandler(!*disableExporterMetrics, *maxRequests, logger, networks, ipHeaders)
http.Handle(*metricsPath, metricsHandler)

if *metricsPath != "/" {
landingConfig := web.LandingConfig{
Name: "Node Exporter",
Expand All @@ -240,7 +443,9 @@ func main() {
logger.Error(err.Error())
os.Exit(1)
}
http.Handle("/", landingPage)
landingHandler := newHandler(false, 0, logger, networks, ipHeaders)
landingHandler.unfilteredHandler = landingPage
http.Handle("/", landingHandler)
}

server := &http.Server{}
Expand Down