diff --git a/conf/shiro.ini.template b/conf/shiro.ini.template index 6721d175f91..24b18e27109 100644 --- a/conf/shiro.ini.template +++ b/conf/shiro.ini.template @@ -87,6 +87,9 @@ sessionManager = org.apache.shiro.web.session.mgt.DefaultWebSessionManager cookie = org.apache.shiro.web.servlet.SimpleCookie cookie.name = JSESSIONID cookie.httpOnly = true +### Restrict the session cookie to same-site requests by default. Set to NONE only when +### Zeppelin is intentionally embedded into a different origin (and 'cookie.secure = true'). +cookie.sameSite = LAX ### Uncomment the below line only when Zeppelin is running over HTTPS #cookie.secure = true sessionManager.sessionIdCookie = $cookie diff --git a/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java b/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java index a21ef991745..2d025d53611 100644 --- a/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java +++ b/zeppelin-client/src/main/java/org/apache/zeppelin/client/ZeppelinClient.java @@ -592,6 +592,7 @@ public String addParagraph(String noteId, String title, String text) throws Exce bodyObject.put("text", text); HttpResponse response = Unirest.post("/notebook/{noteId}/paragraph") .routeParam("noteId", noteId) + .header("Content-Type", "application/json") .body(bodyObject.toString()) .asJson(); checkResponse(response); @@ -617,6 +618,7 @@ public void updateParagraph(String noteId, String paragraphId, String title, Str HttpResponse response = Unirest.put("/notebook/{noteId}/paragraph/{paragraphId}") .routeParam("noteId", noteId) .routeParam("paragraphId", paragraphId) + .header("Content-Type", "application/json") .body(bodyObject.toString()) .asJson(); checkResponse(response); diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java b/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java index 3b9ebee0bad..07fd7399e14 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/conf/ZeppelinConfiguration.java @@ -1045,7 +1045,9 @@ public enum ConfVars { "https://github.com/yarnpkg/yarn/releases/download/"), // Allows a way to specify a ',' separated list of allowed origins for rest and websockets // i.e. http://localhost:8080 - ZEPPELIN_ALLOWED_ORIGINS("zeppelin.server.allowed.origins", "*"), + // Default is empty (no cross-origin requests permitted). Operators that need cross-origin + // access must set this explicitly to the trusted origin(s) or to "*". + ZEPPELIN_ALLOWED_ORIGINS("zeppelin.server.allowed.origins", ""), ZEPPELIN_USERNAME_FORCE_LOWERCASE("zeppelin.username.force.lowercase", false), ZEPPELIN_CREDENTIALS_PERSIST("zeppelin.credentials.persist", true), ZEPPELIN_CREDENTIALS_ENCRYPT_KEY("zeppelin.credentials.encryptKey", null), diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilter.java new file mode 100644 index 00000000000..af2cd417883 --- /dev/null +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilter.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.rest.filter; + +import java.util.Locale; +import java.util.Set; + +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.container.ContainerRequestFilter; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; +import jakarta.ws.rs.ext.Provider; + +import org.apache.zeppelin.utils.HttpMethods; + +/** + * Restricts the request body media types accepted by REST endpoints to a small allow-list. + * Requests carrying state-changing methods (POST/PUT/DELETE/PATCH) with a body must use + * {@code application/json}, {@code application/x-www-form-urlencoded}, or + * {@code multipart/form-data}; anything else is rejected with 415. + */ +@Provider +public class JsonContentTypeFilter implements ContainerRequestFilter { + + private static final Set ALLOWED_TYPES = Set.of( + "application/json", + "application/x-www-form-urlencoded", + "multipart/form-data"); + + @Override + public void filter(ContainerRequestContext ctx) { + String method = ctx.getMethod(); + if (method == null || !HttpMethods.STATE_CHANGING.contains(method.toUpperCase(Locale.ROOT))) { + return; + } + if (!ctx.hasEntity()) { + return; + } + MediaType mt = ctx.getMediaType(); + if (mt == null || !ALLOWED_TYPES.contains(baseType(mt))) { + ctx.abortWith( + Response.status(Response.Status.UNSUPPORTED_MEDIA_TYPE) + .entity("Unsupported Content-Type") + .build()); + } + } + + private static String baseType(MediaType mt) { + return (mt.getType() + "/" + mt.getSubtype()).toLowerCase(Locale.ROOT); + } +} diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java index 51906cfd31d..d9d1c39ec87 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/CorsFilter.java @@ -18,6 +18,7 @@ import java.io.IOException; import java.net.URISyntaxException; +import java.util.Locale; import jakarta.servlet.Filter; import jakarta.servlet.FilterChain; import jakarta.servlet.FilterConfig; @@ -28,6 +29,7 @@ import jakarta.servlet.http.HttpServletResponse; import org.apache.zeppelin.conf.ZeppelinConfiguration; import org.apache.zeppelin.utils.CorsUtils; +import org.apache.zeppelin.utils.HttpMethods; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -46,33 +48,52 @@ public CorsFilter(ZeppelinConfiguration zConf) { @Override public void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) throws IOException, ServletException { - String sourceHost = ((HttpServletRequest) request).getHeader("Origin"); - String origin = ""; + HttpServletRequest httpRequest = (HttpServletRequest) request; + HttpServletResponse httpResponse = (HttpServletResponse) response; - try { - if (CorsUtils.isValidOrigin(sourceHost, zConf)) { - origin = sourceHost; + String sourceHost = httpRequest.getHeader(CorsUtils.HEADER_ORIGIN); + String method = httpRequest.getMethod(); + String allowedOrigin = ""; + + if (sourceHost != null && !sourceHost.isEmpty()) { + try { + if (CorsUtils.isValidOrigin(sourceHost, zConf)) { + allowedOrigin = sourceHost; + } + } catch (URISyntaxException e) { + LOGGER.warn("Rejecting request with malformed Origin header: {}", sourceHost); } - } catch (URISyntaxException e) { - LOGGER.error("Exception in WebDriverManager while getWebDriver ", e); - } - if (((HttpServletRequest) request).getMethod().equals("OPTIONS")) { - HttpServletResponse resp = ((HttpServletResponse) response); - addCorsHeaders(resp, origin); - return; + if (allowedOrigin.isEmpty() && (isCorsPreflight(httpRequest) || isStateChanging(method))) { + LOGGER.warn("Blocking cross-origin {} request from disallowed Origin: {}", + method, sourceHost); + httpResponse.sendError(HttpServletResponse.SC_FORBIDDEN, "Origin not allowed"); + return; + } } - if (response instanceof HttpServletResponse) { - HttpServletResponse alteredResponse = ((HttpServletResponse) response); - addCorsHeaders(alteredResponse, origin); + addCorsHeaders(httpResponse, allowedOrigin); + if (isCorsPreflight(httpRequest)) { + return; } filterChain.doFilter(request, response); } + private static boolean isCorsPreflight(HttpServletRequest request) { + return "OPTIONS".equalsIgnoreCase(request.getMethod()) + && request.getHeader("Access-Control-Request-Method") != null; + } + + private static boolean isStateChanging(String method) { + return method != null + && HttpMethods.STATE_CHANGING.contains(method.toUpperCase(Locale.ROOT)); + } + private void addCorsHeaders(HttpServletResponse response, String origin) { response.setHeader("Access-Control-Allow-Origin", origin); - response.setHeader("Access-Control-Allow-Credentials", "true"); + if (!origin.isEmpty()) { + response.setHeader("Access-Control-Allow-Credentials", "true"); + } response.setHeader("Access-Control-Allow-Headers", "authorization,Content-Type"); response.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, HEAD, DELETE"); diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java b/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java index f4cff80206e..3dca9e269c6 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/server/RestApiApplication.java @@ -36,6 +36,7 @@ import org.apache.zeppelin.rest.ZeppelinRestApi; import org.apache.zeppelin.rest.exception.WebApplicationExceptionMapper; import org.apache.zeppelin.rest.filter.CacheControlFilter; +import org.apache.zeppelin.rest.filter.JsonContentTypeFilter; import org.glassfish.jersey.server.ServerProperties; public class RestApiApplication extends Application { @@ -60,6 +61,7 @@ public Set> getClasses() { s.add(GsonProvider.class); // Filter s.add(CacheControlFilter.class); + s.add(JsonContentTypeFilter.class); return s; } diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java index 363bfaf2590..1d20783c4b8 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/CorsUtils.java @@ -20,6 +20,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.net.UnknownHostException; +import java.util.Locale; import org.apache.zeppelin.conf.ZeppelinConfiguration; public class CorsUtils { @@ -36,15 +37,19 @@ public static boolean isValidOrigin(String sourceHost, ZeppelinConfiguration zCo if (sourceHost != null && !sourceHost.isEmpty()) { sourceUriHost = new URI(sourceHost).getHost(); - sourceUriHost = (sourceUriHost == null) ? "" : sourceUriHost.toLowerCase(); + sourceUriHost = (sourceUriHost == null) ? "" : sourceUriHost.toLowerCase(Locale.ROOT); } - sourceUriHost = sourceUriHost.toLowerCase(); - String currentHost = InetAddress.getLocalHost().getHostName().toLowerCase(); + String currentHost = InetAddress.getLocalHost().getHostName().toLowerCase(Locale.ROOT); + // getAllowedOrigins() returns lowercased entries; normalize sourceHost the same way + // before the membership check so case differences in the Origin header do not produce + // false rejections of explicitly configured origins. + String normalizedOrigin = + sourceHost == null ? "" : sourceHost.toLowerCase(Locale.ROOT); return zConf.getAllowedOrigins().contains("*") || currentHost.equals(sourceUriHost) || "localhost".equals(sourceUriHost) - || zConf.getAllowedOrigins().contains(sourceHost); + || zConf.getAllowedOrigins().contains(normalizedOrigin); } } diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java new file mode 100644 index 00000000000..440ab7ff1fd --- /dev/null +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/utils/HttpMethods.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.utils; + +import java.util.Set; + +public final class HttpMethods { + + private HttpMethods() { + } + + public static final Set STATE_CHANGING = Set.of("POST", "PUT", "DELETE", "PATCH"); +} diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java index a0ab71e5933..63fcd0da212 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/conf/ZeppelinConfigurationTest.java @@ -53,7 +53,7 @@ void getAllowedOriginsNoneTest() throws MalformedURLException { ZeppelinConfiguration zConf = ZeppelinConfiguration.load("zeppelin-test-site.xml"); List origins = zConf.getAllowedOrigins(); - assertEquals(1, origins.size()); + assertTrue(origins.isEmpty()); } @Test diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilterTest.java new file mode 100644 index 00000000000..9401211113c --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/rest/filter/JsonContentTypeFilterTest.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.rest.filter; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import jakarta.ws.rs.container.ContainerRequestContext; +import jakarta.ws.rs.core.MediaType; +import jakarta.ws.rs.core.Response; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; + +class JsonContentTypeFilterTest { + + private final JsonContentTypeFilter filter = new JsonContentTypeFilter(); + + @Test + void getRequestPasses() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("GET"); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } + + @Test + void postWithoutBodyPasses() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(false); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } + + @Test + void postJsonPasses() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(MediaType.APPLICATION_JSON_TYPE); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } + + @Test + void postFormUrlEncodedPasses() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(MediaType.APPLICATION_FORM_URLENCODED_TYPE); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } + + @Test + void postMultipartFormDataPasses() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(MediaType.MULTIPART_FORM_DATA_TYPE); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } + + @Test + void postTextPlainRejected() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(MediaType.TEXT_PLAIN_TYPE); + + filter.filter(ctx); + + ArgumentCaptor captor = ArgumentCaptor.forClass(Response.class); + verify(ctx, times(1)).abortWith(captor.capture()); + org.junit.jupiter.api.Assertions.assertEquals( + Response.Status.UNSUPPORTED_MEDIA_TYPE.getStatusCode(), + captor.getValue().getStatus()); + } + + @Test + void postWithoutContentTypeRejected() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(null); + + filter.filter(ctx); + + verify(ctx, times(1)).abortWith(any()); + } + + @ParameterizedTest + @ValueSource(strings = {"PUT", "DELETE", "PATCH"}) + void stateChangingTextPlainRejected(String method) { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn(method); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn(MediaType.TEXT_PLAIN_TYPE); + + filter.filter(ctx); + + verify(ctx, times(1)).abortWith(any()); + } + + @Test + void contentTypeWithCharsetParameterAllowed() { + ContainerRequestContext ctx = mock(ContainerRequestContext.class); + when(ctx.getMethod()).thenReturn("POST"); + when(ctx.hasEntity()).thenReturn(true); + when(ctx.getMediaType()).thenReturn( + MediaType.valueOf("application/json; charset=UTF-8")); + + filter.filter(ctx); + + verify(ctx, never()).abortWith(any()); + } +} diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java index 0a6f1eddfb6..1048c0ba258 100644 --- a/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/server/CorsFilterTest.java @@ -17,15 +17,21 @@ package org.apache.zeppelin.server; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import org.apache.zeppelin.conf.ZeppelinConfiguration; import org.junit.jupiter.api.Test; -import org.mockito.invocation.InvocationOnMock; -import org.mockito.stubbing.Answer; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import jakarta.servlet.FilterChain; import jakarta.servlet.ServletException; @@ -33,17 +39,16 @@ import jakarta.servlet.http.HttpServletResponse; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; /** * Basic CORS REST API tests. */ class CorsFilterTest { - public static String[] headers = new String[8]; - public static Integer count = 0; @Test - @SuppressWarnings("rawtypes") void validCorsFilterTest() throws IOException, ServletException { CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); HttpServletResponse mockResponse = mock(HttpServletResponse.class); @@ -51,24 +56,14 @@ void validCorsFilterTest() throws IOException, ServletException { HttpServletRequest mockRequest = mock(HttpServletRequest.class); when(mockRequest.getHeader("Origin")).thenReturn("http://localhost:8080"); when(mockRequest.getMethod()).thenReturn("Empty"); - when(mockRequest.getServerName()).thenReturn("localhost"); - count = 0; - - doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - headers[count] = invocationOnMock.getArguments()[1].toString(); - count++; - return null; - } - }).when(mockResponse).setHeader(anyString(), anyString()); + Map setHeaders = recordSetHeaders(mockResponse); filter.doFilter(mockRequest, mockResponse, mockedFilterChain); - assertEquals("http://localhost:8080", headers[0]); + + assertEquals("http://localhost:8080", setHeaders.get("Access-Control-Allow-Origin")); } @Test - @SuppressWarnings("rawtypes") void invalidCorsFilterTest() throws IOException, ServletException { CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); HttpServletResponse mockResponse = mock(HttpServletResponse.class); @@ -76,18 +71,118 @@ void invalidCorsFilterTest() throws IOException, ServletException { HttpServletRequest mockRequest = mock(HttpServletRequest.class); when(mockRequest.getHeader("Origin")).thenReturn("http://evillocalhost:8080"); when(mockRequest.getMethod()).thenReturn("Empty"); - when(mockRequest.getServerName()).thenReturn("evillocalhost"); + Map setHeaders = recordSetHeaders(mockResponse); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + assertEquals("", setHeaders.get("Access-Control-Allow-Origin")); + } + + @ParameterizedTest + @ValueSource(strings = {"POST", "PUT", "DELETE", "PATCH"}) + void crossOriginStateChangingBlocked(String method) throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://evil.example.com"); + when(mockRequest.getMethod()).thenReturn(method); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + verify(mockResponse).sendError(eq(HttpServletResponse.SC_FORBIDDEN), anyString()); + verify(mockedFilterChain, never()).doFilter(mockRequest, mockResponse); + } + + @Test + void crossOriginPreflightBlocked() throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://evil.example.com"); + when(mockRequest.getHeader("Access-Control-Request-Method")).thenReturn("POST"); + when(mockRequest.getMethod()).thenReturn("OPTIONS"); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + verify(mockResponse).sendError(eq(HttpServletResponse.SC_FORBIDDEN), anyString()); + verify(mockedFilterChain, never()).doFilter(mockRequest, mockResponse); + } + + @Test + void allowedOriginPostPasses() throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://localhost"); + when(mockRequest.getMethod()).thenReturn("POST"); + Map setHeaders = recordSetHeaders(mockResponse); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + verify(mockResponse, never()).sendError(anyInt(), anyString()); + verify(mockedFilterChain, times(1)).doFilter(mockRequest, mockResponse); + assertEquals("http://localhost", setHeaders.get("Access-Control-Allow-Origin")); + assertEquals("true", setHeaders.get("Access-Control-Allow-Credentials")); + } - doAnswer(new Answer() { - @Override - public Object answer(InvocationOnMock invocationOnMock) throws Throwable { - headers[count] = invocationOnMock.getArguments()[1].toString(); - count++; - return null; - } - }).when(mockResponse).setHeader(anyString(), anyString()); + @Test + void disallowedOriginGetPasses() throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://evil.example.com"); + when(mockRequest.getMethod()).thenReturn("GET"); + Map setHeaders = recordSetHeaders(mockResponse); filter.doFilter(mockRequest, mockResponse, mockedFilterChain); - assertEquals("", headers[0]); + + verify(mockResponse, never()).sendError(anyInt(), anyString()); + verify(mockedFilterChain, times(1)).doFilter(mockRequest, mockResponse); + assertEquals("", setHeaders.get("Access-Control-Allow-Origin")); + assertNull(setHeaders.get("Access-Control-Allow-Credentials")); + } + + @Test + void noOriginPostPasses() throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn(null); + when(mockRequest.getMethod()).thenReturn("POST"); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + verify(mockResponse, never()).sendError(anyInt(), anyString()); + verify(mockedFilterChain, times(1)).doFilter(mockRequest, mockResponse); + } + + @Test + void simpleOptionsWithoutPreflightHeaderPasses() throws IOException, ServletException { + CorsFilter filter = new CorsFilter(ZeppelinConfiguration.load()); + HttpServletRequest mockRequest = mock(HttpServletRequest.class); + HttpServletResponse mockResponse = mock(HttpServletResponse.class); + FilterChain mockedFilterChain = mock(FilterChain.class); + when(mockRequest.getHeader("Origin")).thenReturn("http://evil.example.com"); + when(mockRequest.getHeader("Access-Control-Request-Method")).thenReturn(null); + when(mockRequest.getMethod()).thenReturn("OPTIONS"); + + filter.doFilter(mockRequest, mockResponse, mockedFilterChain); + + verify(mockResponse, never()).sendError(anyInt(), anyString()); + verify(mockedFilterChain, times(1)).doFilter(mockRequest, mockResponse); + } + + private static Map recordSetHeaders(HttpServletResponse response) { + Map recorded = new HashMap<>(); + doAnswer(invocation -> { + recorded.put(invocation.getArgument(0), invocation.getArgument(1)); + return null; + }).when(response).setHeader(anyString(), anyString()); + return recorded; } }