|
package com.example.websocket;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;
/**
 * Spring WebSocket handler for a simple chat service.
 *
 * <p>A connection is identified by the {@code userId} query parameter of its handshake URL; one
 * user may hold several sessions at once. Text and binary messages (binary payloads are decoded
 * as UTF-8) are reassembled from partial frames, parsed as a JSON {@code Message}, and routed:
 * <ul>
 *   <li>with a non-empty {@code toUserId}: to every session of that user, plus an echo to the
 *       sending session;</li>
 *   <li>without one: to every session of the sender;</li>
 *   <li>if no target session is registered: back to the sending session only.</li>
 * </ul>
 *
 * <p>NOTE(review): the user id is taken from the URL without any authentication, so a client
 * can connect as any user. Bind it to an authenticated principal before relying on it.
 */
public class ChatWebSocketHandler extends AbstractWebSocketHandler {

    /** userId -> currently registered sessions of that user. */
    private final ConcurrentHashMap<String, Set<WebSocketSession>> userSessions = new ConcurrentHashMap<>();

    /** Shared JSON mapper used to parse incoming and serialize outgoing {@code Message}s. */
    private static final ObjectMapper MAPPER = new ObjectMapper();

    // NOTE(review): the fragment buffers below are unbounded; a client that keeps sending
    // non-final fragments can grow them without limit. Consider enforcing a maximum size.

    /** sessionId -> text received so far for a message still arriving in fragments. */
    private final ConcurrentHashMap<String, StringBuilder> textFragments = new ConcurrentHashMap<>();

    /** sessionId -> bytes received so far for a binary message still arriving in fragments. */
    private final ConcurrentHashMap<String, ByteArrayOutputStream> binaryFragments = new ConcurrentHashMap<>();

