Compare commits

...

2 commits

Author SHA1 Message Date
cda8f0cc1b Refactor everything 2024-10-12 21:45:00 +03:00
b07f1b080a Refactor database to separate package 2024-10-12 20:41:30 +03:00
15 changed files with 264 additions and 139 deletions

4
.gitignore vendored
View file

@ -1,3 +1,3 @@
pye-auth pye
private.key private.key
data.db dev-data.db

View file

@ -1,5 +1,8 @@
build: build:
go build go build
run: serve:
go build && ./pye-auth go build && ./pye serve
dev:
go build && ./pye serve --db dev-data.db

View file

@ -11,10 +11,14 @@ in a state that proves I am competent Go developer.
## Current functionality ## Current functionality
* Port `7102` ## `serve`
* `POST /register` - register a user with Basic Auth * `POST /register` - register a user with Basic Auth
* `POST /login` - get a JWT token by Basic Auth * `POST /login` - get a JWT token by Basic Auth
* `GET /pem` - get PEM-encoded public RS256 key * `GET /pem` - get PEM-encoded public RS256 key
* Data persistently stored in an SQLite database `data.db` * Data persistently stored in an SQLite database
(requires creation of empty db) * RS256 key loaded from a file or generated on startup if missing
* RS256 key loaded from `private.key` file or generated on startup if missing
## `verify`
* Verify JWT via public key in a PEM file

View file

