about summary refs log tree commit diff
path: root/pkg/discord/gateway/heartbeat.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/discord/gateway/heartbeat.go')
-rw-r--r--pkg/discord/gateway/heartbeat.go37
1 files changed, 23 insertions, 14 deletions
diff --git a/pkg/discord/gateway/heartbeat.go b/pkg/discord/gateway/heartbeat.go
index 1df753a..1ebe964 100644
--- a/pkg/discord/gateway/heartbeat.go
+++ b/pkg/discord/gateway/heartbeat.go
@@ -10,42 +10,51 @@ import (
 
 type Heartbeat interface {
 	Start(ctx context.Context, interval time.Duration)
+	Stop()
 
 	ForceHeartbeat()
 	Ack()
 }
 
-var _ Heartbeat = &HeartbeatImpl{}
+var _ Heartbeat = &heartbeatImpl{}
 
-type HeartbeatImpl struct {
-	ctx     context.Context
+type heartbeatImpl struct {
 	logger  *zerolog.Logger
 	gateway Gateway
+
+	cancelRoutine context.CancelFunc
 }
 
-func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *HeartbeatImpl {
-	return &HeartbeatImpl{
-		ctx:     nil,
+func NewHeartbeat(logger *zerolog.Logger, gateway Gateway) *heartbeatImpl {
+	return &heartbeatImpl{
 		logger:  logger,
 		gateway: gateway,
+
+		cancelRoutine: nil,
 	}
 }
 
-func (h *HeartbeatImpl) Start(ctx context.Context, interval time.Duration) {
-	h.ctx = ctx
-	go h.heartbeatRoutine(interval)
+func (h *heartbeatImpl) Start(ctx context.Context, interval time.Duration) {
+	routineCtx, cancel := context.WithCancel(ctx)
+	h.cancelRoutine = cancel
+	go h.heartbeatRoutine(routineCtx, interval)
+}
+
+func (h *heartbeatImpl) Stop() {
+	h.cancelRoutine()
+	h.cancelRoutine = nil
 }
 
-func (h *HeartbeatImpl) ForceHeartbeat() {
+func (h *heartbeatImpl) ForceHeartbeat() {
 	h.gateway.Heartbeat()
 }
 
-func (h *HeartbeatImpl) Ack() {
+func (h *heartbeatImpl) Ack() {
 	// What do we do here?
 	h.logger.Debug().Msg("received heartbeat ack")
 }
 
-func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
+func (h *heartbeatImpl) heartbeatRoutine(ctx context.Context, interval time.Duration) {
 	h.logger.Debug().Msgf("beating heart every %dms", interval.Milliseconds())
 
 	// REF: heartbeat_interval * jitter
@@ -53,7 +62,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
 
 	select {
 	case <-time.After(time.Duration(jitter)):
-	case <-h.ctx.Done():
+	case <-ctx.Done():
 		h.logger.Debug().Msg("heartbeat routine stopped before jitter heartbeat")
 		return
 	}
@@ -68,7 +77,7 @@ func (h *HeartbeatImpl) heartbeatRoutine(interval time.Duration) {
 		}
 
 		select {
-		case <-h.ctx.Done():
+		case <-ctx.Done():
 			return
 		case <-ticker.C:
 			continue