diff options
Diffstat (limited to 'pkg/discord/gateway/gateway.go')
| -rw-r--r-- | pkg/discord/gateway/gateway.go | 202 |
1 files changed, 202 insertions, 0 deletions
diff --git a/pkg/discord/gateway/gateway.go b/pkg/discord/gateway/gateway.go new file mode 100644 index 0000000..18cf708 --- /dev/null +++ b/pkg/discord/gateway/gateway.go @@ -0,0 +1,202 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "jinx/pkg/discord/events" + "jinx/pkg/libs/cancellablewebsocket" + "net/http" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +type Gateway interface { + Start(ctx context.Context, url string, token string) error + Close() error + + Heartbeat() error +} + +var _ Gateway = &GatewayImpl{} + +type GatewayImpl struct { + ctx context.Context + logger *zerolog.Logger + conn *cancellablewebsocket.CancellableWebSocket + heartbeat Heartbeat + eventHandler events.EventHandler + lastSeq uint64 +} + +func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl { + gateway := &GatewayImpl{ + ctx: nil, + logger: logger, + conn: nil, + eventHandler: eventHandler, + heartbeat: nil, + lastSeq: 0, + } + + // Cycle dependency, is this the best way to do this? + heartbeat := NewHeartbeat(logger, gateway) + gateway.heartbeat = heartbeat + + return gateway +} + +func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error { + connectHeader := http.Header{} + conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader) + if err != nil { + return err + } + + g.conn = conn + g.ctx = ctx + + hello, err := g.hello() + if err != nil { + return err + } + + g.heartbeat.Start(ctx, hello.HeartbeatInterval) + + if err = g.identify(token); err != nil { + return err + } + + go g.listen() + + return nil +} + +func (g *GatewayImpl) Close() error { + return g.conn.Close() +} + +func (g *GatewayImpl) Heartbeat() error { + msg := Payload[uint64]{ + Op: OP_HEARTBEAT, + Data: g.lastSeq, + } + + return g.send(msg) +} + +func (g *GatewayImpl) hello() (HelloEvent, error) { + var msg Payload[HelloEvent] + if err := receive(g, &msg); err != nil { + return HelloEvent{}, err + } + + if msg.Op != OP_HELLO { + return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, msg.Op) + } + + return msg.Data, nil +} + +func (g *GatewayImpl) identify(token string) error { + msg := Payload[IdentifyCmd]{ + Op: OP_IDENTIFY, + Data: IdentifyCmd{ + Token: token, + Intents: 13991, + Properties: IdentifyProperties{ + OS: "linux", + Browser: "jinx", + Device: "jinx", + }, + }, + Sequence: 0, + } + + if err := g.send(msg); err != nil { + return err + } + + var res Payload[ReadyEvent] + if err := receive(g, &res); err != nil { + return err + } + + g.logger.Debug().Msgf("identify response payload: %+v", res) + + if res.Op != OP_DISPATCH { + return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op) + } + + if res.EventName != "READY" { + return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName) + } + + return nil +} + +func (g *GatewayImpl) listen() { + for { + var msg Payload[json.RawMessage] + if err := receive(g, &msg); err != nil { + g.logger.Error().Err(err).Msgf("error reading message") + continue + } + + select { + case <-g.ctx.Done(): + return + default: + g.onEvent(msg) + } + } +} + +func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error { + switch msg.Op { + case OP_HEARTBEAT_ACK: + g.heartbeat.Ack() + case OP_HEARTBEAT: + g.heartbeat.ForceHeartbeat() + case OP_DISPATCH: + return g.onDispatch(msg.EventName, msg.Data) + default: + g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op) + } + + return nil +} + +func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error { + switch eventName { + case "MESSAGE_CREATE": + var payload MessageCreateEvent + if err := json.Unmarshal(body, &payload); err != nil { + return err + } + + g.eventHandler.Fire(events.MESSAGE, payload) + default: + g.logger.Warn().Msgf("received unknown event: %s", eventName) + } + + return nil +} + +func (g *GatewayImpl) send(payload any) error { + return g.conn.WriteJSON(payload) +} + +func receive[D any](g *GatewayImpl, res *Payload[D]) error { + err := g.conn.ReadJSON(&res) + if err != nil { + return err + } + + if res.Sequence != 0 { + g.lastSeq = res.Sequence + } + + return nil +} |
