about summary refs log tree commit diff
path: root/pkg/discord/gateway/gateway.go
blob: 114eabcb6ecad78fce7a66384362b97985143669 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
package gateway

import (
	"context"
	"encoding/json"
	"fmt"
	"jinx/pkg/discord/events"
	"jinx/pkg/libs/cancellablewebsocket"
	"net/http"
	"time"

	"github.com/gorilla/websocket"
	"github.com/rs/zerolog"
)

// Gateway is a connection to the Discord gateway.
type Gateway interface {
	// Start connects to the gateway at url and authenticates with token.
	// ctx is used as the parent context for the connection's background work.
	Start(ctx context.Context, url string, token string) error

	// Heartbeat sends a single heartbeat over the current connection.
	Heartbeat() error

	// Close shuts the connection down.
	Close() error
}

// Compile-time assertion that *gatewayImpl implements Gateway.
var _ Gateway = (*gatewayImpl)(nil)

// gatewayImpl is the websocket-backed implementation of Gateway.
type gatewayImpl struct {
	logger *zerolog.Logger
	// conn is the live websocket; it is replaced on every reconnect.
	conn *cancellablewebsocket.CancellableWebSocket
	// heartbeat keeps the connection alive; it calls back into this gateway.
	heartbeat Heartbeat
	// eventHandler receives decoded dispatch events.
	eventHandler events.EventHandler

	// parentCtx is the context passed to Start; each connection derives its own
	// cancellable child from it, whose cancel func is stored in cancelRoutine.
	parentCtx     context.Context
	cancelRoutine context.CancelFunc

	// url and token are remembered from Start so reconnect can redial/resume.
	url   string
	token string
	// sessionID comes from the READY event and is sent back when resuming.
	sessionID string
	// lastSeq is the most recent non-zero sequence number seen by receive; it is
	// sent in heartbeats and resume requests.
	// NOTE(review): written by the reading goroutine and read by Heartbeat
	// (presumably from the heartbeat goroutine) without synchronisation —
	// looks like a data race; confirm.
	lastSeq uint64
}

// New returns an unstarted Gateway that logs to logger and forwards decoded
// dispatch events to eventHandler. Call Start to connect.
func New(logger *zerolog.Logger, eventHandler events.EventHandler) Gateway {
	g := &gatewayImpl{
		logger:       logger,
		eventHandler: eventHandler,
	}

	// The heartbeat needs a reference back to the gateway, so it can only be
	// built once the gateway exists; attach it afterwards.
	g.heartbeat = NewHeartbeat(logger, g)

	return g
}

// Start dials the gateway at url, performs the HELLO/IDENTIFY handshake with
// token, starts heartbeating and spawns the listener goroutine.
//
// ctx is kept as the parent for the per-connection context, which is
// recreated on every reconnect. If any step of the handshake fails,
// everything acquired so far (connection, heartbeat, per-connection context)
// is released before the error is returned.
func (g *gatewayImpl) Start(ctx context.Context, url string, token string) error {
	g.parentCtx = ctx
	routineCtx, cancelRoutine := context.WithCancel(ctx)
	g.cancelRoutine = cancelRoutine

	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, url, http.Header{})
	if err != nil {
		cancelRoutine()
		return fmt.Errorf("dialing gateway: %w", err)
	}

	g.url = url
	g.token = token
	g.conn = conn

	hello, err := g.hello()
	if err != nil {
		g.abortStart(false)
		return fmt.Errorf("receiving hello: %w", err)
	}

	// Bind the heartbeat to the per-connection context, as reconnect does, so
	// cancelling that context also ends it.
	heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond
	g.heartbeat.Start(routineCtx, heartbeatInterval)

	if err := g.identify(token); err != nil {
		g.abortStart(true)
		return fmt.Errorf("identifying: %w", err)
	}

	// Set up close handler for reconnecting or killing.
	// Done after hello and identify: if those fail the gateway is torn down
	// above instead of reconnected.
	g.conn.OnClose(func(code int, text string) error {
		g.onClose(code)
		return nil
	})

	go g.listen(routineCtx)

	return nil
}

