about summary refs log tree commit diff
path: root/pkg/libs/cancellablewebsocket/cancellablewebsocket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/libs/cancellablewebsocket/cancellablewebsocket.go')
-rw-r--r--pkg/libs/cancellablewebsocket/cancellablewebsocket.go91
1 files changed, 91 insertions, 0 deletions
diff --git a/pkg/libs/cancellablewebsocket/cancellablewebsocket.go b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go
new file mode 100644
index 0000000..5dcffc1
--- /dev/null
+++ b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go
@@ -0,0 +1,91 @@
+package cancellablewebsocket
+
+import (
+	"context"
+	"net/http"
+	"time"
+
+	"github.com/gorilla/websocket"
+)
+
+type CancellableWebSocket struct {
+	conn   *websocket.Conn
+	ctx    context.Context
+	cancel context.CancelFunc
+	errors chan error
+}
+
+func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) {
+	conn, _, err := dialer.Dial(url, requestHeader)
+	if err != nil {
+		return nil, err
+	}
+
+	childCtx, cancel := context.WithCancel(ctx)
+
+	cws := &CancellableWebSocket{
+		conn:   conn,
+		ctx:    childCtx,
+		cancel: cancel,
+		errors: make(chan error),
+	}
+
+	go cws.listenForCancel()
+
+	return cws, nil
+}
+
+func (cws *CancellableWebSocket) ReadJSON(v interface{}) error {
+	err := cws.conn.ReadJSON(v)
+	if err != nil {
+		// Check if context was not cancelled,
+		// if so, return error
+		if cws.ctx.Err() == nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (cws *CancellableWebSocket) WriteJSON(v interface{}) error {
+	err := cws.conn.WriteJSON(v)
+	if err != nil {
+		if cws.ctx.Err() == nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (cws *CancellableWebSocket) Close() error {
+	cws.cancel()
+	err, ok := <-cws.errors
+
+	if ok && err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (cws *CancellableWebSocket) listenForCancel() {
+
+	<-cws.ctx.Done()
+
+	var err error
+	aLongTimeAgo := time.Unix(0, 0)
+
+	if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil {
+		cws.errors <- err
+	}
+
+	if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil {
+		cws.errors <- err
+	}
+
+	if err = cws.conn.Close(); err != nil {
+		cws.errors <- err
+	}
+
+	close(cws.errors)
+}