about summary refs log tree commit diff
path: root/pkg/discord/gateway/gateway.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/discord/gateway/gateway.go')
-rw-r--r--pkg/discord/gateway/gateway.go202
1 files changed, 202 insertions, 0 deletions
diff --git a/pkg/discord/gateway/gateway.go b/pkg/discord/gateway/gateway.go
new file mode 100644
index 0000000..18cf708
--- /dev/null
+++ b/pkg/discord/gateway/gateway.go
@@ -0,0 +1,202 @@
+package gateway
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"jinx/pkg/discord/events"
+	"jinx/pkg/libs/cancellablewebsocket"
+	"net/http"
+
+	"github.com/gorilla/websocket"
+	"github.com/rs/zerolog"
+)
+
+type Gateway interface {
+	Start(ctx context.Context, url string, token string) error
+	Close() error
+
+	Heartbeat() error
+}
+
+var _ Gateway = &GatewayImpl{}
+
+type GatewayImpl struct {
+	ctx          context.Context
+	logger       *zerolog.Logger
+	conn         *cancellablewebsocket.CancellableWebSocket
+	heartbeat    Heartbeat
+	eventHandler events.EventHandler
+	lastSeq      uint64
+}
+
+func New(logger *zerolog.Logger, eventHandler events.EventHandler) *GatewayImpl {
+	gateway := &GatewayImpl{
+		ctx:          nil,
+		logger:       logger,
+		conn:         nil,
+		eventHandler: eventHandler,
+		heartbeat:    nil,
+		lastSeq:      0,
+	}
+
+	// Cycle dependency, is this the best way to do this?
+	heartbeat := NewHeartbeat(logger, gateway)
+	gateway.heartbeat = heartbeat
+
+	return gateway
+}
+
+func (g *GatewayImpl) Start(ctx context.Context, url string, token string) error {
+	connectHeader := http.Header{}
+	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, ctx, url, connectHeader)
+	if err != nil {
+		return err
+	}
+
+	g.conn = conn
+	g.ctx = ctx
+
+	hello, err := g.hello()
+	if err != nil {
+		return err
+	}
+
+	g.heartbeat.Start(ctx, hello.HeartbeatInterval)
+
+	if err = g.identify(token); err != nil {
+		return err
+	}
+
+	go g.listen()
+
+	return nil
+}
+
+func (g *GatewayImpl) Close() error {
+	return g.conn.Close()
+}
+
+func (g *GatewayImpl) Heartbeat() error {
+	msg := Payload[uint64]{
+		Op:   OP_HEARTBEAT,
+		Data: g.lastSeq,
+	}
+
+	return g.send(msg)
+}
+
+func (g *GatewayImpl) hello() (HelloEvent, error) {
+	var msg Payload[HelloEvent]
+	if err := receive(g, &msg); err != nil {
+		return HelloEvent{}, err
+	}
+
+	if msg.Op != OP_HELLO {
+		return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, msg.Op)
+	}
+
+	return msg.Data, nil
+}
+
+func (g *GatewayImpl) identify(token string) error {
+	msg := Payload[IdentifyCmd]{
+		Op: OP_IDENTIFY,
+		Data: IdentifyCmd{
+			Token:   token,
+			Intents: 13991,
+			Properties: IdentifyProperties{
+				OS:      "linux",
+				Browser: "jinx",
+				Device:  "jinx",
+			},
+		},
+		Sequence: 0,
+	}
+
+	if err := g.send(msg); err != nil {
+		return err
+	}
+
+	var res Payload[ReadyEvent]
+	if err := receive(g, &res); err != nil {
+		return err
+	}
+
+	g.logger.Debug().Msgf("identify response payload: %+v", res)
+
+	if res.Op != OP_DISPATCH {
+		return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op)
+	}
+
+	if res.EventName != "READY" {
+		return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) listen() {
+	for {
+		var msg Payload[json.RawMessage]
+		if err := receive(g, &msg); err != nil {
+			g.logger.Error().Err(err).Msgf("error reading message")
+			continue
+		}
+
+		select {
+		case <-g.ctx.Done():
+			return
+		default:
+			g.onEvent(msg)
+		}
+	}
+}
+
+func (g *GatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
+	switch msg.Op {
+	case OP_HEARTBEAT_ACK:
+		g.heartbeat.Ack()
+	case OP_HEARTBEAT:
+		g.heartbeat.ForceHeartbeat()
+	case OP_DISPATCH:
+		return g.onDispatch(msg.EventName, msg.Data)
+	default:
+		g.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
+	switch eventName {
+	case "MESSAGE_CREATE":
+		var payload MessageCreateEvent
+		if err := json.Unmarshal(body, &payload); err != nil {
+			return err
+		}
+
+		g.eventHandler.Fire(events.MESSAGE, payload)
+	default:
+		g.logger.Warn().Msgf("received unknown event: %s", eventName)
+	}
+
+	return nil
+}
+
+func (g *GatewayImpl) send(payload any) error {
+	return g.conn.WriteJSON(payload)
+}
+
+func receive[D any](g *GatewayImpl, res *Payload[D]) error {
+	err := g.conn.ReadJSON(&res)
+	if err != nil {
+		return err
+	}
+
+	if res.Sequence != 0 {
+		g.lastSeq = res.Sequence
+	}
+
+	return nil
+}