// abortStart releases what a failed Start acquired: it kills the connection,
// stops the heartbeat if it was already started, and cancels the
// per-connection context.
func (g *gatewayImpl) abortStart(heartbeatStarted bool) {
	if err := g.conn.Kill(); err != nil {
		g.logger.Error().Err(err).Msg("error killing gateway connection after failed start")
	}
	if heartbeatStarted {
		g.heartbeat.Stop()
	}
	g.cancelRoutine()
}

// Close shuts the gateway down. It first tries a graceful websocket close with
// status 1000 and, if that fails, kills the connection outright, returning the
// error from the kill. The per-connection context is cancelled in both cases so
// goroutines bound to it can exit.
//
// Close is a no-op if Start never established a connection.
//
// NOTE(review): the heartbeat is not stopped explicitly here; whether it keeps
// running depends on which context it was started with — confirm.
func (g *gatewayImpl) Close() error {
	if g.conn == nil {
		return nil
	}
	// cancelRoutine is always set before conn (Start and reconnect), so it is
	// non-nil here.
	defer g.cancelRoutine()

	// Try closing gracefully.
	g.logger.Debug().Msg("closing gateway gracefully...")
	if err := g.conn.Close(1000); err != nil {
		g.logger.Error().Err(err).Msg("error closing gateway gracefully, trying to murder...")
		return g.conn.Kill()
	}

	return nil
}

// Heartbeat sends one OP_HEARTBEAT payload whose data is the last sequence
// number recorded by receive (0 until one has been seen).
//
// NOTE(review): presumably invoked from the heartbeat routine, concurrently
// with the reader that updates g.lastSeq — the unsynchronised read looks like
// a data race; confirm.
// NOTE(review): Discord's docs describe the heartbeat data as null before any
// sequence has been received; this sends 0 — confirm it is accepted.
func (g *gatewayImpl) Heartbeat() error {
	seq := g.lastSeq
	return g.send(Payload[uint64]{Op: OP_HEARTBEAT, Data: seq})
}

// hello reads the next payload and returns its data, which must be the HELLO
// event (opcode OP_HELLO). Any other opcode is reported as an error.
func (g *gatewayImpl) hello() (HelloEvent, error) {
	var payload Payload[HelloEvent]
	err := receive(g, &payload)

	switch {
	case err != nil:
		return HelloEvent{}, err
	case payload.Op != OP_HELLO:
		return HelloEvent{}, fmt.Errorf("expected opcode %d, got %d", OP_HELLO, payload.Op)
	default:
		return payload.Data, nil
	}
}

// identify sends the IDENTIFY command authenticated with token and waits for
// the READY dispatch, storing the session ID it carries for later resumes.
//
// Start begins heartbeating before calling identify, so a HEARTBEAT_ACK or a
// server heartbeat request can legitimately arrive ahead of READY. Those are
// forwarded to the heartbeat rather than failing the handshake. Any other
// opcode, or a dispatch that is not READY, is an error.
func (g *gatewayImpl) identify(token string) error {
	// Gateway intents bitmask: bits 0, 1, 2, 5, 7, 9, 10, 12 and 13.
	// NOTE(review): per Discord's intents table this appears to be GUILDS,
	// GUILD_MEMBERS, GUILD_MODERATION, GUILD_WEBHOOKS, GUILD_VOICE_STATES,
	// GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS, DIRECT_MESSAGES and
	// DIRECT_MESSAGE_REACTIONS, without MESSAGE_CONTENT — confirm intended.
	const intents = 13991

	msg := Payload[IdentifyCmd]{
		Op: OP_IDENTIFY,
		Data: IdentifyCmd{
			Token:   token,
			Intents: intents,
			Properties: IdentifyProperties{
				OS:      "linux",
				Browser: "jinx",
				Device:  "jinx",
			},
		},
		Sequence: 0,
	}

	if err := g.send(msg); err != nil {
		return err
	}

	for {
		// Decode lazily: only the READY payload is known to fit ReadyEvent.
		var res Payload[json.RawMessage]
		if err := receive(g, &res); err != nil {
			return err
		}

		switch res.Op {
		case OP_HEARTBEAT_ACK:
			g.heartbeat.Ack()
			continue
		case OP_HEARTBEAT:
			g.heartbeat.ForceHeartbeat()
			continue
		case OP_DISPATCH:
			// Handled below.
		default:
			return fmt.Errorf("expected opcode %d, got %d", OP_DISPATCH, res.Op)
		}

		if res.EventName != "READY" {
			return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
		}

		var ready ReadyEvent
		if err := json.Unmarshal(res.Data, &ready); err != nil {
			return fmt.Errorf("decoding READY event: %w", err)
		}

		g.sessionID = ready.SessionID
		return nil
	}
}

