diff options
Diffstat (limited to 'pkg/discord/gateway')
| -rw-r--r-- | pkg/discord/gateway/gateway.go | 202 | ||||
| -rw-r--r-- | pkg/discord/gateway/heartbeat.go | 71 | ||||
| -rw-r--r-- | pkg/discord/gateway/payloads.go | 54 |
3 files changed, 327 insertions, 0 deletions
diff --git a/pkg/discord/gateway/gateway.go b/pkg/discord/gateway/gateway.go new file mode 100644 index 0000000..18cf708 --- /dev/null +++ b/pkg/discord/gateway/gateway.go @@ -0,0 +1,202 @@ +package gateway + +import ( + "context" + "encoding/json" + "fmt" + "jinx/pkg/discord/events" + "jinx/pkg/libs/cancellablewebsocket" + "net/http" + + "github.com/gorilla/websocket" + "github.com/rs/zerolog" +) + +type Gateway interface { + Start(ctx context.Context, url string, token string) error + Close() error + + Heartbeat() error +} + +var _ Gateway = &GatewayImpl{} + +type GatewayImpl struct { + ctx context.Context + logger *zerolog.Logger + conn *cancellablewebsocket.CancellableWebSocket + heartbeat Heartbeat + eventHandler events.EventHandler + lastSeq uint64 +} + +func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl { + gateway := &GatewayImpl{ + ctx: nil, + logger: logger, + conn: nil, + eventHandler: eventHandler, + heartbeat: nil, + lastSeq: 0, + } + + // Cycle dependency, is this the best way to do this? + heartbeat := NewHeartbeat(logger, gateway) + gateway.heartbeat = heartbeat + + return gateway +} + +func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error { + connectHeader := http.Header{} + conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader) + if err != nil { + return err + } + + g.conn = conn + g.ctx = ctx + + hello, err := g.hello() + if err != nil { + return err + } + + g.heartbeat.Start(ctx, hello.HeartbeatInterval) + + if err = g.identify(token); err != nil { + return err + } + + go g.listen() + + return nil +} + +func (g *GatewayImpl) Close() error { + return g.conn.Close() +} + +func (g *GatewayImpl) Heartbeat() error { + msg := Payload[uint64]{ + Op: OP_HEARTBEAT, + Data: g.lastSeq, + } + + return g.send(msg) +} + +func (g *GatewayImpl) hello() (HelloEvent, error) { + var msg Payload[HelloEvent] + if err := receive(g, &msg); err != nil { + return HelloEvent{}, err + } + + if msg.Op != OP_HELLO { + return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, msg.Op) + } + + return msg.Data, nil +} + +func (g *GatewayImpl) identify(token string) error { + msg := Payload[IdentifyCmd]{ + Op: OP_IDENTIFY, + Data: IdentifyCmd{ + Token: token, + Intents: 13991, + Properties: IdentifyProperties{ + OS: "linux", + Browser: "jinx", + Device: "jinx", + }, + }, + Sequence: 0, + } + + if err := g.send(msg); err != nil { + return err + } + + var res Payload[ReadyEvent] + if err := receive(g, &res); err != nil { + return err + } + + g.logger.Debug().Msgf("identify response payload: %+v", res) + + if res.Op != OP_DISPATCH { + return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op) + } + + if res.EventName != "READY" { + return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName) + } + + return nil +} + +func (g *GatewayImpl) listen() { + for { + var msg Payload[json.RawMessage] + if err := receive(g, &msg); err != nil { + g.logger.Error().Err(err).Msgf("error reading message") + continue + } + + select { + case <-g.ctx.Done(): + return + default: + g.onEvent(msg) + } + } +} + +func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error { + switch msg.Op { + case OP_HEARTBEAT_ACK: + g.heartbeat.Ack() + case OP_HEARTBEAT: + g.heartbeat.ForceHeartbeat() + case OP_DISPATCH: + return g.onDispatch(msg.EventName, msg.Data) + default: + g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op) + } + + return nil +} + +func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error { + switch eventName { + case "MESSAGE_CREATE": + var payload MessageCreateEvent + if err := json.Unmarshal(body, &payload); err != nil { + return err + } + + g.eventHandler.Fire(events.MESSAGE, payload) + default: + g.logger.Warn().Msgf("received unknown event: %s", eventName) + } + + return nil +} + +func (g *GatewayImpl) send(payload any) error { + return g.conn.WriteJSON(payload) +} + +func receive[D any](g *GatewayImpl, res *Payload[D]) error { + err := g.conn.ReadJSON(&res) + if err != nil { + return err + } + + if res.Sequence != 0 { + g.lastSeq = res.Sequence + } + + return nil +} diff --git a/pkg/discord/gateway/heartbeat.go b/pkg/discord/gateway/heartbeat.go new file mode 100644 index 0000000..6cefb21 --- /dev/null +++ b/pkg/discord/gateway/heartbeat.go @@ -0,0 +1,71 @@ +package gateway + +import ( + "context" + "math/rand" + "time" + + "github.com/rs/zerolog" +) + +type Heartbeat interface { + Start(ctx context.Context, interval uint64) + + ForceHeartbeat() + Ack() +} + +var _ Heartbeat = &HeartbeatImpl{} + +type HeartbeatImpl struct { + ctx context.Context + logger *zerolog.Logger + gateway Gateway +} + +func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *HeartbeatImpl { + return &HeartbeatImpl{ + ctx: nil, + logger: logger, + gateway: gateway, + } +} + +func (h *HeartbeatImpl) Start(ctx context.Context, interval uint64) { + h.ctx = ctx + go h.heartbeatRoutine(interval) +} + +func (h *HeartbeatImpl) ForceHeartbeat() { + h.gateway.Heartbeat() +} + +func (h *HeartbeatImpl) Ack() { + // What do we do here? + h.logger.Debug().Msg("received heartbeat ack") +} + +func (h *HeartbeatImpl) heartbeatRoutine(interval uint64) { + h.logger.Debug().Msgf("beating heart every %dms", interval) + + // REF: heartbeat_interval * jitter + jitter := rand.Intn(int(interval)) + time.Sleep(time.Duration(jitter) * time.Millisecond) + + ticker := time.NewTicker(time.Duration(interval) * time.Millisecond) + defer ticker.Stop() + + for { + h.logger.Debug().Msg("sending heartbeat") + if err := h.gateway.Heartbeat(); err != nil { + h.logger.Error().Err(err).Msg("failed to send heartbeat") + } + + select { + case <-h.ctx.Done(): + return + case <-ticker.C: + continue + } + } +} diff --git a/pkg/discord/gateway/payloads.go b/pkg/discord/gateway/payloads.go new file mode 100644 index 0000000..8da894f --- /dev/null +++ b/pkg/discord/gateway/payloads.go @@ -0,0 +1,54 @@ +package gateway + +import ( + "jinx/pkg/discord/entities" +) + +type GatewayOp uint8 + +const ( + OP_DISPATCH GatewayOp = 0 + OP_HEARTBEAT GatewayOp = 1 + OP_IDENTIFY GatewayOp = 2 + OP_PRESENCE_UPDATE GatewayOp = 3 + OP_VOICE_STATE_UPDATE GatewayOp = 4 + OP_RESUME GatewayOp = 6 + OP_RECONNECT GatewayOp = 7 + OP_REQUEST_GUILD_MEMBERS GatewayOp = 8 + OP_INVALID_SESSION GatewayOp = 9 + OP_HELLO GatewayOp = 10 + OP_HEARTBEAT_ACK GatewayOp = 11 +) + +type Payload[D any] struct { + Op GatewayOp `json:"op"` + Data D `json:"d,omitempty"` + Sequence uint64 `json:"s,omitempty"` + EventName string `json:"t,omitempty"` +} + +type IdentifyCmd struct { + Token string `json:"token"` + Intents uint64 `json:"intents"` + Properties IdentifyProperties `json:"properties"` +} + +type HelloEvent struct { + HeartbeatInterval uint64 `json:"heartbeat_interval"` +} + +type ReadyEvent struct { + Version uint64 `json:"v"` + User entities.User `json:"user"` + Guilds []entities.Guild `json:"guilds"` + SessionID string `json:"session_id"` + Shard []int `json:"shard"` +} + +type MessageCreateEvent entities.Message + +type IdentifyProperties struct { + OS string `json:"$os"` + Browser string `json:"$browser"` + Device string `json:"$device"` +} |
