// Package cancellablewebsocket wraps a gorilla/websocket connection so that
// blocking reads and writes are interrupted when a context is cancelled.
package cancellablewebsocket

import (
	"context"
	"net/http"

	"github.com/gorilla/websocket"
)

// CancellableWebSocket is a *websocket.Conn whose lifetime is bound to a
// context. Cancelling that context (or calling Close) closes the underlying
// connection, which unblocks any in-flight ReadJSON or WriteJSON call.
//
// As with the wrapped websocket.Conn, at most one goroutine may call
// ReadJSON and at most one goroutine may call WriteJSON at a time.
type CancellableWebSocket struct {
	conn   *websocket.Conn
	ctx    context.Context
	cancel context.CancelFunc

	// done is closed by listenForCancel once conn has been closed.
	done chan struct{}
	// closeErr holds the result of conn.Close. It is written before done is
	// closed and must only be read after receiving from done.
	closeErr error
}

// Dial opens a WebSocket connection to url using dialer and ties its
// lifetime to ctx. The dial itself also honors ctx, so a cancelled or
// expired ctx aborts the connect/handshake.
//
// The returned connection stays open until ctx is done or Close is called,
// whichever happens first.
func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) {
	// DialContext (rather than Dial) lets ctx abort a slow connect/handshake.
	conn, _, err := dialer.DialContext(ctx, url, requestHeader)
	if err != nil {
		return nil, err
	}

	childCtx, cancel := context.WithCancel(ctx)
	cws := &CancellableWebSocket{
		conn:   conn,
		ctx:    childCtx,
		cancel: cancel,
		done:   make(chan struct{}),
	}
	go cws.listenForCancel()
	return cws, nil
}

// ReadJSON reads the next JSON-encoded message from the connection and
// stores it in the value pointed to by v.
//
// If the read fails after the connection's context is done, the context's
// error (context.Canceled or context.DeadlineExceeded) is returned instead
// of the low-level network error, so callers can detect cancellation with
// errors.Is.
func (cws *CancellableWebSocket) ReadJSON(v interface{}) error {
	if err := cws.conn.ReadJSON(v); err != nil {
		return cws.translateErr(err)
	}
	return nil
}

// WriteJSON writes the JSON encoding of v as a message on the connection.
//
// If the write fails after the connection's context is done, the context's
// error is returned instead of the low-level network error.
func (cws *CancellableWebSocket) WriteJSON(v interface{}) error {
	if err := cws.conn.WriteJSON(v); err != nil {
		return cws.translateErr(err)
	}
	return nil
}

// translateErr reports the context's error in place of err when the context
// is done, because in that case err is only a side effect of the connection
// being torn down.
func (cws *CancellableWebSocket) translateErr(err error) error {
	if ctxErr := cws.ctx.Err(); ctxErr != nil {
		return ctxErr
	}
	return err
}

// Close cancels the connection's context, waits for the underlying network
// connection to be closed, and returns any error from closing it. It does
// not send a WebSocket close message. Calling Close more than once is safe;
// later calls return the same result.
func (cws *CancellableWebSocket) Close() error {
	cws.cancel()
	<-cws.done
	return cws.closeErr
}

// listenForCancel waits for the context to be done and then closes the
// connection, which unblocks any goroutine blocked in ReadJSON or WriteJSON.
// It always returns after the close, so it never outlives the connection and
// never blocks on a reader that may not exist.
func (cws *CancellableWebSocket) listenForCancel() {
	defer close(cws.done)
	<-cws.ctx.Done()
	// gorilla/websocket documents Close as safe to call concurrently with all
	// other methods, whereas SetWriteDeadline must not run concurrently with
	// an in-flight write. Closing the network connection is enough to make
	// pending reads and writes return.
	cws.closeErr = cws.conn.Close()
}