// resume asks the gateway to resume session g.sessionID from sequence
// g.lastSeq. It then handles the replayed events until the RESUMED dispatch
// marks the end of the replay.
//
// reconnect calls this before starting a new listener, so replayed events are
// handled here rather than by listen. Errors from handling individual replayed
// events are logged, not returned.
func (g *gatewayImpl) resume() error {
	cmd := Payload[ResumeCmd]{
		Op: OP_RESUME,
		Data: ResumeCmd{
			Token:     g.token,
			SessionID: g.sessionID,
			Sequence:  g.lastSeq,
		},
		Sequence: 0,
	}

	if err := g.send(cmd); err != nil {
		return err
	}

	// Opcode 9 is "Invalid Session" in Discord's gateway opcode table: the
	// session cannot be resumed, so RESUMED will never arrive.
	// NOTE(review): if the package defines an OP_INVALID_SESSION constant, use
	// it here. Proper recovery would re-identify instead of failing — TODO.
	const opInvalidSession = 9

	for {
		var msg Payload[json.RawMessage]
		if err := receive(g, &msg); err != nil {
			return err
		}

		switch {
		case msg.Op == OP_DISPATCH && msg.EventName == "RESUMED":
			// Done replaying.
			return nil
		case msg.Op == opInvalidSession:
			return fmt.Errorf("gateway rejected resume: invalid session")
		}

		// Handle a replayed event.
		if err := g.onEvent(msg); err != nil {
			g.logger.Error().Err(err).Msg("error handling replayed gateway event")
		}
	}
}

// reconnect replaces the current connection with a fresh one and resumes the
// existing session, replaying events missed while disconnected.
//
// Steps:
//   - Kill the old connection, stop the heartbeat and cancel the old
//     per-connection context.
//   - Dial g.url again and redo HELLO.
//   - Resume, then install the close handler and start a new listener.
//
// If the handshake fails, the new connection, the heartbeat and the new
// context are released, and the error is returned.
func (g *gatewayImpl) reconnect() error {
	// The old connection may already be dead; failing to kill it must not
	// prevent establishing a new one.
	if err := g.conn.Kill(); err != nil {
		g.logger.Warn().Err(err).Msg("error killing old gateway connection, reconnecting anyway")
	}

	g.heartbeat.Stop()

	// Renew context: everything bound to the old connection is cancelled.
	g.cancelRoutine()
	routineCtx, cancelRoutine := context.WithCancel(g.parentCtx)
	g.cancelRoutine = cancelRoutine

	conn, err := cancellablewebsocket.Dial(websocket.DefaultDialer, routineCtx, g.url, http.Header{})
	if err != nil {
		cancelRoutine()
		return fmt.Errorf("redialing gateway: %w", err)
	}

	g.conn = conn

	// abort releases the new connection when the handshake fails, so it is not
	// left open with nobody reading from it.
	abort := func(heartbeatStarted bool) {
		if heartbeatStarted {
			g.heartbeat.Stop()
		}
		if err := conn.Kill(); err != nil {
			g.logger.Error().Err(err).Msg("error killing gateway connection after failed reconnect")
		}
		cancelRoutine()
	}

	hello, err := g.hello()
	if err != nil {
		abort(false)
		return fmt.Errorf("receiving hello: %w", err)
	}

	heartbeatInterval := time.Duration(hello.HeartbeatInterval) * time.Millisecond
	g.heartbeat.Start(routineCtx, heartbeatInterval)

	// Replay all missed events.
	if err := g.resume(); err != nil {
		abort(true)
		return fmt.Errorf("resuming session: %w", err)
	}

	g.conn.OnClose(func(code int, text string) error {
		g.onClose(code)
		return nil
	})

	go g.listen(routineCtx)

	return nil
}

