package gateway

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"sync"
	"sync/atomic"

	"jinx/pkg/discord/events"
	"jinx/pkg/libs/cancellablewebsocket"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
)

// Gateway is a client connection to the Discord gateway websocket.
type Gateway interface {
	// Start dials url, performs the hello/identify handshake using token and
	// begins processing incoming events in the background.
	Start(ctx context.Context, url string, token string) error
	// Close closes the underlying websocket connection.
	Close() error
	// Heartbeat sends a heartbeat payload carrying the last seen sequence number.
	Heartbeat() error
}

var _ Gateway = (*GatewayImpl)(nil)

// GatewayImpl is the websocket-backed implementation of Gateway.
// Construct it with New; Start must succeed before events are delivered.
type GatewayImpl struct {
	// lastSeq is the most recent non-zero sequence number received. It is the
	// first field so it stays 64-bit aligned for sync/atomic on 32-bit
	// platforms. It is accessed atomically because the reader goroutine writes
	// it while Heartbeat may run on another goroutine (presumably the
	// heartbeat timer — confirm against the Heartbeat implementation).
	lastSeq uint64

	ctx          context.Context
	logger       *zerolog.Logger
	conn         *cancellablewebsocket.CancellableWebSocket
	heartbeat    Heartbeat
	eventHandler events.EventHandler

	// writeMu serialises writes: gorilla/websocket connections support at
	// most one concurrent writer.
	writeMu sync.Mutex
}

// New returns a GatewayImpl that logs to logger and forwards dispatched
// events to eventHandler. Call Start to connect.
func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl {
	gateway := &GatewayImpl{
		logger:       logger,
		eventHandler: eventHandler,
	}
	// The heartbeat is given a reference back to the gateway, so it can only
	// be built once the gateway exists; the two are wired up here.
	gateway.heartbeat = NewHeartbeat(logger, gateway)
	return gateway
}

// Start dials url, waits for the HELLO payload, starts the heartbeat at the
// interval HELLO advertises, identifies with token, and then spawns a
// goroutine that reads and dispatches events until ctx is cancelled or the
// connection fails. If the handshake fails, the connection is closed before
// the error is returned.
func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error {
	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, http.Header{})
	if err != nil {
		return fmt.Errorf("dialing gateway: %w", err)
	}
	g.conn = conn
	g.ctx = ctx

	hello, err := g.hello()
	if err != nil {
		// The handshake error is more useful to the caller than a close error.
		_ = conn.Close()
		return fmt.Errorf("waiting for hello: %w", err)
	}

	g.heartbeat.Start(ctx, hello.HeartbeatInterval)

	if err := g.identify(token); err != nil {
		// conn is left assigned: the heartbeat is already running and may
		// still call send, which must fail with an error rather than panic.
		_ = conn.Close()
		return fmt.Errorf("identifying: %w", err)
	}

	go g.listen()
	return nil
}

// Close closes the websocket connection. It is a no-op if Start never
// established a connection.
func (g *GatewayImpl) Close() error {
	if g.conn == nil {
		return nil
	}
	return g.conn.Close()
}

// Heartbeat sends an OP_HEARTBEAT payload whose data is the last non-zero
// sequence number received (0 if none has been seen yet).
func (g *GatewayImpl) Heartbeat() error {
	return g.send(Payload[uint64]{
		Op:   OP_HEARTBEAT,
		Data: atomic.LoadUint64(&g.lastSeq),
	})
}

// hello reads the next payload, which must be OP_HELLO, and returns its data.
func (g *GatewayImpl) hello() (HelloEvent, error) {
	var msg Payload[HelloEvent]
	if err := receive(g, &msg); err != nil {
		return HelloEvent{}, err
	}
	if msg.Op != OP_HELLO {
		return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, msg.Op)
	}
	return msg.Data, nil
}

