about summary refs log tree commit diff
path: root/pkg/discord
diff options
context:
space:
mode:
authorMel <einebeere@gmail.com>2022-04-04 13:26:19 +0200
committerMel <einebeere@gmail.com>2022-04-04 13:26:19 +0200
commit165ed818775c915e4bfd2599bdb8ca8e2975bb83 (patch)
treecc653962e6c001952c73c56c20fa5c965758f365 /pkg/discord
parentfdd0ea7911b2c98f95ef99f6d1518ee4eb4dfd7a (diff)
downloadjinx-165ed818775c915e4bfd2599bdb8ca8e2975bb83.tar.zst
jinx-165ed818775c915e4bfd2599bdb8ca8e2975bb83.zip
Extract REST API
Diffstat (limited to 'pkg/discord')
-rw-r--r--pkg/discord/discord.go126
-rw-r--r--pkg/discord/rest.go111
2 files changed, 130 insertions, 107 deletions
diff --git a/pkg/discord/discord.go b/pkg/discord/discord.go
index c5e7d67..fd36c37 100644
--- a/pkg/discord/discord.go
+++ b/pkg/discord/discord.go
@@ -1,9 +1,7 @@
 package discord
 
 import (
-	"bytes"
 	"context"
-	"encoding/json"
 	"errors"
 	"fmt"
 	"log"
@@ -14,47 +12,42 @@ import (
 	"github.com/gorilla/websocket"
 )
 
-const DISCORD_URl = "https://discord.com/api/v9/"
-const USER_AGENT = "DiscordBot (https://jinx.rnrd.eu/, v0.0.0) Jinx"
-
 type Discord struct {
-	Token string
-	Conn  *websocket.Conn
+	token        string
+	conn         *websocket.Conn
+	rest         REST
 }
 
 func NewClient(token string) *Discord {
 	return &Discord{
-		Token: token,
-		Conn:  nil,
+		token:        token,
+		conn:         nil,
+		rest:         NewREST(token),
 	}
 }
 
 func (d *Discord) Connect(ctx context.Context) error {
-	gatewayURL, err := d.getGateway()
+	gatewayURL, err := d.rest.Gateway()
 	if err != nil {
 		return err
 	}
 
-	fmt.Printf("gateway: %s\n", gatewayURL)
-
 	connectHeader := http.Header{}
-	d.Conn, _, err = websocket.DefaultDialer.Dial(gatewayURL, connectHeader)
+	d.conn, _, err = websocket.DefaultDialer.Dial(gatewayURL, connectHeader)
 	if err != nil {
 		return err
 	}
 
 	var helloMsg GatewayPayload[GatewayHelloMsg]
-	if err = d.Conn.ReadJSON(&helloMsg); err != nil {
+	if err = d.conn.ReadJSON(&helloMsg); err != nil {
 		return err
 	}
 
-	fmt.Printf("connection response Payload: %+v\n", helloMsg)
-
 	if helloMsg.Op != GATEWAY_OP_HELLO {
 		return fmt.Errorf("expected opcode %d, got %d", GATEWAY_OP_HELLO, helloMsg.Op)
 	}
 
-	go d.startHeartbeat(ctx, helloMsg.Data.HeartbeatInterval)
+	go d.heartbeat(ctx, helloMsg.Data.HeartbeatInterval)
 
 	if err = d.identify(); err != nil {
 		return err
@@ -66,58 +59,14 @@ func (d *Discord) Connect(ctx context.Context) error {
 }
 
 func (d *Discord) Disconnect() error {
-	if d.Conn == nil {
+	if d.conn == nil {
 		return errors.New("not connected")
 	}
 
-	return d.Conn.Close()
-}
-
-func (d *Discord) getGateway() (string, error) {
-	url := DISCORD_URl + "gateway"
-
-	req, err := http.NewRequest("GET", url, nil)
-	if err != nil {
-		return "", err
-	}
-
-	req.Header.Set("Authorization", d.Token)
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("User-Agent", USER_AGENT)
-
-	resp, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return "", err
-	}
-
-	defer resp.Body.Close()
-
-	var buf bytes.Buffer
-	_, err = buf.ReadFrom(resp.Body)
-	if err != nil {
-		return "", err
-	}
-
-	switch resp.StatusCode {
-	case 200:
-	default:
-		return "", errors.New("gateway response status code: " + resp.Status)
-	}
-
-	body := struct {
-		URL string `json:"url"`
-	}{}
-
-	err = json.Unmarshal(buf.Bytes(), &body)
-	if err != nil {
-		return "", err
-	}
-
-	url = body.URL + "?v=9&encoding=json"
-	return url, nil
+	return d.conn.Close()
 }
 
-func (d *Discord) startHeartbeat(ctx context.Context, interval uint64) {
+func (d *Discord) heartbeat(ctx context.Context, interval uint64) {
 	// REF: heartbeat_interval * jitter
 	jitter := rand.Intn(int(interval))
 	time.Sleep(time.Duration(jitter) * time.Millisecond)
@@ -136,7 +85,7 @@ func (d *Discord) startHeartbeat(ctx context.Context, interval uint64) {
 				Op: GATEWAY_OP_HEARTBEAT,
 			}
 
-			if err := d.Conn.WriteJSON(msg); err != nil {
+			if err := d.conn.WriteJSON(msg); err != nil {
 				log.Fatalf("error sending heartbeat: %s\n", err)
 			}
 		}
@@ -146,7 +95,7 @@ func (d *Discord) startHeartbeat(ctx context.Context, interval uint64) {
 func (d *Discord) listen(ctx context.Context) {
 	for {
 		var msg GatewayPayload[any]
-		if err := d.Conn.ReadJSON(&msg); err != nil {
+		if err := d.conn.ReadJSON(&msg); err != nil {
 			log.Fatalf("error reading message: %s\n", err)
 		}
 
@@ -160,7 +109,7 @@ func (d *Discord) listen(ctx context.Context) {
 				event := msg.Data.(map[string]interface{})
 				if event["content"] == "ping" {
 					fmt.Println("got ping, sending pong...")
-					if err := d.sendMessage(Snowflake(event["channel_id"].(string)), "pong"); err != nil {
+					if err := d.rest.SendMessage(Snowflake(event["channel_id"].(string)), "pong"); err != nil {
 						log.Fatalf("error sending message: %s\n", err)
 					}
 				}
@@ -173,7 +122,7 @@ func (d *Discord) identify() error {
 	msg := GatewayPayload[GatewayIdentifyMsg]{
 		Op: GATEWAY_OP_IDENTIFY,
 		Data: GatewayIdentifyMsg{
-			Token:   d.Token,
+			Token:   d.token,
 			Intents: 13991,
 			Properties: GatewayIdentifyProperties{
 				OS:      "linux",
@@ -184,12 +133,12 @@ func (d *Discord) identify() error {
 		Sequence: 0,
 	}
 
-	if err := d.Conn.WriteJSON(msg); err != nil {
+	if err := d.conn.WriteJSON(msg); err != nil {
 		return err
 	}
 
 	var res GatewayPayload[GatewayReadyMsg]
-	if err := d.Conn.ReadJSON(&res); err != nil {
+	if err := d.conn.ReadJSON(&res); err != nil {
 		return err
 	}
 
@@ -205,40 +154,3 @@ func (d *Discord) identify() error {
 
 	return nil
 }
-
-func (d *Discord) sendMessage(channelID Snowflake, content string) error {
-	url := DISCORD_URl + "channels/" + string(channelID) + "/messages"
-
-	msg := struct {
-		Content string `json:"content"`
-	}{
-		Content: content,
-	}
-
-	raw, err := json.Marshal(msg)
-	if err != nil {
-		return err
-	}
-
-	req, err := http.NewRequest("POST", url, bytes.NewBuffer(raw))
-	if err != nil {
-		return err
-	}
-
-	req.Header.Set("Authorization", "Bot "+d.Token)
-	req.Header.Set("Content-Type", "application/json")
-	req.Header.Set("User-Agent", USER_AGENT)
-
-	res, err := http.DefaultClient.Do(req)
-	if err != nil {
-		return err
-	}
-
-	switch res.StatusCode {
-	case 200:
-	default:
-		return errors.New("unexpected status code after sending message: " + res.Status)
-	}
-
-	return nil
-}
diff --git a/pkg/discord/rest.go b/pkg/discord/rest.go
new file mode 100644
index 0000000..438d8d1
--- /dev/null
+++ b/pkg/discord/rest.go
@@ -0,0 +1,111 @@
+package discord
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"net/http"
+	"time"
+)
+
+const DISCORD_URl = "https://discord.com/api/v9/"
+const USER_AGENT = "DiscordBot (https://jinx.rnrd.eu/, v0.0.0) Jinx"
+
+type REST interface {
+	Gateway() (string, error)
+
+	SendMessage(channelID Snowflake, content string) error
+}
+
+var _ REST = &RESTImpl{}
+
+type RESTImpl struct {
+	token  string
+	client *http.Client
+}
+
+func NewREST(token string) *RESTImpl {
+	return &RESTImpl{
+		token: token,
+		client: &http.Client{
+			Timeout: time.Second * 5,
+		},
+	}
+}
+
+func (r *RESTImpl) Gateway() (string, error) {
+	type gatewayResponse struct {
+		URL string `json:"url"`
+	}
+
+	res, err := request[gatewayResponse](r, "GET", url("gateway"), nil)
+	if err != nil {
+		return "", err
+	}
+
+	return res.URL + "?v=9&encoding=json", nil
+}
+
+func (r *RESTImpl) SendMessage(channelID Snowflake, content string) error {
+	msg := struct {
+		Content string `json:"content"`
+	}{
+		Content: content,
+	}
+
+	_, err := request[any](r, "POST", url("channels/"+string(channelID)+"/messages"), msg)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func request[D any](r *RESTImpl, method string, url string, data any) (*D, error) {
+	var raw []byte
+	if data != nil {
+		var err error
+		raw, err = json.Marshal(data)
+		if err != nil {
+			return nil, err
+		}
+	}
+
+	req, err := http.NewRequest(method, url, bytes.NewBuffer(raw))
+	if err != nil {
+		return nil, err
+	}
+
+	req.Header.Set("Authorization", r.token)
+	req.Header.Set("Content-Type", "application/json")
+	req.Header.Set("User-Agent", USER_AGENT)
+
+	resp, err := r.client.Do(req)
+	if err != nil {
+		return nil, err
+	}
+
+	defer resp.Body.Close()
+
+	switch resp.StatusCode {
+	case 200:
+	default:
+		return nil, errors.New("unexpected status code: " + resp.Status)
+	}
+
+	var buf bytes.Buffer
+	if _, err = buf.ReadFrom(resp.Body); err != nil {
+		return nil, err
+	}
+
+	var res D
+	if err = json.Unmarshal(buf.Bytes(), &res); err != nil {
+		return nil, err
+	}
+
+	return &res, nil
+}
+
+func url(path string) string {
+	return DISCORD_URl + path
+}