diff options
Diffstat (limited to 'pkg/discord')
| -rw-r--r-- | pkg/discord/discord.go | 18 | ||||
| -rw-r--r-- | pkg/discord/gateway.go | 61 |
2 files changed, 39 insertions, 40 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go index 837101a..d0f1ad9 100644 --- a/pkg/discord/discord.go +++ b/pkg/discord/discord.go @@ -10,7 +10,6 @@ type Discord struct { token string logger *zerolog.Logger gateway Gateway - heartbeat Heartbeat eventHandler EventHandler rest REST } @@ -21,13 +20,11 @@ func NewClient(token string, logger *zerolog.Logger) *Discord { eventHandler := NewEventHandler() rest := NewREST(token) gateway := NewGateway(logger, eventHandler) - heartbeat := NewHeartbeat(logger, gateway) return &Discord{ token: token, logger: logger, gateway: gateway, - heartbeat: heartbeat, eventHandler: eventHandler, rest: rest, } @@ -39,24 +36,11 @@ func (d *Discord) Connect(ctx context.Context) error { return err } - err = d.gateway.Start(ctx, gatewayURL) + err = d.gateway.Start(ctx, gatewayURL, d.token) if err != nil { return err } - hello, err := d.gateway.Hello() - if err != nil { - return err - } - - d.heartbeat.Start(ctx, hello.HeartbeatInterval) - - if err = d.gateway.Identify(d.token); err != nil { - return err - } - - go d.gateway.Listen() - // We are ready! d.eventHandler.Fire(DISCORD_EVENT_READY, nil) diff --git a/pkg/discord/gateway.go b/pkg/discord/gateway.go index 68bb3f2..32a1e99 100644 --- a/pkg/discord/gateway.go +++ b/pkg/discord/gateway.go @@ -3,7 +3,6 @@ package discord import ( "context" "encoding/json" - "errors" "fmt" "jinx/pkg/libs/cancellablewebsocket" "net/http" @@ -13,14 +12,9 @@ import ( ) type Gateway interface { - Start(ctx context.Context, url string) error + Start(ctx context.Context, url string, token string) error Close() error - Hello() (GatewayHelloEvent, error) - Identify(token string) error - - Listen() - Heartbeat() error } @@ -30,21 +24,29 @@ type GatewayImpl struct { ctx context.Context logger *zerolog.Logger conn *cancellablewebsocket.CancellableWebSocket + heartbeat Heartbeat eventHandler EventHandler lastSeq uint64 } func NewGateway(logger *zerolog.Logger, eventHandler EventHandler) *GatewayImpl { - return &GatewayImpl{ + gateway := &GatewayImpl{ ctx: nil, logger: logger, conn: nil, eventHandler: eventHandler, + heartbeat: nil, lastSeq: 0, } + + // Cycle dependency, is this the best way to do this? + heartbeat := NewHeartbeat(logger, gateway) + gateway.heartbeat = heartbeat + + return gateway } -func (g *GatewayImpl) Start(ctx context.Context, url string) error { +func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error { connectHeader := http.Header{} conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader) if err != nil { @@ -54,6 +56,19 @@ func (g *GatewayImpl) Start(ctx context.Context, url string) error { g.conn = conn g.ctx = ctx + hello, err := g.hello() + if err != nil { + return err + } + + g.heartbeat.Start(ctx, hello.HeartbeatInterval) + + if err = g.identify(token); err != nil { + return err + } + + go g.listen() + return nil } @@ -61,7 +76,16 @@ func (g *GatewayImpl) Close() error { return g.conn.Close() } -func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) { +func (g *GatewayImpl) Heartbeat() error { + msg := GatewayPayload[uint64]{ + Op: GATEWAY_OP_HEARTBEAT, + Data: g.lastSeq, + } + + return g.send(msg) +} + +func (g *GatewayImpl) hello() (GatewayHelloEvent, error) { var msg GatewayPayload[GatewayHelloEvent] if err := receive(g, &msg); err != nil { return GatewayHelloEvent{}, err @@ -74,7 +98,7 @@ func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) { return msg.Data, nil } -func (g *GatewayImpl) Identify(token string) error { +func (g *GatewayImpl) identify(token string) error { msg := GatewayPayload[GatewayIdentifyMsg]{ Op: GATEWAY_OP_IDENTIFY, Data: GatewayIdentifyMsg{ @@ -111,7 +135,7 @@ func (g *GatewayImpl) Identify(token string) error { return nil } -func (g *GatewayImpl) Listen() { +func (g *GatewayImpl) listen() { for { var msg GatewayPayload[json.RawMessage] if err := receive(g, &msg); err != nil { @@ -128,21 +152,12 @@ func (g *GatewayImpl) Listen() { } } -func (g *GatewayImpl) Heartbeat() error { - msg := GatewayPayload[uint64]{ - Op: GATEWAY_OP_HEARTBEAT, - Data: g.lastSeq, - } - - return g.send(msg) -} - func (g *GatewayImpl) onEvent(msg GatewayPayload[json.RawMessage]) error { switch msg.Op { case GATEWAY_OP_HEARTBEAT_ACK: - g.logger.Debug().Msg("received heartbeat ack.") + g.heartbeat.Ack() case GATEWAY_OP_HEARTBEAT: - return errors.New("on demand heartbeat not implemented") + g.heartbeat.ForceHeartbeat() case GATEWAY_OP_DISPATCH: return g.onDispatch(msg.EventName, msg.Data) default: |
