前言
业务中需要对一些接口进行限流处理，以防止机器人恶意调用，并保证服务质量。
实现方式
- 基于 Redis 的 Lua 脚本计数，配合 Spring AOP 与自定义注解实现
引入依赖
<!-- Spring Data Redis starter: supplies the RedisTemplate / RedisScript types used by the Java code below -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-data-redis</artifactId>
</dependency>
<!-- Spring AOP starter: needed for the @Aspect-based rate limiter below -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-aop</artifactId>
</dependency>
redis配置
package com.qiangesoft.wechat.config;
import org.springframework.cache.annotation.CachingConfigurerSupport;
import org.springframework.cache.annotation.EnableCaching;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.DefaultRedisScript;
import org.springframework.data.redis.serializer.StringRedisSerializer;
/**
 * Redis configuration.
 *
 * <p>Declares the {@link RedisTemplate} used by the application and the Lua script bean
 * that implements the rate-limit counter.
 *
 * @author qiangesoft
 * @date 2024-03-19
 */
@Configuration
@EnableCaching
public class RedisConfig extends CachingConfigurerSupport {

    /**
     * Builds a {@link RedisTemplate} whose keys and hash keys are written as plain strings
     * and whose values and hash values go through {@code FastJson2JsonRedisSerializer}.
     *
     * @param connectionFactory connection factory the template is bound to
     * @return the initialized template
     */
    @Bean
    @SuppressWarnings(value = {"unchecked", "rawtypes"})
    public RedisTemplate<Object, Object> redisTemplate(RedisConnectionFactory connectionFactory) {
        // Keys (plain and hash) are stored as strings.
        StringRedisSerializer keySerializer = new StringRedisSerializer();
        // Values (plain and hash) share one serializer instance.
        FastJson2JsonRedisSerializer valueSerializer = new FastJson2JsonRedisSerializer(Object.class);

        RedisTemplate<Object, Object> redisTemplate = new RedisTemplate<>();
        redisTemplate.setConnectionFactory(connectionFactory);
        redisTemplate.setKeySerializer(keySerializer);
        redisTemplate.setHashKeySerializer(keySerializer);
        redisTemplate.setValueSerializer(valueSerializer);
        redisTemplate.setHashValueSerializer(valueSerializer);
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Exposes the rate-limit Lua script as a bean that returns a {@code Long}.
     *
     * @return the script wrapper around {@link #limitScriptText()}
     */
    @Bean
    public DefaultRedisScript<Long> limitScript() {
        return new DefaultRedisScript<>(limitScriptText(), Long.class);
    }

    /**
     * Source of the rate-limit Lua script.
     *
     * <p>{@code KEYS[1]} is the counter key, {@code ARGV[1]} the allowed count and
     * {@code ARGV[2]} the expiry in seconds. If the stored counter is already greater than
     * the allowed count it is returned unchanged; otherwise the counter is incremented, the
     * expiry is set when the increment yields 1, and the new value is returned.
     *
     * @return the script text, lines separated by {@code \n}
     */
    private String limitScriptText() {
        return String.join("\n",
                "local key = KEYS[1]",
                "local count = tonumber(ARGV[1])",
                "local time = tonumber(ARGV[2])",
                "local current = redis.call('get', key);",
                "if current and tonumber(current) > count then",
                " return tonumber(current);",
                "end",
                "current = redis.call('incr', key)",
                "if tonumber(current) == 1 then",
                " redis.call('expire', key, time)",
                "end",
                "return tonumber(current);");
    }
}
限流注解
package com.qiangesoft.wechat.config;
import java.lang.annotation.*;
/**
 * Marks a method as rate limited.
 *
 * <p>The attributes describe how many calls are allowed within a time window and which
 * dimension the calls are counted by.
 *
 * @author qiangesoft
 * @date 2024-03-19
 */
@Documented
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface RateLimiter {

    /**
     * Number of calls allowed within the window. Defaults to 10.
     */
    int count() default 10;

    /**
     * Length of the limiting window in seconds. Defaults to 60.
     */
    int time() default 60;

    /**
     * Dimension the limit is applied to. Defaults to {@link LimitType#IP}.
     */
    LimitType limitType() default LimitType.IP;
}
限流实现
package com.qiangesoft.wechat.config;
import com.qiangesoft.wechat.utils.IpUtil;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.aspectj.lang.JoinPoint;
import org.aspectj.lang.annotation.Aspect;
import org.aspectj.lang.annotation.Before;
import org.aspectj.lang.reflect.MethodSignature;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.script.RedisScript;
import org.springframework.stereotype.Component;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.List;
/**
* 限流处理
*
* @author qiangesoft
* @date 2024-03-19
*/
@RequiredArgsConstructor
@Slf4j
@Aspect
@Component
public class RateLimiterAspect {
private final RedisTemplate<Object, Object> redisTemplate;
private final RedisScript<Long> limitScript;
@Before("@annotation(rateLimiter)")
public void doBefore(JoinPoint point, RateLimiter rateLimiter) throws Throwable {
int time = rateLimiter.time();
int count = rateLimiter.count();
String key = this.buildKey(rateLimiter, point);
List<Object> keys = Collections.singletonList(key);
try {
Long number = redisTemplate.execute(limitScript, keys, count, time);
if (number == null || number.intValue() > count) {
throw new RuntimeException("访问过于频繁,请稍候再试");
}
} catch (RuntimeException e) {
throw e;
} catch (Exception e) {
throw new RuntimeException("服务器限流异常,请稍候再试");
}
}
/**
* 缓存key
*
* @param rateLimiter
* @param point
* @return
*/
public String buildKey(RateLimiter rateLimiter, JoinPoint point) {
String limitId = "";
if (rateLimiter.limitType() == LimitType.IP) {
limitId = IpUtil.getIpAddr();
}
MethodSignature signature = (MethodSignature) point.getSignature();
Method method = signature.getMethod();
Class<?> targetClass = method.getDeclaringClass();
String key = "rate_limit:" + limitId + "-" + targetClass.getName() + "-" + method.getName();
return key;
}
}
测试代码
package com.qiangesoft.wechat.controller;
import com.qiangesoft.wechat.config.RateLimiter;
import com.qiangesoft.wechat.utils.ResultVO;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
/**
 * Demo controller exposing one rate-limited endpoint.
 *
 * @author qiangesoft
 * @date 2024-03-19
 */
@RestController
public class TestController {

    /**
     * GET {@code /test}, guarded by {@link RateLimiter} with its default settings.
     *
     * @return a success result carrying a fixed message
     */
    @GetMapping("/test")
    @RateLimiter
    public ResultVO test() {
        ResultVO result = ResultVO.ok("调用成功!");
        return result;
    }
}
继续频繁点击：在默认配置下（60 秒内最多 10 次），超过限制后接口会抛出异常，提示“访问过于频繁,请稍候再试”。