Sfoglia il codice sorgente

add simple rate limit for forgot-password

Daniel Bohry 3 settimane fa
parent
commit
f66d3fc3ff

+ 3 - 0
src/main/java/com/danielbohry/authservice/config/SecurityConfig.java

@@ -1,6 +1,7 @@
 package com.danielbohry.authservice.config;
 
 import com.danielbohry.authservice.service.auth.JwtAuthenticationFilter;
+import com.danielbohry.authservice.service.auth.RateLimitingFilter;
 import com.danielbohry.authservice.service.user.UserService;
 import lombok.RequiredArgsConstructor;
 import org.springframework.context.annotation.Bean;
@@ -26,6 +27,7 @@ import static org.springframework.security.config.http.SessionCreationPolicy.STA
 public class SecurityConfig {
 
     private final JwtAuthenticationFilter jwtAuthenticationFilter;
+    private final RateLimitingFilter rateLimitingFilter;
     private final UserService userService;
 
     @Bean
@@ -53,6 +55,7 @@ public class SecurityConfig {
                 )
                 .sessionManagement(manager -> manager.sessionCreationPolicy(STATELESS))
                 .authenticationProvider(authenticationProvider())
+                .addFilterBefore(rateLimitingFilter, UsernamePasswordAuthenticationFilter.class)
                 .addFilterBefore(jwtAuthenticationFilter, UsernamePasswordAuthenticationFilter.class);
 
         return http.build();

+ 70 - 0
src/main/java/com/danielbohry/authservice/service/auth/RateLimitingFilter.java

@@ -0,0 +1,70 @@
+package com.danielbohry.authservice.service.auth;
+
+import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import lombok.extern.slf4j.Slf4j;
+import org.springframework.stereotype.Component;
+import org.springframework.web.filter.OncePerRequestFilter;
+
+import java.io.IOException;
+import java.time.LocalDateTime;
+import java.util.List;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.CopyOnWriteArrayList;
+
+import static java.time.LocalDateTime.now;
+
+@Slf4j
+@Component
+public class RateLimitingFilter extends OncePerRequestFilter {
+
+    private static final int MAX_REQUESTS_PER_MINUTE = 2;
+    private static final String FORGOT_PASSWORD_ENDPOINT = "/api/forgot-password";
+
+    private final ConcurrentHashMap<String, List<LocalDateTime>> requestTracker = new ConcurrentHashMap<>();
+
+    @SuppressWarnings("NullableProblems")
+    @Override
+    protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain)
+            throws ServletException, IOException {
+
+        if (!FORGOT_PASSWORD_ENDPOINT.equals(request.getRequestURI()) ||
+            !"POST".equalsIgnoreCase(request.getMethod())) {
+            filterChain.doFilter(request, response);
+            return;
+        }
+
+        String username = request.getParameter("username");
+        if (username == null || username.trim().isEmpty()) {
+            filterChain.doFilter(request, response);
+            return;
+        }
+
+        if (isRateLimited(username)) {
+            log.warn("Rate limit exceeded for username: {}", username);
+            response.setStatus(429);
+            response.setContentType("application/json");
+            response.getWriter().write("{\"error\":\"Too many requests. Maximum " + MAX_REQUESTS_PER_MINUTE + " requests per minute allowed.\"}");
+            return;
+        }
+
+        recordRequest(username);
+        filterChain.doFilter(request, response);
+    }
+
+    private boolean isRateLimited(String username) {
+        List<LocalDateTime> userRequests = requestTracker.computeIfAbsent(username, k -> new CopyOnWriteArrayList<>());
+        LocalDateTime now = now();
+        LocalDateTime oneMinuteAgo = now.minusMinutes(1);
+        userRequests.removeIf(timestamp -> timestamp.isBefore(oneMinuteAgo));
+
+        return userRequests.size() >= MAX_REQUESTS_PER_MINUTE;
+    }
+
+    private void recordRequest(String username) {
+        List<LocalDateTime> userRequests = requestTracker.computeIfAbsent(username, k -> new CopyOnWriteArrayList<>());
+        userRequests.add(now());
+    }
+}

