about summary refs log tree commit diff
path: root/pkg/discord
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/discord')
-rw-r--r--pkg/discord/discord.go18
-rw-r--r--pkg/discord/gateway.go61
2 files changed, 39 insertions, 40 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index 837101a..d0f1ad9 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -10,7 +10,6 @@ type Discord struct {
 	token        string
 	logger       *zerolog.Logger
 	gateway      Gateway
-	heartbeat    Heartbeat
 	eventHandler EventHandler
 	rest         REST
 }
@@ -21,13 +20,11 @@ func NewClient(token string, logger *zerolog.Logger) *Discord {
 	eventHandler := NewEventHandler()
 	rest := NewREST(token)
 	gateway := NewGateway(logger, eventHandler)
-	heartbeat := NewHeartbeat(logger, gateway)
 
 	return &Discord{
 		token:        token,
 		logger:       logger,
 		gateway:      gateway,
-		heartbeat:    heartbeat,
 		eventHandler: eventHandler,
 		rest:         rest,
 	}
@@ -39,24 +36,11 @@ func (d *Discord) Connect(ctx context.Context) error {
 		return err
 	}
 
-	err = d.gateway.Start(ctx, gatewayURL)
+	err = d.gateway.Start(ctx, gatewayURL, d.token)
 	if err != nil {
 		return err
 	}
 
-	hello, err := d.gateway.Hello()
-	if err != nil {
-		return err
-	}
-
-	d.heartbeat.Start(ctx, hello.HeartbeatInterval)
-
-	if err = d.gateway.Identify(d.token); err != nil {
-		return err
-	}
-
-	go d.gateway.Listen()
-
 	// We are ready!
 	d.eventHandler.Fire(DISCORD_EVENT_READY, nil)
 
diff --git a/pkg/discord/gateway.go b/pkg/discord/gateway.go
index 68bb3f2..32a1e99 100644
--- a/pkg/discord/gateway.go
+++ b/pkg/discord/gateway.go
@@ -3,7 +3,6 @@ package discord
 import (
 	"context"
 	"encoding/json"
-	"errors"
 	"fmt"
 	"jinx/pkg/libs/cancellablewebsocket"
 	"net/http"
@@ -13,14 +12,9 @@ import (
 )
 
 type Gateway interface {
-	Start(ctx context.Context, url string) error
+	Start(ctx context.Context, url string, token string) error
 	Close() error
 
-	Hello() (GatewayHelloEvent, error)
-	Identify(token string) error
-
-	Listen()
-
 	Heartbeat() error
 }
 
@@ -30,21 +24,29 @@ type GatewayImpl struct {
 	ctx          context.Context
 	logger       *zerolog.Logger
 	conn         *cancellablewebsocket.CancellableWebSocket
+	heartbeat    Heartbeat
 	eventHandler EventHandler
 	lastSeq      uint64
 }
 
 func NewGateway(logger *zerolog.Logger, eventHandler EventHandler) *GatewayImpl {
-	return &GatewayImpl{
+	gateway := &GatewayImpl{
 		ctx:          nil,
 		logger:       logger,
 		conn:         nil,
 		eventHandler: eventHandler,
+		heartbeat:    nil,
 		lastSeq:      0,
 	}
+
+	// Cycle dependency, is this the best way to do this?
+	heartbeat := NewHeartbeat(logger, gateway)
+	gateway.heartbeat = heartbeat
+
+	return gateway
 }
 
-func (g *GatewayImpl) Start(ctx context.Context, url string) error {
+func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error {
 	connectHeader := http.Header{}
 	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader)
 	if err != nil {
@@ -54,6 +56,19 @@ func (g *GatewayImpl) Start(ctx context.Context, url string) error {
 	g.conn = conn
 	g.ctx = ctx
 
+	hello, err := g.hello()
+	if err != nil {
+		return err
+	}
+
+	g.heartbeat.Start(ctx, hello.HeartbeatInterval)
+
+	if err = g.identify(token); err != nil {
+		return err
+	}
+
+	go g.listen()
+
 	return nil
 }
 
@@ -61,7 +76,16 @@ func (g *GatewayImpl) Close() error {
 	return g.conn.Close()
 }
 
-func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) {
+func (g *GatewayImpl) Heartbeat() error {
+	msg := GatewayPayload[uint64]{
+		Op:   GATEWAY_OP_HEARTBEAT,
+		Data: g.lastSeq,
+	}
+
+	return g.send(msg)
+}
+
+func (g *GatewayImpl) hello() (GatewayHelloEvent, error) {
 	var msg GatewayPayload[GatewayHelloEvent]
 	if err := receive(g, &msg); err != nil {
 		return GatewayHelloEvent{}, err
@@ -74,7 +98,7 @@ func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) {
 	return msg.Data, nil
 }
 
-func (g *GatewayImpl) Identify(token string) error {
+func (g *GatewayImpl) identify(token string) error {
 	msg := GatewayPayload[GatewayIdentifyMsg]{
 		Op: GATEWAY_OP_IDENTIFY,
 		Data: GatewayIdentifyMsg{
@@ -111,7 +135,7 @@ func (g *GatewayImpl) Identify(token string) error {
 	return nil
 }
 
-func (g *GatewayImpl) Listen() {
+func (g *GatewayImpl) listen() {
 	for {
 		var msg GatewayPayload[json.RawMessage]
 		if err := receive(g, &msg); err != nil {
@@ -128,21 +152,12 @@ func (g *GatewayImpl) Listen() {
 	}
 }
 
-func (g *GatewayImpl) Heartbeat() error {
-	msg := GatewayPayload[uint64]{
-		Op:   GATEWAY_OP_HEARTBEAT,
-		Data: g.lastSeq,
-	}
-
-	return g.send(msg)
-}
-
 func (g *GatewayImpl) onEvent(msg GatewayPayload[json.RawMessage]) error {
 	switch msg.Op {
 	case GATEWAY_OP_HEARTBEAT_ACK:
-		g.logger.Debug().Msg("received heartbeat ack.")
+		g.heartbeat.Ack()
 	case GATEWAY_OP_HEARTBEAT:
-		return errors.New("on demand heartbeat not implemented")
+		g.heartbeat.ForceHeartbeat()
 	case GATEWAY_OP_DISPATCH:
 		return g.onDispatch(msg.EventName, msg.Data)
 	default: