about summary refs log tree commit diff
path: root/pkg/discord/gateway
diff options
context:
space:
mode:
author Mel <einebeere@gmail.com> 2022-04-12 17:21:05 +0200
committer Mel <einebeere@gmail.com> 2022-04-12 17:21:05 +0200
commit 6163d259ed52991e2f95632b5a0516607aa56a5f (patch)
tree d87514d024d55f976ec78176fd0b0d8ebe946d7c /pkg/discord/gateway
parent 6cd2890450aaf71e97004d421237996f0a42d04c (diff)
download jinx-6163d259ed52991e2f95632b5a0516607aa56a5f.tar.zst
jinx-6163d259ed52991e2f95632b5a0516607aa56a5f.zip
Handle gateway errors and reconnections
Diffstat (limited to 'pkg/discord/gateway')
-rw-r--r-- pkg/discord/gateway/closecodes.go 22
-rw-r--r-- pkg/discord/gateway/gateway.go 210
-rw-r--r-- pkg/discord/gateway/heartbeat.go 37
-rw-r--r-- pkg/discord/gateway/payloads.go 10
4 files changed, 234 insertions, 45 deletions
diff --git a/pkg/discord/gateway/closecodes.go b/pkg/discord/gateway/closecodes.go
new file mode 100644
index 0000000..d6e9804
--- /dev/null
+++ b/pkg/discord/gateway/closecodes.go
@@ -0,0 +1,22 @@
+package gateway
+
+// REF: https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes
+const (
+	// Reconnect when receiving these.
+	CloseUnknownError         = 4000 // We're not sure what went wrong. Try reconnecting?
+	CloseUnkownOpcode         = 4001 // You sent an invalid Gateway opcode or an invalid payload for an opcode. Don't do that!
+	CloseDecodeError          = 4002 // You sent an invalid payload to us. Don't do that!
+	CloseNotAuthenticated     = 4003 // You sent us a payload prior to identifying.
+	CloseAlreadyAuthenticated = 4005 // You sent more than one identify payload. Don't do that!
+	CloseInvalidSeq           = 4007 // The sequence sent when resuming the session was invalid. Reconnect and start a new session.
+	CloseRateLimited          = 4008 // Woah nelly! You're sending payloads to us too quickly. Slow it down! You will be disconnected on receiving this.
+	CloseSessionTimedOut      = 4009 // Your session timed out. Reconnect and start a new one.
+
+	// Don't attempt to reconnect on any of the following.
+	CloseAuthenticationFailed = 4004 // The account token sent with your identify payload is incorrect.
+	CloseInvalidShard         = 4010 // You sent us an invalid shard when identifying.
+	CloseShardingRequired     = 4011 // The session would have handled too many guilds - you are required to shard your connection in order to connect.
+	CloseInvalidAPIVersion    = 4012 // You sent an invalid version for the gateway.
+	CloseInvalidIntents       = 4013 // You sent an invalid intent for a Gateway Intent. You may have incorrectly calculated the bitwise value.
+	CloseDisallowedIntents    = 4014 // You sent a disallowed intent for a Gateway Intent. You may have tried to specify an intent that you have not enabled or are not approved for.
+)
diff --git a/pkg/discord/gateway/gateway.go b/pkg/discord/gateway/gateway.go
index 9c86bb6..114eabc 100644
--- a/pkg/discord/gateway/gateway.go
+++ b/pkg/discord/gateway/gateway.go
@@ -20,25 +20,35 @@ type Gateway interface {
 	Heartbeat() error
 }
 
-var _ Gateway = &GatewayImpl{}
+var _ Gateway = &gatewayImpl{}
 
