diff options
Diffstat (limited to 'pkg/discord/gateway.go')
| -rw-r--r-- | pkg/discord/gateway.go | 61 |
1 files changed, 38 insertions, 23 deletions
diff --git a/pkg/discord/gateway.go b/pkg/discord/gateway.go index 68bb3f2..32a1e99 100644 --- a/pkg/discord/gateway.go +++ b/pkg/discord/gateway.go @@ -3,7 +3,6 @@ package discord import ( "context" "encoding/json" - "errors" "fmt" "jinx/pkg/libs/cancellablewebsocket" "net/http" @@ -13,14 +12,9 @@ import ( ) type Gateway interface { - Start(ctx context.Context, url string) error + Start(ctx context.Context, url string, token string) error Close() error - Hello() (GatewayHelloEvent, error) - Identify(token string) error - - Listen() - Heartbeat() error } @@ -30,21 +24,29 @@ type GatewayImpl struct { ctx context.Context logger *zerolog.Logger conn *cancellablewebsocket.CancellableWebSocket + heartbeat Heartbeat eventHandler EventHandler lastSeq uint64 } func NewGateway(logger *zerolog.Logger, eventHandler EventHandler) *GatewayImpl { - return &GatewayImpl{ + gateway := &GatewayImpl{ ctx: nil, logger: logger, conn: nil, eventHandler: eventHandler, + heartbeat: nil, lastSeq: 0, } + + // Cycle dependency, is this the best way to do this? + heartbeat := NewHeartbeat(logger, gateway) + gateway.heartbeat = heartbeat + + return gateway } -func (g *GatewayImpl) Start(ctx context.Context, url string) error { +func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error { connectHeader := http.Header{} conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader) if err != nil { @@ -54,6 +56,19 @@ func (g *GatewayImpl) Start(ctx context.Context, url string) error { g.conn = conn g.ctx = ctx + hello, err := g.hello() + if err != nil { + return err + } + + g.heartbeat.Start(ctx, hello.HeartbeatInterval) + + if err = g.identify(token); err != nil { + return err + } + + go g.listen() + return nil } @@ -61,7 +76,16 @@ func (g *GatewayImpl) Close() error { return g.conn.Close() } -func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) { +func (g *GatewayImpl) Heartbeat() error { + msg := GatewayPayload[uint64]{ + Op: GATEWAY_OP_HEARTBEAT, + Data: g.lastSeq, + } + + return g.send(msg) +} + +func (g *GatewayImpl) hello() (GatewayHelloEvent, error) { var msg GatewayPayload[GatewayHelloEvent] if err := receive(g, &msg); err != nil { return GatewayHelloEvent{}, err @@ -74,7 +98,7 @@ func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) { return msg.Data, nil } -func (g *GatewayImpl) Identify(token string) error { +func (g *GatewayImpl) identify(token string) error { msg := GatewayPayload[GatewayIdentifyMsg]{ Op: GATEWAY_OP_IDENTIFY, Data: GatewayIdentifyMsg{ @@ -111,7 +135,7 @@ func (g *GatewayImpl) Identify(token string) error { return nil } -func (g *GatewayImpl) Listen() { +func (g *GatewayImpl) listen() { for { var msg GatewayPayload[json.RawMessage] if err := receive(g, &msg); err != nil { @@ -128,21 +152,12 @@ func (g *GatewayImpl) Listen() { } } -func (g *GatewayImpl) Heartbeat() error { - msg := GatewayPayload[uint64]{ - Op: GATEWAY_OP_HEARTBEAT, - Data: g.lastSeq, - } - - return g.send(msg) -} - func (g *GatewayImpl) onEvent(msg GatewayPayload[json.RawMessage]) error { switch msg.Op { case GATEWAY_OP_HEARTBEAT_ACK: - g.logger.Debug().Msg("received heartbeat ack.") + g.heartbeat.Ack() case GATEWAY_OP_HEARTBEAT: - return errors.New("on demand heartbeat not implemented") + g.heartbeat.ForceHeartbeat() case GATEWAY_OP_DISPATCH: return g.onDispatch(msg.EventName, msg.Data) default: |
