一、前言
通过学习第三篇文章,我们已经成功地建立了IM与AI服务之间的数据链路。然而,我们目前面临一个紧迫需要解决的安全性问题,即非法用户可能会通过获取WebSocket的连接信息,顺利地连接到我们的服务。这不仅占用了大量的无效连接和资源,还对业务数据带来了潜在的风险。因此,我们需要逐步完善这个安全问题。
二、术语
2.1. 多设备登录
是指在一个应用或平台上使用多个设备进行登录和访问。传统上,用户只能使用单个设备(如个人电脑或手机)登录到应用程序,但随着技术的发展,许多应用和平台开始支持多设备登录功能。
2.2. 黑名单
是一种记录被列入不受欢迎或禁止的个人、组织、IP地址或其他实体的列表。在各种环境中,黑名单用于限制或阻止对特定实体的访问、参与或特权。
三、前置条件
3.1. 调通IM与AI服务的数据链路(参见开源模型应用落地-业务整合篇(三))
3.2. 了解Netty的基本使用
四、技术实现
4.1. 业务流程
# 上游服务(即ws的客户福安)先发送MsgType为2消息,进行全局初始化,示例:
{"userId":12345,"msgType":2}
# 认证通过后,再进行业务对话,示例:
{"userId":12345,"msgType":1,"contents":"你好","history":[]}
4.2. 消息类型枚举类增加初始化类型
import lombok.Getter;
@Getter
public enum MsgType {
CHAT(1, "聊天消息"),
INIT(2, "初始化"),
SYSTEM(9, "系统消息");
private int code;
private String desc;
MsgType(int code, String desc) {
this.code = code;
this.desc = desc;
}
}
4.3. 修改IM的业务逻辑处理类
注释上一篇的代码
增加这一篇的代码
PS: 此处USERID的校验规则应该根据实际业务调整,这里仅简单判断是否小于10000.
import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelHandler;
import lombok.extern.slf4j.Slf4j;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Component;
/**
* @Description: 处理消息的handler
*/
@Slf4j
@ChannelHandler.Sharable
@Component
public class BusinessHandler extends AbstractBusinessLogicHandler<TextWebSocketFrame> {
@Autowired
private AIChatUtils aiChatUtils;
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
String channelId = ctx.channel().id().asShortText();
log.info("add client,channelId:{}", channelId);
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
String channelId = ctx.channel().id().asShortText();
log.info("remove client,channelId:{}", channelId);
}
@Override
protected void channelRead0(ChannelHandlerContext channelHandlerContext, TextWebSocketFrame textWebSocketFrame)
throws Exception {
// 获取客户端传输过来的消息
String content = textWebSocketFrame.text();
log.info("接收到客户端发送的信息: {}",content);
Long userIdForReq;
String msgType = "";
String contents = "";
try {
ApiReqMessage apiReqMessage = JSON.parseObject(content, ApiReqMessage.class);
msgType = apiReqMessage.getMsgType();
contents = apiReqMessage.getContents();
userIdForReq = apiReqMessage.getUserId();
// 用户身份标识校验
if((long)userIdForReq <= 10000){
ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SYSTEM_ERROR.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.contents("用户身份标识有误!")
.msgType(String.valueOf(MsgType.SYSTEM.getCode()))
.build();
buildResponseAndClose(channelHandlerContext, apiRespMessage);
return;
}
// 添加用户
// if(!isExists(userIdForReq)){
// addChannel(channelHandlerContext, userIdForReq);
// }
if(StringUtils.equals(msgType,String.valueOf(MsgType.CHAT.getCode()))){
// ApiRespMessage apiRespMessage = ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
// .respTime(String.valueOf(System.currentTimeMillis()))
// .contents("测试通过,很高兴收到你的信息")
// .msgType(String.valueOf(MsgType.CHAT.getCode()))
// .build();
// String response = JSON.toJSONString(apiRespMessage);
// channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
if(!isExists(userIdForReq)){
String respMessage = "用户标识: "+userIdForReq+" 未登录";
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.NO_LOGIN_711.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.INIT.getCode()))
.contents(respMessage)
.build());
}else{
aiChatUtils.chatStream(apiReqMessage);
}
}else if(StringUtils.equals(msgType,String.valueOf(MsgType.INIT.getCode()))){
//一、业务黑名单检测(多次违规,永久锁定)
//二、账户锁定检测(临时锁定)
//三、多设备登录检测
//四、剩余对话次数检测
//检测通过,绑定用户与channel之间关系
addChannel(channelHandlerContext, userIdForReq);
String respMessage = "用户标识: "+userIdForReq+" 登录成功";
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.INIT.getCode()))
.contents(respMessage)
.build());
}else{
log.info("用户标识: {}, 消息类型有误,不支持类型: {}",userIdForReq,msgType);
}
} catch (Exception e) {
log.warn("【BusinessHandler】接收到请求内容:{},异常信息:{}", content, e.getMessage(), e);
// 异常返回
return;
}
}
}
4.4. 修改IM的业务逻辑处理抽象类
增加这一篇的代码
import com.alibaba.fastjson.JSON;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.DisposableBean;
import java.util.concurrent.ConcurrentHashMap;
@SuppressWarnings("all")
@Slf4j
public abstract class AbstractBusinessLogicHandler<I> extends SimpleChannelInboundHandler<I> implements DisposableBean {
protected static final ConcurrentHashMap<Long, ChannelHandlerContext> USER_ID_TO_CHANNEL = new ConcurrentHashMap<>();
// 用户特征属性与channel绑定
public static final AttributeKey<Long> USER_ID_ATTRIBUTE_KEY = AttributeKey.valueOf("userId");
/**
* 添加socket通道
*
* @param channelHandlerContext socket通道上下文
*/
protected void addChannel(ChannelHandlerContext channelHandlerContext, Long userId) {
// 将当前通道存放起来
USER_ID_TO_CHANNEL.put(userId, channelHandlerContext);
// 记录用户ID
channelHandlerContext.channel().attr(USER_ID_ATTRIBUTE_KEY).set(userId);
}
/**
* 判斷用戶是否存在
* @param userId
* @return
*/
protected boolean isExists(Long userId){
return USER_ID_TO_CHANNEL.containsKey(userId);
}
protected void buildResponse(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg) {
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(code))
.respTime(String.valueOf(respTime))
.msgType(String.valueOf(msgType))
.contents(msg).build());
}
protected void buildResponseIncludeOperateId(ChannelHandlerContext channelHandlerContext, int code, long respTime, int msgType, String msg, String operateId) {
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(code))
.respTime(String.valueOf(respTime))
.msgType(String.valueOf(msgType))
.operateId(operateId)
.contents(msg).build());
}
/**
* 获取用户ID
*
* @param channelHandlerContext socket通道上下文
*/
protected Long GetUserIdByChannel(ChannelHandlerContext channelHandlerContext) {
if (channelHandlerContext.channel().hasAttr(USER_ID_ATTRIBUTE_KEY)) {
Long userId = channelHandlerContext.channel().attr(USER_ID_ATTRIBUTE_KEY).get();
if (userId == null) {
return null;
}
if (USER_ID_TO_CHANNEL.containsKey(userId)) {
return userId;
}
}
return null;
}
protected void buildResponseAndClose(ChannelHandlerContext channelHandlerContext, ApiRespMessage apiRespMessage) {
String response = JSON.toJSONString(apiRespMessage);
Long userId = GetUserIdByChannel(channelHandlerContext);
if(null == userId){
log.warn("【AbstractBusinessLogicHandler】关闭通道!响应内容: {}", response);
ChannelFuture future = channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
future.addListener(new GenericFutureListener<Future<? super Void>>() {
public void operationComplete(Future future) throws Exception {
channelHandlerContext.close();
}
});
}else{
log.warn("【AbstractBusinessLogicHandler】关闭通道!用户ID:{},响应内容: {}", userId, response);
ChannelFuture future = channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
future.addListener(new GenericFutureListener<Future<? super Void>>() {
public void operationComplete(Future future) throws Exception {
// 清除离线用户
channelHandlerContext.channel().attr(USER_ID_ATTRIBUTE_KEY).remove();
channelHandlerContext.close();
}
});
}
}
@Override
public void destroy() throws Exception {
try {
USER_ID_TO_CHANNEL.clear();
} catch (Throwable e) {
}
}
protected static void buildResponse(ChannelHandlerContext channelHandlerContext, ApiRespMessage apiRespMessage) {
String response = JSON.toJSONString(apiRespMessage);
channelHandlerContext.writeAndFlush(new TextWebSocketFrame(response));
}
public static void pushChatMessageForUser(Long userId,String chatRespMessage) {
ChannelHandlerContext channelHandlerContext = USER_ID_TO_CHANNEL.get(userId);
if (channelHandlerContext != null ) {
buildResponse(channelHandlerContext, ApiRespMessage.builder().code(String.valueOf(StatusCode.SUCCESS.getCode()))
.respTime(String.valueOf(System.currentTimeMillis()))
.msgType(String.valueOf(MsgType.CHAT.getCode()))
.contents(chatRespMessage)
.build());
return;
}
}
}
五、测试
5.1. 用户未登录场景测试
# 测试参数:
{"userId":12345,"msgType":1,"contents":"你好","history":[]}
5.2. 用户登录场景测试
# 测试参数:
{"userId":12345,"msgType":2}
# 测试参数:
{"userId":12345,"msgType":1,"contents":"你好","history":[]}
六、附带说明
6.1. 业务黑名单检测
用户在指定周期内被多次锁定,触发系统阈值,被系统自动或人工拉黑
6.2. 账户锁定检测
用户在指定周期内多次发起违规对话(例如:涉黄/涉政/血腥/暴恐等),触发系统阈值,被系统自动锁定
6.3. 多设备登录检测
在业务上只允许单设备在线,但用户在多个设备(例如:手机、平板等)发起登录操作,触发系统阈值
6.4. 剩余对话次数检测
在业务上,我们限制未付费用户每天只能进行N次对话。
PS:上述内容的完善将放在“业务安全系列”文章中,里面包含算法备案、违规词检测、重新开始新的话题等复杂业务。