diff options
| -rw-r--r-- | pkg/bot/bot.go | 20 | ||||
| -rw-r--r-- | pkg/discord/discord.go | 93 | ||||
| -rw-r--r-- | pkg/discord/event_handler.go | 36 | ||||
| -rw-r--r-- | pkg/discord/payloads.go | 7 |
4 files changed, 122 insertions, 34 deletions
diff --git a/pkg/bot/bot.go b/pkg/bot/bot.go index 091e1be..9095d7c 100644 --- a/pkg/bot/bot.go +++ b/pkg/bot/bot.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "jinx/pkg/discord" + "log" ) type Bot struct { @@ -16,16 +17,27 @@ func Start(token string) (*Bot, error) { client := discord.NewClient(token) - ctx, cancel := context.WithCancel(context.Background()) + client.AddEventHandler(discord.DISCORD_EVENT_READY, func(_ any) { + log.Println("bot is ready!") + }) + + client.AddEventHandler(discord.DISCORD_EVENT_MESSAGE, func(m any) { + msg := m.(discord.GatewayMessageCreateEvent) + log.Printf("message: %s", msg.Content) + + if msg.Content == "ping" { + if err := client.SendMessage(msg.ChannelID, "pong"); err != nil { + log.Printf("error sending message: %s", err) + } + } + }) - fmt.Println("connecting..") + ctx, cancel := context.WithCancel(context.Background()) if err := client.Connect(ctx); err != nil { cancel() return nil, err } - fmt.Println("connected..") - return &Bot{ Client: client, cancelContext: cancel, diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go index fd36c37..fd60edd 100644 --- a/pkg/discord/discord.go +++ b/pkg/discord/discord.go @@ -2,6 +2,7 @@ package discord import ( "context" + "encoding/json" "errors" "fmt" "log" @@ -15,13 +16,17 @@ import ( type Discord struct { token string conn *websocket.Conn + eventHandler *EventHandlerImpl rest REST } func NewClient(token string) *Discord { + token = "Bot " + token + return &Discord{ token: token, conn: nil, + eventHandler: NewEventHandler(), rest: NewREST(token), } } @@ -38,7 +43,7 @@ func (d *Discord) Connect(ctx context.Context) error { return err } - var helloMsg GatewayPayload[GatewayHelloMsg] + var helloMsg GatewayPayload[GatewayHelloEvent] if err = d.conn.ReadJSON(&helloMsg); err != nil { return err } @@ -55,6 +60,9 @@ func (d *Discord) Connect(ctx context.Context) error { go d.listen(ctx) + // We are ready! + d.eventHandler.Fire(DISCORD_EVENT_READY, nil) + return nil } @@ -92,32 +100,6 @@ func (d *Discord) heartbeat(ctx context.Context, interval uint64) { } } -func (d *Discord) listen(ctx context.Context) { - for { - var msg GatewayPayload[any] - if err := d.conn.ReadJSON(&msg); err != nil { - log.Fatalf("error reading message: %s\n", err) - } - - select { - case <-ctx.Done(): - return - default: - fmt.Printf("received message: %+v\n", msg) - - if msg.EventName == "MESSAGE_CREATE" { - event := msg.Data.(map[string]interface{}) - if event["content"] == "ping" { - fmt.Println("got ping, sending pong...") - if err := d.rest.SendMessage(Snowflake(event["channel_id"].(string)), "pong"); err != nil { - log.Fatalf("error sending message: %s\n", err) - } - } - } - } - } -} - func (d *Discord) identify() error { msg := GatewayPayload[GatewayIdentifyMsg]{ Op: GATEWAY_OP_IDENTIFY, @@ -137,7 +119,7 @@ func (d *Discord) identify() error { return err } - var res GatewayPayload[GatewayReadyMsg] + var res GatewayPayload[GatewayReadyEvent] if err := d.conn.ReadJSON(&res); err != nil { return err } @@ -154,3 +136,58 @@ func (d *Discord) identify() error { return nil } + +func (d *Discord) listen(ctx context.Context) { + for { + var msg GatewayPayload[json.RawMessage] + if err := d.conn.ReadJSON(&msg); err != nil { + log.Fatalf("error reading message: %s\n", err) + } + + select { + case <-ctx.Done(): + return + default: + d.onEvent(msg) + } + } +} + +func (d *Discord) onEvent(msg GatewayPayload[json.RawMessage]) error { + switch msg.Op { + case GATEWAY_OP_HEARTBEAT_ACK: + fmt.Println("received heartbeat ack.") + case GATEWAY_OP_HEARTBEAT: + return errors.New("on demand heartbeat not implemented") + case GATEWAY_OP_DISPATCH: + return d.onDispatch(msg.EventName, msg.Data) + default: + fmt.Printf("received unknown opcode: %d\n", msg.Op) + } + + return nil +} + +func (d *Discord) onDispatch(eventName string, body json.RawMessage) error { + switch eventName { + case "MESSAGE_CREATE": + var payload GatewayMessageCreateEvent + if err := json.Unmarshal(body, &payload); err != nil { + return err + } + + d.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload) + default: + fmt.Println("received unknown event:", eventName) + } + + return nil +} + +func (d *Discord) SendMessage(channelID Snowflake, content string) error { + return d.rest.SendMessage(channelID, content) +} + +func (d *Discord) AddEventHandler(eventName DiscordEvent, handler func(payload any)) { + d.eventHandler.Add(eventName, handler) +} diff --git a/pkg/discord/event_handler.go b/pkg/discord/event_handler.go new file mode 100644 index 0000000..6f3ded5 --- /dev/null +++ b/pkg/discord/event_handler.go @@ -0,0 +1,36 @@ +package discord + +type DiscordEvent uint8 + +const ( + DISCORD_EVENT_READY DiscordEvent = iota + DISCORD_EVENT_MESSAGE +) + +type EventHandler interface { + Add(event DiscordEvent, handler func(payload any)) + + Fire(event DiscordEvent, payload any) +} + +var _ EventHandler = &EventHandlerImpl{} + +type EventHandlerImpl struct { + handlers map[DiscordEvent]func(payload any) +} + +func NewEventHandler() *EventHandlerImpl { + return &EventHandlerImpl{ + handlers: make(map[DiscordEvent]func(payload any)), + } +} + +func (h *EventHandlerImpl) Add(event DiscordEvent, handler func(payload any)) { + h.handlers[event] = handler +} + +func (h *EventHandlerImpl) Fire(event DiscordEvent, payload any) { + if handler, ok := h.handlers[event]; ok { + handler(payload) + } +} diff --git a/pkg/discord/payloads.go b/pkg/discord/payloads.go index d7ab6d4..81c6c68 100644 --- a/pkg/discord/payloads.go +++ b/pkg/discord/payloads.go @@ -28,11 +28,12 @@ type GatewayIdentifyMsg struct { Intents uint64 `json:"intents"` Properties GatewayIdentifyProperties `json:"properties"` } -type GatewayHelloMsg struct { + +type GatewayHelloEvent struct { HeartbeatInterval uint64 `json:"heartbeat_interval"` } -type GatewayReadyMsg struct { +type GatewayReadyEvent struct { Version uint64 `json:"v"` User User `json:"user"` Guilds []Guild `json:"guilds"` @@ -40,6 +41,8 @@ type GatewayReadyMsg struct { Shard []int `json:"shard"` } +type GatewayMessageCreateEvent Message + type GatewayIdentifyProperties struct { OS string `json:"$os"` Browser string `json:"$browser"` |