// listen reads payloads from the current connection and dispatches them via
// onEvent until ctx is cancelled or the connection fails.
//
// ctx is the per-connection context created by Start/reconnect. receive always
// reads from g.conn, which reconnect swaps out. So ctx is checked before every
// read and after every failed read: once a newer listener owns the connection,
// this one must stop rather than compete with it for messages.
func (g *gatewayImpl) listen(ctx context.Context) {
	for {
		if ctx.Err() != nil {
			return
		}

		var msg Payload[json.RawMessage]
		if err := receive(g, &msg); err != nil {
			if ctx.Err() != nil {
				// The connection was replaced or shut down while we were reading.
				return
			}

			switch err.(type) {
			case *json.SyntaxError, *json.UnmarshalTypeError:
				// A malformed payload; the connection itself is still usable.
				// NOTE(review): assumes ReadJSON returns encoding/json errors
				// unwrapped — a wrapped one would take the default branch.
				g.logger.Error().Err(err).Msgf("error reading message")
				continue
			default:
				// gorilla/websocket read errors are permanent, assuming
				// CancellableWebSocket passes them through, so retrying
				// would only spin on the same error.
				g.logger.Error().Err(err).Msg("gateway connection read failed, stopping listener")
				return
			}
		}

		select {
		case <-ctx.Done():
			return
		default:
			if err := g.onEvent(msg); err != nil {
				g.logger.Error().Err(err).Msg("error handling gateway event")
			}
		}
	}
}

// onEvent routes a received payload by opcode:
//   - dispatches go to onDispatch, whose error is returned;
//   - heartbeat requests force an immediate heartbeat;
//   - heartbeat ACKs are reported to the heartbeat;
//   - anything else is logged as unknown.
func (g *gatewayImpl) onEvent(msg Payload[json.RawMessage]) error {
	switch op := msg.Op; op {
	case OP_DISPATCH:
		return g.onDispatch(msg.EventName, msg.Data)
	case OP_HEARTBEAT:
		g.heartbeat.ForceHeartbeat()
	case OP_HEARTBEAT_ACK:
		g.heartbeat.Ack()
	default:
		g.logger.Warn().Msgf("received unknown opcode: %d", op)
	}

	return nil
}

// onDispatch decodes the DISPATCH payload body for eventName and forwards it
// to the event handler.
//
// Only MESSAGE_CREATE is currently translated (into an events.MESSAGE event).
// Other event names are logged and ignored. A decoding failure is returned
// with the event name for context.
func (g *gatewayImpl) onDispatch(eventName string, body json.RawMessage) error {
	switch eventName {
	case "MESSAGE_CREATE":
		var gatewayEvent MessageCreateEvent
		if err := json.Unmarshal(body, &gatewayEvent); err != nil {
			return fmt.Errorf("decoding %s event: %w", eventName, err)
		}

		// The gateway type converts directly into the public event type.
		payload := events.Message(gatewayEvent)
		g.eventHandler.Fire(events.MESSAGE, payload)
	default:
		g.logger.Warn().Msgf("received unknown event: %s", eventName)
	}

	return nil
}

