// Package gateway implements a client for the Discord gateway WebSocket: it
// connects, identifies, keeps the session alive with heartbeats, dispatches
// incoming events and reconnects/resumes when the connection is closed.
package gateway

import (
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"time"

	"jinx/pkg/discord/events"
	"jinx/pkg/libs/cancellablewebsocket"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
)

// Gateway is a single connection to the gateway.
type Gateway interface {
	// Start dials url, performs the hello/identify handshake with token and
	// then listens for events in a background goroutine until ctx is
	// cancelled or the gateway is closed.
	Start(ctx context.Context, url string, token string) error
	// Close closes the connection, gracefully if possible.
	Close() error
	// Heartbeat sends one heartbeat payload carrying the last sequence
	// number that was received.
	Heartbeat() error
}

var _ Gateway = (*gatewayImpl)(nil)

// gatewayImpl is the default Gateway implementation.
type gatewayImpl struct {
	logger       *zerolog.Logger
	conn         *cancellablewebsocket.CancellableWebSocket
	heartbeat    Heartbeat
	eventHandler events.EventHandler

	// parentCtx is the context handed to Start. It is kept so that every
	// reconnect can derive a fresh per-connection context from it.
	parentCtx context.Context
	// cancelRoutine cancels the context of the current connection.
	cancelRoutine context.CancelFunc

	url       string
	token     string
	sessionID string
	// lastSeq is the last non-zero sequence number received.
	// NOTE(review): written by receive and read by Heartbeat without
	// synchronisation; if the heartbeat runs on its own goroutine this is a
	// data race - confirm and guard if so.
	lastSeq uint64
}

// New creates a gateway that logs to logger and fires decoded events on
// eventHandler. The returned gateway is idle until Start is called.
func New(logger *zerolog.Logger, eventHandler events.EventHandler) Gateway {
	g := &gatewayImpl{
		logger:       logger,
		eventHandler: eventHandler,
	}

	// Cyclic dependency: the heartbeat needs the gateway to send through,
	// so it can only be wired up once the gateway value exists.
	g.heartbeat = NewHeartbeat(logger, g)

	return g
}

// Start implements Gateway. On any failure the per-connection context is
// cancelled (and the heartbeat stopped if it was already running) before the
// error is returned, so a failed Start does not leave background work behind.
func (g *gatewayImpl) Start(ctx context.Context, url string, token string) error {
	g.parentCtx = ctx
	g.url = url
	g.token = token

	routineCtx, cancelRoutine := context.WithCancel(ctx)
	g.cancelRoutine = cancelRoutine

	if err := g.connect(routineCtx); err != nil {
		cancelRoutine()
		return err
	}

	if err := g.identify(token); err != nil {
		g.heartbeat.Stop()
		cancelRoutine()
		return fmt.Errorf("identifying: %w", err)
	}

	// The close handler is installed only after hello and identify
	// succeeded, because a failure there aborts Start instead of
	// reconnecting.
	g.startListening(routineCtx)

	return nil
}

// Close implements Gateway. It first attempts a graceful close (code 1000)
// and falls back to killing the connection. Calling Close before a
// connection was ever dialed is a no-op.
func (g *gatewayImpl) Close() error {
	if g.conn == nil {
		return nil
	}

	// Whatever happens below, cancel the per-connection context afterwards
	// so the listener goroutine exits instead of reading a closed socket.
	// NOTE(review): the heartbeat is started with the same context and is
	// presumed to stop on cancellation - confirm in Heartbeat.Start.
	if g.cancelRoutine != nil {
		defer g.cancelRoutine()
	}

	// Try closing gracefully.
	g.logger.Debug().Msg("closing gateway gracefully...")
	if err := g.conn.Close(1000); err != nil {
		g.logger.Error().Err(err).Msg("error closing gateway gracefully, trying to murder...")
		return g.conn.Kill()
	}

	return nil
}

// Heartbeat implements Gateway.
func (g *gatewayImpl) Heartbeat() error {
	return g.send(Payload[uint64]{
		Op:   OP_HEARTBEAT,
		Data: g.lastSeq,
	})
}

// connect dials g.url, waits for the hello payload and starts the heartbeat
// with the interval announced in it. The connection and the heartbeat are
// both bound to ctx. If connect fails the caller is expected to cancel ctx.
func (g *gatewayImpl) connect(ctx context.Context) error {
	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, g.url, http.Header{})
	if err != nil {
		return fmt.Errorf("dialing gateway: %w", err)
	}
	g.conn = conn

	hello, err := g.hello()
	if err != nil {
		return fmt.Errorf("waiting for hello: %w", err)
	}

	// The interval is announced in milliseconds.
	interval := time.Duration(hello.HeartbeatInterval) * time.Millisecond
	g.heartbeat.Start(ctx, interval)

	return nil
}

