173 lines
3.2 KiB
Go
173 lines
3.2 KiB
Go
package worker
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"os"
	"os/signal"
	"sync"
	"syscall"
	"time"

	"github.com/nats-io/nats.go"
)
|
|
|
|
type Worker struct {
|
|
js nats.JetStreamContext
|
|
sub *nats.Subscription
|
|
handler Handler
|
|
log *slog.Logger
|
|
concurrency int
|
|
}
|
|
|
|
func New(nc *nats.Conn, opts ...Option) (*Worker, error) {
|
|
cfg := &config{
|
|
concurrency: 1,
|
|
ackWait: 30 * time.Second,
|
|
maxDeliver: 1,
|
|
deliverPolicy: nats.DeliverAllPolicy,
|
|
replayPolicy: nats.ReplayInstantPolicy,
|
|
log: slog.Default(),
|
|
storage: nats.FileStorage,
|
|
}
|
|
|
|
for _, opt := range opts {
|
|
opt(cfg)
|
|
}
|
|
|
|
if cfg.handler == nil {
|
|
return nil, ErrInvalidHandler
|
|
}
|
|
|
|
if cfg.streamName == "" {
|
|
return nil, ErrStreamNameRequired
|
|
}
|
|
|
|
if cfg.filterSubject == "" {
|
|
return nil, ErrSubjectRequired
|
|
}
|
|
|
|
js, err := nc.JetStream()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
w := &Worker{
|
|
js: js,
|
|
log: cfg.log,
|
|
handler: cfg.handler,
|
|
concurrency: cfg.concurrency,
|
|
}
|
|
|
|
_, err = js.StreamInfo(cfg.streamName)
|
|
if err != nil {
|
|
if errors.Is(err, nats.ErrStreamNotFound) {
|
|
w.log.Info("creating stream", "name", cfg.streamName)
|
|
|
|
_, err = js.AddStream(&nats.StreamConfig{
|
|
Name: cfg.streamName,
|
|
Subjects: []string{cfg.filterSubject + ".>"},
|
|
Storage: cfg.storage,
|
|
})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
if cfg.consumerName == "" {
|
|
cfg.consumerName = cfg.durableName
|
|
}
|
|
|
|
consumerCfg := &nats.ConsumerConfig{
|
|
Durable: cfg.durableName,
|
|
AckPolicy: nats.AckExplicitPolicy,
|
|
AckWait: cfg.ackWait,
|
|
MaxDeliver: cfg.maxDeliver,
|
|
FilterSubject: cfg.filterSubject,
|
|
DeliverPolicy: cfg.deliverPolicy,
|
|
ReplayPolicy: cfg.replayPolicy,
|
|
}
|
|
|
|
_, err = js.AddConsumer(cfg.streamName, consumerCfg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
w.sub, err = js.PullSubscribe(cfg.filterSubject, cfg.consumerName, nats.BindStream(cfg.streamName))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return w, nil
|
|
}
|
|
|
|
func (w *Worker) Run() {
|
|
w.log.Info("background worker started")
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
defer cancel()
|
|
|
|
sigs := make(chan os.Signal, 1)
|
|
signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM)
|
|
go func() {
|
|
<-sigs
|
|
w.log.Info("shutting down background worker")
|
|
cancel()
|
|
}()
|
|
|
|
var wg sync.WaitGroup
|
|
for i := 0; i < w.concurrency; i++ {
|
|
wg.Add(1)
|
|
go func(id int) {
|
|
defer wg.Done()
|
|
w.processMessages(ctx, id)
|
|
}(i)
|
|
}
|
|
|
|
wg.Wait()
|
|
w.log.Info("background worker stopped")
|
|
}
|
|
|
|
func (w *Worker) processMessages(ctx context.Context, workerID int) {
|
|
batchSize := 10
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
default:
|
|
msgs, err := w.sub.Fetch(
|
|
batchSize,
|
|
nats.Context(ctx),
|
|
nats.MaxWait(5*time.Second),
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, nats.ErrTimeout) {
|
|
continue
|
|
}
|
|
w.log.Error("error fetching messages", "error", err)
|
|
time.Sleep(time.Second)
|
|
continue
|
|
}
|
|
|
|
for _, msg := range msgs {
|
|
if err := w.handler(ctx, msg); err != nil {
|
|
w.log.Error("error processing message", "error", err)
|
|
msg.Nak()
|
|
} else {
|
|
msg.Ack()
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (w *Worker) Shutdown() {
|
|
if w.sub != nil {
|
|
w.sub.Drain()
|
|
}
|
|
}
|