about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-04-06 20:13:25 +0200
committerMel <einebeere@gmail.com>2022-04-06 20:13:25 +0200
commit4baf8edf31a2fc10f401a770636d8c98535264cc (patch)
treed6e01ae430652d14ca867d9a9b7f02e5d2d5fc18 /pkg
parenteea8d77a8676feef17ca7c675742958c7bc3e93c (diff)
downloadjinx-4baf8edf31a2fc10f401a770636d8c98535264cc.tar.zst
jinx-4baf8edf31a2fc10f401a770636d8c98535264cc.zip
Split off gateway and heartbeat service
Diffstat (limited to 'pkg')
-rw-r--r--pkg/discord/discord.go160
-rw-r--r--pkg/discord/gateway.go174
-rw-r--r--pkg/discord/heartbeat.go67
3 files changed, 261 insertions, 140 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index 709f5e1..837101a 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -2,35 +2,34 @@ package discord
 
 import (
 	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"jinx/pkg/libs/cancellablewebsocket"
-	"math/rand"
-	"net/http"
-	"time"
-
-	"github.com/gorilla/websocket"
+
 	"github.com/rs/zerolog"
 )
 
 type Discord struct {
 	token        string
 	logger       *zerolog.Logger
-	conn         *cancellablewebsocket.CancellableWebSocket
-	eventHandler *EventHandlerImpl
+	gateway      Gateway
+	heartbeat    Heartbeat
+	eventHandler EventHandler
 	rest         REST
 }
 
 func NewClient(token string, logger *zerolog.Logger) *Discord {
 	token = "Bot " + token
 
+	eventHandler := NewEventHandler()
+	rest := NewREST(token)
+	gateway := NewGateway(logger, eventHandler)
+	heartbeat := NewHeartbeat(logger, gateway)
+
 	return &Discord{
 		token:        token,
 		logger:       logger,
-		conn:         nil,
-		eventHandler: NewEventHandler(),
-		rest:         NewREST(token),
+		gateway:      gateway,
+		heartbeat:    heartbeat,
+		eventHandler: eventHandler,
+		rest:         rest,
 	}
 }
 
@@ -40,28 +39,23 @@ func (d *Discord) Connect(ctx context.Context) error {
 		return err
 	}
 
-	connectHeader := http.Header{}
-	d.conn, err = cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, gatewayURL, connectHeader)
+	err = d.gateway.Start(ctx, gatewayURL)
 	if err != nil {
 		return err
 	}
 
-	var helloMsg GatewayPayload[GatewayHelloEvent]
-	if err = d.conn.ReadJSON(&helloMsg); err != nil {
+	hello, err := d.gateway.Hello()
+	if err != nil {
 		return err
 	}
 
-	if helloMsg.Op != GATEWAY_OP_HELLO {
-		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_HELLO, helloMsg.Op)
-	}
-
-	go d.heartbeat(ctx, helloMsg.Data.HeartbeatInterval)
+	d.heartbeat.Start(ctx, hello.HeartbeatInterval)
 
