about summary refs log tree commit diff
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/bot/bot.go57
-rw-r--r--pkg/discord/discord.go20
2 files changed, 47 insertions, 30 deletions
diff --git a/pkg/bot/bot.go b/pkg/bot/bot.go
index 9095d7c..add3e56 100644
--- a/pkg/bot/bot.go
+++ b/pkg/bot/bot.go
@@ -2,49 +2,64 @@ package bot
 
 import (
 	"context"
-	"fmt"
+	"errors"
 	"jinx/pkg/discord"
-	"log"
+
+	"github.com/rs/zerolog"
 )
 
 type Bot struct {
-	Client        *discord.Discord
+	client        *discord.Discord
+	logger        *zerolog.Logger
 	cancelContext context.CancelFunc
 }
 
-func Start(token string) (*Bot, error) {
-	fmt.Println("hi..!")
+func NewBot(token string, logger *zerolog.Logger) *Bot {
+	return &Bot{
+		client:        discord.NewClient(token, logger),
+		logger:        logger,
+		cancelContext: nil,
+	}
+}
 
-	client := discord.NewClient(token)
+func (b *Bot) Start() error {
+	var err error
+	ctx, cancel := context.WithCancel(context.Background())
+	b.cancelContext = cancel
 
-	client.AddEventHandler(discord.DISCORD_EVENT_READY, func(_ any) {
-		log.Println("bot is ready!")
+	defer func() {
+		if err != nil {
+			cancel()
+		}
+	}()
+
+	b.client.AddEventHandler(discord.DISCORD_EVENT_READY, func(_ any) {
+		b.logger.Info().Msg("ready!")
 	})
 
-	client.AddEventHandler(discord.DISCORD_EVENT_MESSAGE, func(m any) {
+	b.client.AddEventHandler(discord.DISCORD_EVENT_MESSAGE, func(m any) {
 		msg := m.(discord.GatewayMessageCreateEvent)
-		log.Printf("message: %s", msg.Content)
+		b.logger.Debug().Msgf("message: %s", msg.Content)
 
 		if msg.Content == "ping" {
-			if err := client.SendMessage(msg.ChannelID, "pong"); err != nil {
-				log.Printf("error sending message: %s", err)
+			if err := b.client.SendMessage(msg.ChannelID, "pong"); err != nil {
+				b.logger.Error().Err(err).Msg("error sending message")
 			}
 		}
 	})
 
-	ctx, cancel := context.WithCancel(context.Background())
-	if err := client.Connect(ctx); err != nil {
-		cancel()
-		return nil, err
+	if err := b.client.Connect(ctx); err != nil {
+		return err
 	}
 
-	return &Bot{
-		Client:        client,
-		cancelContext: cancel,
-	}, nil
+	return nil
 }
 
 func (b *Bot) Stop() error {
+	if b.cancelContext == nil {
+		return errors.New("bot has no context")
+	}
+
 	b.cancelContext()
-	return b.Client.Disconnect()
+	return b.client.Disconnect()
 }
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index fd60edd..bd8a7f7 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -5,26 +5,28 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
-	"log"
 	"math/rand"
 	"net/http"
 	"time"
 
 	"github.com/gorilla/websocket"
+	"github.com/rs/zerolog"
 )
 
 type Discord struct {
 	token        string
+	logger       *zerolog.Logger
 	conn         *websocket.Conn
 	eventHandler *EventHandlerImpl
 	rest         REST
 }
 
-func NewClient(token string) *Discord {
+func NewClient(token string, logger *zerolog.Logger) *Discord {
 	token = "Bot " + token
 
 	return &Discord{
 		token:        token,
+		logger:       logger,
 		conn:         nil,
 		eventHandler: NewEventHandler(),
 		rest:         NewREST(token),
@@ -87,14 +89,14 @@ func (d *Discord) heartbeat(ctx context.Context, interval uint64) {
 		case <-ctx.Done():
 			return
 		case <-ticker.C:
-			fmt.Println("sending heartbeat.")
+			d.logger.Debug().Msg("sending heartbeat")
 
 			msg := GatewayPayload[any]{
 				Op: GATEWAY_OP_HEARTBEAT,
 			}
 
 			if err := d.conn.WriteJSON(msg); err != nil {
-				log.Fatalf("error sending heartbeat: %s\n", err)
+				d.logger.Error().Err(err).Msg("error sending heartbeat")
 			}
 		}
 	}
@@ -124,7 +126,7 @@ func (d *Discord) identify() error {
 		return err
 	}
 
-	fmt.Printf("identify response payload: %+v\n", res)
+	d.logger.Debug().Msgf("identify response payload: %+v", res)
 
 	if res.Op != GATEWAY_OP_DISPATCH {
 		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_DISPATCH, res.Op)
@@ -141,7 +143,7 @@ func (d *Discord) listen(ctx context.Context) {
 	for {
 		var msg GatewayPayload[json.RawMessage]
 		if err := d.conn.ReadJSON(&msg); err != nil {
-			log.Fatalf("error reading message: %s\n", err)
+			d.logger.Error().Err(err).Msg("error reading message")
 		}
 
 		select {
@@ -156,13 +158,13 @@ func (d *Discord) listen(ctx context.Context) {
 func (d *Discord) onEvent(msg GatewayPayload[json.RawMessage]) error {
 	switch msg.Op {
 	case GATEWAY_OP_HEARTBEAT_ACK:
-		fmt.Println("received heartbeat ack.")
+		d.logger.Debug().Msg("received heartbeat ack.")
 	case GATEWAY_OP_HEARTBEAT:
 		return errors.New("on demand heartbeat not implemented")
 	case GATEWAY_OP_DISPATCH:
 		return d.onDispatch(msg.EventName, msg.Data)
 	default:
-		fmt.Printf("received unknown opcode: %d\n", msg.Op)
+		d.logger.Warn().Msgf("received unknown opcode: %d", msg.Op)
 	}
 
 	return nil
@@ -178,7 +180,7 @@ func (d *Discord) onDispatch(eventName string, body json.RawMessage) error {
 
 		d.eventHandler.Fire(DISCORD_EVENT_MESSAGE, payload)
 	default:
-		fmt.Println("received unknown event:", eventName)
+		d.logger.Warn().Msgf("received unknown event: %s", eventName)
 	}
 
 	return nil