+ 15 - 0
src/test/java/com/danielbohry/authservice/api/AuthControllerUnitTest.java

@@ -214,4 +214,19 @@ class AuthControllerUnitTest {
 
         verify(authService).register(testRequest);
     }
+
+    @Test
+    void shouldHandleForgotPasswordSuccessfully() {
+        // given
+        String username = "testuser";
+        doNothing().when(authService).forgotPassword(username);
+
+        // when
+        ResponseEntity<Void> response = authController.forgotPassword(username);
+
+        // then
+        assertNotNull(response);
+        assertEquals(HttpStatus.OK, response.getStatusCode());
+        verify(authService).forgotPassword(username);
+    }
 }

+ 183 - 0
src/test/java/com/danielbohry/authservice/service/auth/RateLimitingFilterTest.java

@@ -0,0 +1,183 @@
+package com.danielbohry.authservice.service.auth;
+
+import jakarta.servlet.FilterChain;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.mockito.InjectMocks;
+import org.mockito.Mock;
+import org.mockito.MockitoAnnotations;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.io.StringWriter;
+
+import static org.junit.jupiter.api.Assertions.*;
+import static org.mockito.Mockito.*;
+
+class RateLimitingFilterTest {
+
+    @InjectMocks
+    private RateLimitingFilter rateLimitingFilter;
+
+    @Mock
+    private HttpServletRequest request;
+
+    @Mock
+    private HttpServletResponse response;
+
+    @Mock
+    private FilterChain filterChain;
+
+    private StringWriter stringWriter;
+    private PrintWriter printWriter;
+
+    @BeforeEach
+    void setUp() {
+        MockitoAnnotations.openMocks(this);
+        stringWriter = new StringWriter();
+        printWriter = new PrintWriter(stringWriter);
+    }
+
+    @Test
+    void shouldNotApplyRateLimitingToNonForgotPasswordEndpoints() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/register");
+        when(request.getMethod()).thenReturn("POST");
+
+        // when
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldNotApplyRateLimitingToGetRequests() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("GET");
+
+        // when
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldNotApplyRateLimitingWhenUsernameIsNull() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(request.getParameter("username")).thenReturn(null);
+
+        // when
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldNotApplyRateLimitingWhenUsernameIsEmpty() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(request.getParameter("username")).thenReturn("");
+
+        // when
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldAllowFirstRequest() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(request.getParameter("username")).thenReturn("testuser");
+
+        // when
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldAllowSecondRequest() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(request.getParameter("username")).thenReturn("testuser2");
+
+        // when - first request
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // when - second request
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain, times(2)).doFilter(request, response);
+        verify(response, never()).setStatus(anyInt());
+    }
+
+    @Test
+    void shouldBlockThirdRequest() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(request.getParameter("username")).thenReturn("testuser3");
+        when(response.getWriter()).thenReturn(printWriter);
+
+        // when - first and second requests (should be allowed)
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // when - third request (should be blocked)
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain, times(2)).doFilter(request, response);
+        verify(response).setStatus(429);
+        verify(response).setContentType("application/json");
+
+        String responseContent = stringWriter.toString();
+        assertTrue(responseContent.contains("Too many requests"));
+        assertTrue(responseContent.contains("Maximum 2 requests per minute allowed"));
+    }
+
+    @Test
+    void shouldNotAffectDifferentUsernames() throws ServletException, IOException {
+        // given
+        when(request.getRequestURI()).thenReturn("/api/forgot-password");
+        when(request.getMethod()).thenReturn("POST");
+        when(response.getWriter()).thenReturn(printWriter);
+
+        // when - make 2 requests for user1
+        when(request.getParameter("username")).thenReturn("user1");
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // when - make third request for user1 (should be blocked)
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // when - make first request for user2 (should be allowed)
+        when(request.getParameter("username")).thenReturn("user2");
+        rateLimitingFilter.doFilterInternal(request, response, filterChain);
+
+        // then
+        verify(filterChain, times(3)).doFilter(request, response);
+        verify(response, times(1)).setStatus(429);
+    }
+}