// identify sends an OP_IDENTIFY command with token. It then requires the next
// payload to be an OP_DISPATCH whose event name is "READY".
//
// NOTE(review): exactly one payload is read after IDENTIFY, but the heartbeat
// is already running. If a heartbeat ACK (or a server heartbeat request)
// arrives before READY, identify fails with an opcode mismatch. Confirm the
// heartbeat's first-beat delay makes this impossible in practice.
func (g *GatewayImpl) identify(token string) error {
	msg := Payload[IdentifyCmd]{
		Op: OP_IDENTIFY,
		Data: IdentifyCmd{
			Token: token,
			// Gateway intents bitmask; verify the enabled bits against the
			// Discord intents documentation before changing it.
			Intents: 13991,
			Properties: IdentifyProperties{
				OS:      "linux",
				Browser: "jinx",
				Device:  "jinx",
			},
		},
		Sequence: 0,
	}
	if err := g.send(msg); err != nil {
		return err
	}

	var res Payload[ReadyEvent]
	if err := receive(g, &res); err != nil {
		return err
	}
	g.logger.Debug().Msgf("identify response payload: %+v", res)

	if res.Op != OP_DISPATCH {
		return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op)
	}
	if res.EventName != "READY" {
		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
	}
	return nil
}

// listen reads payloads until g.ctx is cancelled or the connection fails,
// handing each one to onEvent. Payloads that fail JSON decoding are logged and
// skipped. Any other read error ends the loop: a failed websocket connection
// keeps returning errors (gorilla panics after repeated reads on one), so
// retrying would only spin.
func (g *GatewayImpl) listen() {
	for {
		var msg Payload[json.RawMessage]
		if err := receive(g, &msg); err != nil {
			if g.ctx.Err() != nil {
				// Read failed because we are shutting down; nothing to report.
				return
			}
			var syntaxErr *json.SyntaxError
			var typeErr *json.UnmarshalTypeError
			if errors.As(err, &syntaxErr) || errors.As(err, &typeErr) {
				g.logger.Error().Err(err).Msg("error decoding message")
				continue
			}
			g.logger.Error().Err(err).Msg("error reading message")
			return
		}

		if g.ctx.Err() != nil {
			return
		}
		if err := g.onEvent(msg); err != nil {
			g.logger.Error().Err(err).Msg("error handling event")
		}
	}
}

// onEvent routes a payload by opcode:
//   - heartbeat ACKs go to the heartbeat's Ack;
//   - heartbeat requests trigger ForceHeartbeat;
//   - dispatches go to onDispatch.
//
// Unknown opcodes are logged and ignored.
func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
	switch msg.Op {
	case OP_HEARTBEAT_ACK:
		g.heartbeat.Ack()
	case OP_HEARTBEAT:
		g.heartbeat.ForceHeartbeat()
	case OP_DISPATCH:
		return g.onDispatch(msg.EventName, msg.Data)
	default:
		g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
	}
	return nil
}

// onDispatch decodes a dispatch body according to eventName and fires the
// matching event on the event handler. Unhandled event names are logged and
// ignored.
func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
	switch eventName {
	case "MESSAGE_CREATE":
		var payload MessageCreateEvent
		if err := json.Unmarshal(body, &payload); err != nil {
			return fmt.Errorf("decoding %s payload: %w", eventName, err)
		}
		g.eventHandler.Fire(events.MESSAGE, payload)
	default:
		g.logger.Warn().Msgf("received unknown event: %s", eventName)
	}
	return nil
}

// send writes payload to the connection as JSON. Writes are serialised with
// writeMu because send is reachable from more than one goroutine: Start and
// identify, and Heartbeat (presumably driven by the heartbeat loop).
func (g *GatewayImpl) send(payload any) error {
	g.writeMu.Lock()
	defer g.writeMu.Unlock()
	return g.conn.WriteJSON(payload)
}

// receive reads the next JSON payload into res. When the payload carries a
// non-zero sequence number, it is stored as the last seen sequence for
// heartbeats.
func receive[D any](g *GatewayImpl, res *Payload[D]) error {
	// res is already a pointer; decode into it directly rather than via &res.
	if err := g.conn.ReadJSON(res); err != nil {
		return err
	}
	if res.Sequence != 0 {
		atomic.StoreUint64(&g.lastSeq, res.Sequence)
	}
	return nil
}