这是一个使用 Java（JDK 8）和 Spring Boot 实现的 WebSocket 演示项目，旨在解决多端之间的消息通信问题。
WebSocket 是一种基于 TCP 的全双工通信协议，其核心作用是突破传统 HTTP 协议“请求-响应”模式的局限，实现客户端与服务器之间实时、双向、低延迟的数据传输。
源码地址:https://gitee.com/lqh4188/web-socket
可通过 URL 查询参数 userId（例如 /ws?userId=alice）为每个用户建立独立的连接，实现用户隔离；同一用户可同时保持多个连接（多端在线）。

由于 WebSocket 容器对单条消息的缓冲区大小有限制，若消息内容较大，可在配置中调大缓冲区，并开启分片消息（partial message）支持，由服务端按会话将分片重新组装为完整消息后再处理。
ChatWebSocketHandler.java代码:
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
package com.example.websocket;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.URLDecoder;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.springframework.web.socket.BinaryMessage;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.AbstractWebSocketHandler;

import com.fasterxml.jackson.databind.ObjectMapper;
public class ChatWebSocketHandler extends AbstractWebSocketHandler { private final ConcurrentHashMap<String, Set<WebSocketSession>> userSessions = new ConcurrentHashMap<>(); private static final ObjectMapper MAPPER = new ObjectMapper(); private final ConcurrentHashMap<String, StringBuilder> textFragments = new ConcurrentHashMap<>(); private final ConcurrentHashMap<String, ByteArrayOutputStream> binaryFragments = new ConcurrentHashMap<>();
@Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { // 验证用户ID的有效性 String uid = resolveUserId(session); if (uid == null || uid.isEmpty()) { session.close(CloseStatus.BAD_DATA); return; } session.getAttributes().put("userId", uid); //多会话管理 userSessions.computeIfAbsent(uid, k -> ConcurrentHashMap.newKeySet()).add(session); }
@Override protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception { // 分片处理 String id = session.getId(); if (!message.isLast()) { textFragments.computeIfAbsent(id, k -> new StringBuilder()).append(message.getPayload()); return; } StringBuilder sb = textFragments.remove(id); String payload = sb != null ? sb.append(message.getPayload()).toString() : message.getPayload(); routePayload(session, payload); }
@Override protected void handleBinaryMessage(WebSocketSession session, BinaryMessage message) throws Exception { //二进制消息处理 String id = session.getId(); ByteBuffer buf = message.getPayload(); byte[] chunk = new byte[buf.remaining()]; buf.get(chunk); ByteArrayOutputStream acc = binaryFragments.computeIfAbsent(id, k -> new ByteArrayOutputStream()); acc.write(chunk); if (message.isLast()) { byte[] all = acc.toByteArray(); binaryFragments.remove(id); String payload = new String(all, StandardCharsets.UTF_8); routePayload(session, payload); } }
@Override public boolean supportsPartialMessages() { return true; }
@Override public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception { // WebSocket 连接关闭时的清理逻辑 Object v = session.getAttributes().get("userId"); if (v == null) return; String uid = String.valueOf(v); Set<WebSocketSession> set = userSessions.get(uid); if (set != null) { set.remove(session); if (set.isEmpty()) userSessions.remove(uid); } }
/** 从 WebSocket 连接的 URL 查询参数中提取用户ID */ private String resolveUserId(WebSocketSession session) { URI uri = session.getUri(); if (uri == null) return null; String q = uri.getQuery(); if (q == null || q.isEmpty()) return null; String[] parts = q.split("&"); for (String p : parts) { int i = p.indexOf('='); if (i > 0) { String k = p.substring(0, i); String val = p.substring(i + 1); if ("userId".equals(k)) return val; } } return null; }
private void routePayload(WebSocketSession session, String payload) throws Exception { Object v = session.getAttributes().get("userId"); if (v == null) return; String fromUid = String.valueOf(v);
// 解析消息 Message message = new Message(); message.setFromUserId(fromUid);
try { // 尝试将payload解析为Message对象 Message receivedMsg = MAPPER.readValue(payload, Message.class); message.setToUserId(receivedMsg.getToUserId()); message.setContent(receivedMsg.getContent()); message.setType(receivedMsg.getType()); } catch (Exception e) { // 如果解析失败,将整个payload作为content message.setContent(payload); }
String toUid = message.getToUserId(); boolean isP2P = toUid != null && !toUid.isEmpty();
Set<WebSocketSession> targets; if (isP2P) { targets = userSessions.get(toUid); } else { targets = userSessions.get(fromUid); }
// 序列化消息对象 String outStr = MAPPER.writeValueAsString(message); TextMessage msg = new TextMessage(outStr);
if (targets == null || targets.isEmpty()) { if (session.isOpen()) { session.sendMessage(msg); } return; }
for (WebSocketSession s : targets) { if (s.isOpen()) { s.sendMessage(msg); } } if (isP2P && session.isOpen()) { session.sendMessage(msg); } } } |
配置类 WebSocketConfig.java 代码：
|
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
package com.example.websocket;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean;
@Configuration @EnableWebSocket public class WebSocketConfig implements WebSocketConfigurer { @Override public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) { registry.addHandler(chatHandler(), "/ws").setAllowedOriginPatterns("*"); }
@Bean public WebSocketHandler chatHandler() { return new ChatWebSocketHandler(); }
// 配置 WebSocket 容器参数(解决消息过大、超时等问题) @Bean public ServletServerContainerFactoryBean createWebSocketContainer() { ServletServerContainerFactoryBean container = new ServletServerContainerFactoryBean(); // 文本消息缓冲区:2MB(解决解码后消息过大的核心配置) container.setMaxTextMessageBufferSize(2 * 1024 * 1024); // 二进制消息缓冲区:4MB(按需配置) container.setMaxBinaryMessageBufferSize(4 * 1024 * 1024); // 会话空闲超时:60秒(无交互则关闭连接) container.setMaxSessionIdleTimeout(60_000L); return container; } } |