package discord

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log"
	"math/rand"
	"net"
	"net/http"
	"sync"
	"time"

	"github.com/gorilla/websocket"
)

// Discord is a bot client combining a gateway websocket connection (for
// receiving events) with a REST client (for sending messages).
type Discord struct {
	token        string
	conn         *websocket.Conn
	eventHandler *EventHandlerImpl
	rest         REST

	// writeMu serializes writes to conn: gorilla/websocket permits at most
	// one concurrent writer, and the heartbeat goroutine, identify and
	// onEvent all write to the socket.
	writeMu sync.Mutex

	// seqMu guards lastSeq, the Sequence of the most recent dispatch payload.
	// It is echoed back as the Data of every heartbeat; it stays nil until the
	// first dispatch (READY) has been received.
	seqMu   sync.Mutex
	lastSeq any

	// cancel stops the heartbeat and listen goroutines started by Connect.
	cancel context.CancelFunc
}

// NewClient returns a client for the given bot token. The token is prefixed
// with "Bot " and that prefixed value is used both for the gateway IDENTIFY
// payload and for the REST client, so callers pass the bare token.
func NewClient(token string) *Discord {
	token = "Bot " + token
	return &Discord{
		token:        token,
		conn:         nil,
		eventHandler: NewEventHandler(),
		rest:         NewREST(token),
	}
}

// Connect opens the gateway websocket, performs the HELLO/IDENTIFY handshake,
// starts the heartbeat and listen goroutines, and fires DISCORD_EVENT_READY.
//
// The goroutines stop when ctx is cancelled or Disconnect is called
// (after cancellation alone, listen exits once its pending read returns).
// If any handshake step fails, the goroutines are stopped and the socket is
// closed before the error is returned.
func (d *Discord) Connect(ctx context.Context) error {
	gatewayURL, err := d.rest.Gateway()
	if err != nil {
		return fmt.Errorf("fetching gateway URL: %w", err)
	}

	conn, _, err := websocket.DefaultDialer.DialContext(ctx, gatewayURL, http.Header{})
	if err != nil {
		return fmt.Errorf("dialing gateway: %w", err)
	}
	d.conn = conn

	runCtx, cancel := context.WithCancel(ctx)
	// fail tears down everything started so far so a failed handshake leaves
	// no running goroutine and no open socket behind.
	fail := func(err error) error {
		cancel()
		conn.Close()
		return err
	}

	var hello GatewayPayload[GatewayHelloEvent]
	if err := conn.ReadJSON(&hello); err != nil {
		return fail(fmt.Errorf("reading HELLO: %w", err))
	}
	if hello.Op != GATEWAY_OP_HELLO {
		return fail(fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_HELLO, hello.Op))
	}
	interval := hello.Data.HeartbeatInterval
	if interval == 0 {
		// rand.Intn / time.NewTicker would panic on a zero interval.
		return fail(errors.New("gateway sent a zero heartbeat interval"))
	}

	go d.heartbeat(runCtx, interval)

	if err := d.identify(); err != nil {
		return fail(fmt.Errorf("identifying: %w", err))
	}

	d.cancel = cancel
	go d.listen(runCtx)

	// We are ready!
	d.eventHandler.Fire(DISCORD_EVENT_READY, nil)
	return nil
}

// Disconnect stops the goroutines started by Connect and closes the gateway
// connection. It returns an error if no connection was ever opened.
func (d *Discord) Disconnect() error {
	if d.conn == nil {
		return errors.New("not connected")
	}
	if d.cancel != nil {
		d.cancel()
	}
	return d.conn.Close()
}

