From eea8d77a8676feef17ca7c675742958c7bc3e93c Mon Sep 17 00:00:00 2001 From: Mel Date: Wed, 6 Apr 2022 01:05:43 +0200 Subject: Excuse websockets cancelled through ctx --- .../cancellablewebsocket/cancellablewebsocket.go | 91 ++++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 pkg/libs/cancellablewebsocket/cancellablewebsocket.go (limited to 'pkg/libs') diff --git a/pkg/libs/cancellablewebsocket/cancellablewebsocket.go b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go new file mode 100644 index 0000000..5dcffc1 --- /dev/null +++ b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go @@ -0,0 +1,91 @@ +package cancellablewebsocket + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type CancellableWebSocket struct { + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + errors chan error +} + +func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) { + conn, _, err := dialer.Dial(url, requestHeader) + if err != nil { + return nil, err + } + + childCtx, cancel := context.WithCancel(ctx) + + cws := &CancellableWebSocket{ + conn: conn, + ctx: childCtx, + cancel: cancel, + errors: make(chan error), + } + + go cws.listenForCancel() + + return cws, nil +} + +func (cws *CancellableWebSocket) ReadJSON(v interface{}) error { + err := cws.conn.ReadJSON(v) + if err != nil { + // Check if context was not cancelled, + // if so, return error + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) WriteJSON(v interface{}) error { + err := cws.conn.WriteJSON(v) + if err != nil { + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) Close() error { + cws.cancel() + err, ok := <-cws.errors + + if ok && err != nil { + return err + } + + return nil +} + +func (cws *CancellableWebSocket) listenForCancel() { + + <-cws.ctx.Done() + + var err error + aLongTimeAgo := time.Unix(0, 0) + + if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.Close(); err != nil { + cws.errors <- err + } + + close(cws.errors) +} -- cgit 1.4.1