@ -1,10 +1,12 @@
package main package auth
import ( import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/mail" "net/mail"
"strings" "strings"
"git.a71.su/Andrew71/pye/storage"
) )
func validEmail(email string) bool { func validEmail(email string) bool {
@ -16,30 +18,24 @@ func validPass(pass string) bool {
return len(pass) >= 8 return len(pass) >= 8
} }
func Register(w http.ResponseWriter, r *http.Request) { func Register(w http.ResponseWriter, r *http.Request, data storage.Storage) {
email, password, ok := r.BasicAuth() email, password, ok := r.BasicAuth()
if ok { if ok {
email = strings.TrimSpace(email) email = strings.TrimSpace(email)
password = strings.TrimSpace(password) password = strings.TrimSpace(password)
if !(validEmail(email) && validPass(password) && !emailExists(email)) { if !(validEmail(email) && validPass(password) && !data.EmailExists(email)) {
slog.Debug("Outcome", slog.Debug("Outcome",
"email", validEmail(email), "email", validEmail(email),
"pass", validPass(password), "pass", validPass(password),
"taken", !emailExists(email)) "taken", !data.EmailExists(email))
http.Error(w, "invalid auth credentials", http.StatusBadRequest) http.Error(w, "invalid auth credentials", http.StatusBadRequest)
return return
} }
user, err := NewUser(email, password) err := data.AddUser(email, password)
if err != nil { if err != nil {
slog.Error("error creating a new user", "error", err) slog.Error("error adding a new user", "error", err)
http.Error(w, "error creating a new user", http.StatusInternalServerError) http.Error(w, "error adding a new user", http.StatusInternalServerError)
return
}
err = addUser(user)
if err != nil {
slog.Error("error saving a new user", "error", err)
http.Error(w, "error saving a new user", http.StatusInternalServerError)
return return
} }
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
@ -52,14 +48,15 @@ func Register(w http.ResponseWriter, r *http.Request) {
http.Error(w, "This API requires authorization", http.StatusUnauthorized) http.Error(w, "This API requires authorization", http.StatusUnauthorized)
} }
func Login(w http.ResponseWriter, r *http.Request) { func Login(w http.ResponseWriter, r *http.Request, data storage.Storage) {
email, password, ok := r.BasicAuth() email, password, ok := r.BasicAuth()
if ok { if ok {
email = strings.TrimSpace(email) email = strings.TrimSpace(email)
password = strings.TrimSpace(password) password = strings.TrimSpace(password)
user, ok := byEmail(email) user, ok := data.ByEmail(email)
if !ok || !user.PasswordFits(password) { if !ok || !user.PasswordFits(password) {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "you did something wrong", http.StatusUnauthorized) http.Error(w, "you did something wrong", http.StatusUnauthorized)
return return
} }

View file

@ -1,4 +1,4 @@
package main package auth
import ( import (
"crypto/rand" "crypto/rand"
@ -11,19 +11,20 @@ import (
"os" "os"
"time" "time"
"git.a71.su/Andrew71/pye/config"
"git.a71.su/Andrew71/pye/storage"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
var ( var (
KeyFile = "private.key" key *rsa.PrivateKey
key *rsa.PrivateKey
) )
// LoadKey attempts to load a private key from KeyFile. // LoadKey attempts to load a private key from KeyFile.
// If the file does not exist, it generates a new key (and saves it) // If the file does not exist, it generates a new key (and saves it)
func LoadKey() { func MustLoadKey() {
// If the key doesn't exist, create it // If the key doesn't exist, create it
if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(config.Cfg.KeyFile); errors.Is(err, os.ErrNotExist) {
key, err = rsa.GenerateKey(rand.Reader, 4096) key, err = rsa.GenerateKey(rand.Reader, 4096)
if err != nil { if err != nil {
slog.Error("error generating key", "error", err) slog.Error("error generating key", "error", err)
@ -33,7 +34,7 @@ func LoadKey() {
// Save key to disk // Save key to disk
km := x509.MarshalPKCS1PrivateKey(key) km := x509.MarshalPKCS1PrivateKey(key)
block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"} block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"}
f, err := os.OpenFile(KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) f, err := os.OpenFile(config.Cfg.KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil { if err != nil {
slog.Error("error opening/creating file", "error", err) slog.Error("error opening/creating file", "error", err)
os.Exit(1) os.Exit(1)
@ -43,9 +44,9 @@ func LoadKey() {
slog.Error("error closing file", "error", err) slog.Error("error closing file", "error", err)
os.Exit(1) os.Exit(1)
} }
slog.Info("generated new key") slog.Info("generated new key", "file", config.Cfg.KeyFile)
} else { } else {
km, err := os.ReadFile(KeyFile) km, err := os.ReadFile(config.Cfg.KeyFile)
if err != nil { if err != nil {
slog.Error("error reading key", "error", err) slog.Error("error reading key", "error", err)
os.Exit(1) os.Exit(1)
@ -55,23 +56,27 @@ func LoadKey() {
slog.Error("error parsing key", "error", err) slog.Error("error parsing key", "error", err)
os.Exit(1) os.Exit(1)
} }
slog.Info("loaded private key") slog.Info("loaded private key", "file", config.Cfg.KeyFile)
} }
} }
// publicKey returns our public key as PEM block func init() {
func publicKey(w http.ResponseWriter, r *http.Request) { MustLoadKey()
}
// PublicKey returns our public key as PEM block over http
func PublicKey(w http.ResponseWriter, r *http.Request) {
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey) key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"} block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"}
pem.Encode(w, &block) pem.Encode(w, &block)
} }
func CreateJWT(usr User) (string, error) { func CreateJWT(user storage.User) (string, error) {
t := jwt.NewWithClaims(jwt.SigningMethodRS256, t := jwt.NewWithClaims(jwt.SigningMethodRS256,
jwt.MapClaims{ jwt.MapClaims{
"iss": "pye", "iss": "pye",
"uid": usr.Uuid, "uid": user.Uuid,
"sub": usr.Email, "sub": user.Email,
"iat": time.Now().Unix(), "iat": time.Now().Unix(),
"exp": time.Now().Add(time.Hour * 24 * 7).Unix(), "exp": time.Now().Add(time.Hour * 24 * 7).Unix(),
}) })
@ -85,8 +90,8 @@ func CreateJWT(usr User) (string, error) {
// VerifyToken receives a JWT and PEM-encoded public key, // VerifyToken receives a JWT and PEM-encoded public key,
// then returns whether the token is valid // then returns whether the token is valid
func VerifyJWT(token string, publicKey []byte) bool { func VerifyJWT(token string, publicKey []byte) (*jwt.Token, error) {
_, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) { t, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey) key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey)
if err != nil { if err != nil {
return nil, err return nil, err
@ -96,10 +101,5 @@ func VerifyJWT(token string, publicKey []byte) bool {
} }
return key, nil return key, nil
}) })
slog.Info("Error check", "err", err) return t, err
return err == nil
}
func init() {
LoadKey()
} }

51
cmd/main.go Normal file
View file

@ -0,0 +1,51 @@
package cmd
import (
"flag"
"fmt"
"os"
"git.a71.su/Andrew71/pye/cmd/serve"
"git.a71.su/Andrew71/pye/cmd/verify"
"git.a71.su/Andrew71/pye/config"
)
func Run() {
// configFlag := flag.String("config", "", "override config file")
// flag.Parse()
// if *configFlag != "" {
// config.Load()
// }
serveCmd := flag.NewFlagSet("serve", flag.ExitOnError)
servePort := serveCmd.Int("port", 0, "override port")
serveDb := serveCmd.String("db", "", "override sqlite database")
verifyCmd := flag.NewFlagSet("verify", flag.ExitOnError)
if len(os.Args) < 2 {
fmt.Println("expected 'serve' or 'verify' subcommands")
os.Exit(0)
}
switch os.Args[1] {
case "serve":
serveCmd.Parse(os.Args[2:])
if *servePort != 0 {
config.Cfg.Port = *servePort
}
if *serveDb != "" {
config.Cfg.SQLiteFile = *serveDb
}
serve.Serve()
case "verify":
verifyCmd.Parse(os.Args[2:])
if len(os.Args) != 4 {
fmt.Println("Usage: <jwt> <pem file>")
}
verify.Verify(os.Args[2], os.Args[3])
default:
fmt.Println("expected 'serve' or 'verify' subcommands")
os.Exit(0)
}
}

32
cmd/serve/main.go Normal file
View file

@ -0,0 +1,32 @@
package serve
import (
"log/slog"
"net/http"
"strconv"
"git.a71.su/Andrew71/pye/auth"
"git.a71.su/Andrew71/pye/config"
"git.a71.su/Andrew71/pye/storage"
"git.a71.su/Andrew71/pye/storage/sqlite"
)
var data storage.Storage
func Serve() {
data = sqlite.MustLoadSQLite(config.Cfg.SQLiteFile)
router := http.NewServeMux()
router.HandleFunc("GET /pem", auth.PublicKey)
router.HandleFunc("POST /register", func(w http.ResponseWriter, r *http.Request) { auth.Register(w, r, data) })
router.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) { auth.Login(w, r, data) })
// Note: likely temporary, possibly to be replaced by a fake "frontend"
router.HandleFunc("GET /register", func(w http.ResponseWriter, r *http.Request) { auth.Register(w, r, data) })
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { auth.Login(w, r, data) })
slog.Info("🪐 pye started", "port", config.Cfg.Port)
http.ListenAndServe(":"+strconv.Itoa(config.Cfg.Port), router)
}

17
cmd/verify/main.go Normal file
View file

@ -0,0 +1,17 @@
package verify
import (
"log/slog"
"os"
"git.a71.su/Andrew71/pye/auth"
)
func Verify(token, filename string) {
key, err := os.ReadFile(filename)
if err != nil {
slog.Error("error reading file", "error", err, "file", filename)
}
t, err := auth.VerifyJWT(token, key)
slog.Info("result", "token", t, "error", err, "ok", err == nil)
}

20
config/config.go Normal file
View file

@ -0,0 +1,20 @@
package config
type Config struct {
Port int `json:"port"`
KeyFile string `json:"key-file"`
SQLiteFile string `json:"sqlite-file"`
}
var DefaultConfig = Config{
Port: 7102,
KeyFile: "private.key",
SQLiteFile: "data.db",
}
var Cfg = MustLoadConfig()
// TODO: Implement
func MustLoadConfig() Config {
return DefaultConfig
}

73
db.go
View file

@ -1,73 +0,0 @@
package main
import (
"database/sql"
"errors"
"log/slog"
"os"
_ "github.com/mattn/go-sqlite3"
)
const create string = `
CREATE TABLE "users" (
"uuid" TEXT NOT NULL UNIQUE,
"email" TEXT NOT NULL UNIQUE,
"password" TEXT NOT NULL,
PRIMARY KEY("uuid")
);`
var (
DataFile = "data.db"
db *sql.DB = LoadDb()
)
func LoadDb() *sql.DB {
// I *think* we need some file, even if only empty
if _, err := os.Stat(DataFile); errors.Is(err, os.ErrNotExist) {
slog.Error("sqlite3 database file required", "file", DataFile)
os.Exit(1)
}
db, err := sql.Open("sqlite3", DataFile)
if err != nil {
slog.Error("error opening database", "error", err)
os.Exit(1)
}
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
slog.Info("error initialising database table", "error", err)
os.Exit(1)
}
slog.Info("loaded database")
return db
}
func addUser(user User) error {
_, err := db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
user.Uuid.String(), user.Email, user.Hash)
if err != nil {
slog.Error("error adding user", "error", err, "user", user)
return err
}
return nil
}
func byId(uuid string) (User, bool) {
row := db.QueryRow("select * from users where uuid = $1", uuid)
user := User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func byEmail(email string) (User, bool) {
row := db.QueryRow("select * from users where email = $1", email)
user := User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func emailExists(email string) bool {
_, ok := byEmail(email)
return ok
}

2
go.mod
View file

@ -1,4 +1,4 @@
module pye-auth module git.a71.su/Andrew71/pye
go 1.22 go 1.22

20
main.go
View file

@ -1,23 +1,7 @@
package main package main
import ( import "git.a71.su/Andrew71/pye/cmd"
"fmt"
"net/http"
)
func main() { func main() {
fmt.Println("=== Working on port 7102 ===") cmd.Run()
router := http.NewServeMux()
router.HandleFunc("GET /pem", publicKey)
router.HandleFunc("POST /register", Register)
router.HandleFunc("POST /login", Login)
// Note: likely temporary, possibly to be replaced by a fake "frontend"
router.HandleFunc("GET /login", Login)
router.HandleFunc("GET /register", Register)
http.ListenAndServe(":7102", router)
} }

79
storage/sqlite/sqlite.go Normal file
View file

@ -0,0 +1,79 @@
package sqlite
import (
"database/sql"
"errors"
"log/slog"
"os"
"git.a71.su/Andrew71/pye/storage"
_ "github.com/mattn/go-sqlite3"
)
const create string = `
CREATE TABLE "users" (
"uuid" TEXT NOT NULL UNIQUE,
"email" TEXT NOT NULL UNIQUE,
"password" TEXT NOT NULL,
PRIMARY KEY("uuid")
);`
type SQLiteStorage struct {
db *sql.DB
}
func (s SQLiteStorage) AddUser(email, password string) error {
user, err := storage.NewUser(email, password)
if err != nil {
return err
}
_, err = s.db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
user.Uuid.String(), user.Email, user.Hash)
if err != nil {
slog.Error("error adding user to database", "error", err, "user", user)
return err
}
return nil
}
func (s SQLiteStorage) ById(uuid string) (storage.User, bool) {
row := s.db.QueryRow("select * from users where uuid = $1", uuid)
user := storage.User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func (s SQLiteStorage) ByEmail(email string) (storage.User, bool) {
row := s.db.QueryRow("select * from users where email = $1", email)
user := storage.User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func (s SQLiteStorage) EmailExists(email string) bool {
_, ok := s.ByEmail(email)
return ok
}
func MustLoadSQLite(dataFile string) SQLiteStorage {
// I *think* we need some file, even if only empty
if _, err := os.Stat(dataFile); errors.Is(err, os.ErrNotExist) {
slog.Error("sqlite3 database file required", "file", dataFile)
os.Exit(1)
}
db, err := sql.Open("sqlite3", dataFile)
if err != nil {
slog.Error("error opening database", "error", err)
os.Exit(1)
}
// TODO: Apparently "prepare" works here
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
slog.Info("error initialising database table", "error", err)
os.Exit(1)
}
slog.Info("loaded database", "file", dataFile)
return SQLiteStorage{db}
}

8
storage/storage.go Normal file
View file

@ -0,0 +1,8 @@
package storage
type Storage interface {
AddUser(email, password string) error
ById(uuid string) (User, bool)
ByEmail(uuid string) (User, bool)
EmailExists(email string) bool
}

View file

@ -1,6 +1,8 @@
package main package storage
import ( import (
"log/slog"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -19,6 +21,7 @@ func (u User) PasswordFits(password string) bool {
func NewUser(email, password string) (User, error) { func NewUser(email, password string) (User, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), 14) hash, err := bcrypt.GenerateFromPassword([]byte(password), 14)
if err != nil { if err != nil {
slog.Error("error creating a new user", "error", err)
return User{}, err return User{}, err
} }
return User{uuid.New(), email, hash}, nil return User{uuid.New(), email, hash}, nil