-type GatewayImpl struct {
-	ctx          context.Context
+type gatewayImpl struct {
 	logger       *zerolog.Logger
 	conn         *cancellablewebsocket.CancellableWebSocket
 	heartbeat    Heartbeat
 	eventHandler events.EventHandler
-	lastSeq      uint64
+
+	parentCtx     context.Context
+	cancelRoutine context.CancelFunc
+
+	url       string
+	token     string
+	sessionID string
+	lastSeq   uint64
 }
 
-func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl {
-	gateway := &GatewayImpl{
-		ctx:          nil,
-		logger:       logger,
-		conn:         nil,
-		eventHandler: eventHandler,
-		heartbeat:    nil,
-		lastSeq:      0,
+func New(logger *zerolog.Logger, eventHandler events.EventHandler) Gateway {
+	gateway := &gatewayImpl{
+		logger:        logger,
+		conn:          nil,
+		eventHandler:  eventHandler,
+		heartbeat:     nil,
+		parentCtx:     nil,
+		cancelRoutine: nil,
+		url:           "",
+		token:         "",
+		sessionID:     "",
+		lastSeq:       0,
 	}
 
 	// Cycle dependency, is this the best way to do this?
@@ -48,15 +58,20 @@ func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl
 	return gateway
 }
 
-func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error {
+func (g *gatewayImpl) Start(ctx context.Context, url string, token string) error {
+	g.parentCtx = ctx
+	routineCtx, cancelRoutine := context.WithCancel(ctx)
+	g.cancelRoutine = cancelRoutine
+
 	connectHeader := http.Header{}
-	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader)
+	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, url, connectHeader)
 	if err != nil {
 		return err
 	}
 
+	g.url = url
+	g.token = token
 	g.conn = conn
-	g.ctx = ctx
 
 	hello, err := g.hello()
 	if err != nil {
@@ -64,18 +79,25 @@ func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error
 	}
 
 	heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond
-	g.heartbeat.Start(ctx, heartbeatInterval)
+	g.heartbeat.Start(g.parentCtx, heartbeatInterval)
 
 	if err = g.identify(token); err != nil {
 		return err
 	}
 
-	go g.listen()
+	// Set up close handler for reconnecting or killing.
+	// Done after identity and hello, because we kill the gateway if those fail.
+	g.conn.OnClose(func(code int, text string) error {
+		g.onClose(code)
+		return nil
+	})
+
+	go g.listen(routineCtx)
 
 	return nil
 }
 
-func (g *GatewayImpl) Close() error {
+func (g *gatewayImpl) Close() error {
 	// Try closing gracefully.
 	g.logger.Debug().Msg("closing gateway gracefully...")
 	if err := g.conn.Close(1000); err != nil {
@@ -86,7 +108,7 @@ func (g *GatewayImpl) Close() error {
 	return nil
 }
 
-func (g *GatewayImpl) Heartbeat() error {
+func (g *gatewayImpl) Heartbeat() error {
 	msg := Payload[uint64]{
 		Op:   OP_HEARTBEAT,
 		Data: g.lastSeq,
@@ -95,7 +117,7 @@ func (g *GatewayImpl) Heartbeat() error {
 	return g.send(msg)
 }
 
-func (g *GatewayImpl) hello() (HelloEvent, error) {
+func (g *gatewayImpl) hello() (HelloEvent, error) {
 	var msg Payload[HelloEvent]
 	if err := receive(g, &msg); err != nil {
 		return HelloEvent{}, err
@@ -108,7 +130,7 @@ func (g *GatewayImpl) hello() (HelloEvent, error) {
 	return msg.Data, nil
 }
 
-func (g *GatewayImpl) identify(token string) error {
+func (g *gatewayImpl) identify(token string) error {
 	msg := Payload[IdentifyCmd]{
 		Op: OP_IDENTIFY,
 		Data: IdentifyCmd{
@@ -132,8 +154,6 @@ func (g *GatewayImpl) identify(token string) error {
 		return err
 	}
 
-	g.logger.Debug().Msgf("identify response payload: %+v", res)
-
 	if res.Op != OP_DISPATCH {
 		return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op)
 	}
@@ -142,10 +162,88 @@ func (g *GatewayImpl) identify(token string) error {
 		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
 	}
 
+	g.sessionID = res.Data.SessionID
+
+	return nil
+}
+
+func (g *gatewayImpl) resume() error {
+	msg := Payload[ResumeCmd]{
+		Op: OP_RESUME,
+		Data: ResumeCmd{
+			Token:     g.token,
+			SessionID: g.sessionID,
+			Sequence:  g.lastSeq,
+		},
+		Sequence: 0,
+	}
+
+	if err := g.send(msg); err != nil {
+		return err
+	}
+
+	// From now on Discord will send us all missed events,
+	// which we handle outside of the listener here.
+	for {
+		var msg Payload[json.RawMessage]
+		if err := receive(g, &msg); err != nil {
+			return err
+		}
+
+		if msg.Op == OP_DISPATCH && msg.EventName == "RESUMED" {
+			// Done replaying.
+			return nil
+		}
+
+		// Handle normal event.
+		g.onEvent(msg)
+	}
+}
+
+func (g *gatewayImpl) reconnect() error {
+	if err := g.conn.Kill(); err != nil {
+		return err
+	}
+
+	g.heartbeat.Stop()
+
+	// Renew context.
+	g.cancelRoutine()
+	routineCtx, cancelRoutines := context.WithCancel(g.parentCtx)
+	g.cancelRoutine = cancelRoutines
+
+	connectHeader := http.Header{}
+	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, g.url, connectHeader)
+	if err != nil {
+		return err
+	}
+
+	g.conn = conn
+
+	hello, err := g.hello()
+	if err != nil {
+		return err
+	}
+
+	heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond
+	g.heartbeat.Start(routineCtx, heartbeatInterval)
+
+	// Replay all missed events.
+	if err = g.resume(); err != nil {
+		return err
+	}
+
+	g.conn.OnClose(func(code int, text string) error {
+		g.onClose(code)
+		return nil
+	})
+
+	go g.listen(routineCtx)
+
 	return nil
 }
 
-func (g *GatewayImpl) listen() {
+func (g *gatewayImpl) listen(ctx context.Context) {
 	for {
 		var msg Payload[json.RawMessage]
 		if err := receive(g, &msg); err != nil {
@@ -154,7 +252,7 @@ func (g *GatewayImpl) listen() {
 		}
 
 		select {
-		case <-g.ctx.Done():
+		case <-ctx.Done():
 			return
 		default:
 			g.onEvent(msg)
@@ -162,7 +260,7 @@ func (g *GatewayImpl) listen() {
 	}
 }
 
-func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
+func (g *gatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
 	switch msg.Op {
 	case OP_HEARTBEAT_ACK:
 		g.heartbeat.Ack()
@@ -177,7 +275,7 @@ func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
 	return nil
 }
 
-func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
+func (g *gatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
 	switch eventName {
 	case "MESSAGE_CREATE":
 		var gatewayEvent MessageCreateEvent
@@ -194,11 +292,65 @@ func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
 	return nil
 }
 
-func (g *GatewayImpl) send(payload any) error {
+func (g *gatewayImpl) onClose(code int) {
+	shouldReconnect := true
+
+	switch code {
+	case CloseUnknownError:
+		g.logger.Warn().Msg("gateway closed because of an unknown error, reconnecting...")
+	case CloseUnkownOpcode:
+		g.logger.Warn().Msg("gateway closed after receiving an unknown opcode, reconnecting...")
+	case CloseDecodeError:
+		g.logger.Warn().Msg("gateway closed after receiving an invalid payload, reconnecting...")
+	case CloseNotAuthenticated:
+		g.logger.Warn().Msg("gateway closed because of missing authentication, reconnecting...")
+	case CloseAlreadyAuthenticated:
+		g.logger.Warn().Msg("gateway closed because the client is already authenticated, reconnecting...")
+	case CloseInvalidSeq:
+		g.logger.Warn().Msg("gateway closed after receiving an invalid sequence number, reconnecting...")
+	case CloseRateLimited:
+		// TODO: Implement rate limiting.
+		g.logger.Warn().Msg("gateway rate-limited, reconnecting...")
+	case CloseSessionTimedOut:
+		g.logger.Warn().Msg("gateway closed because of a session timeout, reconnecting...")
+
+	case CloseAuthenticationFailed:
+		g.logger.Error().Msg("gateway closed because of an authentication error.")
+		shouldReconnect = false
+	case CloseInvalidShard:
+		g.logger.Error().Msg("gateway closed because the given shard is invalid.")
+		shouldReconnect = false
+	case CloseShardingRequired:
+		g.logger.Error().Msg("gateway closed because the bot requires sharding.")
+		shouldReconnect = false
+	case CloseInvalidAPIVersion:
+		g.logger.Error().Msg("gateway closed because the given API version is invalid.")
+		shouldReconnect = false
+	case CloseInvalidIntents:
+		g.logger.Error().Msg("gateway closed because the given intents are invalid.")
+		shouldReconnect = false
+	case CloseDisallowedIntents:
+		g.logger.Error().Msg("gateway closed because the given intents are disallowed.")
+		shouldReconnect = false
+	default:
+		g.logger.Error().Msgf("gateway closed with unknown code %d.", code)
+		shouldReconnect = false
+	}
+
+	if shouldReconnect {
+		if err := g.reconnect(); err != nil {
+			g.logger.Error().Err(err).Msg("error reconnecting")
+			return
+		}
+		g.logger.Info().Msg("successfully reconnected!")
+	}
+}
+
+func (g *gatewayImpl) send(payload any) error {
 	return g.conn.WriteJSON(payload)
 }
 
-func receive[D any](g *GatewayImpl, res *Payload[D]) error {
+func receive[D any](g *gatewayImpl, res *Payload[D]) error {
 	err := g.conn.ReadJSON(&res)
 	if err != nil {
 		return err
diff --git a/pkg/discord/gateway/heartbeat.go b/pkg/discord/gateway/heartbeat.go
index 1df753a..1ebe964 100644
--- a/pkg/discord/gateway/heartbeat.go
+++ b/pkg/discord/gateway/heartbeat.go
@@ -10,42 +10,51 @@ import (
 
 type Heartbeat interface {
 	Start(ctx context.Context, interval time.Duration)
+	Stop()
 
 	ForceHeartbeat()
 	Ack()
 }
 
-var _ Heartbeat = &HeartbeatImpl{}
+var _ Heartbeat = &heartbeatImpl{}
 
-type HeartbeatImpl struct {
-	ctx     context.Context
+type heartbeatImpl struct {
 	logger  *zerolog.Logger
 	gateway Gateway
+
+	cancelRoutine context.CancelFunc
 }
 
-func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *HeartbeatImpl {
-	return &HeartbeatImpl{
-		ctx:     nil,
+func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *heartbeatImpl {
+	return &heartbeatImpl{
 		logger:  logger,
 		gateway: gateway,
+
+		cancelRoutine: nil,
 	}
 }
 
-func (h *HeartbeatImpl) Start(ctx context.Context, interval time.Duration) {
-	h.ctx = ctx
-	go h.heartbeatRoutine(interval)
+func (h *heartbeatImpl) Start(ctx context.Context, interval time.Duration) {
+	routineCtx, cancel := context.WithCancel(ctx)
+	h.cancelRoutine = cancel
+	go h.heartbeatRoutine(routineCtx, interval)
+}
+
+func (h *heartbeatImpl) Stop() {
+	h.cancelRoutine()
+	h.cancelRoutine = nil
 }
 
-func (h *HeartbeatImpl) ForceHeartbeat() {
+func (h *heartbeatImpl) ForceHeartbeat() {
 	h.gateway.Heartbeat()
 }
 
-func (h *HeartbeatImpl) Ack() {
+func (h *heartbeatImpl) Ack() {
 	// What do we do here?
 	h.logger.Debug().Msg("received heartbeat ack")
 }
 
-func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
+func (h *heartbeatImpl) heartbeatRoutine(ctx context.Context, interval time.Duration) {
 	h.logger.Debug().Msgf("beating heart every %dms", interval.Milliseconds())
 
 	// REF: heartbeat_interval * jitter
@@ -53,7 +62,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
 
 	select {
 	case <-time.After(time.Duration(jitter)):
-	case <-h.ctx.Done():
+	case <-ctx.Done():
 		h.logger.Debug().Msg("heartbeat routine stopped before jitter heartbeat")
 		return
 	}
@@ -68,7 +77,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
 		}
 
 		select {
-		case <-h.ctx.Done():
+		case <-ctx.Done():
 			return
 		case <-ticker.C:
 			continue
diff --git a/pkg/discord/gateway/payloads.go b/pkg/discord/gateway/payloads.go
index 8da894f..87c37d3 100644
--- a/pkg/discord/gateway/payloads.go
+++ b/pkg/discord/gateway/payloads.go
@@ -28,11 +28,17 @@ type Payload[D any] struct {
 }
 
 type IdentifyCmd struct {
-	Token      string                    `json:"token"`
-	Intents    uint64                    `json:"intents"`
+	Token      string             `json:"token"`
+	Intents    uint64             `json:"intents"`
 	Properties IdentifyProperties `json:"properties"`
 }
 
+type ResumeCmd struct {
+	Token     string `json:"token"`
+	SessionID string `json:"session_id"`
+	Sequence  uint64 `json:"seq"`
+}
+
 type HelloEvent struct {
 	HeartbeatInterval uint64 `json:"heartbeat_interval"`
 }