diff --git a/backend/README.md b/backend/README.md index a4972e9..1cabee8 100644 --- a/backend/README.md +++ b/backend/README.md @@ -18,12 +18,14 @@ 默认配置: ```bash +APP_JWT_SECRET=<至少32字节的随机密钥> \ mvn spring-boot:run ``` 本地联调建议使用 `dev` 环境: ```bash +APP_JWT_SECRET=<至少32字节的随机密钥> \ mvn spring-boot:run -Dspring-boot.run.profiles=dev ``` @@ -33,6 +35,13 @@ mvn spring-boot:run -Dspring-boot.run.profiles=dev - CQU 接口返回 mock 数据 - 方便和 `vue/` 前端直接联调 +JWT 启动要求: + +- `app.jwt.secret` 不能为空 +- 不允许使用默认占位值 +- 至少需要 32 字节强密钥 +- 仓库内的 `application.yml` / `application-dev.yml` 只从环境变量 `APP_JWT_SECRET` 读取,不再内置可直接启动的默认 secret + ## 访问地址 - Swagger: `http://localhost:8080/swagger-ui.html` @@ -84,6 +93,7 @@ CREATE INDEX IF NOT EXISTS idx_grade_user_semester ON portal_grade (user_id, sem - `POST /api/auth/register` - `POST /api/auth/login` +- `POST /api/auth/refresh` - `GET /api/user/profile` - `POST /api/files/upload` - `POST /api/files/upload/initiate` diff --git a/backend/src/main/java/com/yoyuzh/auth/AuthController.java b/backend/src/main/java/com/yoyuzh/auth/AuthController.java index f5242d8..33ebd84 100644 --- a/backend/src/main/java/com/yoyuzh/auth/AuthController.java +++ b/backend/src/main/java/com/yoyuzh/auth/AuthController.java @@ -2,6 +2,7 @@ package com.yoyuzh.auth; import com.yoyuzh.auth.dto.AuthResponse; import com.yoyuzh.auth.dto.LoginRequest; +import com.yoyuzh.auth.dto.RefreshTokenRequest; import com.yoyuzh.auth.dto.RegisterRequest; import com.yoyuzh.common.ApiResponse; import io.swagger.v3.oas.annotations.Operation; @@ -30,4 +31,10 @@ public class AuthController { public ApiResponse login(@Valid @RequestBody LoginRequest request) { return ApiResponse.success(authService.login(request)); } + + @Operation(summary = "刷新访问令牌") + @PostMapping("/refresh") + public ApiResponse refresh(@Valid @RequestBody RefreshTokenRequest request) { + return ApiResponse.success(authService.refresh(request.refreshToken())); + } } diff --git a/backend/src/main/java/com/yoyuzh/auth/AuthService.java b/backend/src/main/java/com/yoyuzh/auth/AuthService.java index 47a2c93..be08876 100644 --- a/backend/src/main/java/com/yoyuzh/auth/AuthService.java +++ b/backend/src/main/java/com/yoyuzh/auth/AuthService.java @@ -23,6 +23,7 @@ public class AuthService { private final PasswordEncoder passwordEncoder; private final AuthenticationManager authenticationManager; private final JwtTokenProvider jwtTokenProvider; + private final RefreshTokenService refreshTokenService; private final FileService fileService; @Transactional @@ -40,7 +41,7 @@ public class AuthService { user.setPasswordHash(passwordEncoder.encode(request.password())); User saved = userRepository.save(user); fileService.ensureDefaultDirectories(saved); - return new AuthResponse(jwtTokenProvider.generateToken(saved.getId(), saved.getUsername()), toProfile(saved)); + return issueTokens(saved); } public AuthResponse login(LoginRequest request) { @@ -54,7 +55,7 @@ public class AuthService { User user = userRepository.findByUsername(request.username()) .orElseThrow(() -> new BusinessException(ErrorCode.NOT_LOGGED_IN, "用户不存在")); fileService.ensureDefaultDirectories(user); - return new AuthResponse(jwtTokenProvider.generateToken(user.getId(), user.getUsername()), toProfile(user)); + return issueTokens(user); } @Transactional @@ -73,7 +74,13 @@ public class AuthService { return userRepository.save(created); }); fileService.ensureDefaultDirectories(user); - return new AuthResponse(jwtTokenProvider.generateToken(user.getId(), user.getUsername()), toProfile(user)); + return issueTokens(user); + } + + @Transactional + public AuthResponse refresh(String refreshToken) { + RefreshTokenService.RotatedRefreshToken rotated = refreshTokenService.rotateRefreshToken(refreshToken); + return issueTokens(rotated.user(), rotated.refreshToken()); } public UserProfileResponse getProfile(String username) { @@ -85,4 +92,13 @@ public class AuthService { private UserProfileResponse toProfile(User user) { return new UserProfileResponse(user.getId(), user.getUsername(), user.getEmail(), user.getCreatedAt()); } + + private AuthResponse issueTokens(User user) { + return issueTokens(user, refreshTokenService.issueRefreshToken(user)); + } + + private AuthResponse issueTokens(User user, String refreshToken) { + String accessToken = jwtTokenProvider.generateAccessToken(user.getId(), user.getUsername()); + return AuthResponse.issued(accessToken, refreshToken, toProfile(user)); + } } diff --git a/backend/src/main/java/com/yoyuzh/auth/JwtTokenProvider.java b/backend/src/main/java/com/yoyuzh/auth/JwtTokenProvider.java index 5957045..20a63e0 100644 --- a/backend/src/main/java/com/yoyuzh/auth/JwtTokenProvider.java +++ b/backend/src/main/java/com/yoyuzh/auth/JwtTokenProvider.java @@ -15,6 +15,8 @@ import java.util.Date; @Component public class JwtTokenProvider { + private static final String DEFAULT_SECRET = "change-me-change-me-change-me-change-me"; + private final JwtProperties jwtProperties; private SecretKey secretKey; @@ -24,16 +26,26 @@ public class JwtTokenProvider { @PostConstruct public void init() { - secretKey = Keys.hmacShaKeyFor(jwtProperties.getSecret().getBytes(StandardCharsets.UTF_8)); + String secret = jwtProperties.getSecret() == null ? "" : jwtProperties.getSecret().trim(); + if (secret.isEmpty()) { + throw new IllegalStateException("app.jwt.secret 未配置,请设置强密钥后再启动"); + } + if (DEFAULT_SECRET.equals(secret)) { + throw new IllegalStateException("检测到默认 JWT 密钥,请替换 app.jwt.secret 后再启动"); + } + if (secret.getBytes(StandardCharsets.UTF_8).length < 32) { + throw new IllegalStateException("JWT 密钥长度过短,至少需要 32 字节"); + } + secretKey = Keys.hmacShaKeyFor(secret.getBytes(StandardCharsets.UTF_8)); } - public String generateToken(Long userId, String username) { + public String generateAccessToken(Long userId, String username) { Instant now = Instant.now(); return Jwts.builder() .subject(username) .claim("uid", userId) .issuedAt(Date.from(now)) - .expiration(Date.from(now.plusSeconds(jwtProperties.getExpirationSeconds()))) + .expiration(Date.from(now.plusSeconds(jwtProperties.getAccessExpirationSeconds()))) .signWith(secretKey) .compact(); } diff --git a/backend/src/main/java/com/yoyuzh/auth/RefreshToken.java b/backend/src/main/java/com/yoyuzh/auth/RefreshToken.java new file mode 100644 index 0000000..9aa845c --- /dev/null +++ b/backend/src/main/java/com/yoyuzh/auth/RefreshToken.java @@ -0,0 +1,115 @@ +package com.yoyuzh.auth; + +import jakarta.persistence.Column; +import jakarta.persistence.Entity; +import jakarta.persistence.FetchType; +import jakarta.persistence.GeneratedValue; +import jakarta.persistence.GenerationType; +import jakarta.persistence.Id; +import jakarta.persistence.Index; +import jakarta.persistence.JoinColumn; +import jakarta.persistence.ManyToOne; +import jakarta.persistence.PrePersist; +import jakarta.persistence.Table; + +import java.time.LocalDateTime; + +@Entity +@Table(name = "portal_refresh_token", indexes = { + @Index(name = "uk_refresh_token_hash", columnList = "token_hash", unique = true), + @Index(name = "idx_refresh_user_expired", columnList = "user_id, expires_at"), + @Index(name = "idx_refresh_revoked", columnList = "revoked") +}) +public class RefreshToken { + + @Id + @GeneratedValue(strategy = GenerationType.IDENTITY) + private Long id; + + @ManyToOne(fetch = FetchType.LAZY, optional = false) + @JoinColumn(name = "user_id", nullable = false) + private User user; + + @Column(name = "token_hash", nullable = false, length = 64, unique = true) + private String tokenHash; + + @Column(name = "expires_at", nullable = false) + private LocalDateTime expiresAt; + + @Column(name = "revoked", nullable = false) + private boolean revoked; + + @Column(name = "created_at", nullable = false) + private LocalDateTime createdAt; + + @Column(name = "revoked_at") + private LocalDateTime revokedAt; + + @PrePersist + public void prePersist() { + if (createdAt == null) { + createdAt = LocalDateTime.now(); + } + } + + public void revoke(LocalDateTime revokedAt) { + this.revoked = true; + this.revokedAt = revokedAt; + } + + public Long getId() { + return id; + } + + public void setId(Long id) { + this.id = id; + } + + public User getUser() { + return user; + } + + public void setUser(User user) { + this.user = user; + } + + public String getTokenHash() { + return tokenHash; + } + + public void setTokenHash(String tokenHash) { + this.tokenHash = tokenHash; + } + + public LocalDateTime getExpiresAt() { + return expiresAt; + } + + public void setExpiresAt(LocalDateTime expiresAt) { + this.expiresAt = expiresAt; + } + + public boolean isRevoked() { + return revoked; + } + + public void setRevoked(boolean revoked) { + this.revoked = revoked; + } + + public LocalDateTime getCreatedAt() { + return createdAt; + } + + public void setCreatedAt(LocalDateTime createdAt) { + this.createdAt = createdAt; + } + + public LocalDateTime getRevokedAt() { + return revokedAt; + } + + public void setRevokedAt(LocalDateTime revokedAt) { + this.revokedAt = revokedAt; + } +} diff --git a/backend/src/main/java/com/yoyuzh/auth/RefreshTokenRepository.java b/backend/src/main/java/com/yoyuzh/auth/RefreshTokenRepository.java new file mode 100644 index 0000000..f4aecba --- /dev/null +++ b/backend/src/main/java/com/yoyuzh/auth/RefreshTokenRepository.java @@ -0,0 +1,15 @@ +package com.yoyuzh.auth; + +import jakarta.persistence.LockModeType; +import org.springframework.data.jpa.repository.JpaRepository; +import org.springframework.data.jpa.repository.Lock; +import org.springframework.data.jpa.repository.Query; + +import java.util.Optional; + +public interface RefreshTokenRepository extends JpaRepository { + + @Lock(LockModeType.PESSIMISTIC_WRITE) + @Query("select token from RefreshToken token join fetch token.user where token.tokenHash = :tokenHash") + Optional findForUpdateByTokenHash(String tokenHash); +} diff --git a/backend/src/main/java/com/yoyuzh/auth/RefreshTokenService.java b/backend/src/main/java/com/yoyuzh/auth/RefreshTokenService.java new file mode 100644 index 0000000..72426a7 --- /dev/null +++ b/backend/src/main/java/com/yoyuzh/auth/RefreshTokenService.java @@ -0,0 +1,84 @@ +package com.yoyuzh.auth; + +import com.yoyuzh.common.BusinessException; +import com.yoyuzh.common.ErrorCode; +import com.yoyuzh.config.JwtProperties; +import lombok.RequiredArgsConstructor; +import org.springframework.stereotype.Service; +import org.springframework.transaction.annotation.Transactional; + +import java.nio.charset.StandardCharsets; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.time.LocalDateTime; +import java.util.Base64; +import java.util.HexFormat; + +@Service +@RequiredArgsConstructor +public class RefreshTokenService { + + private static final int REFRESH_TOKEN_BYTES = 48; + + private final RefreshTokenRepository refreshTokenRepository; + private final JwtProperties jwtProperties; + private final SecureRandom secureRandom = new SecureRandom(); + + @Transactional + public String issueRefreshToken(User user) { + String rawToken = generateRawToken(); + + RefreshToken refreshToken = new RefreshToken(); + refreshToken.setUser(user); + refreshToken.setTokenHash(hashToken(rawToken)); + refreshToken.setExpiresAt(LocalDateTime.now().plusSeconds(jwtProperties.getRefreshExpirationSeconds())); + refreshToken.setRevoked(false); + refreshTokenRepository.save(refreshToken); + + return rawToken; + } + + @Transactional(noRollbackFor = BusinessException.class) + public RotatedRefreshToken rotateRefreshToken(String rawToken) { + RefreshToken existing = refreshTokenRepository.findForUpdateByTokenHash(hashToken(rawToken)) + .orElseThrow(() -> new BusinessException(ErrorCode.NOT_LOGGED_IN, "刷新令牌无效")); + + if (existing.isRevoked()) { + throw new BusinessException(ErrorCode.NOT_LOGGED_IN, "刷新令牌无效或已使用"); + } + + if (existing.getExpiresAt().isBefore(LocalDateTime.now())) { + existing.revoke(LocalDateTime.now()); + throw new BusinessException(ErrorCode.NOT_LOGGED_IN, "刷新令牌已过期"); + } + + User user = existing.getUser(); + existing.revoke(LocalDateTime.now()); + + String nextRefreshToken = issueRefreshToken(user); + return new RotatedRefreshToken(user, nextRefreshToken); + } + + private String generateRawToken() { + byte[] bytes = new byte[REFRESH_TOKEN_BYTES]; + secureRandom.nextBytes(bytes); + return Base64.getUrlEncoder().withoutPadding().encodeToString(bytes); + } + + private String hashToken(String rawToken) { + if (rawToken == null || rawToken.isBlank()) { + throw new BusinessException(ErrorCode.NOT_LOGGED_IN, "刷新令牌不能为空"); + } + try { + MessageDigest digest = MessageDigest.getInstance("SHA-256"); + byte[] hash = digest.digest(rawToken.getBytes(StandardCharsets.UTF_8)); + return HexFormat.of().formatHex(hash); + } catch (NoSuchAlgorithmException ex) { + throw new IllegalStateException("无法初始化刷新令牌哈希算法", ex); + } + } + + public record RotatedRefreshToken(User user, String refreshToken) { + } +} diff --git a/backend/src/main/java/com/yoyuzh/auth/dto/AuthResponse.java b/backend/src/main/java/com/yoyuzh/auth/dto/AuthResponse.java index b318269..56a91b2 100644 --- a/backend/src/main/java/com/yoyuzh/auth/dto/AuthResponse.java +++ b/backend/src/main/java/com/yoyuzh/auth/dto/AuthResponse.java @@ -1,4 +1,8 @@ package com.yoyuzh.auth.dto; -public record AuthResponse(String token, UserProfileResponse user) { +public record AuthResponse(String token, String accessToken, String refreshToken, UserProfileResponse user) { + + public static AuthResponse issued(String accessToken, String refreshToken, UserProfileResponse user) { + return new AuthResponse(accessToken, accessToken, refreshToken, user); + } } diff --git a/backend/src/main/java/com/yoyuzh/auth/dto/RefreshTokenRequest.java b/backend/src/main/java/com/yoyuzh/auth/dto/RefreshTokenRequest.java new file mode 100644 index 0000000..3c64b01 --- /dev/null +++ b/backend/src/main/java/com/yoyuzh/auth/dto/RefreshTokenRequest.java @@ -0,0 +1,6 @@ +package com.yoyuzh.auth.dto; + +import jakarta.validation.constraints.NotBlank; + +public record RefreshTokenRequest(@NotBlank String refreshToken) { +} diff --git a/backend/src/main/java/com/yoyuzh/auth/dto/RegisterRequest.java b/backend/src/main/java/com/yoyuzh/auth/dto/RegisterRequest.java index 743e333..0725188 100644 --- a/backend/src/main/java/com/yoyuzh/auth/dto/RegisterRequest.java +++ b/backend/src/main/java/com/yoyuzh/auth/dto/RegisterRequest.java @@ -1,12 +1,40 @@ package com.yoyuzh.auth.dto; import jakarta.validation.constraints.Email; +import jakarta.validation.constraints.AssertTrue; import jakarta.validation.constraints.NotBlank; import jakarta.validation.constraints.Size; public record RegisterRequest( @NotBlank @Size(min = 3, max = 64) String username, @NotBlank @Email @Size(max = 128) String email, - @NotBlank @Size(min = 6, max = 64) String password + @NotBlank @Size(min = 10, max = 64, message = "密码至少10位,且必须包含大写字母、小写字母、数字和特殊字符") String password ) { + + @AssertTrue(message = "密码至少10位,且必须包含大写字母、小写字母、数字和特殊字符") + public boolean isPasswordStrong() { + if (password == null || password.length() < 10) { + return false; + } + + boolean hasLower = false; + boolean hasUpper = false; + boolean hasDigit = false; + boolean hasSpecial = false; + + for (int i = 0; i < password.length(); i += 1) { + char c = password.charAt(i); + if (Character.isLowerCase(c)) { + hasLower = true; + } else if (Character.isUpperCase(c)) { + hasUpper = true; + } else if (Character.isDigit(c)) { + hasDigit = true; + } else { + hasSpecial = true; + } + } + + return hasLower && hasUpper && hasDigit && hasSpecial; + } } diff --git a/backend/src/main/java/com/yoyuzh/common/GlobalExceptionHandler.java b/backend/src/main/java/com/yoyuzh/common/GlobalExceptionHandler.java index 380ef17..c4cfa1b 100644 --- a/backend/src/main/java/com/yoyuzh/common/GlobalExceptionHandler.java +++ b/backend/src/main/java/com/yoyuzh/common/GlobalExceptionHandler.java @@ -1,15 +1,19 @@ package com.yoyuzh.common; +import jakarta.validation.ConstraintViolation; import jakarta.validation.ConstraintViolationException; import lombok.extern.slf4j.Slf4j; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.security.access.AccessDeniedException; import org.springframework.security.authentication.BadCredentialsException; +import org.springframework.validation.ObjectError; import org.springframework.web.bind.MethodArgumentNotValidException; import org.springframework.web.bind.annotation.ExceptionHandler; import org.springframework.web.bind.annotation.RestControllerAdvice; +import java.util.Objects; + @Slf4j @RestControllerAdvice public class GlobalExceptionHandler { @@ -27,7 +31,27 @@ public class GlobalExceptionHandler { @ExceptionHandler({MethodArgumentNotValidException.class, ConstraintViolationException.class}) public ResponseEntity> handleValidationException(Exception ex) { - return ResponseEntity.badRequest().body(ApiResponse.error(ErrorCode.UNKNOWN, ex.getMessage())); + if (ex instanceof MethodArgumentNotValidException validationException) { + String message = validationException.getBindingResult().getAllErrors().stream() + .map(ObjectError::getDefaultMessage) + .filter(Objects::nonNull) + .map(String::trim) + .filter(msg -> !msg.isEmpty()) + .findFirst() + .orElse("请求参数不合法"); + return ResponseEntity.badRequest().body(ApiResponse.error(ErrorCode.UNKNOWN, message)); + } + if (ex instanceof ConstraintViolationException validationException) { + String message = validationException.getConstraintViolations().stream() + .map(ConstraintViolation::getMessage) + .filter(Objects::nonNull) + .map(String::trim) + .filter(msg -> !msg.isEmpty()) + .findFirst() + .orElse("请求参数不合法"); + return ResponseEntity.badRequest().body(ApiResponse.error(ErrorCode.UNKNOWN, message)); + } + return ResponseEntity.badRequest().body(ApiResponse.error(ErrorCode.UNKNOWN, "请求参数不合法")); } @ExceptionHandler(AccessDeniedException.class) diff --git a/backend/src/main/java/com/yoyuzh/config/JwtProperties.java b/backend/src/main/java/com/yoyuzh/config/JwtProperties.java index 1991f89..3be0db6 100644 --- a/backend/src/main/java/com/yoyuzh/config/JwtProperties.java +++ b/backend/src/main/java/com/yoyuzh/config/JwtProperties.java @@ -5,8 +5,9 @@ import org.springframework.boot.context.properties.ConfigurationProperties; @ConfigurationProperties(prefix = "app.jwt") public class JwtProperties { - private String secret = "change-me-change-me-change-me-change-me"; - private long expirationSeconds = 86400; + private String secret = ""; + private long accessExpirationSeconds = 900; + private long refreshExpirationSeconds = 1209600; public String getSecret() { return secret; @@ -16,11 +17,19 @@ public class JwtProperties { this.secret = secret; } - public long getExpirationSeconds() { - return expirationSeconds; + public long getAccessExpirationSeconds() { + return accessExpirationSeconds; } - public void setExpirationSeconds(long expirationSeconds) { - this.expirationSeconds = expirationSeconds; + public void setAccessExpirationSeconds(long accessExpirationSeconds) { + this.accessExpirationSeconds = accessExpirationSeconds; + } + + public long getRefreshExpirationSeconds() { + return refreshExpirationSeconds; + } + + public void setRefreshExpirationSeconds(long refreshExpirationSeconds) { + this.refreshExpirationSeconds = refreshExpirationSeconds; } } diff --git a/backend/src/main/resources/application-dev.yml b/backend/src/main/resources/application-dev.yml index cbd5e5b..eb238dd 100644 --- a/backend/src/main/resources/application-dev.yml +++ b/backend/src/main/resources/application-dev.yml @@ -13,5 +13,7 @@ spring: path: /h2-console app: + jwt: + secret: ${APP_JWT_SECRET:} cqu: mock-enabled: true diff --git a/backend/src/main/resources/application.yml b/backend/src/main/resources/application.yml index a9f02d1..e26cacf 100644 --- a/backend/src/main/resources/application.yml +++ b/backend/src/main/resources/application.yml @@ -23,8 +23,9 @@ spring: app: jwt: - secret: change-me-change-me-change-me-change-me - expiration-seconds: 86400 + secret: ${APP_JWT_SECRET:} + access-expiration-seconds: 900 + refresh-expiration-seconds: 1209600 storage: root-dir: ./storage max-file-size: 524288000 diff --git a/backend/src/test/java/com/yoyuzh/auth/AuthControllerValidationTest.java b/backend/src/test/java/com/yoyuzh/auth/AuthControllerValidationTest.java new file mode 100644 index 0000000..cb1ca7b --- /dev/null +++ b/backend/src/test/java/com/yoyuzh/auth/AuthControllerValidationTest.java @@ -0,0 +1,79 @@ +package com.yoyuzh.auth; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yoyuzh.auth.dto.AuthResponse; +import com.yoyuzh.auth.dto.UserProfileResponse; +import com.yoyuzh.common.GlobalExceptionHandler; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.junit.jupiter.api.extension.ExtendWith; +import org.springframework.http.MediaType; +import org.springframework.test.web.servlet.MockMvc; +import org.springframework.test.web.servlet.setup.MockMvcBuilders; + +import java.time.LocalDateTime; + +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.springframework.test.web.servlet.request.MockMvcRequestBuilders.post; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.jsonPath; +import static org.springframework.test.web.servlet.result.MockMvcResultMatchers.status; + +@ExtendWith(MockitoExtension.class) +class AuthControllerValidationTest { + + @Mock + private AuthService authService; + + private MockMvc mockMvc; + private final ObjectMapper objectMapper = new ObjectMapper(); + + @BeforeEach + void setUp() { + mockMvc = MockMvcBuilders.standaloneSetup(new AuthController(authService)) + .setControllerAdvice(new GlobalExceptionHandler()) + .build(); + } + + @Test + void shouldReturnReadablePasswordValidationMessage() throws Exception { + mockMvc.perform(post("/api/auth/register") + .contentType(MediaType.APPLICATION_JSON) + .content(""" + { + "username": "alice", + "email": "alice@example.com", + "password": "weakpass" + } + """)) + .andExpect(status().isBadRequest()) + .andExpect(jsonPath("$.code").value(1000)) + .andExpect(jsonPath("$.msg").value("密码至少10位,且必须包含大写字母、小写字母、数字和特殊字符")); + } + + @Test + void shouldExposeRefreshEndpointContract() throws Exception { + AuthResponse response = AuthResponse.issued( + "new-access-token", + "new-refresh-token", + new UserProfileResponse(7L, "alice", "alice@example.com", LocalDateTime.now()) + ); + when(authService.refresh("refresh-1")).thenReturn(response); + + mockMvc.perform(post("/api/auth/refresh") + .contentType(MediaType.APPLICATION_JSON) + .content(objectMapper.writeValueAsString(new Object() { + public final String refreshToken = "refresh-1"; + }))) + .andExpect(status().isOk()) + .andExpect(jsonPath("$.code").value(0)) + .andExpect(jsonPath("$.data.token").value("new-access-token")) + .andExpect(jsonPath("$.data.accessToken").value("new-access-token")) + .andExpect(jsonPath("$.data.refreshToken").value("new-refresh-token")) + .andExpect(jsonPath("$.data.user.username").value("alice")); + + verify(authService).refresh("refresh-1"); + } +} diff --git a/backend/src/test/java/com/yoyuzh/auth/AuthServiceTest.java b/backend/src/test/java/com/yoyuzh/auth/AuthServiceTest.java index 01cccfd..e263e29 100644 --- a/backend/src/test/java/com/yoyuzh/auth/AuthServiceTest.java +++ b/backend/src/test/java/com/yoyuzh/auth/AuthServiceTest.java @@ -39,6 +39,9 @@ class AuthServiceTest { @Mock private JwtTokenProvider jwtTokenProvider; + @Mock + private RefreshTokenService refreshTokenService; + @Mock private FileService fileService; @@ -47,29 +50,32 @@ class AuthServiceTest { @Test void shouldRegisterUserWithEncryptedPassword() { - RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "plain-password"); + RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "StrongPass1!"); when(userRepository.existsByUsername("alice")).thenReturn(false); when(userRepository.existsByEmail("alice@example.com")).thenReturn(false); - when(passwordEncoder.encode("plain-password")).thenReturn("encoded-password"); + when(passwordEncoder.encode("StrongPass1!")).thenReturn("encoded-password"); when(userRepository.save(any(User.class))).thenAnswer(invocation -> { User user = invocation.getArgument(0); user.setId(1L); user.setCreatedAt(LocalDateTime.now()); return user; }); - when(jwtTokenProvider.generateToken(1L, "alice")).thenReturn("jwt-token"); + when(jwtTokenProvider.generateAccessToken(1L, "alice")).thenReturn("access-token"); + when(refreshTokenService.issueRefreshToken(any(User.class))).thenReturn("refresh-token"); AuthResponse response = authService.register(request); - assertThat(response.token()).isEqualTo("jwt-token"); + assertThat(response.token()).isEqualTo("access-token"); + assertThat(response.accessToken()).isEqualTo("access-token"); + assertThat(response.refreshToken()).isEqualTo("refresh-token"); assertThat(response.user().username()).isEqualTo("alice"); - verify(passwordEncoder).encode("plain-password"); + verify(passwordEncoder).encode("StrongPass1!"); verify(fileService).ensureDefaultDirectories(any(User.class)); } @Test void shouldRejectDuplicateUsernameOnRegister() { - RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "plain-password"); + RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "StrongPass1!"); when(userRepository.existsByUsername("alice")).thenReturn(true); assertThatThrownBy(() -> authService.register(request)) @@ -87,17 +93,39 @@ class AuthServiceTest { user.setPasswordHash("encoded-password"); user.setCreatedAt(LocalDateTime.now()); when(userRepository.findByUsername("alice")).thenReturn(Optional.of(user)); - when(jwtTokenProvider.generateToken(1L, "alice")).thenReturn("jwt-token"); + when(jwtTokenProvider.generateAccessToken(1L, "alice")).thenReturn("access-token"); + when(refreshTokenService.issueRefreshToken(user)).thenReturn("refresh-token"); AuthResponse response = authService.login(request); verify(authenticationManager).authenticate( new UsernamePasswordAuthenticationToken("alice", "plain-password")); - assertThat(response.token()).isEqualTo("jwt-token"); + assertThat(response.token()).isEqualTo("access-token"); + assertThat(response.accessToken()).isEqualTo("access-token"); + assertThat(response.refreshToken()).isEqualTo("refresh-token"); assertThat(response.user().email()).isEqualTo("alice@example.com"); verify(fileService).ensureDefaultDirectories(user); } + @Test + void shouldRotateRefreshTokenAndReturnNewCredentials() { + User user = new User(); + user.setId(1L); + user.setUsername("alice"); + user.setEmail("alice@example.com"); + user.setCreatedAt(LocalDateTime.now()); + when(refreshTokenService.rotateRefreshToken("old-refresh")) + .thenReturn(new RefreshTokenService.RotatedRefreshToken(user, "new-refresh")); + when(jwtTokenProvider.generateAccessToken(1L, "alice")).thenReturn("new-access"); + + AuthResponse response = authService.refresh("old-refresh"); + + assertThat(response.token()).isEqualTo("new-access"); + assertThat(response.accessToken()).isEqualTo("new-access"); + assertThat(response.refreshToken()).isEqualTo("new-refresh"); + assertThat(response.user().username()).isEqualTo("alice"); + } + @Test void shouldThrowBusinessExceptionWhenAuthenticationFails() { LoginRequest request = new LoginRequest("alice", "wrong-password"); @@ -119,11 +147,14 @@ class AuthServiceTest { user.setCreatedAt(LocalDateTime.now()); return user; }); - when(jwtTokenProvider.generateToken(9L, "demo")).thenReturn("jwt-token"); + when(jwtTokenProvider.generateAccessToken(9L, "demo")).thenReturn("access-token"); + when(refreshTokenService.issueRefreshToken(any(User.class))).thenReturn("refresh-token"); AuthResponse response = authService.devLogin("demo"); assertThat(response.user().username()).isEqualTo("demo"); + assertThat(response.accessToken()).isEqualTo("access-token"); + assertThat(response.refreshToken()).isEqualTo("refresh-token"); verify(fileService).ensureDefaultDirectories(any(User.class)); } } diff --git a/backend/src/test/java/com/yoyuzh/auth/JwtTokenProviderTest.java b/backend/src/test/java/com/yoyuzh/auth/JwtTokenProviderTest.java new file mode 100644 index 0000000..c5afab9 --- /dev/null +++ b/backend/src/test/java/com/yoyuzh/auth/JwtTokenProviderTest.java @@ -0,0 +1,76 @@ +package com.yoyuzh.auth; + +import com.yoyuzh.config.JwtProperties; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.security.Keys; +import org.junit.jupiter.api.Test; + +import javax.crypto.SecretKey; +import java.nio.charset.StandardCharsets; +import java.time.Instant; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +class JwtTokenProviderTest { + + @Test + void shouldRejectEmptyJwtSecret() { + JwtProperties properties = new JwtProperties(); + properties.setSecret(" "); + + JwtTokenProvider provider = new JwtTokenProvider(properties); + + assertThatThrownBy(provider::init) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("未配置"); + } + + @Test + void shouldRejectDefaultJwtSecret() { + JwtProperties properties = new JwtProperties(); + properties.setSecret("change-me-change-me-change-me-change-me"); + + JwtTokenProvider provider = new JwtTokenProvider(properties); + + assertThatThrownBy(provider::init) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("默认 JWT 密钥"); + } + + @Test + void shouldRejectTooShortJwtSecret() { + JwtProperties properties = new JwtProperties(); + properties.setSecret("too-short-secret"); + + JwtTokenProvider provider = new JwtTokenProvider(properties); + + assertThatThrownBy(provider::init) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("至少需要 32 字节"); + } + + @Test + void shouldGenerateShortLivedAccessToken() { + JwtProperties properties = new JwtProperties(); + properties.setSecret("0123456789abcdef0123456789abcdef"); + properties.setAccessExpirationSeconds(900); + + JwtTokenProvider provider = new JwtTokenProvider(properties); + provider.init(); + + String token = provider.generateAccessToken(7L, "alice"); + SecretKey secretKey = Keys.hmacShaKeyFor(properties.getSecret().getBytes(StandardCharsets.UTF_8)); + Instant expiration = Jwts.parser().verifyWith(secretKey).build() + .parseSignedClaims(token) + .getPayload() + .getExpiration() + .toInstant(); + + assertThat(provider.validateToken(token)).isTrue(); + assertThat(provider.getUsername(token)).isEqualTo("alice"); + assertThat(provider.getUserId(token)).isEqualTo(7L); + assertThat(expiration).isAfter(Instant.now().plusSeconds(850)); + assertThat(expiration).isBefore(Instant.now().plusSeconds(950)); + } +} diff --git a/backend/src/test/java/com/yoyuzh/auth/RefreshTokenServiceIntegrationTest.java b/backend/src/test/java/com/yoyuzh/auth/RefreshTokenServiceIntegrationTest.java new file mode 100644 index 0000000..f44fd94 --- /dev/null +++ b/backend/src/test/java/com/yoyuzh/auth/RefreshTokenServiceIntegrationTest.java @@ -0,0 +1,157 @@ +package com.yoyuzh.auth; + +import com.yoyuzh.PortalBackendApplication; +import com.yoyuzh.common.BusinessException; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@SpringBootTest( + classes = PortalBackendApplication.class, + properties = { + "spring.datasource.url=jdbc:h2:mem:refresh_token_test;MODE=MySQL;DB_CLOSE_DELAY=-1;LOCK_TIMEOUT=10000", + "spring.datasource.driver-class-name=org.h2.Driver", + "spring.datasource.username=sa", + "spring.datasource.password=", + "spring.jpa.hibernate.ddl-auto=create-drop", + "app.jwt.secret=0123456789abcdef0123456789abcdef", + "app.storage.root-dir=./target/test-storage-refresh", + "app.cqu.require-login=true", + "app.cqu.mock-enabled=false" + } +) +class RefreshTokenServiceIntegrationTest { + + @Autowired + private RefreshTokenService refreshTokenService; + + @Autowired + private RefreshTokenRepository refreshTokenRepository; + + @Autowired + private UserRepository userRepository; + + @BeforeEach + void setUp() { + refreshTokenRepository.deleteAll(); + userRepository.deleteAll(); + } + + @Test + void shouldRejectRefreshTokenReuseAfterRotation() { + User user = createUser("alice"); + + String rawToken = refreshTokenService.issueRefreshToken(user); + RefreshTokenService.RotatedRefreshToken rotated = refreshTokenService.rotateRefreshToken(rawToken); + + assertThat(rotated.refreshToken()).isNotBlank().isNotEqualTo(rawToken); + assertThatThrownBy(() -> refreshTokenService.rotateRefreshToken(rawToken)) + .isInstanceOf(BusinessException.class) + .hasMessageContaining("无效或已使用"); + assertThat(refreshTokenRepository.findAll()) + .hasSize(2) + .filteredOn(RefreshToken::isRevoked) + .hasSize(1); + } + + @Test + void shouldStoreRefreshTokenAsHashInsteadOfPlaintext() { + User user = createUser("hash-check"); + + String rawToken = refreshTokenService.issueRefreshToken(user); + + assertThat(refreshTokenRepository.findAll()) + .singleElement() + .satisfies(token -> { + assertThat(token.getTokenHash()).hasSize(64); + assertThat(token.getTokenHash()).isNotEqualTo(rawToken); + assertThat(token.getTokenHash()).doesNotContain(rawToken.substring(0, 8)); + }); + } + + @Test + void shouldRejectExpiredRefreshTokenAndRevokeIt() { + User user = createUser("expired"); + String rawToken = refreshTokenService.issueRefreshToken(user); + RefreshToken storedToken = refreshTokenRepository.findAll().get(0); + storedToken.setExpiresAt(LocalDateTime.now().minusSeconds(1)); + refreshTokenRepository.save(storedToken); + + assertThatThrownBy(() -> refreshTokenService.rotateRefreshToken(rawToken)) + .isInstanceOf(BusinessException.class) + .hasMessageContaining("刷新令牌已过期"); + assertThat(refreshTokenRepository.findById(storedToken.getId())) + .get() + .extracting(RefreshToken::isRevoked) + .isEqualTo(true); + } + + @Test + void shouldAllowConcurrentRefreshTokenConsumptionOnlyOnce() throws Exception { + User user = createUser("bob"); + String rawToken = refreshTokenService.issueRefreshToken(user); + ExecutorService executorService = Executors.newFixedThreadPool(2); + CountDownLatch ready = new CountDownLatch(2); + CountDownLatch start = new CountDownLatch(1); + + try { + List> futures = new ArrayList<>(); + for (int i = 0; i < 2; i += 1) { + futures.add(executorService.submit(() -> { + ready.countDown(); + start.await(5, TimeUnit.SECONDS); + try { + return refreshTokenService.rotateRefreshToken(rawToken); + } catch (BusinessException ex) { + return ex; + } + })); + } + + assertThat(ready.await(5, TimeUnit.SECONDS)).isTrue(); + start.countDown(); + + List results = new ArrayList<>(); + for (Future future : futures) { + results.add(future.get(5, TimeUnit.SECONDS)); + } + + assertThat(results) + .filteredOn(result -> result instanceof RefreshTokenService.RotatedRefreshToken) + .hasSize(1); + assertThat(results) + .filteredOn(result -> result instanceof BusinessException) + .singleElement() + .extracting(result -> ((BusinessException) result).getMessage()) + .isEqualTo("刷新令牌无效或已使用"); + assertThat(refreshTokenRepository.findAll()) + .hasSize(2) + .filteredOn(token -> !token.isRevoked()) + .hasSize(1); + } finally { + executorService.shutdownNow(); + } + } + + private User createUser(String username) { + User user = new User(); + user.setUsername(username); + user.setEmail(username + "@example.com"); + user.setPasswordHash("encoded-password"); + user.setCreatedAt(LocalDateTime.now()); + return userRepository.save(user); + } +} diff --git a/backend/src/test/java/com/yoyuzh/auth/RegisterRequestValidationTest.java b/backend/src/test/java/com/yoyuzh/auth/RegisterRequestValidationTest.java new file mode 100644 index 0000000..db922d5 --- /dev/null +++ b/backend/src/test/java/com/yoyuzh/auth/RegisterRequestValidationTest.java @@ -0,0 +1,33 @@ +package com.yoyuzh.auth; + +import com.yoyuzh.auth.dto.RegisterRequest; +import jakarta.validation.Validation; +import jakarta.validation.Validator; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +class RegisterRequestValidationTest { + + private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator(); + + @Test + void shouldRejectWeakPassword() { + RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "weakpass"); + + var violations = validator.validate(request); + + assertThat(violations) + .extracting(violation -> violation.getMessage()) + .contains("密码至少10位,且必须包含大写字母、小写字母、数字和特殊字符"); + } + + @Test + void shouldAcceptStrongPassword() { + RegisterRequest request = new RegisterRequest("alice", "alice@example.com", "StrongPass1!"); + + var violations = validator.validate(request); + + assertThat(violations).isEmpty(); + } +} diff --git a/backend/src/test/java/com/yoyuzh/cqu/CquDataServiceTransactionTest.java b/backend/src/test/java/com/yoyuzh/cqu/CquDataServiceTransactionTest.java index 4841839..c283a15 100644 --- a/backend/src/test/java/com/yoyuzh/cqu/CquDataServiceTransactionTest.java +++ b/backend/src/test/java/com/yoyuzh/cqu/CquDataServiceTransactionTest.java @@ -22,6 +22,7 @@ import static org.mockito.Mockito.when; "spring.datasource.username=sa", "spring.datasource.password=", "spring.jpa.hibernate.ddl-auto=create-drop", + "app.jwt.secret=0123456789abcdef0123456789abcdef", "app.cqu.require-login=true", "app.cqu.mock-enabled=false" } diff --git a/front/src/auth/AuthProvider.tsx b/front/src/auth/AuthProvider.tsx index cb7575b..f22fd39 100644 --- a/front/src/auth/AuthProvider.tsx +++ b/front/src/auth/AuthProvider.tsx @@ -3,6 +3,7 @@ import React, { createContext, useContext, useEffect, useState } from 'react'; import { apiRequest } from '@/src/lib/api'; import { clearStoredSession, + createSession, readStoredSession, saveStoredSession, SESSION_EVENT_NAME, @@ -27,10 +28,7 @@ interface AuthContextValue { const AuthContext = createContext(null); function buildSession(auth: AuthResponse): AuthSession { - return { - token: auth.token, - user: auth.user, - }; + return createSession(auth); } export function AuthProvider({ children }: { children: React.ReactNode }) { diff --git a/front/src/lib/api.test.ts b/front/src/lib/api.test.ts index 23291e4..efb3575 100644 --- a/front/src/lib/api.test.ts +++ b/front/src/lib/api.test.ts @@ -2,7 +2,7 @@ import assert from 'node:assert/strict'; import { afterEach, beforeEach, test } from 'node:test'; import { apiBinaryUploadRequest, apiRequest, apiUploadRequest, shouldRetryRequest, toNetworkApiError } from './api'; -import { clearStoredSession, saveStoredSession } from './session'; +import { clearStoredSession, readStoredSession, saveStoredSession } from './session'; class MemoryStorage implements Storage { private store = new Map(); @@ -135,6 +135,7 @@ test('apiRequest attaches bearer token and unwraps response payload', async () = let request: Request | URL | string | undefined; saveStoredSession({ token: 'token-123', + refreshToken: 'refresh-123', user: { id: 1, username: 'tester', @@ -230,6 +231,7 @@ test('network fetch failures are converted to readable api errors', () => { test('apiUploadRequest attaches auth header and forwards upload progress', async () => { saveStoredSession({ token: 'token-456', + refreshToken: 'refresh-456', user: { id: 2, username: 'uploader', @@ -309,3 +311,158 @@ test('apiBinaryUploadRequest sends raw file body to signed upload url', async () {loaded: 128, total: 128}, ]); }); + +test('apiRequest refreshes expired access token once and retries the original request', async () => { + const calls: Array<{url: string; authorization: string | null; body: string | null}> = []; + saveStoredSession({ + token: 'expired-token', + refreshToken: 'refresh-1', + user: { + id: 3, + username: 'alice', + email: 'alice@example.com', + createdAt: '2026-03-18T10:00:00', + }, + }); + + globalThis.fetch = async (input, init) => { + const url = String(input); + const headers = new Headers(init?.headers); + calls.push({ + url, + authorization: headers.get('Authorization'), + body: typeof init?.body === 'string' ? init.body : null, + }); + + if (url.endsWith('/user/profile') && calls.length === 1) { + return new Response( + JSON.stringify({ + code: 1001, + msg: '用户未登录', + data: null, + }), + { + status: 401, + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + } + + if (url.endsWith('/auth/refresh')) { + return new Response( + JSON.stringify({ + code: 0, + msg: 'success', + data: { + token: 'new-access-token', + accessToken: 'new-access-token', + refreshToken: 'refresh-2', + user: { + id: 3, + username: 'alice', + email: 'alice@example.com', + createdAt: '2026-03-18T10:00:00', + }, + }, + }), + { + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + } + + return new Response( + JSON.stringify({ + code: 0, + msg: 'success', + data: { + id: 3, + username: 'alice', + email: 'alice@example.com', + createdAt: '2026-03-18T10:00:00', + }, + }), + { + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + }; + + const profile = await apiRequest<{id: number; username: string}>('/user/profile'); + + assert.equal(profile.username, 'alice'); + assert.equal(calls.length, 3); + assert.equal(calls[0]?.authorization, 'Bearer expired-token'); + assert.equal(calls[1]?.url, '/api/auth/refresh'); + assert.equal(calls[2]?.authorization, 'Bearer new-access-token'); + assert.deepEqual(JSON.parse(calls[1]?.body || '{}'), {refreshToken: 'refresh-1'}); + assert.deepEqual(readStoredSession(), { + token: 'new-access-token', + refreshToken: 'refresh-2', + user: { + id: 3, + username: 'alice', + email: 'alice@example.com', + createdAt: '2026-03-18T10:00:00', + }, + }); +}); + +test('apiRequest clears session when refresh fails after a 401 response', async () => { + let callCount = 0; + saveStoredSession({ + token: 'expired-token', + refreshToken: 'refresh-1', + user: { + id: 5, + username: 'bob', + email: 'bob@example.com', + createdAt: '2026-03-18T10:00:00', + }, + }); + + globalThis.fetch = async (input) => { + callCount += 1; + const url = String(input); + + if (url.endsWith('/auth/refresh')) { + return new Response( + JSON.stringify({ + code: 1001, + msg: '刷新令牌已过期', + data: null, + }), + { + status: 401, + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + } + + return new Response( + JSON.stringify({ + code: 1001, + msg: '用户未登录', + data: null, + }), + { + status: 401, + headers: { + 'Content-Type': 'application/json', + }, + }, + ); + }; + + await assert.rejects(() => apiRequest('/user/profile'), /用户未登录/); + assert.equal(callCount, 2); + assert.equal(readStoredSession(), null); +}); diff --git a/front/src/lib/api.ts b/front/src/lib/api.ts index 02e3d86..f31d8dd 100644 --- a/front/src/lib/api.ts +++ b/front/src/lib/api.ts @@ -1,4 +1,5 @@ -import { clearStoredSession, readStoredSession } from './session'; +import type { AuthResponse } from './types'; +import { clearStoredSession, createSession, readStoredSession, saveStoredSession } from './session'; interface ApiEnvelope { code: number; @@ -25,6 +26,9 @@ interface ApiBinaryUploadRequestInit { } const API_BASE_URL = (import.meta.env?.VITE_API_BASE_URL || '/api').replace(/\/$/, ''); +const AUTH_REFRESH_PATH = '/auth/refresh'; + +let refreshRequestPromise: Promise | null = null; export class ApiError extends Error { code?: number; @@ -93,6 +97,20 @@ function resolveUrl(path: string) { return `${API_BASE_URL}${normalizedPath}`; } +function normalizePath(path: string) { + return path.startsWith('/') ? path : `/${path}`; +} + +function shouldAttemptTokenRefresh(path: string) { + const normalizedPath = normalizePath(path); + return ![ + '/auth/login', + '/auth/register', + '/auth/dev-login', + AUTH_REFRESH_PATH, + ].includes(normalizedPath); +} + function buildRequestBody(body: ApiRequestInit['body']) { if (body == null) { return undefined; @@ -111,6 +129,58 @@ function buildRequestBody(body: ApiRequestInit['body']) { return JSON.stringify(body); } +async function refreshAccessToken() { + const currentSession = readStoredSession(); + if (!currentSession?.refreshToken) { + clearStoredSession(); + return false; + } + + if (refreshRequestPromise) { + return refreshRequestPromise; + } + + refreshRequestPromise = (async () => { + try { + const response = await fetch(resolveUrl(AUTH_REFRESH_PATH), { + method: 'POST', + headers: { + Accept: 'application/json', + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ + refreshToken: currentSession.refreshToken, + }), + }); + const contentType = response.headers.get('content-type') || ''; + if (!response.ok || !contentType.includes('application/json')) { + clearStoredSession(); + return false; + } + + const payload = (await response.json()) as ApiEnvelope; + if (payload.code !== 0 || !payload.data) { + clearStoredSession(); + return false; + } + + saveStoredSession({ + ...currentSession, + ...createSession(payload.data), + user: payload.data.user ?? currentSession.user, + }); + return true; + } catch { + clearStoredSession(); + return false; + } finally { + refreshRequestPromise = null; + } + })(); + + return refreshRequestPromise; +} + async function parseApiError(response: Response) { const contentType = response.headers.get('content-type') || ''; if (!contentType.includes('application/json')) { @@ -140,7 +210,7 @@ export function shouldRetryRequest( return attempt <= getMaxRetryAttempts(path, init); } -async function performRequest(path: string, init: ApiRequestInit = {}) { +async function performRequest(path: string, init: ApiRequestInit = {}, allowRefresh = true): Promise { const session = readStoredSession(); const headers = new Headers(init.headers); const requestBody = buildRequestBody(init.body); @@ -180,7 +250,14 @@ async function performRequest(path: string, init: ApiRequestInit = {}) { throw toNetworkApiError(lastError); } - if (response.status === 401 || response.status === 403) { + if (response.status === 401 && allowRefresh && shouldAttemptTokenRefresh(path)) { + const refreshed = await refreshAccessToken(); + if (refreshed) { + return performRequest(path, init, false); + } + } + + if (response.status === 401) { clearStoredSession(); } @@ -200,16 +277,13 @@ export async function apiRequest(path: string, init?: ApiRequestInit) { const payload = (await response.json()) as ApiEnvelope; if (!response.ok || payload.code !== 0) { - if (response.status === 401 || payload.code === 401) { - clearStoredSession(); - } throw new ApiError(payload.msg || `请求失败 (${response.status})`, response.status, payload.code); } return payload.data; } -export function apiUploadRequest(path: string, init: ApiUploadRequestInit) { +function apiUploadRequestInternal(path: string, init: ApiUploadRequestInit, allowRefresh: boolean): Promise { const session = readStoredSession(); const headers = new Headers(init.headers); @@ -248,8 +322,21 @@ export function apiUploadRequest(path: string, init: ApiUploadRequestInit) { xhr.onload = () => { const contentType = xhr.getResponseHeader('content-type') || ''; - if (xhr.status === 401 || xhr.status === 403) { - clearStoredSession(); + if (xhr.status === 401 && allowRefresh && shouldAttemptTokenRefresh(path)) { + refreshAccessToken() + .then((refreshed) => { + if (refreshed) { + resolve(apiUploadRequestInternal(path, init, false)); + return; + } + clearStoredSession(); + reject(new ApiError('登录状态已失效,请重新登录', 401)); + }) + .catch((error) => { + clearStoredSession(); + reject(error instanceof ApiError ? error : toNetworkApiError(error)); + }); + return; } if (!contentType.includes('application/json')) { @@ -264,7 +351,7 @@ export function apiUploadRequest(path: string, init: ApiUploadRequestInit) { const payload = JSON.parse(xhr.responseText) as ApiEnvelope; if (xhr.status < 200 || xhr.status >= 300 || payload.code !== 0) { - if (xhr.status === 401 || payload.code === 401) { + if (xhr.status === 401) { clearStoredSession(); } reject(new ApiError(payload.msg || `请求失败 (${xhr.status})`, xhr.status, payload.code)); @@ -278,6 +365,10 @@ export function apiUploadRequest(path: string, init: ApiUploadRequestInit) { }); } +export function apiUploadRequest(path: string, init: ApiUploadRequestInit): Promise { + return apiUploadRequestInternal(path, init, true); +} + export function apiBinaryUploadRequest(path: string, init: ApiBinaryUploadRequestInit) { const headers = new Headers(init.headers); diff --git a/front/src/lib/session.ts b/front/src/lib/session.ts index 6b214d7..9ca9840 100644 --- a/front/src/lib/session.ts +++ b/front/src/lib/session.ts @@ -1,4 +1,4 @@ -import type { AuthSession } from './types'; +import type { AuthResponse, AuthSession } from './types'; const SESSION_STORAGE_KEY = 'portal-session'; const POST_LOGIN_PENDING_KEY = 'portal-post-login-pending'; @@ -10,6 +10,40 @@ function notifySessionChanged() { } } +function normalizeSession(value: unknown): AuthSession | null { + if (!value || typeof value !== 'object') { + return null; + } + + const candidate = value as Partial & {accessToken?: string}; + const token = typeof candidate.token === 'string' && candidate.token.trim() + ? candidate.token + : typeof candidate.accessToken === 'string' && candidate.accessToken.trim() + ? candidate.accessToken + : null; + + if (!token || !candidate.user) { + return null; + } + + return { + token, + refreshToken: + typeof candidate.refreshToken === 'string' && candidate.refreshToken.trim() + ? candidate.refreshToken + : null, + user: candidate.user, + }; +} + +export function createSession(auth: AuthResponse): AuthSession { + return { + token: auth.accessToken || auth.token, + refreshToken: auth.refreshToken ?? null, + user: auth.user, + }; +} + export function readStoredSession(): AuthSession | null { if (typeof localStorage === 'undefined') { return null; @@ -21,7 +55,11 @@ export function readStoredSession(): AuthSession | null { } try { - return JSON.parse(rawValue) as AuthSession; + const session = normalizeSession(JSON.parse(rawValue)); + if (!session) { + localStorage.removeItem(SESSION_STORAGE_KEY); + } + return session; } catch { localStorage.removeItem(SESSION_STORAGE_KEY); return null; diff --git a/front/src/lib/types.ts b/front/src/lib/types.ts index 85cea09..62166d2 100644 --- a/front/src/lib/types.ts +++ b/front/src/lib/types.ts @@ -7,11 +7,14 @@ export interface UserProfile { export interface AuthSession { token: string; + refreshToken?: string | null; user: UserProfile; } export interface AuthResponse { token: string; + accessToken?: string; + refreshToken?: string | null; user: UserProfile; } diff --git a/front/src/pages/Login.tsx b/front/src/pages/Login.tsx index 8a5cfe8..2f38129 100644 --- a/front/src/pages/Login.tsx +++ b/front/src/pages/Login.tsx @@ -8,7 +8,7 @@ import { Button } from '@/src/components/ui/button'; import { Input } from '@/src/components/ui/input'; import { apiRequest, ApiError } from '@/src/lib/api'; import { cn } from '@/src/lib/utils'; -import { markPostLoginPending, saveStoredSession } from '@/src/lib/session'; +import { createSession, markPostLoginPending, saveStoredSession } from '@/src/lib/session'; import type { AuthResponse } from '@/src/lib/types'; const DEV_LOGIN_ENABLED = import.meta.env.DEV || import.meta.env.VITE_ENABLE_DEV_LOGIN === 'true'; @@ -59,10 +59,7 @@ export default function Login() { } } - saveStoredSession({ - token: auth.token, - user: auth.user, - }); + saveStoredSession(createSession(auth)); markPostLoginPending(); setLoading(false); navigate('/overview'); @@ -87,10 +84,7 @@ export default function Login() { }, }); - saveStoredSession({ - token: auth.token, - user: auth.user, - }); + saveStoredSession(createSession(auth)); markPostLoginPending(); setLoading(false); navigate('/overview'); @@ -301,10 +295,13 @@ export default function Login() { value={registerPassword} onChange={(event) => setRegisterPassword(event.target.value)} required - minLength={6} + minLength={10} maxLength={64} /> +

+ 至少 10 位,并包含大写字母、小写字母、数字和特殊字符。 +