diff --git a/Package.swift b/Package.swift index e4cd52e..9d6affb 100644 --- a/Package.swift +++ b/Package.swift @@ -40,7 +40,7 @@ let package = Package( ), .testTarget( name: "ATOAuthKitTests", - dependencies: ["ATOAuthKit"] + dependencies: ["ATOAuthKit", "OAuthTypes"] ), ] ) diff --git a/Sources/OAuthTypes/Constants.swift b/Sources/OAuthTypes/Constants.swift index e4328db..1c1a887 100644 --- a/Sources/OAuthTypes/Constants.swift +++ b/Sources/OAuthTypes/Constants.swift @@ -1,5 +1,5 @@ // -// Models.swift +// Constants.swift // ATOAuthKit // // Created by Christopher Jr Riley on 2025-07-28. diff --git a/Sources/OAuthTypes/OAuthAuthorizationDetails.swift b/Sources/OAuthTypes/OAuthAuthorizationDetails.swift index d290f33..2088beb 100644 --- a/Sources/OAuthTypes/OAuthAuthorizationDetails.swift +++ b/Sources/OAuthTypes/OAuthAuthorizationDetails.swift @@ -36,4 +36,13 @@ public struct AuthorizationDetail: Codable { self.identifier = identifier self.privileges = privileges } + + enum CodingKeys: String, CodingKey { + case type + case locations + case actions + case dataTypes = "datatypes" + case identifier + case privileges + } } diff --git a/Sources/OAuthTypes/OAuthAuthorizationServerMetadata.swift b/Sources/OAuthTypes/OAuthAuthorizationServerMetadata.swift index 279389d..bf58e28 100644 --- a/Sources/OAuthTypes/OAuthAuthorizationServerMetadata.swift +++ b/Sources/OAuthTypes/OAuthAuthorizationServerMetadata.swift @@ -264,11 +264,16 @@ public struct AuthorizationServerMetadata: Codable { self.authorizationEndpoint = try container.decode(URI.WebURI.self, forKey: .authorizationEndpoint) self.tokenEndpoint = try container.decode(URI.WebURI.self, forKey: .tokenEndpoint) - let tokenEndpointAuthMethodsSupported = try container.decode([String].self, forKey: .tokenEndpointAuthMethodsSupported) - if tokenEndpointAuthMethodsSupported == [""] { - self.tokenEndpointAuthMethodsSupported = ["client_secret_basic"] + // The TypeScript zod `.default(["client_secret_basic"])` only applies when the key is + // absent; an explicit JSON null fails validation. Mirror that: decode non-optionally when + // the key is present (so null throws), and default only when the key is omitted entirely. + if container.contains(.tokenEndpointAuthMethodsSupported) { + self.tokenEndpointAuthMethodsSupported = try container.decode( + [String].self, + forKey: .tokenEndpointAuthMethodsSupported + ) } else { - self.tokenEndpointAuthMethodsSupported = try container.decode([String].self, forKey: .tokenEndpointAuthMethodsSupported) + self.tokenEndpointAuthMethodsSupported = ["client_secret_basic"] } self.tokenEndpointAuthSigningAlgorithmValuesSupported = try container.decodeIfPresent([String].self, forKey: .tokenEndpointAuthSigningAlgorithmValuesSupported) diff --git a/Sources/OAuthTypes/OAuthClientRredentials.swift b/Sources/OAuthTypes/OAuthClientCredentials.swift similarity index 100% rename from Sources/OAuthTypes/OAuthClientRredentials.swift rename to Sources/OAuthTypes/OAuthClientCredentials.swift diff --git a/Sources/OAuthTypes/OAuthClientIDDiscoverable.swift b/Sources/OAuthTypes/OAuthClientIDDiscoverable.swift index 6f4e99e..b74be3c 100644 --- a/Sources/OAuthTypes/OAuthClientIDDiscoverable.swift +++ b/Sources/OAuthTypes/OAuthClientIDDiscoverable.swift @@ -40,10 +40,10 @@ public struct ClientIDDiscoverable: Codable, CustomStringConvertible { try URI.validateHTTPSURI(uriString: clientID) guard let urlComponents = URLComponents(string: clientID) else { - throw OAuthClientIDDiscoverableError.invlaidURL + throw OAuthClientIDDiscoverableError.invalidURL } - guard urlComponents.user != nil || urlComponents.password != nil else { + guard urlComponents.user == nil, urlComponents.password == nil else { throw OAuthClientIDDiscoverableError.credentialsDetected } @@ -51,27 +51,41 @@ public struct ClientIDDiscoverable: Codable, CustomStringConvertible { throw OAuthClientIDDiscoverableError.containsFragment } - guard urlComponents.path != "/" else { + // The WHATWG URL parser normalizes a bare authority (e.g. "https://example.com") + // to a "/" pathname; URLComponents leaves it empty. Treat an empty path as "no path + // component" to match the TypeScript `url.pathname === '/'` rejection. + guard !urlComponents.path.isEmpty else { + throw OAuthClientIDDiscoverableError.endsInTrailingSlash + } + + // Reject any path that ends in a trailing slash (TS: `url.pathname.endsWith('/')`). + // This also covers the root path "/" once a bare authority has been normalized. + guard !urlComponents.path.hasSuffix("/") else { throw OAuthClientIDDiscoverableError.endsInTrailingSlash } guard let hostname = urlComponents.host else { - throw OAuthClientIDDiscoverableError.invlaidURL + throw OAuthClientIDDiscoverableError.invalidURL } guard isHostnameIPAddress(hostname) == false else { throw OAuthClientIDDiscoverableError.containsIPAddress } - let url = URL(string: clientID) - guard let originalURLPath = url?.path(percentEncoded: false) else { - throw OAuthClientIDDiscoverableError.invlaidURL + // The WHATWG URL parser normalizes the path (resolving "." and ".." segments), + // so we compare the raw, un-normalized path against the normalized one to reject + // path traversal (TS: `extractUrlPath(value) !== url.pathname`). `URLComponents` + // preserves the raw path; `URL.standardized` resolves the dot segments. A mismatch + // means the original value was not in canonical form. + let rawPath = urlComponents.path + guard let normalizedPath = URL(string: clientID)?.standardized.path(percentEncoded: false) else { + throw OAuthClientIDDiscoverableError.invalidURL } - guard originalURLPath == urlComponents.path else { + guard rawPath == normalizedPath else { throw OAuthClientIDDiscoverableError.incorrectCanonicalForm( - expectedValue: urlComponents.path, - foundValue: originalURLPath + expectedValue: normalizedPath, + foundValue: rawPath ) } } @@ -106,15 +120,19 @@ public struct ConventionalOAuthClientID: Codable, CustomStringConvertible { /// /// - Parameter rawValue: The raw value to validate and use for the new instance. public init(validating rawValue: String) throws { + // A conventional client ID is the intersection of the discoverable schema + // and the conventional-specific checks, so validate it as discoverable first. + _ = try ClientIDDiscoverable(validating: rawValue) + guard let urlComponents = URLComponents(string: rawValue) else { - throw OAuthClientIDDiscoverableError.invlaidURL + throw OAuthClientIDDiscoverableError.invalidURL } - guard urlComponents.port != nil else { + guard urlComponents.port == nil else { throw OAuthClientIDDiscoverableError.containsPort } - guard urlComponents.query != nil else { + guard urlComponents.query == nil else { throw OAuthClientIDDiscoverableError.containsQuery } diff --git a/Sources/OAuthTypes/OAuthIssuerIdentifier.swift b/Sources/OAuthTypes/OAuthIssuerIdentifier.swift index 0a0197e..728fc65 100644 --- a/Sources/OAuthTypes/OAuthIssuerIdentifier.swift +++ b/Sources/OAuthTypes/OAuthIssuerIdentifier.swift @@ -29,28 +29,40 @@ public struct IssuerIdentifier: Codable, CustomStringConvertible { let webURIValue = try String(describing: URI.WebURI(validating: rawValue)) guard webURIValue.last != "/" else { - throw OAuthIssuerIdentifierError.issurURLEndsWithSlash + throw OAuthIssuerIdentifierError.issuerURLEndsWithSlash } guard let urlComponents = URLComponents(string: webURIValue) else { throw OAuthIssuerIdentifierError.invalidURL } - guard urlComponents.user == nil || urlComponents.password == nil else { + guard urlComponents.user == nil, urlComponents.password == nil else { throw OAuthIssuerIdentifierError.usernameOrPasswordDetected } - guard urlComponents.query != nil || urlComponents.fragment != nil else { + guard urlComponents.query == nil, urlComponents.fragment == nil else { throw OAuthIssuerIdentifierError.queryOrFragmentDetected } - let port = urlComponents.port != nil ? ":\(String(describing: urlComponents.port))" : nil - - guard let scheme = urlComponents.scheme, let host = urlComponents.host, let absoluteString = urlComponents.url?.absoluteString else { + guard let scheme = urlComponents.scheme, let host = urlComponents.host else { throw OAuthIssuerIdentifierError.invalidURL } - let canonicalValue = urlComponents.path == "/" ? "\(scheme)://\(host)\(port ?? "")" : "\(absoluteString)" + // The WHATWG URL parser lowercases the scheme and host when it builds `url.origin` + // and `url.href`, so a mixed-case host such as "https://AUTH.EXAMPLE.com" is not in + // canonical form. Lowercase the scheme and host (never the path) before comparing. + let canonicalScheme = scheme.lowercased() + let canonicalHost = host.lowercased() + + // The parser also drops the scheme's default port (443 for "https", 80 for "http") + // when it builds `url.origin`, so an issuer that spells out the default port is not + // in canonical form. Any other port is kept. + let isDefaultPort = (canonicalScheme == "https" && urlComponents.port == 443) + || (canonicalScheme == "http" && urlComponents.port == 80) + let port = urlComponents.port.flatMap { isDefaultPort ? nil : ":\($0)" } + let canonicalValue = urlComponents.path.isEmpty || urlComponents.path == "/" + ? "\(canonicalScheme)://\(canonicalHost)\(port ?? "")" + : "\(canonicalScheme)://\(canonicalHost)\(port ?? "")\(urlComponents.path)" guard rawValue == canonicalValue else { throw OAuthIssuerIdentifierError.notInCanonicalForm diff --git a/Sources/OAuthTypes/OAuthPARResponse.swift b/Sources/OAuthTypes/OAuthPARResponse.swift index 926ffc5..2bb99f7 100644 --- a/Sources/OAuthTypes/OAuthPARResponse.swift +++ b/Sources/OAuthTypes/OAuthPARResponse.swift @@ -11,27 +11,39 @@ public struct OAuthPARResponse: Codable { /// The request URI of the PAR response. public let requestURI: String - /// The date and time it will expire (in a UNIX timestamp). + /// The lifetime of the request URI, in seconds. /// - /// This will always be a positive number. + /// This will always be a positive integer. public let expiresIn: Int /// Creates an instance of `OAuthPARResponse`. /// - /// If a negative number is inserted into `expiresIn`, it will be converted to a positive number. - /// /// - Parameters: /// - requestURI: The request URI of the PAR response. - /// - expiresIn: This will always be a positive number. + /// - expiresIn: The lifetime of the request URI, in seconds. Must be a positive integer. public init(requestURI: String, expiresIn: Int) { self.requestURI = requestURI - self.expiresIn = abs(expiresIn) + self.expiresIn = expiresIn + } + + enum CodingKeys: String, CodingKey { + case requestURI = "request_uri" + case expiresIn = "expires_in" } public init(from decoder: any Decoder) throws { let container = try decoder.container(keyedBy: CodingKeys.self) self.requestURI = try container.decode(String.self, forKey: .requestURI) - self.expiresIn = abs(try container.decode(Int.self, forKey: .expiresIn)) + + let decodedExpiresIn = try container.decode(Int.self, forKey: .expiresIn) + guard decodedExpiresIn >= 1 else { + throw DecodingError.dataCorruptedError( + forKey: .expiresIn, + in: container, + debugDescription: "expires_in must be a positive integer; received \(decodedExpiresIn)." + ) + } + self.expiresIn = decodedExpiresIn } } diff --git a/Sources/OAuthTypes/OAuthProtectedResourceMetadata.swift b/Sources/OAuthTypes/OAuthProtectedResourceMetadata.swift index b14a850..d0028ac 100644 --- a/Sources/OAuthTypes/OAuthProtectedResourceMetadata.swift +++ b/Sources/OAuthTypes/OAuthProtectedResourceMetadata.swift @@ -54,7 +54,6 @@ public struct ProtectedResourceMetadata: Codable { resourcePolicyURI: URI.WebURI? = nil, resourceTermsOfServiceURI: URI.WebURI? = nil ) throws { -#if !DEBUG if String(describing: resource).contains("?") { throw OAuthProtectedResourceMetadataError.containsQuery } @@ -62,7 +61,7 @@ public struct ProtectedResourceMetadata: Codable { if String(describing: resource).contains("#") { throw OAuthProtectedResourceMetadataError.containsFragment } -#endif + self.resource = resource self.authorizationServers = authorizationServers self.jwksURI = jwksURI @@ -78,7 +77,7 @@ public struct ProtectedResourceMetadata: Codable { let container = try decoder.container(keyedBy: CodingKeys.self) let decodedResource = try container.decode(URI.WebURI.self, forKey: .resource) -#if !DEBUG + if String(describing: decodedResource).contains("?") { throw OAuthProtectedResourceMetadataError.containsQuery } @@ -86,7 +85,7 @@ public struct ProtectedResourceMetadata: Codable { if String(describing: decodedResource).contains("#") { throw OAuthProtectedResourceMetadataError.containsFragment } -#endif + self.resource = decodedResource self.authorizationServers = try container.decodeIfPresent([IssuerIdentifier].self, forKey: .authorizationServers) self.jwksURI = try container.decodeIfPresent(URI.WebURI.self, forKey: .jwksURI) diff --git a/Sources/OAuthTypes/OAuthRedirectURI.swift b/Sources/OAuthTypes/OAuthRedirectURI.swift index e60eb6c..70cc311 100644 --- a/Sources/OAuthTypes/OAuthRedirectURI.swift +++ b/Sources/OAuthTypes/OAuthRedirectURI.swift @@ -20,9 +20,14 @@ public struct OAuthLoopbackRedirectURI: Codable, CustomStringConvertible { /// /// - Parameter rawValue: The raw value to validate and use for the new instance. public init(validating rawValue: String) throws { + // Chain the full loopback URI schema first (http: scheme + loopback IP host). + _ = try URI.LoopbackRedirectURI(validating: rawValue) + + // Then exclude the "localhost" hostname (RFC 8252). guard !rawValue.starts(with: "http://localhost") else { throw OAuthRedirectURIError.localhostDetected } + self.rawValue = rawValue } diff --git a/Sources/OAuthTypes/OAuthScope.swift b/Sources/OAuthTypes/OAuthScope.swift index c0f8874..e1d22b6 100644 --- a/Sources/OAuthTypes/OAuthScope.swift +++ b/Sources/OAuthTypes/OAuthScope.swift @@ -7,7 +7,7 @@ import Foundation -/// An structure representing the set of possible OAuth scopes as plain strings. +/// A structure representing the set of possible OAuth scopes as plain strings. /// /// Each OAuth scope defines what permissions or access a client is requesting from the user. /// This enum provides a helper for validating if a given scope string matches the expected diff --git a/Sources/OAuthTypes/OAuthTokenIdentification.swift b/Sources/OAuthTypes/OAuthTokenIdentification.swift index 799e62e..e504fb6 100644 --- a/Sources/OAuthTypes/OAuthTokenIdentification.swift +++ b/Sources/OAuthTypes/OAuthTokenIdentification.swift @@ -12,8 +12,13 @@ public struct TokenIdentification: Codable { /// The actual OAuth token string. public let token: Token - /// A hint to help the server identify the type of token being sent. - public let tokenTypeHint: TokenTypeHint + /// A hint to help the server identify the type of token being sent. Optional. + public let tokenTypeHint: TokenTypeHint? + + enum CodingKeys: String, CodingKey { + case token + case tokenTypeHint = "token_type_hint" + } /// A representation of possible OAuth token variants. public enum Token: Codable { diff --git a/Sources/OAuthTypes/OAuthTokenResponse.swift b/Sources/OAuthTypes/OAuthTokenResponse.swift index 606b38e..3003c50 100644 --- a/Sources/OAuthTypes/OAuthTokenResponse.swift +++ b/Sources/OAuthTypes/OAuthTokenResponse.swift @@ -22,8 +22,8 @@ public struct TokenResponse: Codable { /// The refresh token of the response. Optional. public let refreshToken: String? - /// The date and time the response expires. Optional. - public let expiresIn: Date? + /// The lifetime of the access token, in seconds. Optional. + public let expiresIn: Int? /// The signed JSON Web Token (JWT). Optional. public let idToken: SignedJWT? diff --git a/Sources/OAuthTypes/OAuthTokenType.swift b/Sources/OAuthTypes/OAuthTokenType.swift index 2213950..fda6cc1 100644 --- a/Sources/OAuthTypes/OAuthTokenType.swift +++ b/Sources/OAuthTypes/OAuthTokenType.swift @@ -36,4 +36,22 @@ public enum OAuthTokenType: String, Codable, CaseIterable { return nil } } + + public init(from decoder: any Decoder) throws { + let container = try decoder.singleValueContainer() + let value = try container.decode(String.self) + + guard let instance = Self.parse(value) else { + throw DecodingError.dataCorruptedError( + in: container, + debugDescription: "Invalid OAuthTokenType value: \(value)" + ) + } + self = instance + } + + public func encode(to encoder: any Encoder) throws { + var container = encoder.singleValueContainer() + try container.encode(self.rawValue) + } } diff --git a/Sources/OAuthTypes/OAuthTypesLabsErrors.swift b/Sources/OAuthTypes/OAuthTypesLabsErrors.swift index 55aeb5d..2c51f5d 100644 --- a/Sources/OAuthTypes/OAuthTypesLabsErrors.swift +++ b/Sources/OAuthTypes/OAuthTypesLabsErrors.swift @@ -22,8 +22,8 @@ public enum OAuthTypesLabsURIError: Error, LocalizedError, CustomStringConvertib /// The URI/URL lacks the "HTTPS" protocol. case noHTTPSProtocol - /// The URI has less than two segements. - case lessThanTwoSegementsInURI + /// The URI has less than two segments. + case lessThanTwoSegmentsInURI /// The URI ended with `.local`. case endsInLocal @@ -53,7 +53,7 @@ public enum OAuthTypesLabsURIError: Error, LocalizedError, CustomStringConvertib return "The URI must be 'http://localhost:8080', '127.0.0.1', or '[::1]' as the hostname." case .noHTTPSProtocol: return "The URI must start with 'https://'." - case .lessThanTwoSegementsInURI: + case .lessThanTwoSegmentsInURI: return "The URI must contain at least two segments." case .endsInLocal: return "The URI must not end with '.local'." @@ -271,7 +271,7 @@ public enum OAuthClientIDLoopbackError: Error, LocalizedError, CustomStringConve public enum OAuthIssuerIdentifierError: Error, LocalizedError, CustomStringConvertible { /// Issuer URL contained a slash (`/`) at the end. - case issurURLEndsWithSlash + case issuerURLEndsWithSlash /// The URL provided was invalid. case invalidURL @@ -287,7 +287,7 @@ public enum OAuthIssuerIdentifierError: Error, LocalizedError, CustomStringConve public var errorDescription: String? { switch self { - case .issurURLEndsWithSlash: + case .issuerURLEndsWithSlash: return "Issuer URL must not end with a slash." case .invalidURL: return "The URL provided was invalid." @@ -338,7 +338,7 @@ public enum OAuthAuthorizationServerMetadataError: Error, CustomStringConvertibl public enum OAuthClientIDDiscoverableError: Error, LocalizedError, CustomStringConvertible { /// The client ID provided doesn't have the "https://" protocol. - case invlaidURL + case invalidURL /// The client ID contains a username and/or password. case credentialsDetected @@ -377,7 +377,7 @@ public enum OAuthClientIDDiscoverableError: Error, LocalizedError, CustomStringC public var errorDescription: String? { switch self { - case .invlaidURL: + case .invalidURL: return "The client ID URL is invalid." case .credentialsDetected: return "The client ID must not contain a username or password." diff --git a/Sources/OAuthTypes/OIDCClaimsParameter.swift b/Sources/OAuthTypes/OIDCClaimsParameter.swift index 69736ed..2588170 100644 --- a/Sources/OAuthTypes/OIDCClaimsParameter.swift +++ b/Sources/OAuthTypes/OIDCClaimsParameter.swift @@ -25,7 +25,7 @@ public enum OpenIDConnectClaimsParameter: String, CaseIterable, Codable { /// This lets clients enforce minimum assurance levels case authenticationContextClassReference = "acr" - // Profile-specifc + // Profile-specific /// The full name of the user. case name @@ -36,7 +36,7 @@ public enum OpenIDConnectClaimsParameter: String, CaseIterable, Codable { /// The given name of the user. /// - /// Otehrwise known as the "first name." + /// Otherwise known as the "first name." case givenName = "given_name" /// The middle name of the user. @@ -66,7 +66,7 @@ public enum OpenIDConnectClaimsParameter: String, CaseIterable, Codable { /// The time zone of the user. /// /// This would typically be in the IANA format. - case timeZone = "zoneInfo" + case timeZone = "zoneinfo" /// The user's preferred language. case locale diff --git a/Sources/OAuthTypes/URI.swift b/Sources/OAuthTypes/URI.swift index cb33354..b70979f 100644 --- a/Sources/OAuthTypes/URI.swift +++ b/Sources/OAuthTypes/URI.swift @@ -64,7 +64,11 @@ public enum URI { throw OAuthTypesLabsURIError.noHTTPProtocol } - guard isLoopbackHost(rawValue) else { + guard let parsedHost = URLComponents(string: rawValue)?.host, !parsedHost.isEmpty else { + throw OAuthTypesLabsURIError.invalidLoopbackURI + } + + guard isLoopbackHost(parsedHost) else { throw OAuthTypesLabsURIError.invalidLoopbackURI } @@ -85,21 +89,26 @@ public enum URI { throw OAuthTypesLabsURIError.noHTTPSProtocol } - guard isLoopbackHost(uriString) else { - throw OAuthTypesLabsURIError.invalidLoopbackURI + guard let parsedHost = URLComponents(string: uriString)?.host, !parsedHost.isEmpty else { + throw OAuthTypesLabsURIError.noURIHostname } - guard let uriHost = URL(string: uriString)?.host() else { - throw OAuthTypesLabsURIError.noURIHostname + // Re-bracket an unbracketed IPv6 hostname so it matches the form the comparison helpers expect. + let uriHost = parsedHost.contains(":") && !parsedHost.hasPrefix("[") ? "[\(parsedHost)]" : parsedHost + + // Disallow loopback URLs with the "https:" protocol. + guard !isLoopbackHost(uriHost) else { + throw OAuthTypesLabsURIError.invalidLoopbackURI } - if !isHostnameIPAddress(uriString) { + if !isHostnameIPAddress(uriHost) { + // The hostname is a domain name. if !uriHost.contains(".") { - throw OAuthTypesLabsURIError.lessThanTwoSegementsInURI + throw OAuthTypesLabsURIError.lessThanTwoSegmentsInURI } - if uriHost.contains(".local") { - throw OAuthTypesLabsURIError.noURIHostname + if uriHost.hasSuffix(".local") { + throw OAuthTypesLabsURIError.endsInLocal } } } diff --git a/Sources/OAuthTypes/Utilities.swift b/Sources/OAuthTypes/Utilities.swift index 0822281..394f97a 100644 --- a/Sources/OAuthTypes/Utilities.swift +++ b/Sources/OAuthTypes/Utilities.swift @@ -18,6 +18,11 @@ public func isHostnameIPAddress(_ hostname: String) -> Bool { return true } + // IPv6: a bracketed hostname (e.g. "[::1]") is treated as an IP address. + if hostname.hasPrefix("[") && hostname.hasSuffix("]") { + return true + } + // IPv6 regex: covers full, shorthand, and mixed notations. let ipAddressV6Check = #"^(([0-9A-Fa-f]{1,4}:){7}([0-9A-Fa-f]{1,4}|:))|(([0-9A-Fa-f]{1,4}:){1,7}:)|(([0-9A-Fa-f]{1,4}:){1,6}:[0-9A-Fa-f]{1,4})|(([0-9A-Fa-f]{1,4}:){1,5}(:[0-9A-Fa-f]{1,4}){1,2})|(([0-9A-Fa-f]{1,4}:){1,4}(:[0-9A-Fa-f]{1,4}){1,3})|(([0-9A-Fa-f]{1,4}:){1,3}(:[0-9A-Fa-f]{1,4}){1,4})|(([0-9A-Fa-f]{1,4}:){1,2}(:[0-9A-Fa-f]{1,4}){1,5})|([0-9A-Fa-f]{1,4}:((:[0-9A-Fa-f]{1,4}){1,6}))|(:((:[0-9A-Fa-f]{1,4}){1,7}|:))$"# if hostname.range(of: ipAddressV6Check, options: .regularExpression) != nil { @@ -42,11 +47,15 @@ public enum LoopbackHost: String, CaseIterable { /// Checks if a string is a recognized loopback host. /// -/// - Parameter host: The string to check. +/// The host is matched directly against `"localhost"`, `"127.0.0.1"`, and `"[::1]"`. An +/// unbracketed IPv6 loopback (`"::1"`, the form `URL/host()` produces) is normalized to its +/// bracketed equivalent before matching. +/// +/// - Parameter host: The hostname to check. /// - Returns: `true` if the string is a loopback host, or `false` if it's not. public func isLoopbackHost(_ host: String) -> Bool { - guard let urlHost = URL(string: host)?.host() else { return false } - return LoopbackHost.allCases.contains { $0.rawValue == urlHost } + let normalizedHost = host == "::1" ? "[::1]" : host + return LoopbackHost.allCases.contains { $0.rawValue == normalizedHost } } /// Determines whether the host of the `URL` is a loopback host. diff --git a/Tests/ATOAuthKitTests/ClientIdentifierTests.swift b/Tests/ATOAuthKitTests/ClientIdentifierTests.swift new file mode 100644 index 0000000..4aeba56 --- /dev/null +++ b/Tests/ATOAuthKitTests/ClientIdentifierTests.swift @@ -0,0 +1,284 @@ +import Testing +import OAuthTypes + +@Suite("Discoverable OAuth client IDs") +struct ClientIDDiscoverableTests { + + @Test( + "Accepts well-formed discoverable client IDs", + arguments: [ + "https://app.example.com/oauth-client-metadata.json", + "https://app.example.com/client-metadata.json", + "https://example.com/some/nested/path.json", + "https://example.com:8443/oauth-client-metadata.json", + "https://example.com/oauth-client-metadata.json?foo=bar" + ] + ) + func acceptsValidDiscoverable(_ clientID: String) throws { + #expect(throws: Never.self) { + try ClientIDDiscoverable(validating: clientID) + } + #expect(ClientIDDiscoverable.isDiscoverable(clientID: clientID)) + } + + @Test( + "Rejects discoverable client IDs that embed credentials", + arguments: [ + "https://user@example.com/oauth-client-metadata.json", + "https://user:pass@example.com/oauth-client-metadata.json", + "https://:pass@example.com/oauth-client-metadata.json" + ] + ) + func rejectsCredentials(_ clientID: String) { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: clientID) + } + #expect(ClientIDDiscoverable.isDiscoverable(clientID: clientID) == false) + } + + @Test("Rejects a discoverable client ID with a fragment") + func rejectsFragment() { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: "https://example.com/oauth-client-metadata.json#section") + } + } + + @Test( + "Rejects discoverable client IDs without a path component", + arguments: [ + "https://example.com/", + "https://example.com" + ] + ) + func rejectsRootPath(_ clientID: String) { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: clientID) + } + } + + @Test("Rejects a discoverable client ID whose host is an IP address") + func rejectsIPHost() { + // A non-loopback IP reaches the discoverable IP-address guard. A loopback IP such as + // 127.0.0.1 is rejected earlier by the https URI layer (matching the TypeScript + // httpsUriSchema, which forbids loopback hosts before the IP-address refinement runs). + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: "https://8.8.8.8/oauth-client-metadata.json") + } + } + + @Test("Accepts a non-root path that does not end in a slash") + func acceptsNonRootPathWithoutTrailingSlash() throws { + #expect(throws: Never.self) { + try ClientIDDiscoverable(validating: "https://example.com/foo/oauth-client-metadata.json") + } + } + + @Test("Rejects a non-root path that ends in a trailing slash") + func rejectsNonRootTrailingSlash() { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: "https://example.com/foo/") + } + } + + @Test( + "Rejects a path that is not in canonical form", + arguments: [ + "https://example.com/a/../oauth-client-metadata.json", + "https://example.com/./oauth-client-metadata.json" + ] + ) + func rejectsNonCanonicalPath(_ clientID: String) { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ClientIDDiscoverable(validating: clientID) + } + } +} + +@Suite("Conventional OAuth client IDs") +struct ConventionalOAuthClientIDTests { + + @Test("Accepts the canonical conventional client ID") + func acceptsCanonical() throws { + #expect(throws: Never.self) { + try ConventionalOAuthClientID(validating: "https://app.example.com/oauth-client-metadata.json") + } + } + + @Test("Rejects a conventional client ID that contains a port") + func rejectsPort() { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ConventionalOAuthClientID(validating: "https://app.example.com:8443/oauth-client-metadata.json") + } + } + + @Test("Rejects a conventional client ID that contains a query string") + func rejectsQuery() { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ConventionalOAuthClientID(validating: "https://app.example.com/oauth-client-metadata.json?foo=bar") + } + } + + @Test( + "Rejects a conventional client ID whose path is not /oauth-client-metadata.json", + arguments: [ + "https://app.example.com/client-metadata.json", + "https://app.example.com/other.json" + ] + ) + func rejectsWrongPath(_ clientID: String) { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ConventionalOAuthClientID(validating: clientID) + } + } + + @Test("Rejects a conventional client ID that embeds credentials (discoverable intersection)") + func rejectsCredentials() { + #expect(throws: OAuthClientIDDiscoverableError.self) { + try ConventionalOAuthClientID(validating: "https://user:pass@app.example.com/oauth-client-metadata.json") + } + } +} + +@Suite("OAuth issuer identifiers") +struct IssuerIdentifierTests { + + @Test( + "Accepts canonical issuer identifiers", + arguments: [ + "https://auth.example.com", + "https://auth.example.com/issuer" + ] + ) + func acceptsCanonical(_ issuer: String) throws { + #expect(throws: Never.self) { + try IssuerIdentifier(validating: issuer) + } + } + + @Test("Rejects an issuer identifier that ends with a trailing slash") + func rejectsTrailingSlash() { + #expect(throws: OAuthIssuerIdentifierError.self) { + try IssuerIdentifier(validating: "https://auth.example.com/") + } + } + + @Test( + "Rejects issuer identifiers that embed credentials", + arguments: [ + "https://user@auth.example.com", + "https://user:pass@auth.example.com", + "https://:pass@auth.example.com" + ] + ) + func rejectsCredentials(_ issuer: String) { + #expect(throws: OAuthIssuerIdentifierError.self) { + try IssuerIdentifier(validating: issuer) + } + } + + @Test( + "Rejects issuer identifiers with a query or fragment", + arguments: [ + "https://auth.example.com?foo=bar", + "https://auth.example.com#frag", + "https://auth.example.com/issuer?foo=bar" + ] + ) + func rejectsQueryOrFragment(_ issuer: String) { + #expect(throws: OAuthIssuerIdentifierError.self) { + try IssuerIdentifier(validating: issuer) + } + } + + @Test( + "Rejects issuer identifiers whose host is not lowercased", + arguments: [ + "https://AUTH.EXAMPLE.com", + "https://AUTH.EXAMPLE.com/issuer" + ] + ) + func rejectsMixedCaseHost(_ issuer: String) { + #expect(throws: OAuthIssuerIdentifierError.self) { + try IssuerIdentifier(validating: issuer) + } + } + + @Test( + "Accepts an issuer identifier with a non-default port", + arguments: [ + "https://auth.example.com:8080", + "https://auth.example.com:8080/issuer" + ] + ) + func acceptsNonDefaultPort(_ issuer: String) throws { + #expect(throws: Never.self) { + try IssuerIdentifier(validating: issuer) + } + } + + @Test("Rejects an issuer identifier that spells out the scheme's default port") + func rejectsExplicitDefaultPort() { + #expect(throws: OAuthIssuerIdentifierError.self) { + try IssuerIdentifier(validating: "https://auth.example.com:443") + } + } +} + +@Suite("Loopback OAuth client IDs") +struct ClientIDLoopbackTests { + + @Test( + "Parses loopback client IDs with no query string", + arguments: [ + "http://localhost", + "http://localhost/" + ] + ) + func parsesNoQuery(_ clientID: String) throws { + let result = try ClientIDLoopback.parse(outhLoopbackClientID: clientID) + #expect(result.scope == nil) + #expect(result.redirectURIs == nil) + } + + @Test( + "Parses loopback client IDs with an empty query string", + arguments: [ + "http://localhost?", + "http://localhost/?" + ] + ) + func parsesEmptyQuery(_ clientID: String) throws { + let result = try ClientIDLoopback.parse(outhLoopbackClientID: clientID) + #expect(result.scope == nil) + #expect(result.redirectURIs == nil) + } + + @Test("Parses a loopback client ID after a single slash and query") + func parsesSlashThenQuery() throws { + let result = try ClientIDLoopback.parse(outhLoopbackClientID: "http://localhost/?scope=atproto") + #expect(result.scope != nil) + } + + @Test("Rejects a loopback client ID that contains a path component") + func rejectsPathComponent() { + #expect(throws: OAuthClientIDLoopbackError.self) { + try ClientIDLoopback.parse(outhLoopbackClientID: "http://localhost/path") + } + } + + @Test("Treats a second question mark as part of the query value") + func multipleQuestionMarks() { + // "http://localhost?a=1?b=2" slices to the query "a=1?b=2"; the name "a" is + // not an allowed loopback parameter, so parsing rejects it. + #expect(throws: OAuthClientIDLoopbackError.self) { + try ClientIDLoopback.parse(outhLoopbackClientID: "http://localhost?a=1?b=2") + } + } + + @Test("Rejects a loopback client ID that does not start with the prefix") + func rejectsWrongPrefix() { + #expect(throws: OAuthClientIDLoopbackError.self) { + try ClientIDLoopback.parse(outhLoopbackClientID: "https://localhost") + } + } +} diff --git a/Tests/ATOAuthKitTests/URIValidationTests.swift b/Tests/ATOAuthKitTests/URIValidationTests.swift new file mode 100644 index 0000000..baa0411 --- /dev/null +++ b/Tests/ATOAuthKitTests/URIValidationTests.swift @@ -0,0 +1,180 @@ +import Testing +import OAuthTypes + +@Suite("HTTPS URI validation") +struct HTTPSURIValidationTests { + + @Test( + "Accepts well-formed https URIs with a public hostname or IP address", + arguments: [ + "https://example.com", + "https://sub.example.com", + "https://example.co.uk", + "https://8.8.8.8", + "https://[2001:db8::1]" + ] + ) + func acceptsValidHTTPSURIs(uriString: String) throws { + try URI.validateHTTPSURI(uriString: uriString) + } + + @Test( + "Rejects loopback hosts over https", + arguments: [ + "https://127.0.0.1", + "https://localhost", + "https://[::1]" + ] + ) + func rejectsLoopbackHTTPSURIs(uriString: String) { + #expect(throws: OAuthTypesLabsURIError.self) { + try URI.validateHTTPSURI(uriString: uriString) + } + } + + @Test( + "Rejects https URIs that fail the domain refinements", + arguments: [ + "https://example", + "https://foo.local", + "http://example.com", + "ftp://evil.com" + ] + ) + func rejectsInvalidHTTPSURIs(uriString: String) { + #expect(throws: OAuthTypesLabsURIError.self) { + try URI.validateHTTPSURI(uriString: uriString) + } + } +} + +@Suite("Loopback redirect URI validation") +struct LoopbackRedirectURIValidationTests { + + @Test( + "Accepts http loopback URIs, including localhost", + arguments: [ + "http://127.0.0.1/", + "http://[::1]/", + "http://localhost/", + "http://localhost:8080/callback" + ] + ) + func acceptsLoopbackURIs(rawValue: String) throws { + _ = try URI.LoopbackRedirectURI(validating: rawValue) + } + + @Test( + "Rejects non-loopback hosts and non-http schemes", + arguments: [ + "http://example.com/", + "https://127.0.0.1/", + "ftp://evil.com" + ] + ) + func rejectsNonLoopbackURIs(rawValue: String) { + #expect(throws: OAuthTypesLabsURIError.self) { + try URI.LoopbackRedirectURI(validating: rawValue) + } + } +} + +@Suite("OAuth loopback redirect URI validation") +struct OAuthLoopbackRedirectURIValidationTests { + + @Test( + "Accepts loopback IP redirect URIs", + arguments: [ + "http://127.0.0.1/", + "http://[::1]/", + "http://127.0.0.1:8080/callback" + ] + ) + func acceptsLoopbackIPURIs(rawValue: String) throws { + _ = try OAuthLoopbackRedirectURI(validating: rawValue) + } + + @Test( + "Excludes the localhost hostname", + arguments: [ + "http://localhost/", + "http://localhost:8080/callback" + ] + ) + func excludesLocalhost(rawValue: String) { + #expect(throws: OAuthRedirectURIError.self) { + try OAuthLoopbackRedirectURI(validating: rawValue) + } + } + + @Test( + "Rejects non-loopback hosts and non-http schemes", + arguments: [ + "ftp://evil.com", + "http://example.com/", + "https://127.0.0.1/" + ] + ) + func rejectsNonLoopbackURIs(rawValue: String) { + // These reach the chained loopback URI schema (`URI.LoopbackRedirectURI`), which throws + // before the `localhost` exclusion can run: the non-http schemes fail the protocol guard + // and the non-loopback host fails the loopback guard. + #expect(throws: OAuthTypesLabsURIError.self) { + try OAuthLoopbackRedirectURI(validating: rawValue) + } + } +} + +@Suite("Hostname helper functions") +struct HostnameHelperTests { + + @Test( + "Recognizes IP-address hostnames, including bracketed IPv6", + arguments: [ + "127.0.0.1", + "8.8.8.8", + "[::1]", + "[2001:db8::1]" + ] + ) + func recognizesIPAddresses(hostname: String) { + #expect(isHostnameIPAddress(hostname)) + } + + @Test( + "Does not treat domain names as IP addresses", + arguments: [ + "example.com", + "localhost", + "foo.local" + ] + ) + func rejectsDomainNames(hostname: String) { + #expect(!isHostnameIPAddress(hostname)) + } + + @Test( + "Recognizes loopback hostnames", + arguments: [ + "localhost", + "127.0.0.1", + "[::1]", + "::1" + ] + ) + func recognizesLoopbackHosts(hostname: String) { + #expect(isLoopbackHost(hostname)) + } + + @Test( + "Does not treat non-loopback hostnames as loopback", + arguments: [ + "example.com", + "8.8.8.8", + "http://127.0.0.1/" + ] + ) + func rejectsNonLoopbackHosts(hostname: String) { + #expect(!isLoopbackHost(hostname)) + } +} diff --git a/Tests/ATOAuthKitTests/WireFormatTests.swift b/Tests/ATOAuthKitTests/WireFormatTests.swift new file mode 100644 index 0000000..0399f7e --- /dev/null +++ b/Tests/ATOAuthKitTests/WireFormatTests.swift @@ -0,0 +1,216 @@ +import Foundation +import Testing +import OAuthTypes + +@Suite("OAuth wire-format round trips") +struct WireFormatTests { + + private let decoder = JSONDecoder() + private let encoder = JSONEncoder() + + // MARK: - Token response expires_in + + @Test("Token response decodes expires_in as a lifetime in seconds") + func tokenResponseExpiresInIsSeconds() throws { + let json = Data(""" + {"access_token":"abc123","token_type":"DPoP","expires_in":3600} + """.utf8) + + let response = try decoder.decode(TokenResponse.self, from: json) + + #expect(response.expiresIn == 3600) + #expect(response.accessToken == "abc123") + #expect(response.tokenType == .dpop) + } + + @Test("Token response omits expires_in when absent") + func tokenResponseExpiresInOptional() throws { + let json = Data(""" + {"access_token":"abc123","token_type":"Bearer"} + """.utf8) + + let response = try decoder.decode(TokenResponse.self, from: json) + + #expect(response.expiresIn == nil) + } + + // MARK: - Authorization details "datatypes" + + @Test("Authorization detail maps the lowercase datatypes wire key") + func authorizationDetailDatatypesKey() throws { + let json = Data(""" + {"type":"payment","datatypes":["account","balance"]} + """.utf8) + + let detail = try decoder.decode(AuthorizationDetail.self, from: json) + #expect(detail.dataTypes == ["account", "balance"]) + + let reencoded = try encoder.encode(detail) + let object = try JSONSerialization.jsonObject(with: reencoded) as? [String: Any] + #expect(object?["datatypes"] != nil) + #expect(object?["dataTypes"] == nil) + } + + // MARK: - OIDC zoneinfo claim + + @Test("zoneinfo claim decodes from the lowercase wire value") + func zoneInfoClaimIsLowercase() throws { + let decoded = try decoder.decode(OpenIDConnectClaimsParameter.self, from: Data("\"zoneinfo\"".utf8)) + #expect(decoded == .timeZone) + + let encoded = try encoder.encode(OpenIDConnectClaimsParameter.timeZone) + #expect(String(data: encoded, encoding: .utf8) == "\"zoneinfo\"") + + #expect(throws: (any Error).self) { + _ = try decoder.decode(OpenIDConnectClaimsParameter.self, from: Data("\"zoneInfo\"".utf8)) + } + } + + // MARK: - Token identification optional hint + + @Test("Token identification decodes without token_type_hint") + func tokenIdentificationHintOptional() throws { + let json = Data(""" + {"token":"access-token-value"} + """.utf8) + + let identification = try decoder.decode(TokenIdentification.self, from: json) + #expect(identification.tokenTypeHint == nil) + } + + @Test("Token identification still decodes an explicit token_type_hint") + func tokenIdentificationHintPresent() throws { + let json = Data(""" + {"token":"access-token-value","token_type_hint":"refresh_token"} + """.utf8) + + let identification = try decoder.decode(TokenIdentification.self, from: json) + #expect(identification.tokenTypeHint == .refreshToken) + } + + // MARK: - Token type case-insensitivity + + @Test("DPoP token type decodes case-insensitively", arguments: ["DPoP", "dpop", "DPOP", "dPoP"]) + func dpopTokenTypeDecodesCaseInsensitively(rawValue: String) throws { + let decoded = try decoder.decode(OAuthTokenType.self, from: Data("\"\(rawValue)\"".utf8)) + #expect(decoded == .dpop) + } + + @Test("Bearer token type decodes case-insensitively", arguments: ["Bearer", "bearer", "BEARER", "bEaReR"]) + func bearerTokenTypeDecodesCaseInsensitively(rawValue: String) throws { + let decoded = try decoder.decode(OAuthTokenType.self, from: Data("\"\(rawValue)\"".utf8)) + #expect(decoded == .bearer) + } + + @Test("Token type encodes the canonical form") + func tokenTypeEncodesCanonical() throws { + #expect(String(data: try encoder.encode(OAuthTokenType.dpop), encoding: .utf8) == "\"DPoP\"") + #expect(String(data: try encoder.encode(OAuthTokenType.bearer), encoding: .utf8) == "\"Bearer\"") + } + + @Test("Token type rejects an unknown value") + func tokenTypeRejectsUnknown() { + #expect(throws: (any Error).self) { + _ = try decoder.decode(OAuthTokenType.self, from: Data("\"MAC\"".utf8)) + } + } + + // MARK: - PAR response positive expires_in + + @Test("PAR response decodes a positive expires_in") + func parResponsePositiveExpiresIn() throws { + let json = Data(""" + {"request_uri":"urn:ietf:params:oauth:request_uri:abc","expires_in":90} + """.utf8) + + let response = try decoder.decode(OAuthPARResponse.self, from: json) + #expect(response.expiresIn == 90) + } + + @Test("PAR response rejects a non-positive expires_in", arguments: [0, -1, -3600]) + func parResponseRejectsNonPositiveExpiresIn(expiresIn: Int) { + let json = Data(""" + {"request_uri":"urn:ietf:params:oauth:request_uri:abc","expires_in":\(expiresIn)} + """.utf8) + + #expect(throws: (any Error).self) { + _ = try self.decoder.decode(OAuthPARResponse.self, from: json) + } + } + + // MARK: - Authorization server metadata default auth methods + + @Test("Server metadata defaults token_endpoint_auth_methods_supported when absent") + func serverMetadataDefaultsAuthMethods() throws { + let json = Data(""" + { + "issuer":"https://issuer.example.com", + "authorization_endpoint":"https://issuer.example.com/authorize", + "token_endpoint":"https://issuer.example.com/token" + } + """.utf8) + + let metadata = try decoder.decode(AuthorizationServerMetadata.self, from: json) + #expect(metadata.tokenEndpointAuthMethodsSupported == ["client_secret_basic"]) + } + + @Test("Server metadata keeps an explicit token_endpoint_auth_methods_supported") + func serverMetadataKeepsExplicitAuthMethods() throws { + let json = Data(""" + { + "issuer":"https://issuer.example.com", + "authorization_endpoint":"https://issuer.example.com/authorize", + "token_endpoint":"https://issuer.example.com/token", + "token_endpoint_auth_methods_supported":["private_key_jwt","none"] + } + """.utf8) + + let metadata = try decoder.decode(AuthorizationServerMetadata.self, from: json) + #expect(metadata.tokenEndpointAuthMethodsSupported == ["private_key_jwt", "none"]) + } + + @Test("Server metadata rejects an explicit null token_endpoint_auth_methods_supported") + func serverMetadataRejectsExplicitNullAuthMethods() { + let json = Data(""" + { + "issuer":"https://issuer.example.com", + "authorization_endpoint":"https://issuer.example.com/authorize", + "token_endpoint":"https://issuer.example.com/token", + "token_endpoint_auth_methods_supported":null + } + """.utf8) + + #expect(throws: (any Error).self) { + _ = try self.decoder.decode(AuthorizationServerMetadata.self, from: json) + } + } + + // MARK: - Protected resource metadata query/fragment rejection + + @Test("Protected resource metadata accepts a clean resource URL") + func protectedResourceAcceptsCleanResource() throws { + let json = Data(""" + {"resource":"https://resource.example.com"} + """.utf8) + + let metadata = try decoder.decode(ProtectedResourceMetadata.self, from: json) + #expect(String(describing: metadata.resource) == "https://resource.example.com") + } + + @Test( + "Protected resource metadata rejects query or fragment in debug builds too", + arguments: [ + "https://resource.example.com/path?token=1", + "https://resource.example.com/path#section" + ] + ) + func protectedResourceRejectsQueryOrFragment(resource: String) { + let json = Data(""" + {"resource":"\(resource)"} + """.utf8) + + #expect(throws: (any Error).self) { + _ = try self.decoder.decode(ProtectedResourceMetadata.self, from: json) + } + } +}