diff options
| -rw-r--r-- | pkg/discord/discord.go | 5 | ||||
| -rw-r--r-- | pkg/libs/cancellablewebsocket/cancellablewebsocket.go | 91 |
2 files changed, 94 insertions, 2 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go index bd8a7f7..709f5e1 100644 --- a/pkg/discord/discord.go +++ b/pkg/discord/discord.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "jinx/pkg/libs/cancellablewebsocket" "math/rand" "net/http" "time" @@ -16,7 +17,7 @@ import ( type Discord struct { token string logger *zerolog.Logger - conn *websocket.Conn + conn *cancellablewebsocket.CancellableWebSocket eventHandler *EventHandlerImpl rest REST } @@ -40,7 +41,7 @@ func (d *Discord) Connect(ctx context.Context) error { } connectHeader := http.Header{} - d.conn, _, err = websocket.DefaultDialer.Dial(gatewayURL, connectHeader) + d.conn, err = cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, gatewayURL, connectHeader) if err != nil { return err } diff --git a/pkg/libs/cancellablewebsocket/cancellablewebsocket.go b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go new file mode 100644 index 0000000..5dcffc1 --- /dev/null +++ b/pkg/libs/cancellablewebsocket/cancellablewebsocket.go @@ -0,0 +1,91 @@ +package cancellablewebsocket + +import ( + "context" + "net/http" + "time" + + "github.com/gorilla/websocket" +) + +type CancellableWebSocket struct { + conn *websocket.Conn + ctx context.Context + cancel context.CancelFunc + errors chan error +} + +func Dial(dialer *websocket.Dialer, ctx context.Context, url string, requestHeader http.Header) (*CancellableWebSocket, error) { + conn, _, err := dialer.Dial(url, requestHeader) + if err != nil { + return nil, err + } + + childCtx, cancel := context.WithCancel(ctx) + + cws := &CancellableWebSocket{ + conn: conn, + ctx: childCtx, + cancel: cancel, + errors: make(chan error), + } + + go cws.listenForCancel() + + return cws, nil +} + +func (cws *CancellableWebSocket) ReadJSON(v interface{}) error { + err := cws.conn.ReadJSON(v) + if err != nil { + // Check if context was not cancelled, + // if so, return error + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) WriteJSON(v interface{}) error { + err := cws.conn.WriteJSON(v) + if err != nil { + if cws.ctx.Err() == nil { + return err + } + } + return nil +} + +func (cws *CancellableWebSocket) Close() error { + cws.cancel() + err, ok := <-cws.errors + + if ok && err != nil { + return err + } + + return nil +} + +func (cws *CancellableWebSocket) listenForCancel() { + + <-cws.ctx.Done() + + var err error + aLongTimeAgo := time.Unix(0, 0) + + if err = cws.conn.SetReadDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.SetWriteDeadline(aLongTimeAgo); err != nil { + cws.errors <- err + } + + if err = cws.conn.Close(); err != nil { + cws.errors <- err + } + + close(cws.errors) +} |
