diff --git a/.gitignore b/.gitignore index 84695b4..c9a0010 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,3 @@ pye-auth -key +private.key data.json \ No newline at end of file diff --git a/README.md b/README.md index d39f102..6c999d9 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# PYE Auth +# Auth microservice **Mission**: Science compels us to create a microservice! @@ -14,9 +14,9 @@ obviously I'd use **SQL** for production ## Current functionality -* Port 7102 +* Port `7102` * `POST /register` - register a user with Basic Auth * `POST /login` - get a JWT token by Basic Auth -* `GET /public-key` - get PEM-encoded public HS256 key +* `GET /pem` - get PEM-encoded public RS256 key * Data persistently stored in... `data.json`, for convenience -* HS256 key loaded from `key` file or generated on startup if missing \ No newline at end of file +* RS256 key loaded from `private.key` file or generated on startup if missing \ No newline at end of file diff --git a/auth.go b/auth.go index 64ac598..cc2cf3f 100644 --- a/auth.go +++ b/auth.go @@ -13,11 +13,10 @@ func ValidEmail(email string) bool { } func ValidPass(pass string) bool { // TODO: Obviously, we *might* want something more sophisticated here - return true - //return len(pass) >= 8 + return len(pass) >= 8 } func EmailTaken(email string) bool { - // TODO: Implement properly + // FIXME: Implement properly return EmailExists(email) } func Register(w http.ResponseWriter, r *http.Request) { @@ -27,12 +26,11 @@ func Register(w http.ResponseWriter, r *http.Request) { email = strings.TrimSpace(email) password = strings.TrimSpace(password) if !(ValidEmail(email) && ValidPass(password) && !EmailTaken(email)) { - // TODO: Provide descriptive error and check if 400 is best code? slog.Info("Outcome", "email", ValidEmail(email), "pass", ValidPass(password), "taken", !EmailTaken(email)) - http.Error(w, "Invalid auth credentials", http.StatusBadRequest) + http.Error(w, "invalid auth credentials", http.StatusBadRequest) return } user, err := NewUser(email, password) @@ -59,7 +57,7 @@ func Login(w http.ResponseWriter, r *http.Request) { password = strings.TrimSpace(password) user, ok := ByEmail(email) if !ok || !user.PasswordFits(password) { - http.Error(w, "You did something wrong", http.StatusUnauthorized) + http.Error(w, "you did something wrong", http.StatusUnauthorized) return } diff --git a/jwt.go b/jwt.go index e219f93..c6dcc5e 100644 --- a/jwt.go +++ b/jwt.go @@ -1,9 +1,8 @@ package main import ( - "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" + "crypto/rsa" "crypto/x509" "encoding/pem" "errors" @@ -15,10 +14,10 @@ import ( "github.com/golang-jwt/jwt/v5" ) -var KeyFile = "key" +var KeyFile = "private.key" var ( - key *ecdsa.PrivateKey + key *rsa.PrivateKey ) // LoadKey attempts to load a private key from KeyFile. @@ -26,17 +25,25 @@ var ( func LoadKey() { // If the key doesn't exist, create it if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) { - key, err = ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + key, err = rsa.GenerateKey(rand.Reader, 4096) if err != nil { slog.Error("error generating key", "error", err) os.Exit(1) } - km, err := x509.MarshalECPrivateKey(key) // Save private key to disk + + // Save key to disk + km := x509.MarshalPKCS1PrivateKey(key) + block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"} + f, err := os.OpenFile(KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) if err != nil { - slog.Error("error marshalling key", "error", err) + slog.Error("error opening/creating file", "error", err) + os.Exit(1) + } + f.Write(pem.EncodeToMemory(&block)) + if err := f.Close(); err != nil { + slog.Error("error closing file", "error", err) os.Exit(1) } - os.WriteFile(KeyFile, km, 0644) slog.Info("generated new key") } else { km, err := os.ReadFile(KeyFile) @@ -44,47 +51,56 @@ func LoadKey() { slog.Error("error reading key", "error", err) os.Exit(1) } - key, err = x509.ParseECPrivateKey(km) + key, err = jwt.ParseRSAPrivateKeyFromPEM(km) if err != nil { slog.Error("error parsing key", "error", err) os.Exit(1) } slog.Info("loaded private key") } - slog.Debug("private key", "key", key) } -// publicKey returns our public key in PKIX, ASN.1 DER form +// publicKey returns our public key as PEM block func publicKey(w http.ResponseWriter, r *http.Request) { - key_marshalled, err := x509.MarshalPKIXPublicKey(&key.PublicKey) - if err != nil { - slog.Error("error marshalling public key", "error", err) - http.Error(w, "error marshalling public key", http.StatusInternalServerError) - return - } - // w.Write(key_marshalled) - block := pem.Block{Bytes: key_marshalled, Type: "ECDSA PUBLIC KEY"} - // slog.Info("public key", "orig", key_marshalled, "block", block) + key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey) + block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"} pem.Encode(w, &block) } +func CreateJWT(usr User) (string, error) { + t := jwt.NewWithClaims(jwt.SigningMethodRS256, + jwt.MapClaims{ + "iss": "pye", + "uid": usr.Uuid, + "sub": usr.Email, + "iat": time.Now().Unix(), + "exp": time.Now().Add(time.Hour * 24 * 7).Unix(), + }) + s, err := t.SignedString(key) + if err != nil { + slog.Error("error creating JWT", "error", err) + return "", err + } + return s, nil +} + +// VerifyToken receives a JWT and PEM-encoded public key, +// then returns whether the token is valid +func VerifyJWT(token string, publicKey []byte) bool { + _, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { + key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey) + if err != nil { + return nil, err + } + if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok { + return nil, err + } + return key, nil + }) + slog.Info("Error check", "err", err) + return err == nil +} + func init() { LoadKey() } - -func CreateJWT(usr User) (string, error) { - t := jwt.NewWithClaims(jwt.SigningMethodES256, - jwt.MapClaims{ - "iss": "pye", - "uid": usr.Uuid, - "sub": usr.Email, - "iat": time.Now(), - "exp": time.Now().Add(time.Hour * 24 * 7), - }) - s, err := t.SignedString(key) - if err != nil { - slog.Error("Error creating JWT", "error", err) - return "", err - } - return s, nil -} diff --git a/main.go b/main.go index 83fd878..7ee58f0 100644 --- a/main.go +++ b/main.go @@ -10,10 +10,12 @@ func main() { router := http.NewServeMux() - router.HandleFunc("GET /public-key", publicKey) + router.HandleFunc("GET /pem", publicKey) router.HandleFunc("POST /register", Register) router.HandleFunc("POST /login", Login) + router.HandleFunc("GET /login", Login) // TODO: temp + http.ListenAndServe(":7102", router) }