Files
web/worker/worker.go

173 lines
3.2 KiB
Go

package worker
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/nats-io/nats.go"
)
type Worker struct {
js nats.JetStreamContext
sub *nats.Subscription
handler Handler
log *slog.Logger
concurrency int
}
// New builds a Worker that consumes from a JetStream pull consumer on nc.
//
// Options are applied on top of these defaults: concurrency 1, 30s ack wait,
// MaxDeliver 1, DeliverAll, ReplayInstant, slog.Default() and file storage.
// A handler, a stream name and a filter subject are required. If one is
// missing, New returns ErrInvalidHandler, ErrStreamNameRequired or
// ErrSubjectRequired respectively.
//
// If the named stream does not exist, New creates it with the filter subject
// as its subject. It then adds an explicit-ack consumer for the filter
// subject and binds a pull subscription to it. NATS errors are returned
// wrapped with context; use errors.Is to match them.
func New(nc *nats.Conn, opts ...Option) (*Worker, error) {
	cfg := &config{
		concurrency:   1,
		ackWait:       30 * time.Second,
		maxDeliver:    1,
		deliverPolicy: nats.DeliverAllPolicy,
		replayPolicy:  nats.ReplayInstantPolicy,
		log:           slog.Default(),
		storage:       nats.FileStorage,
	}
	for _, opt := range opts {
		opt(cfg)
	}

	switch {
	case cfg.handler == nil:
		return nil, ErrInvalidHandler
	case cfg.streamName == "":
		return nil, ErrStreamNameRequired
	case cfg.filterSubject == "":
		return nil, ErrSubjectRequired
	}

	js, err := nc.JetStream()
	if err != nil {
		return nil, fmt.Errorf("getting jetstream context: %w", err)
	}

	w := &Worker{
		js:          js,
		log:         cfg.log,
		handler:     cfg.handler,
		concurrency: cfg.concurrency,
	}

	_, err = js.StreamInfo(cfg.streamName)
	switch {
	case err == nil:
		// Stream already exists; use it as-is.
	case errors.Is(err, nats.ErrStreamNotFound):
		w.log.Info("creating stream", "name", cfg.streamName)
		_, err = js.AddStream(&nats.StreamConfig{
			Name: cfg.streamName,
			// The consumer below filters on cfg.filterSubject, and JetStream
			// rejects a consumer filter that is not a subset of the stream's
			// subjects. filterSubject+".*>"-style subjects would not match
			// filterSubject itself (">" requires at least one more token),
			// so the stream must cover the filter subject directly.
			Subjects: []string{cfg.filterSubject},
			Storage:  cfg.storage,
		})
		if err != nil {
			return nil, fmt.Errorf("creating stream %q: %w", cfg.streamName, err)
		}
	default:
		return nil, fmt.Errorf("looking up stream %q: %w", cfg.streamName, err)
	}

	// NOTE(review): if durableName is empty, the consumer added below gets a
	// server-assigned name, and consumerName falls back to "" as well. Confirm
	// whether durableName should be required.
	if cfg.consumerName == "" {
		cfg.consumerName = cfg.durableName
	}

	consumerCfg := &nats.ConsumerConfig{
		Durable:       cfg.durableName,
		AckPolicy:     nats.AckExplicitPolicy,
		AckWait:       cfg.ackWait,
		MaxDeliver:    cfg.maxDeliver,
		FilterSubject: cfg.filterSubject,
		DeliverPolicy: cfg.deliverPolicy,
		ReplayPolicy:  cfg.replayPolicy,
	}
	if _, err := js.AddConsumer(cfg.streamName, consumerCfg); err != nil {
		return nil, fmt.Errorf("adding consumer to stream %q: %w", cfg.streamName, err)
	}

	w.sub, err = js.PullSubscribe(cfg.filterSubject, cfg.consumerName, nats.BindStream(cfg.streamName))
	if err != nil {
		return nil, fmt.Errorf("pull subscribing to %q: %w", cfg.filterSubject, err)
	}
	return w, nil
}
// Run starts the worker goroutines and blocks until the process receives
// SIGINT or SIGTERM and every worker goroutine has returned.
//
// Run does not drain the subscription; call Shutdown for that.
func (w *Worker) Run() {
	w.log.Info("background worker started")

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	sigs := make(chan os.Signal, 1)
	signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
	// Stop relaying once Run returns. Otherwise SIGINT/SIGTERM stay captured
	// into a channel nobody reads and default termination never comes back.
	defer signal.Stop(sigs)

	go func() {
		select {
		case <-sigs:
			w.log.Info("shutting down background worker")
			// Restore default handling right away so that a second signal
			// terminates the process instead of being swallowed while
			// in-flight messages finish.
			signal.Stop(sigs)
			cancel()
		case <-ctx.Done():
			// Defensive: lets this goroutine exit if Run ever returns
			// without a signal having been received.
		}
	}()

	workers := w.concurrency
	if workers < 1 {
		// Zero or negative concurrency would start no goroutines and make
		// Run return immediately without consuming anything.
		workers = 1
	}

	var wg sync.WaitGroup
	for i := 0; i < workers; i++ {
		wg.Add(1)
		go func(id int) {
			defer wg.Done()
			w.processMessages(ctx, id)
		}(i)
	}
	wg.Wait()
	w.log.Info("background worker stopped")
}
// processMessages is the fetch/handle loop run by each worker goroutine.
//
// It pulls batches of up to batchSize messages and invokes the handler for
// each one. A message is Acked when the handler succeeds and Naked when it
// returns an error. The loop returns once ctx is done. Fetch errors other
// than an empty poll are logged and retried after retryDelay.
func (w *Worker) processMessages(ctx context.Context, workerID int) {
	const (
		batchSize  = 10
		fetchWait  = 5 * time.Second
		retryDelay = time.Second
	)
	logger := w.log.With("worker", workerID)

	for ctx.Err() == nil {
		// nats.go rejects a Fetch that sets both nats.Context and nats.MaxWait
		// (ErrContextAndTimeout), so bound the wait with a child context. The
		// child also inherits shutdown cancellation from ctx.
		fetchCtx, cancel := context.WithTimeout(ctx, fetchWait)
		msgs, err := w.sub.Fetch(batchSize, nats.Context(fetchCtx))
		cancel()

		if err != nil {
			switch {
			case ctx.Err() != nil:
				return
			case errors.Is(err, nats.ErrTimeout), errors.Is(err, context.DeadlineExceeded):
				// No messages arrived within fetchWait; poll again. With a
				// context option the empty poll surfaces as DeadlineExceeded.
				continue
			}
			logger.Error("error fetching messages", "error", err)
			// Back off before retrying, but wake immediately on shutdown.
			timer := time.NewTimer(retryDelay)
			select {
			case <-ctx.Done():
				timer.Stop()
				return
			case <-timer.C:
			}
			continue
		}

		for _, msg := range msgs {
			if err := w.handler(ctx, msg); err != nil {
				logger.Error("error processing message", "error", err)
				if nakErr := msg.Nak(); nakErr != nil {
					logger.Error("error nacking message", "error", nakErr)
				}
				continue
			}
			if ackErr := msg.Ack(); ackErr != nil {
				logger.Error("error acking message", "error", ackErr)
			}
		}
	}
}
// Shutdown drains the worker's pull subscription, if it has one.
// A drain failure is logged rather than silently discarded; Shutdown keeps
// its no-return signature so existing callers are unaffected.
func (w *Worker) Shutdown() {
	if w.sub == nil {
		return
	}
	if err := w.sub.Drain(); err != nil {
		w.log.Error("error draining subscription", "error", err)
	}
}