// heartbeat sends gateway heartbeats until ctx is cancelled or a write fails.
// interval is the heartbeat interval from HELLO, in milliseconds.
// The first beat goes out after a random fraction of interval
// (REF: heartbeat_interval * jitter) and then every interval thereafter.
func (d *Discord) heartbeat(ctx context.Context, interval uint64) {
	if interval == 0 {
		log.Printf("error sending heartbeat: zero heartbeat interval\n")
		return
	}

	jitter := time.Duration(rand.Int63n(int64(interval))) * time.Millisecond
	timer := time.NewTimer(jitter)
	defer timer.Stop()
	select {
	case <-ctx.Done():
		return
	case <-timer.C:
	}

	ticker := time.NewTicker(time.Duration(interval) * time.Millisecond)
	defer ticker.Stop()
	for {
		if err := d.sendHeartbeat(); err != nil {
			// A write failure after cancellation/Disconnect is expected.
			if ctx.Err() == nil {
				log.Printf("error sending heartbeat: %s\n", err)
			}
			return
		}
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}

// sendHeartbeat writes one HEARTBEAT payload whose Data is the last dispatch
// sequence number recorded by identify/listen (nil if none yet).
func (d *Discord) sendHeartbeat() error {
	fmt.Println("sending heartbeat.")
	return d.writeJSON(GatewayPayload[any]{
		Op:   GATEWAY_OP_HEARTBEAT,
		Data: d.sequence(),
	})
}

// writeJSON writes v to the gateway connection while holding writeMu, so at
// most one goroutine writes to the socket at a time.
func (d *Discord) writeJSON(v any) error {
	d.writeMu.Lock()
	defer d.writeMu.Unlock()
	return d.conn.WriteJSON(v)
}

// setSequence records the sequence number of the latest dispatch payload.
func (d *Discord) setSequence(seq any) {
	d.seqMu.Lock()
	d.lastSeq = seq
	d.seqMu.Unlock()
}

// sequence returns the last recorded dispatch sequence number, or nil.
func (d *Discord) sequence() any {
	d.seqMu.Lock()
	defer d.seqMu.Unlock()
	return d.lastSeq
}

// identify sends the IDENTIFY payload and waits for the READY dispatch.
// It must run before listen starts, since it reads from the connection.
func (d *Discord) identify() error {
	msg := GatewayPayload[GatewayIdentifyMsg]{
		Op: GATEWAY_OP_IDENTIFY,
		Data: GatewayIdentifyMsg{
			Token: d.token,
			// 13991 sets intent bits 0, 1, 2, 5, 7, 9, 10, 12 and 13. Per
			// Discord's Gateway Intents table: GUILDS, GUILD_MEMBERS,
			// GUILD_MODERATION, GUILD_WEBHOOKS, GUILD_VOICE_STATES,
			// GUILD_MESSAGES, GUILD_MESSAGE_REACTIONS, DIRECT_MESSAGES,
			// DIRECT_MESSAGE_REACTIONS.
			// NOTE(review): MESSAGE_CONTENT (1 << 15) is not requested, so
			// guild message content may arrive empty — confirm intended.
			Intents: 13991,
			Properties: GatewayIdentifyProperties{
				OS:      "linux",
				Browser: "jinx",
				Device:  "jinx",
			},
		},
		Sequence: 0,
	}
	if err := d.writeJSON(msg); err != nil {
		return err
	}

	for {
		var res GatewayPayload[GatewayReadyEvent]
		if err := d.conn.ReadJSON(&res); err != nil {
			return err
		}
		// The heartbeat goroutine is already running, so an ACK may arrive
		// before READY; it is not a handshake failure.
		if res.Op == GATEWAY_OP_HEARTBEAT_ACK {
			fmt.Println("received heartbeat ack.")
			continue
		}

		fmt.Printf("identify response payload: %+v\n", res)
		if res.Op != GATEWAY_OP_DISPATCH {
			return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_DISPATCH, res.Op)
		}
		if res.EventName != "READY" {
			return fmt.Errorf("expected event name %s, got %s", "READY", res.EventName)
		}
		d.setSequence(res.Sequence)
		return nil
	}
}

// listen reads gateway payloads and hands them to onEvent until ctx is
// cancelled or the connection fails. Read errors caused by cancellation or
// by Disconnect closing the socket end the loop quietly; any other read
// error, and any error returned by onEvent, is logged.
func (d *Discord) listen(ctx context.Context) {
	for {
		var msg GatewayPayload[json.RawMessage]
		if err := d.conn.ReadJSON(&msg); err != nil {
			if ctx.Err() == nil && !errors.Is(err, net.ErrClosed) {
				log.Printf("error reading message: %s\n", err)
			}
			// gorilla/websocket read errors are permanent; stop reading.
			return
		}
		if ctx.Err() != nil {
			return
		}
		if msg.Op == GATEWAY_OP_DISPATCH {
			d.setSequence(msg.Sequence)
		}
		if err := d.onEvent(msg); err != nil {
			log.Printf("error handling opcode %d: %s\n", msg.Op, err)
		}
	}
}

// onEvent routes one gateway payload by opcode.
func (d *Discord) onEvent(msg GatewayPayload[json.RawMessage]) error {
	switch msg.Op {
	case GATEWAY_OP_HEARTBEAT_ACK:
		fmt.Println("received heartbeat ack.")
	case GATEWAY_OP_HEARTBEAT:
		// The gateway asked for a heartbeat outside the regular schedule.
		return d.sendHeartbeat()
	case GATEWAY_OP_DISPATCH:
		return d.onDispatch(msg.EventName, msg.Data)
	default:
		fmt.Printf("received unknown opcode: %d\n", msg.Op)
	}
	return nil
}

// onDispatch decodes a dispatch payload by event name and fires the matching
// registered event. Unrecognized events are printed and ignored.
func (d *Discord) onDispatch(eventName string, body json.RawMessage) error {
	switch eventName {
	case "MESSAGE_CREATE":
		var payload GatewayMessageCreateEvent
		if err := json.Unmarshal(body, &payload); err != nil {
			return fmt.Errorf("decoding %s: %w", eventName, err)
		}
		d.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload)
	default:
		fmt.Println("received unknown event:", eventName)
	}
	return nil
}

// SendMessage posts content to the given channel through the REST client.
func (d *Discord) SendMessage(channelID Snowflake, content string) error {
	return d.rest.SendMessage(channelID, content)
}

// AddEventHandler registers handler to be called whenever eventName fires.
func (d *Discord) AddEventHandler(eventName DiscordEvent, handler func(payload any)) {
	d.eventHandler.Add(eventName, handler)
}