about summary refs log tree commit diff
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-04-06 01:05:43 +0200
committerMel <einebeere@gmail.com>2022-04-06 01:05:43 +0200
commiteea8d77a8676feef17ca7c675742958c7bc3e93c (patch)
tree375bd5f917a4efa5190d82db8c82eb8d72919c92
parent3160b53ddfb6773624e96c306aec86ac2e80de31 (diff)
downloadjinx-eea8d77a8676feef17ca7c675742958c7bc3e93c.tar.zst
jinx-eea8d77a8676feef17ca7c675742958c7bc3e93c.zip
Excuse websockets cancelled through ctx
-rw-r--r--pkg/discord/discord.go5
-rw-r--r--pkg/libs/cancellablewebsocket/cancellablewebsocket.go91
2 files changed, 94 insertions, 2 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index bd8a7f7..709f5e1 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -5,6 +5,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"jinx/pkg/libs/cancellablewebsocket"
 	"math/rand"
 	"net/http"
 	"time"
@@ -16,7 +17,7 @@ import (
 type Discord struct {
 	token        string
 	logger       *zerolog.Logger
-	conn         *websocket.Conn
+	conn         *cancellablewebsocket.CancellableWebSocket
 	eventHandler *EventHandlerImpl
 	rest         REST
 }
@@ -40,7 +41,7 @@ func (d *Discord) Connect(ctx context.Context) error {
 	}
 
 	connectHeader := http.Header{}
-	d.conn, _, err = websocket.DefaultDialer.Dial(gatewayURL, connectHeader)
+	d.conn, err = cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, gatewayURL, connectHeader)
 	if err != nil {
 		return err
 	}
diff --git a/pkg/libs/cancellablewebsocket/cancellablewebsocket.go b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go
new file mode 100644
index 0000000..5dcffc1
--- /dev/null
+++ b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go
@@ -0,0 +1,91 @@
+package cancellablewebsocket
+
+import (
+	"context"
+	"net/http"
+	"time"
+
+	"github.com/gorilla/websocket"
+)
+
+type CancellableWebSocket struct {
+	conn   *websocket.Conn
+	ctx    context.Context
+	cancel context.CancelFunc
+	errors chan error
+}
+
+func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) {
+	conn, _, err := dialer.Dial(url, requestHeader)
+	if err != nil {
+		return nil, err
+	}
+
+	childCtx, cancel := context.WithCancel(ctx)
+
+	cws := &CancellableWebSocket{
+		conn:   conn,
+		ctx:    childCtx,
+		cancel: cancel,
+		errors: make(chan error),
+	}
+
+	go cws.listenForCancel()
+
+	return cws, nil
+}
+
+func (cws *CancellableWebSocket) ReadJSON(v interface{}) error {
+	err := cws.conn.ReadJSON(v)
+	if err != nil {
+		// Check if context was not cancelled,
+		// if so, return error
+		if cws.ctx.Err() == nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (cws *CancellableWebSocket) WriteJSON(v interface{}) error {
+	err := cws.conn.WriteJSON(v)
+	if err != nil {
+		if cws.ctx.Err() == nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (cws *CancellableWebSocket) Close() error {
+	cws.cancel()
+	err, ok := <-cws.errors
+
+	if ok && err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (cws *CancellableWebSocket) listenForCancel() {
+
+	<-cws.ctx.Done()
+
+	var err error
+	aLongTimeAgo := time.Unix(0, 0)
+
+	if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil {
+		cws.errors <- err
+	}
+
+	if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil {
+		cws.errors <- err
+	}
+
+	if err = cws.conn.Close(); err != nil {
+		cws.errors <- err
+	}
+
+	close(cws.errors)
+}