6 Commits
v0.0.1 ... main

Author SHA1 Message Date
b457c74d49 fix auth stuff 2025-12-18 14:12:00 -05:00
9278ddf029 fix auth stuff 2025-12-18 14:06:07 -05:00
ec095a3955 fix auth stuff 2025-12-18 14:04:26 -05:00
7bc4e9b846 fix middleware 2025-12-18 13:40:38 -05:00
1c9688efd0 fixed an issue with session manager interface 2025-12-18 13:30:02 -05:00
fd0458cc08 fixed an issue with session manager interface 2025-12-18 13:26:00 -05:00
6 changed files with 71 additions and 35 deletions

View File

@@ -2,6 +2,7 @@ package auth0
import ( import (
"context" "context"
"encoding/gob"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@@ -11,6 +12,10 @@ import (
"git.citc.tech/go/web/auth/auth0/authenticator" "git.citc.tech/go/web/auth/auth0/authenticator"
) )
func init() {
gob.Register(SessionUser{})
}
type Logger interface { type Logger interface {
Debug(msg string, args ...any) Debug(msg string, args ...any)
Info(msg string, args ...any) Info(msg string, args ...any)
@@ -19,7 +24,7 @@ type Logger interface {
type SessionManager interface { type SessionManager interface {
Get(ctx context.Context, key string) any Get(ctx context.Context, key string) any
Put(ctx context.Context, key string, value any) error Put(ctx context.Context, key string, value any)
} }
type Config struct { type Config struct {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@@ -34,11 +35,7 @@ func HandleLogin(deps *deps) http.HandlerFunc {
deps.log.Info("generated state", "state", state) deps.log.Info("generated state", "state", state)
if err = deps.sessions.Put(r.Context(), StateKey, state); err != nil { deps.sessions.Put(r.Context(), StateKey, state)
deps.log.Error("unable to store state in session", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
http.Redirect(w, r, deps.auth.AuthCodeURL(state), http.StatusFound) http.Redirect(w, r, deps.auth.AuthCodeURL(state), http.StatusFound)
} }
@@ -46,13 +43,8 @@ func HandleLogin(deps *deps) http.HandlerFunc {
func HandleLogout(deps *deps) http.HandlerFunc { func HandleLogout(deps *deps) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if err := deps.sessions.Put(r.Context(), "user", nil); err != nil { deps.sessions.Put(r.Context(), "user", nil)
deps.log.Error("unable to remove user from session", "error", err) deps.sessions.Put(r.Context(), StateKey, nil)
}
if err := deps.sessions.Put(r.Context(), StateKey, nil); err != nil {
deps.log.Error("unable to remove state from session", "error", err)
}
scheme := "http" scheme := "http"
if r.TLS != nil { if r.TLS != nil {
@@ -85,9 +77,7 @@ func HandleCallback(deps *deps) http.HandlerFunc {
return return
} }
if err := deps.sessions.Put(r.Context(), StateKey, nil); err != nil { deps.sessions.Put(r.Context(), StateKey, nil)
deps.log.Error("unable to remove state from session", "error", err)
}
token, err := deps.auth.Exchange(r.Context(), r.URL.Query().Get("code")) token, err := deps.auth.Exchange(r.Context(), r.URL.Query().Get("code"))
if err != nil { if err != nil {
@@ -103,24 +93,40 @@ func HandleCallback(deps *deps) http.HandlerFunc {
return return
} }
var profile map[string]any var rawClaims map[string]json.RawMessage
if err = idToken.Claims(&profile); err != nil { if err = idToken.Claims(&rawClaims); err != nil {
deps.log.Error("unable to decode ID token claims", "error", err) deps.log.Error("unable to decode ID token claims", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
if err = deps.sessions.Put(r.Context(), "user", profile); err != nil { var user SessionUser
deps.log.Error("unable to store user profile in session", "error", err) if sub, ok := rawClaims["sub"]; ok {
http.Error(w, err.Error(), http.StatusInternalServerError) json.Unmarshal(sub, &user)
return }
if name, ok := rawClaims["name"]; ok {
json.Unmarshal(name, &user.Name)
}
if email, ok := rawClaims["email"]; ok {
json.Unmarshal(email, &user.Email)
}
if picture, ok := rawClaims["picture"]; ok {
json.Unmarshal(picture, &user.Picture)
} }
if err = deps.sessions.Put(r.Context(), "access_token", token.AccessToken); err != nil { customMap := make(map[string]json.RawMessage)
deps.log.Error("unable to store access token in session", "error", err) for k, v := range rawClaims {
http.Error(w, err.Error(), http.StatusInternalServerError) if k != "sub" && k != "name" && k != "email" && k != "picture" {
return customMap[k] = v
} }
}
if len(customMap) > 0 {
user.Custom, _ = json.Marshal(customMap)
}
deps.sessions.Put(r.Context(), "user", user)
deps.sessions.Put(r.Context(), "access_token", token.AccessToken)
http.Redirect(w, r, "/", http.StatusFound) http.Redirect(w, r, "/", http.StatusFound)
} }

View File

@@ -29,7 +29,7 @@ func (m *mockSessionManager) Get(ctx context.Context, key string) any {
return m.store[key] return m.store[key]
} }
func (m *mockSessionManager) Put(ctx context.Context, key string, value any) error { func (m *mockSessionManager) Put(ctx context.Context, key string, value any) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
if value == nil { if value == nil {
@@ -37,7 +37,6 @@ func (m *mockSessionManager) Put(ctx context.Context, key string, value any) err
} else { } else {
m.store[key] = value m.store[key] = value
} }
return nil
} }
func TestHandleLogic(t *testing.T) { func TestHandleLogic(t *testing.T) {

View File

@@ -20,11 +20,7 @@ func authenticatedMiddleware(deps *deps, next http.Handler) http.Handler {
return return
} }
if err = deps.sessions.Put(r.Context(), StateKey, state); err != nil { deps.sessions.Put(r.Context(), StateKey, state)
deps.log.Error("unable to store state in session", "error", err)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
loginURL := deps.auth.AuthCodeURL(state) loginURL := deps.auth.AuthCodeURL(state)
http.Redirect(w, r, loginURL, http.StatusFound) http.Redirect(w, r, loginURL, http.StatusFound)
@@ -37,6 +33,6 @@ func authenticatedMiddleware(deps *deps, next http.Handler) http.Handler {
}) })
} }
func CurrentUser(r *http.Request) any { func CurrentUser(r *http.Request) SessionUser {
return r.Context().Value(userContextKey{}) return r.Context().Value(userContextKey{}).(SessionUser)
} }

25
auth/auth0/session.go Normal file
View File

@@ -0,0 +1,25 @@
package auth0
import "encoding/json"
type SessionUser struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
Picture string `json:"picture"`
Custom json.RawMessage `json:"-"`
}
func (u *SessionUser) CustomClaims() (map[string]any, error) {
if len(u.Custom) == 0 {
return nil, nil
}
var claims map[string]any
if err := json.Unmarshal(u.Custom, &claims); err != nil {
return nil, err
}
return claims, nil
}

View File

@@ -2,6 +2,7 @@ package server
import ( import (
"log/slog" "log/slog"
"net/http"
"time" "time"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@@ -32,3 +33,7 @@ func WithWriteTimeout(d time.Duration) Option {
func WithIdleTimeout(d time.Duration) Option { func WithIdleTimeout(d time.Duration) Option {
return func(server *Server) { server.idleTimeout = d } return func(server *Server) { server.idleTimeout = d }
} }
func WithMiddleware(mw func(http.Handler) http.Handler) Option {
return func(server *Server) { server.Router.Use(mw) }
}