-	if err = d.identify(); err != nil {
+	if err = d.gateway.Identify(d.token); err != nil {
 		return err
 	}
 
-	go d.listen(ctx)
+	go d.gateway.Listen()
 
 	// We are ready!
 	d.eventHandler.Fire(DISCORD_EVENT_READY, nil)
@@ -70,121 +64,7 @@ func (d *Discord) Connect(ctx context.Context) error {
 }
 
 func (d *Discord) Disconnect() error {
-	if d.conn == nil {
-		return errors.New("not connected")
-	}
-
-	return d.conn.Close()
-}
-
-func (d *Discord) heartbeat(ctx context.Context, interval uint64) {
-	// REF: heartbeat_interval * jitter
-	jitter := rand.Intn(int(interval))
-	time.Sleep(time.Duration(jitter) * time.Millisecond)
-
-	ticker := time.NewTicker(time.Duration(interval) * time.Millisecond)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			d.logger.Debug().Msg("sending heartbeat")
-
-			msg := GatewayPayload[any]{
-				Op: GATEWAY_OP_HEARTBEAT,
-			}
-
-			if err := d.conn.WriteJSON(msg); err != nil {
-				d.logger.Error().Err(err).Msg("error sending heartbeat")
-			}
-		}
-	}
-}
-
-func (d *Discord) identify() error {
-	msg := GatewayPayload[GatewayIdentifyMsg]{
-		Op: GATEWAY_OP_IDENTIFY,
-		Data: GatewayIdentifyMsg{
-			Token:   d.token,
-			Intents: 13991,
-			Properties: GatewayIdentifyProperties{
-				OS:      "linux",
-				Browser: "jinx",
-				Device:  "jinx",
-			},
-		},
-		Sequence: 0,
-	}
-
-	if err := d.conn.WriteJSON(msg); err != nil {
-		return err
-	}
-
-	var res GatewayPayload[GatewayReadyEvent]
-	if err := d.conn.ReadJSON(&res); err != nil {
-		return err
-	}
-
-	d.logger.Debug().Msgf("identify response payload: %+v", res)
-
-	if res.Op != GATEWAY_OP_DISPATCH {
-		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_DISPATCH, res.Op)
-	}
-
-	if res.EventName != "READY" {
-		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
-	}
-
-	return nil
-}
-
-func (d *Discord) listen(ctx context.Context) {
-	for {
-		var msg GatewayPayload[json.RawMessage]
-		if err := d.conn.ReadJSON(&msg); err != nil {
-			d.logger.Error().Err(err).Msg("error reading message")
-		}
-
-		select {
-		case <-ctx.Done():
-			return
-		default:
-			d.onEvent(msg)
-		}
-	}
-}
-
-func (d *Discord) onEvent(msg GatewayPayload[json.RawMessage]) error {
-	switch msg.Op {
-	case GATEWAY_OP_HEARTBEAT_ACK:
-		d.logger.Debug().Msg("received heartbeat ack.")
-	case GATEWAY_OP_HEARTBEAT:
-		return errors.New("on demand heartbeat not implemented")
-	case GATEWAY_OP_DISPATCH:
-		return d.onDispatch(msg.EventName, msg.Data)
-	default:
-		d.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
-	}
-
-	return nil
-}
-
-func (d *Discord) onDispatch(eventName string, body json.RawMessage) error {
-	switch eventName {
-	case "MESSAGE_CREATE":
-		var payload GatewayMessageCreateEvent
-		if err := json.Unmarshal(body, &payload); err != nil {
-			return err
-		}
-
-		d.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload)
-	default:
-		d.logger.Warn().Msgf("received unknown event: %s", eventName)
-	}
-
-	return nil
+	return d.gateway.Close()
 }
 
 func (d *Discord) SendMessage(channelID Snowflake, content string) error {
diff --git a/pkg/discord/gateway.go b/pkg/discord/gateway.go
new file mode 100644
index 0000000..cfabdc7
--- /dev/null
+++ b/pkg/discord/gateway.go
@@ -0,0 +1,174 @@
+package discord
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"jinx/pkg/libs/cancellablewebsocket"
+	"net/http"
+
+	"github.com/gorilla/websocket"
+	"github.com/rs/zerolog"
+)
+
+type Gateway interface {
+	Start(ctx context.Context, url string) error
+	Close() error
+
+	Hello() (GatewayHelloEvent, error)
+	Identify(token string) error
+
+	Listen()
+
+	Heartbeat() error
+}
+
+var _ Gateway = &GatewayImpl{}
+
+type GatewayImpl struct {
+	ctx          context.Context
+	logger       *zerolog.Logger
+	conn         *cancellablewebsocket.CancellableWebSocket
+	eventHandler EventHandler
+}
+
+func NewGateway(logger *zerolog.Logger, eventHandler EventHandler) *GatewayImpl {
+	return &GatewayImpl{
+		ctx:          nil,
+		logger:       logger,
+		conn:         nil,
+		eventHandler: eventHandler,
+	}
+}
+
+func (g *GatewayImpl) Start(ctx context.Context, url string) error {
+	connectHeader := http.Header{}
+	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader)
+	if err != nil {
+		return err
+	}
+
+	g.conn = conn
+	g.ctx = ctx
+
+	return nil
+}
+
+func (g *GatewayImpl) Close() error {
+	return g.conn.Close()
+}
+
+func (g *GatewayImpl) Hello() (GatewayHelloEvent, error) {
+	var msg GatewayPayload[GatewayHelloEvent]
+	if err := g.receive(&msg); err != nil {
+		return GatewayHelloEvent{}, err
+	}
+
+	if msg.Op != GATEWAY_OP_HELLO {
+		return GatewayHelloEvent{}, fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_HELLO, msg.Op)
+	}
+
+	return msg.Data, nil
+}
+
+func (g *GatewayImpl) Identify(token string) error {
+	msg := GatewayPayload[GatewayIdentifyMsg]{
+		Op: GATEWAY_OP_IDENTIFY,
+		Data: GatewayIdentifyMsg{
+			Token:   token,
+			Intents: 13991,
+			Properties: GatewayIdentifyProperties{
+				OS:      "linux",
+				Browser: "jinx",
+				Device:  "jinx",
+			},
+		},
+		Sequence: 0,
+	}
+
+	if err := g.send(msg); err != nil {
+		return err
+	}
+
+	var res GatewayPayload[GatewayReadyEvent]
+	if err := g.receive(&res); err != nil {
+		return err
+	}
+
+	g.logger.Debug().Msgf("identify response payload: %+v", res)
+
+	if res.Op != GATEWAY_OP_DISPATCH {
+		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_DISPATCH, res.Op)
+	}
+
+	if res.EventName != "READY" {
+		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) Listen() {
+	for {
+		var msg GatewayPayload[json.RawMessage]
+		if err := g.receive(&msg); err != nil {
+			g.logger.Error().Err(err).Msgf("error reading message")
+			continue
+		}
+
+		select {
+		case <-g.ctx.Done():
+			return
+		default:
+			g.onEvent(msg)
+		}
+	}
+}
+
+func (g *GatewayImpl) Heartbeat() error {
+	msg := GatewayPayload[any]{
+		Op: GATEWAY_OP_HEARTBEAT,
+	}
+
+	return g.send(msg)
+}
+
+func (g *GatewayImpl) onEvent(msg GatewayPayload[json.RawMessage]) error {
+	switch msg.Op {
+	case GATEWAY_OP_HEARTBEAT_ACK:
+		g.logger.Debug().Msg("received heartbeat ack.")
+	case GATEWAY_OP_HEARTBEAT:
+		return errors.New("on demand heartbeat not implemented")
+	case GATEWAY_OP_DISPATCH:
+		return g.onDispatch(msg.EventName, msg.Data)
+	default:
+		g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
+	switch eventName {
+	case "MESSAGE_CREATE":
+		var payload GatewayMessageCreateEvent
+		if err := json.Unmarshal(body, &payload); err != nil {
+			return err
+		}
+
+		g.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload)
+	default:
+		g.logger.Warn().Msgf("received unknown event: %s", eventName)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) receive(res interface{}) error {
+	return g.conn.ReadJSON(res)
+}
+
+func (g *GatewayImpl) send(payload interface{}) error {
+	return g.conn.WriteJSON(payload)
+}
diff --git a/pkg/discord/heartbeat.go b/pkg/discord/heartbeat.go
new file mode 100644
index 0000000..5d2972f
--- /dev/null
+++ b/pkg/discord/heartbeat.go
@@ -0,0 +1,67 @@
+package discord
+
+import (
+	"context"
+	"math/rand"
+	"time"
+
+	"github.com/rs/zerolog"
+)
+
+type Heartbeat interface {
+	Start(ctx context.Context, interval uint64)
+
+	ForceHeartbeat()
+	Ack()
+}
+
+var _ Heartbeat = &HeartbeatImpl{}
+
+type HeartbeatImpl struct {
+	ctx     context.Context
+	logger  *zerolog.Logger
+	gateway Gateway
+}
+
+func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *HeartbeatImpl {
+	return &HeartbeatImpl{
+		ctx:     nil,
+		logger:  logger,
+		gateway: gateway,
+	}
+}
+
+func (h *HeartbeatImpl) Start(ctx context.Context, interval uint64) {
+	h.ctx = ctx
+	go h.heartbeatRoutine(interval)
+}
+
+func (h *HeartbeatImpl) ForceHeartbeat() {
+	h.gateway.Heartbeat()
+}
+
+func (h *HeartbeatImpl) Ack() {
+	// What do we do here?
+	h.logger.Debug().Msg("received heartbeat ack")
+}
+
+func (h *HeartbeatImpl) heartbeatRoutine(interval uint64) {
+	// REF: heartbeat_interval * jitter
+	jitter := rand.Intn(int(interval))
+	time.Sleep(time.Duration(jitter) * time.Millisecond)
+
+	ticker := time.NewTicker(time.Duration(interval) * time.Millisecond)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-h.ctx.Done():
+			return
+		case <-ticker.C:
+			h.logger.Debug().Msg("sending heartbeat")
+			if err := h.gateway.Heartbeat(); err != nil {
+				h.logger.Error().Err(err).Msg("failed to send heartbeat")
+			}
+		}
+	}
+}