diff options
Diffstat (limited to 'pkg/libs/cancellablewebsocket/cancellablewebsocket.go')
| -rw-r--r-- | pkg/libs/cancellablewebsocket/cancellablewebsocket.go | 91 |
1 files changed, 91 insertions, 0 deletions
diff --git a/pkg/libs/cancellablewebsocket/cancellablewebsocket.go b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go new file mode 100644 index 0000000..5dcffc1 --- /dev/null +++ b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go @@ -0,0 +1,91 @@ +package cancellablewebsocket + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type CancellableWebSocket struct { + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + errors chan error +} + +func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) { + conn, _, err := dialer.Dial(url, requestHeader) + if err != nil { + return nil, err + } + + childCtx, cancel := context.WithCancel(ctx) + + cws := &CancellableWebSocket{ + conn: conn, + ctx: childCtx, + cancel: cancel, + errors: make(chan error), + } + + go cws.listenForCancel() + + return cws, nil +} + +func (cws *CancellableWebSocket) ReadJSON(v interface{}) error { + err := cws.conn.ReadJSON(v) + if err != nil { + // Check if context was not cancelled, + // if so, return error + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) WriteJSON(v interface{}) error { + err := cws.conn.WriteJSON(v) + if err != nil { + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) Close() error { + cws.cancel() + err, ok := <-cws.errors + + if ok && err != nil { + return err + } + + return nil +} + +func (cws *CancellableWebSocket) listenForCancel() { + + <-cws.ctx.Done() + + var err error + aLongTimeAgo := time.Unix(0, 0) + + if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.Close(); err != nil { + cws.errors <- err + } + + close(cws.errors) +} |
