diff options
Diffstat (limited to 'pkg/discord')
| -rw-r--r-- | pkg/discord/gateway/closecodes.go | 22 | ||||
| -rw-r--r-- | pkg/discord/gateway/gateway.go | 210 | ||||
| -rw-r--r-- | pkg/discord/gateway/heartbeat.go | 37 | ||||
| -rw-r--r-- | pkg/discord/gateway/payloads.go | 10 |
4 files changed, 234 insertions, 45 deletions
diff --git a/pkg/discord/gateway/closecodes.go b/pkg/discord/gateway/closecodes.go new file mode 100644 index 0000000..d6e9804 --- /dev/null +++ b/pkg/discord/gateway/closecodes.go @@ -0,0 +1,22 @@ +package gateway + +// REF: https://discord.com/developers/docs/topics/opcodes-and-status-codes#gateway-gateway-close-event-codes +const ( + // Reconnect when receiving these. + CloseUnknownError = 4000 // We're not sure what went wrong. Try reconnecting? + CloseUnkownOpcode = 4001 // You sent an invalid Gateway opcode or an invalid payload for an opcode. Don't do that! + CloseDecodeError = 4002 // You sent an invalid payload to us. Don't do that! + CloseNotAuthenticated = 4003 // You sent us a payload prior to identifying. + CloseAlreadyAuthenticated = 4005 // You sent more than one identify payload. Don't do that! + CloseInvalidSeq = 4007 // The sequence sent when resuming the session was invalid. Reconnect and start a new session. + CloseRateLimited = 4008 // Woah nelly! You're sending payloads to us too quickly. Slow it down! You will be disconnected on receiving this. + CloseSessionTimedOut = 4009 // Your session timed out. Reconnect and start a new one. + + // Don't attempt to reconnect on any of the following. + CloseAuthenticationFailed = 4004 // The account token sent with your identify payload is incorrect. + CloseInvalidShard = 4010 // You sent us an invalid shard when identifying. + CloseShardingRequired = 4011 // The session would have handled too many guilds - you are required to shard your connection in order to connect. + CloseInvalidAPIVersion = 4012 // You sent an invalid version for the gateway. + CloseInvalidIntents = 4013 // You sent an invalid intent for a Gateway Intent. You may have incorrectly calculated the bitwise value. + CloseDisallowedIntents = 4014 // You sent a disallowed intent for a Gateway Intent. You may have tried to specify an intent that you have not enabled or are not approved for. +) diff --git a/pkg/discord/gateway/gateway.go b/pkg/discord/gateway/gateway.go index 9c86bb6..114eabc 100644 --- a/pkg/discord/gateway/gateway.go +++ b/pkg/discord/gateway/gateway.go @@ -20,25 +20,35 @@ type Gateway interface { Heartbeat() error } -var _ Gateway = &GatewayImpl{} +var _ Gateway = &gatewayImpl{} -type GatewayImpl struct { - ctx context.Context +type gatewayImpl struct { logger *zerolog.Logger conn *cancellablewebsocket.CancellableWebSocket heartbeat Heartbeat eventHandler events.EventHandler - lastSeq uint64 + + parentCtx context.Context + cancelRoutine context.CancelFunc + + url string + token string + sessionID string + lastSeq uint64 } -func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl { - gateway := &GatewayImpl{ - ctx: nil, - logger: logger, - conn: nil, - eventHandler: eventHandler, - heartbeat: nil, - lastSeq: 0, +func New(logger *zerolog.Logger, eventHandler events.EventHandler) Gateway { + gateway := &gatewayImpl{ + logger: logger, + conn: nil, + eventHandler: eventHandler, + heartbeat: nil, + parentCtx: nil, + cancelRoutine: nil, + url: "", + token: "", + sessionID: "", + lastSeq: 0, } // Cycle dependency, is this the best way to do this? @@ -48,15 +58,20 @@ func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl return gateway } -func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error { +func (g *gatewayImpl) Start(ctx context.Context, url string, token string) error { + g.parentCtx = ctx + routineCtx, cancelRoutine := context.WithCancel(ctx) + g.cancelRoutine = cancelRoutine + connectHeader := http.Header{} - conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader) + conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, url, connectHeader) if err != nil { return err } + g.url = url + g.token = token g.conn = conn - g.ctx = ctx hello, err := g.hello() if err != nil { @@ -64,18 +79,25 @@ func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error } heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond - g.heartbeat.Start(ctx, heartbeatInterval) + g.heartbeat.Start(g.parentCtx, heartbeatInterval) if err = g.identify(token); err != nil { return err } - go g.listen() + // Set up close handler for reconnecting or killing. + // Done after identity and hello, because we kill the gateway if those fail. + g.conn.OnClose(func(code int, text string) error { + g.onClose(code) + return nil + }) + + go g.listen(routineCtx) return nil } -func (g *GatewayImpl) Close() error { +func (g *gatewayImpl) Close() error { // Try closing gracefully. g.logger.Debug().Msg("closing gateway gracefully...") if err := g.conn.Close(1000); err != nil { @@ -86,7 +108,7 @@ func (g *GatewayImpl) Close() error { return nil } -func (g *GatewayImpl) Heartbeat() error { +func (g *gatewayImpl) Heartbeat() error { msg := Payload[uint64]{ Op: OP_HEARTBEAT, Data: g.lastSeq, @@ -95,7 +117,7 @@ func (g *GatewayImpl) Heartbeat() error { return g.send(msg) } -func (g *GatewayImpl) hello() (HelloEvent, error) { +func (g *gatewayImpl) hello() (HelloEvent, error) { var msg Payload[HelloEvent] if err := receive(g, &msg); err != nil { return HelloEvent{}, err @@ -108,7 +130,7 @@ func (g *GatewayImpl) hello() (HelloEvent, error) { return msg.Data, nil } -func (g *GatewayImpl) identify(token string) error { +func (g *gatewayImpl) identify(token string) error { msg := Payload[IdentifyCmd]{ Op: OP_IDENTIFY, Data: IdentifyCmd{ @@ -132,8 +154,6 @@ func (g *GatewayImpl) identify(token string) error { return err } - g.logger.Debug().Msgf("identify response payload: %+v", res) - if res.Op != OP_DISPATCH { return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op) } @@ -142,10 +162,88 @@ func (g *GatewayImpl) identify(token string) error { return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName) } + g.sessionID = res.Data.SessionID + + return nil +} + +func (g *gatewayImpl) resume() error { + msg := Payload[ResumeCmd]{ + Op: OP_RESUME, + Data: ResumeCmd{ + Token: g.token, + SessionID: g.sessionID, + Sequence: g.lastSeq, + }, + Sequence: 0, + } + + if err := g.send(msg); err != nil { + return err + } + + // From now on Discord will send us all missed events, + // which we handle outside of the listener here. + for { + var msg Payload[json.RawMessage] + if err := receive(g, &msg); err != nil { + return err + } + + if msg.Op == OP_DISPATCH && msg.EventName == "RESUMED" { + // Done replaying. + return nil + } + + // Handle normal event. + g.onEvent(msg) + } +} + +func (g *gatewayImpl) reconnect() error { + if err := g.conn.Kill(); err != nil { + return err + } + + g.heartbeat.Stop() + + // Renew context. + g.cancelRoutine() + routineCtx, cancelRoutines := context.WithCancel(g.parentCtx) + g.cancelRoutine = cancelRoutines + + connectHeader := http.Header{} + conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, g.url, connectHeader) + if err != nil { + return err + } + + g.conn = conn + + hello, err := g.hello() + if err != nil { + return err + } + + heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond + g.heartbeat.Start(routineCtx, heartbeatInterval) + + // Replay all missed events. + if err = g.resume(); err != nil { + return err + } + + g.conn.OnClose(func(code int, text string) error { + g.onClose(code) + return nil + }) + + go g.listen(routineCtx) + return nil } -func (g *GatewayImpl) listen() { +func (g *gatewayImpl) listen(ctx context.Context) { for { var msg Payload[json.RawMessage] if err := receive(g, &msg); err != nil { @@ -154,7 +252,7 @@ func (g *GatewayImpl) listen() { } select { - case <-g.ctx.Done(): + case <-ctx.Done(): return default: g.onEvent(msg) @@ -162,7 +260,7 @@ func (g *GatewayImpl) listen() { } } -func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error { +func (g *gatewayImpl) onEvent(msg Payload[json.RawMessage]) error { switch msg.Op { case OP_HEARTBEAT_ACK: g.heartbeat.Ack() @@ -177,7 +275,7 @@ func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error { return nil } -func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error { +func (g *gatewayImpl) onDispatch(eventName string, body json.RawMessage) error { switch eventName { case "MESSAGE_CREATE": var gatewayEvent MessageCreateEvent @@ -194,11 +292,65 @@ func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error { return nil } -func (g *GatewayImpl) send(payload any) error { +func (g *gatewayImpl) onClose(code int) { + shouldReconnect := true + + switch code { + case CloseUnknownError: + g.logger.Warn().Msg("gateway closed because of an unknown error, reconnecting...") + case CloseUnkownOpcode: + g.logger.Warn().Msg("gateway closed after receiving an unknown opcode, reconnecting...") + case CloseDecodeError: + g.logger.Warn().Msg("gateway closed after receiving an invalid payload, reconnecting...") + case CloseNotAuthenticated: + g.logger.Warn().Msg("gateway closed because of missing authentication, reconnecting...") + case CloseAlreadyAuthenticated: + g.logger.Warn().Msg("gateway closed because the client is already authenticated, reconnecting...") + case CloseInvalidSeq: + g.logger.Warn().Msg("gateway closed after receiving an invalid sequence number, reconnecting...") + case CloseRateLimited: + // TODO: Implement rate limiting. + g.logger.Warn().Msg("gateway rate-limited, reconnecting...") + case CloseSessionTimedOut: + g.logger.Warn().Msg("gateway closed because of a session timeout, reconnecting...") + + case CloseAuthenticationFailed: + g.logger.Error().Msg("gateway closed because of an authentication error.") + shouldReconnect = false + case CloseInvalidShard: + g.logger.Error().Msg("gateway closed because the given shard is invalid.") + shouldReconnect = false + case CloseShardingRequired: + g.logger.Error().Msg("gateway closed because the bot requires sharding.") + shouldReconnect = false + case CloseInvalidAPIVersion: + g.logger.Error().Msg("gateway closed because the given API version is invalid.") + shouldReconnect = false + case CloseInvalidIntents: + g.logger.Error().Msg("gateway closed because the given intents are invalid.") + shouldReconnect = false + case CloseDisallowedIntents: + g.logger.Error().Msg("gateway closed because the given intents are disallowed.") + shouldReconnect = false + default: + g.logger.Error().Msgf("gateway closed with unknown code %d.", code) + shouldReconnect = false + } + + if shouldReconnect { + if err := g.reconnect(); err != nil { + g.logger.Error().Err(err).Msg("error reconnecting") + return + } + g.logger.Info().Msg("successfully reconnected!") + } +} + +func (g *gatewayImpl) send(payload any) error { return g.conn.WriteJSON(payload) } -func receive[D any](g *GatewayImpl, res *Payload[D]) error { +func receive[D any](g *gatewayImpl, res *Payload[D]) error { err := g.conn.ReadJSON(&res) if err != nil { return err diff --git a/pkg/discord/gateway/heartbeat.go b/pkg/discord/gateway/heartbeat.go index 1df753a..1ebe964 100644 --- a/pkg/discord/gateway/heartbeat.go +++ b/pkg/discord/gateway/heartbeat.go @@ -10,42 +10,51 @@ import ( type Heartbeat interface { Start(ctx context.Context, interval time.Duration) + Stop() ForceHeartbeat() Ack() } -var _ Heartbeat = &HeartbeatImpl{} +var _ Heartbeat = &heartbeatImpl{} -type HeartbeatImpl struct { - ctx context.Context +type heartbeatImpl struct { logger *zerolog.Logger gateway Gateway + + cancelRoutine context.CancelFunc } -func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *HeartbeatImpl { - return &HeartbeatImpl{ - ctx: nil, +func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *heartbeatImpl { + return &heartbeatImpl{ logger: logger, gateway: gateway, + + cancelRoutine: nil, } } -func (h *HeartbeatImpl) Start(ctx context.Context, interval time.Duration) { - h.ctx = ctx - go h.heartbeatRoutine(interval) +func (h *heartbeatImpl) Start(ctx context.Context, interval time.Duration) { + routineCtx, cancel := context.WithCancel(ctx) + h.cancelRoutine = cancel + go h.heartbeatRoutine(routineCtx, interval) +} + +func (h *heartbeatImpl) Stop() { + h.cancelRoutine() + h.cancelRoutine = nil } -func (h *HeartbeatImpl) ForceHeartbeat() { +func (h *heartbeatImpl) ForceHeartbeat() { h.gateway.Heartbeat() } -func (h *HeartbeatImpl) Ack() { +func (h *heartbeatImpl) Ack() { // What do we do here? h.logger.Debug().Msg("received heartbeat ack") } -func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) { +func (h *heartbeatImpl) heartbeatRoutine(ctx context.Context, interval time.Duration) { h.logger.Debug().Msgf("beating heart every %dms", interval.Milliseconds()) // REF: heartbeat_interval * jitter @@ -53,7 +62,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) { select { case <-time.After(time.Duration(jitter)): - case <-h.ctx.Done(): + case <-ctx.Done(): h.logger.Debug().Msg("heartbeat routine stopped before jitter heartbeat") return } @@ -68,7 +77,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) { } select { - case <-h.ctx.Done(): + case <-ctx.Done(): return case <-ticker.C: continue diff --git a/pkg/discord/gateway/payloads.go b/pkg/discord/gateway/payloads.go index 8da894f..87c37d3 100644 --- a/pkg/discord/gateway/payloads.go +++ b/pkg/discord/gateway/payloads.go @@ -28,11 +28,17 @@ type Payload[D any] struct { } type IdentifyCmd struct { - Token string `json:"token"` - Intents uint64 `json:"intents"` + Token string `json:"token"` + Intents uint64 `json:"intents"` Properties IdentifyProperties `json:"properties"` } +type ResumeCmd struct { + Token string `json:"token"` + SessionID string `json:"session_id"` + Sequence uint64 `json:"seq"` +} + type HelloEvent struct { HeartbeatInterval uint64 `json:"heartbeat_interval"` } |