// startListening installs the close handler on the current connection and
// starts the listener goroutine, which runs until ctx is cancelled.
func (g *gatewayImpl) startListening(ctx context.Context) {
	// Set up close handler for reconnecting or killing.
	g.conn.OnClose(func(code int, text string) error {
		g.onClose(code)
		return nil
	})

	go g.listen(ctx)
}

// hello reads the next payload and returns its data, failing if the payload
// is not a hello.
func (g *gatewayImpl) hello() (HelloEvent, error) {
	var msg Payload[HelloEvent]
	if err := receive(g, &msg); err != nil {
		return HelloEvent{}, err
	}

	if msg.Op != OP_HELLO {
		return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, msg.Op)
	}

	return msg.Data, nil
}

// identify sends the identify command for token and waits for the READY
// dispatch, storing the session ID it carries for later resumes.
func (g *gatewayImpl) identify(token string) error {
	cmd := Payload[IdentifyCmd]{
		Op: OP_IDENTIFY,
		Data: IdentifyCmd{
			Token: token,
			// NOTE(review): hard-coded intents bitmask; consider making it
			// configurable.
			Intents: 13991,
			Properties: IdentifyProperties{
				OS:      "linux",
				Browser: "jinx",
				Device:  "jinx",
			},
		},
	}
	if err := g.send(cmd); err != nil {
		return err
	}

	var res Payload[ReadyEvent]
	if err := receive(g, &res); err != nil {
		return err
	}

	if res.Op != OP_DISPATCH {
		return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op)
	}
	if res.EventName != "READY" {
		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
	}

	g.sessionID = res.Data.SessionID

	return nil
}

// resume sends the resume command for the stored session and handles every
// replayed event until the RESUMED dispatch arrives.
func (g *gatewayImpl) resume() error {
	cmd := Payload[ResumeCmd]{
		Op: OP_RESUME,
		Data: ResumeCmd{
			Token:     g.token,
			SessionID: g.sessionID,
			Sequence:  g.lastSeq,
		},
	}
	if err := g.send(cmd); err != nil {
		return err
	}

	// From now on Discord will send us all missed events, which we handle
	// here because the listener is not running yet.
	for {
		var event Payload[json.RawMessage]
		if err := receive(g, &event); err != nil {
			return err
		}

		if event.Op == OP_DISPATCH && event.EventName == "RESUMED" {
			// Done replaying.
			return nil
		}

		// Handle it like a normal event. A single bad event must not abort
		// the replay, so the error is logged rather than returned.
		if err := g.onEvent(event); err != nil {
			g.logger.Error().Err(err).Msg("error handling replayed event")
		}
	}
}

// reconnect tears down the current connection, dials a new one and resumes
// the previous session. The old connection's heartbeat and context are
// always torn down, even when killing the old connection fails, so the old
// listener cannot linger.
func (g *gatewayImpl) reconnect() error {
	killErr := g.conn.Kill()

	g.heartbeat.Stop()
	g.cancelRoutine()

	if killErr != nil {
		return fmt.Errorf("killing old connection: %w", killErr)
	}

	// Renew context.
	routineCtx, cancelRoutine := context.WithCancel(g.parentCtx)
	g.cancelRoutine = cancelRoutine

	if err := g.connect(routineCtx); err != nil {
		cancelRoutine()
		return err
	}

	// Replay all missed events.
	if err := g.resume(); err != nil {
		g.heartbeat.Stop()
		cancelRoutine()
		return fmt.Errorf("resuming session: %w", err)
	}

	g.startListening(routineCtx)

	return nil
}

// listen reads payloads from the current connection and handles them until
// ctx is cancelled.
func (g *gatewayImpl) listen(ctx context.Context) {
	for ctx.Err() == nil {
		var msg Payload[json.RawMessage]
		err := receive(g, &msg)

		// Check for cancellation before looking at the result: reconnect and
		// Close cancel ctx, and after a reconnect g.conn already points at
		// the new connection, which belongs to the new listener.
		if ctx.Err() != nil {
			return
		}

		if err != nil {
			// NOTE(review): a persistent read error that never reaches the
			// close handler makes this loop spin; consider reconnecting or
			// bailing out on non-decode errors.
			g.logger.Error().Err(err).Msg("error reading message")
			continue
		}

		if err := g.onEvent(msg); err != nil {
			g.logger.Error().Err(err).Msg("error handling event")
		}
	}
}