    /**
     * Validates the connecting user's id and registers the session under it.
     *
     * <p>Sessions without a usable {@code userId} query parameter are closed with
     * {@link CloseStatus#BAD_DATA}.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession session) throws Exception {
        String uid = resolveUserId(session);
        if (uid == null || uid.isEmpty()) {
            session.close(CloseStatus.BAD_DATA);
            return;
        }
        session.getAttributes().put("userId", uid);
        // Create-and-add inside compute() so it is atomic with the remove-and-prune in
        // afterConnectionClosed(); otherwise a concurrent close for the same user could drop the
        // set between "get or create" and "add", orphaning this session.
        userSessions.compute(uid, (k, sessions) -> {
            Set<WebSocketSession> set = sessions != null ? sessions : ConcurrentHashMap.newKeySet();
            set.add(session);
            return set;
        });
    }

    /**
     * Buffers text fragments per session and routes the complete text once the last fragment
     * arrives. Non-fragmented messages are routed directly.
     */
    @Override
    protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
        String id = session.getId();
        if (!message.isLast()) {
            textFragments.computeIfAbsent(id, k -> new StringBuilder()).append(message.getPayload());
            return;
        }
        StringBuilder sb = textFragments.remove(id);
        String payload = sb != null ? sb.append(message.getPayload()).toString() : message.getPayload();
        routePayload(session, payload);
    }

    /**
     * Buffers binary fragments per session; once the last fragment arrives, decodes all bytes as
     * UTF-8 and routes the result like a text message. Bytes are joined before decoding so a
     * multi-byte character split across fragments still decodes correctly.
     */
    @Override
    protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception {
        String id = session.getId();
        ByteBuffer buf = message.getPayload();
        byte[] chunk = new byte[buf.remaining()];
        buf.get(chunk);
        ByteArrayOutputStream acc = binaryFragments.computeIfAbsent(id, k -> new ByteArrayOutputStream());
        acc.write(chunk);
        if (message.isLast()) {
            binaryFragments.remove(id);
            routePayload(session, new String(acc.toByteArray(), StandardCharsets.UTF_8));
        }
    }

    /** Opts in to receiving partial (fragmented) messages, which the handlers above reassemble. */
    @Override
    public boolean supportsPartialMessages() {
        return true;
    }

    /**
     * Cleans up after a closed connection: discards any half-received message and unregisters
     * the session, removing the user entry once it has no sessions left.
     */
    @Override
    public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
        // Without this, a session closed mid-fragment would leak its buffer forever.
        String id = session.getId();
        textFragments.remove(id);
        binaryFragments.remove(id);

        Object v = session.getAttributes().get("userId");
        if (v == null) return;
        String uid = String.valueOf(v);
        // Remove and prune in one atomic step (see afterConnectionEstablished).
        userSessions.computeIfPresent(uid, (k, sessions) -> {
            sessions.remove(session);
            return sessions.isEmpty() ? null : sessions;
        });
    }

    /**
     * Extracts the {@code userId} query parameter from the handshake URL.
     *
     * @return the decoded user id, or {@code null} if there is no URI, no query, no such
     *     parameter, or its value is malformed
     */
    private String resolveUserId(WebSocketSession session) {
        URI uri = session.getUri();
        if (uri == null) return null;
        // Split the raw (still percent-encoded) query and decode each component afterwards:
        // URI.getQuery() decodes first, so an encoded '&' or '=' inside a value would be
        // mistaken for a separator.
        String q = uri.getRawQuery();
        if (q == null || q.isEmpty()) return null;
        for (String p : q.split("&")) {
            int i = p.indexOf('=');
            if (i > 0 && "userId".equals(decodeQueryComponent(p.substring(0, i)))) {
                return decodeQueryComponent(p.substring(i + 1));
            }
        }
        return null;
    }

    /**
     * Percent-decodes one query-string component as UTF-8.
     *
     * @return the decoded text, or {@code null} if it contains a malformed escape
     */
    private static String decodeQueryComponent(String s) {
        try {
            // Escape '+' first so it stays a literal '+', as URI.getQuery() left it, instead of
            // being form-decoded into a space.
            return URLDecoder.decode(s.replace("+", "%2B"), StandardCharsets.UTF_8.name());
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 charset unavailable", e);
        } catch (IllegalArgumentException e) {
            return null;
        }
    }

    /**
     * Builds an outgoing {@code Message} from a complete payload and delivers it.
     *
     * <p>The sender id always comes from the session attributes, never from the payload. If the
     * payload does not parse as a JSON {@code Message}, it is forwarded verbatim as the content.
     */
    private void routePayload(WebSocketSession session, String payload) throws Exception {
        Object v = session.getAttributes().get("userId");
        if (v == null) return;
        String fromUid = String.valueOf(v);

        Message message = new Message();
        message.setFromUserId(fromUid);
        Message received = null;
        try {
            received = MAPPER.readValue(payload, Message.class);
        } catch (Exception ignored) {
            // Not a JSON Message; handled below by forwarding the raw payload.
        }
        if (received != null) {
            message.setToUserId(received.getToUserId());
            message.setContent(received.getContent());
            message.setType(received.getType());
        } else {
            // Unparseable payload, or the JSON literal null.
            message.setContent(payload);
        }

        String toUid = message.getToUserId();
        boolean isP2P = toUid != null && !toUid.isEmpty();
        Set<WebSocketSession> targets = userSessions.get(isP2P ? toUid : fromUid);
        TextMessage out = new TextMessage(MAPPER.writeValueAsString(message));

        if (targets == null || targets.isEmpty()) {
            // Nobody registered to receive it (e.g. recipient offline): echo to the sender only.
            deliver(session, out);
            return;
        }
        for (WebSocketSession s : targets) {
            deliver(s, out);
        }
        // P2P: also echo to the sending session, unless it was itself one of the targets (a user
        // messaging themselves), which would otherwise receive the message twice.
        if (isP2P && !targets.contains(session)) {
            deliver(session, out);
        }
    }

    /**
     * Sends {@code msg} to {@code target} if it is still open.
     *
     * <p>Writes to one session are serialized because {@code WebSocketSession#sendMessage} does
     * not support concurrent calls, and messages from different senders can target the same
     * session at once. An I/O failure is contained to {@code target}: the session is closed
     * with {@link CloseStatus#SERVER_ERROR} rather than the exception propagating to the caller.
     * Propagating would abort delivery to the remaining targets and surface as an error on the
     * sender's session.
     */
    private static void deliver(WebSocketSession target, TextMessage msg) {
        boolean failed = false;
        synchronized (target) {
            if (!target.isOpen()) return;
            try {
                target.sendMessage(msg);
            } catch (IOException e) {
                failed = true;
            }
        }
        if (failed) {
            try {
                target.close(CloseStatus.SERVER_ERROR);
            } catch (IOException ignored) {
                // The connection is already broken; nothing more can be done.
            }
        }
    }
}
|