pye/internal/auth/jwt.go

108 lines
3 KiB
Go
Raw Normal View History

2024-10-12 21:45:00 +03:00
package auth
2024-10-11 23:57:57 +03:00
import (
"crypto/rand"
2024-10-12 16:59:47 +03:00
"crypto/rsa"
2024-10-11 23:57:57 +03:00
"crypto/x509"
2024-10-12 09:55:58 +03:00
"encoding/pem"
"errors"
2024-10-11 23:57:57 +03:00
"log/slog"
2024-10-12 09:55:58 +03:00
"net/http"
"os"
2024-10-12 10:22:05 +03:00
"time"
2024-10-11 23:57:57 +03:00
2024-10-13 21:03:44 +03:00
"git.a71.su/Andrew71/pye/internal/config"
"git.a71.su/Andrew71/pye/internal/models/user"
2024-10-12 09:55:58 +03:00
"github.com/golang-jwt/jwt/v5"
2024-10-11 23:57:57 +03:00
)
2024-10-13 14:49:41 +03:00
// key is the RSA private key used to sign tokens (see Create). It is
// populated by MustLoadKey; its public half is served by ServePublicKey and
// used by VerifyLocal.
var key *rsa.PrivateKey
2024-10-12 09:55:58 +03:00
2024-10-13 17:18:53 +03:00
// LoadKey attempts to load a private RS256 key from file.
2024-10-12 09:55:58 +03:00
// If the file does not exist, it generates a new key (and saves it)
2024-10-12 20:41:30 +03:00
func MustLoadKey() {
2024-10-12 09:55:58 +03:00
// If the key doesn't exist, create it
2024-10-12 21:45:00 +03:00
if _, err := os.Stat(config.Cfg.KeyFile); errors.Is(err, os.ErrNotExist) {
2024-10-12 16:59:47 +03:00
key, err = rsa.GenerateKey(rand.Reader, 4096)
2024-10-12 09:55:58 +03:00
if err != nil {
slog.Error("error generating key", "error", err)
os.Exit(1)
}
2024-10-12 16:59:47 +03:00
// Save key to disk
km := x509.MarshalPKCS1PrivateKey(key)
block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"}
2024-10-12 21:45:00 +03:00
f, err := os.OpenFile(config.Cfg.KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
2024-10-12 09:55:58 +03:00
if err != nil {
2024-10-12 16:59:47 +03:00
slog.Error("error opening/creating file", "error", err)
os.Exit(1)
}
f.Write(pem.EncodeToMemory(&block))
if err := f.Close(); err != nil {
slog.Error("error closing file", "error", err)
2024-10-12 09:55:58 +03:00
os.Exit(1)
}
2024-10-13 16:38:13 +03:00
slog.Debug("generated new key", "file", config.Cfg.KeyFile)
2024-10-12 09:55:58 +03:00
} else {
2024-10-12 21:45:00 +03:00
km, err := os.ReadFile(config.Cfg.KeyFile)
2024-10-12 09:55:58 +03:00
if err != nil {
slog.Error("error reading key", "error", err)
os.Exit(1)
}
2024-10-12 16:59:47 +03:00
key, err = jwt.ParseRSAPrivateKeyFromPEM(km)
2024-10-12 09:55:58 +03:00
if err != nil {
slog.Error("error parsing key", "error", err)
os.Exit(1)
}
2024-10-13 16:38:13 +03:00
slog.Debug("loaded private key", "file", config.Cfg.KeyFile)
2024-10-12 09:55:58 +03:00
}
}
2024-10-13 17:18:53 +03:00
// ServePublicKey returns our public key as PEM block over HTTP
func ServePublicKey(w http.ResponseWriter, r *http.Request) {
2024-10-12 16:59:47 +03:00
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"}
2024-10-12 09:55:58 +03:00
pem.Encode(w, &block)
2024-10-11 23:57:57 +03:00
}
2024-10-13 17:18:53 +03:00
// Create creates a JSON Web Token that expires after a week
2024-10-13 21:03:44 +03:00
func Create(user user.User) (token string, err error) {
2024-10-12 16:59:47 +03:00
t := jwt.NewWithClaims(jwt.SigningMethodRS256,
2024-10-12 09:55:58 +03:00
jwt.MapClaims{
"iss": "pye",
2024-10-12 21:45:00 +03:00
"uid": user.Uuid,
"sub": user.Email,
2024-10-12 16:59:47 +03:00
"iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24 * 7).Unix(),
2024-10-12 09:55:58 +03:00
})
2024-10-13 17:18:53 +03:00
token, err = t.SignedString(key)
2024-10-12 09:55:58 +03:00
if err != nil {
2024-10-12 16:59:47 +03:00
slog.Error("error creating JWT", "error", err)
2024-10-12 09:55:58 +03:00
return "", err
}
2024-10-13 17:18:53 +03:00
return
2024-10-12 09:55:58 +03:00
}
2024-10-12 16:59:47 +03:00
2024-10-13 17:18:53 +03:00
// Verify receives a JWT and PEM-encoded public key,
2024-10-12 16:59:47 +03:00
// then returns whether the token is valid
2024-10-13 17:18:53 +03:00
func Verify(token string, publicKey []byte) (*jwt.Token, error) {
2024-10-12 21:45:00 +03:00
t, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
2024-10-12 16:59:47 +03:00
key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey)
if err != nil {
return nil, err
}
if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
return nil, err
}
return key, nil
})
2024-10-12 21:45:00 +03:00
return t, err
2024-10-12 16:59:47 +03:00
}
2024-10-13 16:16:19 +03:00
2024-10-13 17:18:53 +03:00
// VerifyLocal calls Verify with public key set to current local one
func VerifyLocal(token string) (*jwt.Token, error) {
2024-10-13 16:16:19 +03:00
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"}
2024-10-13 17:18:53 +03:00
return Verify(token, pem.EncodeToMemory(&block))
2024-10-13 16:38:13 +03:00
}