// onEvent routes a payload by opcode: heartbeat traffic goes to the
// heartbeat, dispatches go to onDispatch, anything else is logged.
func (g *gatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
	switch msg.Op {
	case OP_HEARTBEAT_ACK:
		g.heartbeat.Ack()
	case OP_HEARTBEAT:
		g.heartbeat.ForceHeartbeat()
	case OP_DISPATCH:
		return g.onDispatch(msg.EventName, msg.Data)
	default:
		g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
	}

	return nil
}

// onDispatch decodes the body of a dispatch named eventName and fires the
// matching event on the event handler. Unknown event names are logged.
func (g *gatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
	switch eventName {
	case "MESSAGE_CREATE":
		var gatewayEvent MessageCreateEvent
		if err := json.Unmarshal(body, &gatewayEvent); err != nil {
			return fmt.Errorf("decoding %s: %w", eventName, err)
		}

		g.eventHandler.Fire(events.MESSAGE, events.Message(gatewayEvent))
	default:
		g.logger.Warn().Msgf("received unknown event: %s", eventName)
	}

	return nil
}

// onClose is the connection's close handler. Depending on the close code it
// either reconnects or shuts the gateway's background work down.
func (g *gatewayImpl) onClose(code int) {
	shouldReconnect := true

	switch code {
	case CloseUnknownError:
		g.logger.Warn().Msg("gateway closed because of an unknown error, reconnecting...")
	case CloseUnkownOpcode: // (sic) constant is declared with this spelling.
		g.logger.Warn().Msg("gateway closed after receiving an unknown opcode, reconnecting...")
	case CloseDecodeError:
		g.logger.Warn().Msg("gateway closed after receiving an invalid payload, reconnecting...")
	case CloseNotAuthenticated:
		g.logger.Warn().Msg("gateway closed because of missing authentication, reconnecting...")
	case CloseAlreadyAuthenticated:
		g.logger.Warn().Msg("gateway closed because the client is already authenticated, reconnecting...")
	case CloseInvalidSeq:
		g.logger.Warn().Msg("gateway closed after receiving an invalid sequence number, reconnecting...")
	case CloseRateLimited:
		// TODO: Implement rate limiting.
		g.logger.Warn().Msg("gateway rate-limited, reconnecting...")
	case CloseSessionTimedOut:
		g.logger.Warn().Msg("gateway closed because of a session timeout, reconnecting...")
	case CloseAuthenticationFailed:
		g.logger.Error().Msg("gateway closed because of an authentication error.")
		shouldReconnect = false
	case CloseInvalidShard:
		g.logger.Error().Msg("gateway closed because the given shard is invalid.")
		shouldReconnect = false
	case CloseShardingRequired:
		g.logger.Error().Msg("gateway closed because the bot requires sharding.")
		shouldReconnect = false
	case CloseInvalidAPIVersion:
		g.logger.Error().Msg("gateway closed because the given API version is invalid.")
		shouldReconnect = false
	case CloseInvalidIntents:
		g.logger.Error().Msg("gateway closed because the given intents are invalid.")
		shouldReconnect = false
	case CloseDisallowedIntents:
		g.logger.Error().Msg("gateway closed because the given intents are disallowed.")
		shouldReconnect = false
	default:
		g.logger.Error().Msgf("gateway closed with unknown code %d.", code)
		shouldReconnect = false
	}

	if !shouldReconnect {
		// The connection is gone for good: stop the heartbeat and cancel the
		// per-connection context so the listener exits instead of reading
		// from a closed connection forever.
		g.heartbeat.Stop()
		g.cancelRoutine()
		return
	}

	if err := g.reconnect(); err != nil {
		g.logger.Error().Err(err).Msg("error reconnecting")
		return
	}

	g.logger.Info().Msg("successfully reconnected!")
}

// send writes payload to the current connection as JSON.
func (g *gatewayImpl) send(payload any) error {
	return g.conn.WriteJSON(payload)
}

// receive reads the next JSON payload from the current connection into res
// and records its sequence number when it is non-zero.
func receive[D any](g *gatewayImpl, res *Payload[D]) error {
	// res is already a pointer. Passing &res would hand the decoder a
	// pointer-to-pointer, and (assuming ReadJSON decodes with encoding/json)
	// a JSON null would then set the local res to nil and make the
	// res.Sequence access below panic.
	if err := g.conn.ReadJSON(res); err != nil {
		return err
	}

	if res.Sequence != 0 {
		g.lastSeq = res.Sequence
	}

	return nil
}