// onClose is the connection's close handler.
//
// It logs why the gateway closed the connection. For recoverable close codes
// it calls reconnect; for fatal or unknown codes it leaves the gateway closed.
// A normal closure (1000, the status Close sends) is treated as a deliberate
// shutdown and is not reconnected.
//
// NOTE(review): Discord documents invalid-sequence and session-timeout closes
// as requiring a fresh session, but reconnect always resumes — confirm.
func (g *gatewayImpl) onClose(code int) {
	// Websocket "normal closure" status; the same value Close sends.
	const normalClosure = 1000

	shouldReconnect := true

	switch code {
	case normalClosure:
		// Presumably the server acknowledging our own Close; not an error.
		g.logger.Info().Msg("gateway closed normally.")
		shouldReconnect = false

	case CloseUnknownError:
		g.logger.Warn().Msg("gateway closed because of an unknown error, reconnecting...")
	case CloseUnkownOpcode:
		g.logger.Warn().Msg("gateway closed after receiving an unknown opcode, reconnecting...")
	case CloseDecodeError:
		g.logger.Warn().Msg("gateway closed after receiving an invalid payload, reconnecting...")
	case CloseNotAuthenticated:
		g.logger.Warn().Msg("gateway closed because of missing authentication, reconnecting...")
	case CloseAlreadyAuthenticated:
		g.logger.Warn().Msg("gateway closed because the client is already authenticated, reconnecting...")
	case CloseInvalidSeq:
		g.logger.Warn().Msg("gateway closed after receiving an invalid sequence number, reconnecting...")
	case CloseRateLimited:
		// TODO: Implement rate limiting.
		g.logger.Warn().Msg("gateway rate-limited, reconnecting...")
	case CloseSessionTimedOut:
		g.logger.Warn().Msg("gateway closed because of a session timeout, reconnecting...")

	case CloseAuthenticationFailed:
		g.logger.Error().Msg("gateway closed because of an authentication error.")
		shouldReconnect = false
	case CloseInvalidShard:
		g.logger.Error().Msg("gateway closed because the given shard is invalid.")
		shouldReconnect = false
	case CloseShardingRequired:
		g.logger.Error().Msg("gateway closed because the bot requires sharding.")
		shouldReconnect = false
	case CloseInvalidAPIVersion:
		g.logger.Error().Msg("gateway closed because the given API version is invalid.")
		shouldReconnect = false
	case CloseInvalidIntents:
		g.logger.Error().Msg("gateway closed because the given intents are invalid.")
		shouldReconnect = false
	case CloseDisallowedIntents:
		g.logger.Error().Msg("gateway closed because the given intents are disallowed.")
		shouldReconnect = false
	default:
		g.logger.Error().Msgf("gateway closed with unknown code %d.", code)
		shouldReconnect = false
	}

	if shouldReconnect {
		if err := g.reconnect(); err != nil {
			g.logger.Error().Err(err).Msg("error reconnecting")
			return
		}
		g.logger.Info().Msg("successfully reconnected!")
	}
}

// send writes payload to the current connection as a JSON message.
//
// NOTE(review): send is reached from Heartbeat (presumably on the heartbeat
// routine) and from the identify/resume handshakes. gorilla/websocket allows
// only one concurrent writer, so unless CancellableWebSocket serialises
// WriteJSON this needs a lock — confirm.
func (g *gatewayImpl) send(payload any) error {
	conn := g.conn
	return conn.WriteJSON(payload)
}

// receive reads the next payload from the current connection into res. It
// records the payload's sequence number in g.lastSeq unless that number is 0,
// which is treated as "no sequence".
func receive[D any](g *gatewayImpl, res *Payload[D]) error {
	// Pass res itself, not &res: decoding into a **Payload lets a JSON null
	// message reset the local pointer to nil, which would make the Sequence
	// access below panic.
	if err := g.conn.ReadJSON(res); err != nil {
		return err
	}

	if res.Sequence != 0 {
		g.lastSeq = res.Sequence
	}

	return nil
}