about summary refs log tree commit diff
path: root/pkg/discord/discord.go
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-04-06 20:13:25 +0200
committerMel <einebeere@gmail.com>2022-04-06 20:13:25 +0200
commit4baf8edf31a2fc10f401a770636d8c98535264cc (patch)
treed6e01ae430652d14ca867d9a9b7f02e5d2d5fc18 /pkg/discord/discord.go
parenteea8d77a8676feef17ca7c675742958c7bc3e93c (diff)
downloadjinx-4baf8edf31a2fc10f401a770636d8c98535264cc.tar.zst
jinx-4baf8edf31a2fc10f401a770636d8c98535264cc.zip
Split off gateway and heartbeat service
Diffstat (limited to 'pkg/discord/discord.go')
-rw-r--r--pkg/discord/discord.go160
1 files changed, 20 insertions, 140 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index 709f5e1..837101a 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -2,35 +2,34 @@ package discord
 
 import (
 	"context"
-	"encoding/json"
-	"errors"
-	"fmt"
-	"jinx/pkg/libs/cancellablewebsocket"
-	"math/rand"
-	"net/http"
-	"time"
-
-	"github.com/gorilla/websocket"
+
 	"github.com/rs/zerolog"
 )
 
 type Discord struct {
 	token        string
 	logger       *zerolog.Logger
-	conn         *cancellablewebsocket.CancellableWebSocket
-	eventHandler *EventHandlerImpl
+	gateway      Gateway
+	heartbeat    Heartbeat
+	eventHandler EventHandler
 	rest         REST
 }
 
 func NewClient(token string, logger *zerolog.Logger) *Discord {
 	token = "Bot " + token
 
+	eventHandler := NewEventHandler()
+	rest := NewREST(token)
+	gateway := NewGateway(logger, eventHandler)
+	heartbeat := NewHeartbeat(logger, gateway)
+
 	return &Discord{
 		token:        token,
 		logger:       logger,
-		conn:         nil,
-		eventHandler: NewEventHandler(),
-		rest:         NewREST(token),
+		gateway:      gateway,
+		heartbeat:    heartbeat,
+		eventHandler: eventHandler,
+		rest:         rest,
 	}
 }
 
@@ -40,28 +39,23 @@ func (d *Discord) Connect(ctx context.Context) error {
 		return err
 	}
 
-	connectHeader := http.Header{}
-	d.conn, err = cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, gatewayURL, connectHeader)
+	err = d.gateway.Start(ctx, gatewayURL)
 	if err != nil {
 		return err
 	}
 
-	var helloMsg GatewayPayload[GatewayHelloEvent]
-	if err = d.conn.ReadJSON(&helloMsg); err != nil {
+	hello, err := d.gateway.Hello()
+	if err != nil {
 		return err
 	}
 
-	if helloMsg.Op != GATEWAY_OP_HELLO {
-		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_HELLO, helloMsg.Op)
-	}
-
-	go d.heartbeat(ctx, helloMsg.Data.HeartbeatInterval)
+	d.heartbeat.Start(ctx, hello.HeartbeatInterval)
 
-	if err = d.identify(); err != nil {
+	if err = d.gateway.Identify(d.token); err != nil {
 		return err
 	}
 
-	go d.listen(ctx)
+	go d.gateway.Listen()
 
 	// We are ready!
 	d.eventHandler.Fire(DISCORD_EVENT_READY, nil)
@@ -70,121 +64,7 @@ func (d *Discord) Connect(ctx context.Context) error {
 }
 
 func (d *Discord) Disconnect() error {
-	if d.conn == nil {
-		return errors.New("not connected")
-	}
-
-	return d.conn.Close()
-}
-
-func (d *Discord) heartbeat(ctx context.Context, interval uint64) {
-	// REF: heartbeat_interval * jitter
-	jitter := rand.Intn(int(interval))
-	time.Sleep(time.Duration(jitter) * time.Millisecond)
-
-	ticker := time.NewTicker(time.Duration(interval) * time.Millisecond)
-	defer ticker.Stop()
-
-	for {
-		select {
-		case <-ctx.Done():
-			return
-		case <-ticker.C:
-			d.logger.Debug().Msg("sending heartbeat")
-
-			msg := GatewayPayload[any]{
-				Op: GATEWAY_OP_HEARTBEAT,
-			}
-
-			if err := d.conn.WriteJSON(msg); err != nil {
-				d.logger.Error().Err(err).Msg("error sending heartbeat")
-			}
-		}
-	}
-}
-
-func (d *Discord) identify() error {
-	msg := GatewayPayload[GatewayIdentifyMsg]{
-		Op: GATEWAY_OP_IDENTIFY,
-		Data: GatewayIdentifyMsg{
-			Token:   d.token,
-			Intents: 13991,
-			Properties: GatewayIdentifyProperties{
-				OS:      "linux",
-				Browser: "jinx",
-				Device:  "jinx",
-			},
-		},
-		Sequence: 0,
-	}
-
-	if err := d.conn.WriteJSON(msg); err != nil {
-		return err
-	}
-
-	var res GatewayPayload[GatewayReadyEvent]
-	if err := d.conn.ReadJSON(&res); err != nil {
-		return err
-	}
-
-	d.logger.Debug().Msgf("identify response payload: %+v", res)
-
-	if res.Op != GATEWAY_OP_DISPATCH {
-		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_DISPATCH, res.Op)
-	}
-
-	if res.EventName != "READY" {
-		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
-	}
-
-	return nil
-}
-
-func (d *Discord) listen(ctx context.Context) {
-	for {
-		var msg GatewayPayload[json.RawMessage]
-		if err := d.conn.ReadJSON(&msg); err != nil {
-			d.logger.Error().Err(err).Msg("error reading message")
-		}
-
-		select {
-		case <-ctx.Done():
-			return
-		default:
-			d.onEvent(msg)
-		}
-	}
-}
-
-func (d *Discord) onEvent(msg GatewayPayload[json.RawMessage]) error {
-	switch msg.Op {
-	case GATEWAY_OP_HEARTBEAT_ACK:
-		d.logger.Debug().Msg("received heartbeat ack.")
-	case GATEWAY_OP_HEARTBEAT:
-		return errors.New("on demand heartbeat not implemented")
-	case GATEWAY_OP_DISPATCH:
-		return d.onDispatch(msg.EventName, msg.Data)
-	default:
-		d.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
-	}
-
-	return nil
-}
-
-func (d *Discord) onDispatch(eventName string, body json.RawMessage) error {
-	switch eventName {
-	case "MESSAGE_CREATE":
-		var payload GatewayMessageCreateEvent
-		if err := json.Unmarshal(body, &payload); err != nil {
-			return err
-		}
-
-		d.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload)
-	default:
-		d.logger.Warn().Msgf("received unknown event: %s", eventName)
-	}
-
-	return nil
+	return d.gateway.Close()
 }
 
 func (d *Discord) SendMessage(channelID Snowflake, content string) error {