package cancellablewebsocket import ( "context" "net/http" "time" "github.com/gorilla/websocket" ) type CancellableWebSocket struct { conn *websocket.Conn ctx context.Context cancel context.CancelFunc closeCode int // 0 means killed errors chan error } func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) { conn, _, err := dialer.Dial(url, requestHeader) if err != nil { return nil, err } childCtx, cancel := context.WithCancel(ctx) cws := &CancellableWebSocket{ conn: conn, ctx: childCtx, cancel: cancel, closeCode: 0, errors: make(chan error), } go cws.listenForCancel() return cws, nil } func (cws *CancellableWebSocket) ReadJSON(v interface{}) error { err := cws.conn.ReadJSON(v) if err != nil { // Check if context was not cancelled, // if so, return error if cws.ctx.Err() == nil { return err } } return nil } func (cws *CancellableWebSocket) WriteJSON(v interface{}) error { err := cws.conn.WriteJSON(v) if err != nil { if cws.ctx.Err() == nil { return err } } return nil } func (cws *CancellableWebSocket) Kill() error { cws.cancel() err, ok := <-cws.errors if ok && err != nil { return err } return nil } func (cws *CancellableWebSocket) Close(code int) error { cws.closeCode = code if err := cws.Kill(); err != nil { cws.closeCode = 0 // Reset close code return err } return nil } func (cws *CancellableWebSocket) OnClose(f func(code int, text string) error) { cws.conn.SetCloseHandler(f) } func (cws *CancellableWebSocket) listenForCancel() { <-cws.ctx.Done() var err error aLongTimeAgo := time.Unix(0, 0) if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil { cws.errors <- err } if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil { cws.errors <- err } if cws.closeCode == 0 { // Close without sending close. if err = cws.conn.Close(); err != nil { cws.errors <- err } } else { // Send close with code. // NOTE: The SetWriteDeadline above does not affect this. :) err = cws.conn.WriteControl(websocket.CloseMessage, websocket.FormatCloseMessage(cws.closeCode, ""), time.Now().Add(time.Second)) if err != nil { cws.errors <- err } } close(cws.errors) }