66import io .jooby .SneakyThrows ;
77import io .jooby .WebSocket ;
88import io .jooby .WebSocketCloseStatus ;
9- import io .jooby .WebSocketContext ;
9+ import io .jooby .WebSocketListener ;
1010import io .jooby .WebSocketMessage ;
1111import io .netty .buffer .ByteBuf ;
1212import io .netty .buffer .Unpooled ;
1717import io .netty .handler .codec .http .websocketx .TextWebSocketFrame ;
1818import io .netty .handler .codec .http .websocketx .WebSocketFrame ;
1919
20- public class NettyWebSocketContext implements WebSocketContext , WebSocket {
20+ import java .util .Collections ;
21+ import java .util .List ;
22+ import java .util .concurrent .ConcurrentHashMap ;
23+ import java .util .concurrent .ConcurrentMap ;
24+ import java .util .concurrent .CopyOnWriteArrayList ;
25+
26+ public class NettyWebSocket implements WebSocketListener , WebSocket {
27+ /** All connected websocket. */
28+ private static final ConcurrentMap <String , List <WebSocket >> all = new ConcurrentHashMap <>();
29+
2130 private final NettyContext netty ;
2231 private final boolean dispatch ;
32+ private final String key ;
2333 private ByteBuf buffer ;
2434 private WebSocket .OnConnect connectCallback ;
2535 private WebSocket .OnMessage messageCallback ;
2636 private OnClose onCloseCallback ;
2737 private OnError onErrorCallback ;
2838
29- public NettyWebSocketContext (NettyContext ctx ) {
39+ public NettyWebSocket (NettyContext ctx ) {
3040 this .netty = ctx ;
41+ this .key = ctx .pathString ();
3142 dispatch = !ctx .isInIoThread ();
3243 }
3344
34- public WebSocket send (String text ) {
35- return send (new TextWebSocketFrame (text ));
45+ public WebSocket send (String text , boolean broadcast ) {
46+ if (broadcast ) {
47+ for (WebSocket ws : all .getOrDefault (key , Collections .emptyList ())) {
48+ ws .send (text , false );
49+ }
50+ } else {
51+ send (new TextWebSocketFrame (text ));
52+ }
53+ return this ;
54+ }
55+
56+ public WebSocket send (byte [] bytes , boolean broadcast ) {
57+ if (broadcast ) {
58+ for (WebSocket ws : all .getOrDefault (key , Collections .emptyList ())) {
59+ ws .send (bytes , false );
60+ }
61+ } else {
62+ send (new TextWebSocketFrame (Unpooled .wrappedBuffer (bytes )));
63+ }
64+ return this ;
3665 }
3766
38- public WebSocket send (byte [] bytes ) {
39- return send (new TextWebSocketFrame (Unpooled .wrappedBuffer (bytes )));
67+ @ Override public WebSocket render (Object message , boolean broadcast ) {
68+ if (broadcast ) {
69+ for (WebSocket ws : all .getOrDefault (key , Collections .emptyList ())) {
70+ ws .render (message , false );
71+ }
72+ } else {
73+ Context .websocket (netty , this ).render (message );
74+ }
75+ return this ;
4076 }
4177
4278 private WebSocket send (TextWebSocketFrame frame ) {
@@ -48,11 +84,6 @@ private WebSocket send(TextWebSocketFrame frame) {
4884 return this ;
4985 }
5086
51- @ Override public WebSocket render (Object message ) {
52- Context .websocket (netty , this ).render (message );
53- return this ;
54- }
55-
5687 @ Override public Context getContext () {
5788 return Context .readOnly (netty );
5889 }
@@ -61,20 +92,24 @@ public boolean isOpen() {
6192 return netty .ctx .channel ().isOpen ();
6293 }
6394
64- @ Override public void onConnect (WebSocket .OnConnect callback ) {
95+ @ Override public WebSocketListener onConnect (WebSocket .OnConnect callback ) {
6596 connectCallback = callback ;
97+ return this ;
6698 }
6799
68- @ Override public void onMessage (WebSocket .OnMessage callback ) {
100+ @ Override public WebSocketListener onMessage (WebSocket .OnMessage callback ) {
69101 messageCallback = callback ;
102+ return this ;
70103 }
71104
72- @ Override public void onClose (WebSocket .OnClose callback ) {
105+ @ Override public WebSocketListener onClose (WebSocket .OnClose callback ) {
73106 onCloseCallback = callback ;
107+ return this ;
74108 }
75109
76- @ Override public void onError (OnError callback ) {
110+ @ Override public WebSocketListener onError (OnError callback ) {
77111 onErrorCallback = callback ;
112+ return this ;
78113 }
79114
80115 void handleFrame (WebSocketFrame frame ) {
@@ -113,22 +148,26 @@ private void handleMessage(WebSocketFrame frame) {
113148 }
114149
115150 private void handleClose (WebSocketCloseStatus closeStatus ) {
116- if (isOpen ()) {
117- if (onCloseCallback != null ) {
118- Runnable task = webSocketTask (() -> onCloseCallback .onClose (this , closeStatus ));
119- Runnable closeCallback = () -> {
120- try {
121- task .run ();
122- } finally {
123- netty .ctx .channel ()
124- .writeAndFlush (
125- new CloseWebSocketFrame (closeStatus .getCode (), closeStatus .getReason ()))
126- .addListener (ChannelFutureListener .CLOSE );
127- }
128- };
129-
130- fireCallback (closeCallback );
151+ try {
152+ if (isOpen ()) {
153+ if (onCloseCallback != null ) {
154+ Runnable task = webSocketTask (() -> onCloseCallback .onClose (this , closeStatus ));
155+ Runnable closeCallback = () -> {
156+ try {
157+ task .run ();
158+ } finally {
159+ netty .ctx .channel ()
160+ .writeAndFlush (
161+ new CloseWebSocketFrame (closeStatus .getCode (), closeStatus .getReason ()))
162+ .addListener (ChannelFutureListener .CLOSE );
163+ }
164+ };
165+
166+ fireCallback (closeCallback );
167+ }
131168 }
169+ } finally {
170+ removeSession (this );
132171 }
133172 }
134173
@@ -150,6 +189,7 @@ private void handleError(Throwable x) {
150189 }
151190
152191 void fireConnect () {
192+ addSession (this );
153193 if (connectCallback != null ) {
154194 fireCallback (webSocketTask (() -> connectCallback .onConnect (this )));
155195 }
@@ -193,4 +233,15 @@ private static WebSocketCloseStatus toWebSocketCloseStatus(CloseWebSocketFrame f
193233 frame .release ();
194234 }
195235 }
236+
237+ private void addSession (NettyWebSocket ws ) {
238+ all .computeIfAbsent (ws .key , k -> new CopyOnWriteArrayList <>()).add (ws );
239+ }
240+
241+ private void removeSession (NettyWebSocket ws ) {
242+ List <WebSocket > sockets = all .get (ws .key );
243+ if (sockets != null ) {
244+ sockets .remove (ws );
245+ }
246+ }
196247}
0 commit comments