Compare commits
2 commits
363f557c35
...
cda8f0cc1b
Author | SHA1 | Date | |
---|---|---|---|
cda8f0cc1b | |||
b07f1b080a |
15 changed files with 264 additions and 139 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
@ -1,3 +1,3 @@
|
||||||
pye-auth
|
pye
|
||||||
private.key
|
private.key
|
||||||
data.db
|
dev-data.db
|
7
Makefile
7
Makefile
|
@ -1,5 +1,8 @@
|
||||||
build:
|
build:
|
||||||
go build
|
go build
|
||||||
|
|
||||||
run:
|
serve:
|
||||||
go build && ./pye-auth
|
go build && ./pye serve
|
||||||
|
|
||||||
|
dev:
|
||||||
|
go build && ./pye serve --db dev-data.db
|
12
README.md
12
README.md
|
@ -11,10 +11,14 @@ in a state that proves I am competent Go developer.
|
||||||
|
|
||||||
## Current functionality
|
## Current functionality
|
||||||
|
|
||||||
* Port `7102`
|
## `serve`
|
||||||
|
|
||||||
* `POST /register` - register a user with Basic Auth
|
* `POST /register` - register a user with Basic Auth
|
||||||
* `POST /login` - get a JWT token by Basic Auth
|
* `POST /login` - get a JWT token by Basic Auth
|
||||||
* `GET /pem` - get PEM-encoded public RS256 key
|
* `GET /pem` - get PEM-encoded public RS256 key
|
||||||
* Data persistently stored in an SQLite database `data.db`
|
* Data persistently stored in an SQLite database
|
||||||
(requires creation of empty db)
|
* RS256 key loaded from a file or generated on startup if missing
|
||||||
* RS256 key loaded from `private.key` file or generated on startup if missing
|
|
||||||
|
## `verify`
|
||||||
|
|
||||||
|
* Verify JWT via public key in a PEM file
|
|
@ -1,10 +1,12 @@
|
||||||
package main
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/storage"
|
||||||
)
|
)
|
||||||
|
|
||||||
func validEmail(email string) bool {
|
func validEmail(email string) bool {
|
||||||
|
@ -16,30 +18,24 @@ func validPass(pass string) bool {
|
||||||
return len(pass) >= 8
|
return len(pass) >= 8
|
||||||
}
|
}
|
||||||
|
|
||||||
func Register(w http.ResponseWriter, r *http.Request) {
|
func Register(w http.ResponseWriter, r *http.Request, data storage.Storage) {
|
||||||
email, password, ok := r.BasicAuth()
|
email, password, ok := r.BasicAuth()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
email = strings.TrimSpace(email)
|
email = strings.TrimSpace(email)
|
||||||
password = strings.TrimSpace(password)
|
password = strings.TrimSpace(password)
|
||||||
if !(validEmail(email) && validPass(password) && !emailExists(email)) {
|
if !(validEmail(email) && validPass(password) && !data.EmailExists(email)) {
|
||||||
slog.Debug("Outcome",
|
slog.Debug("Outcome",
|
||||||
"email", validEmail(email),
|
"email", validEmail(email),
|
||||||
"pass", validPass(password),
|
"pass", validPass(password),
|
||||||
"taken", !emailExists(email))
|
"taken", !data.EmailExists(email))
|
||||||
http.Error(w, "invalid auth credentials", http.StatusBadRequest)
|
http.Error(w, "invalid auth credentials", http.StatusBadRequest)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
user, err := NewUser(email, password)
|
err := data.AddUser(email, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error creating a new user", "error", err)
|
slog.Error("error adding a new user", "error", err)
|
||||||
http.Error(w, "error creating a new user", http.StatusInternalServerError)
|
http.Error(w, "error adding a new user", http.StatusInternalServerError)
|
||||||
return
|
|
||||||
}
|
|
||||||
err = addUser(user)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error saving a new user", "error", err)
|
|
||||||
http.Error(w, "error saving a new user", http.StatusInternalServerError)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
w.WriteHeader(http.StatusCreated)
|
w.WriteHeader(http.StatusCreated)
|
||||||
|
@ -52,14 +48,15 @@ func Register(w http.ResponseWriter, r *http.Request) {
|
||||||
http.Error(w, "This API requires authorization", http.StatusUnauthorized)
|
http.Error(w, "This API requires authorization", http.StatusUnauthorized)
|
||||||
}
|
}
|
||||||
|
|
||||||
func Login(w http.ResponseWriter, r *http.Request) {
|
func Login(w http.ResponseWriter, r *http.Request, data storage.Storage) {
|
||||||
email, password, ok := r.BasicAuth()
|
email, password, ok := r.BasicAuth()
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
email = strings.TrimSpace(email)
|
email = strings.TrimSpace(email)
|
||||||
password = strings.TrimSpace(password)
|
password = strings.TrimSpace(password)
|
||||||
user, ok := byEmail(email)
|
user, ok := data.ByEmail(email)
|
||||||
if !ok || !user.PasswordFits(password) {
|
if !ok || !user.PasswordFits(password) {
|
||||||
|
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||||
http.Error(w, "you did something wrong", http.StatusUnauthorized)
|
http.Error(w, "you did something wrong", http.StatusUnauthorized)
|
||||||
return
|
return
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package main
|
package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
|
@ -11,19 +11,20 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/config"
|
||||||
|
"git.a71.su/Andrew71/pye/storage"
|
||||||
"github.com/golang-jwt/jwt/v5"
|
"github.com/golang-jwt/jwt/v5"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
KeyFile = "private.key"
|
|
||||||
key *rsa.PrivateKey
|
key *rsa.PrivateKey
|
||||||
)
|
)
|
||||||
|
|
||||||
// LoadKey attempts to load a private key from KeyFile.
|
// LoadKey attempts to load a private key from KeyFile.
|
||||||
// If the file does not exist, it generates a new key (and saves it)
|
// If the file does not exist, it generates a new key (and saves it)
|
||||||
func LoadKey() {
|
func MustLoadKey() {
|
||||||
// If the key doesn't exist, create it
|
// If the key doesn't exist, create it
|
||||||
if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) {
|
if _, err := os.Stat(config.Cfg.KeyFile); errors.Is(err, os.ErrNotExist) {
|
||||||
key, err = rsa.GenerateKey(rand.Reader, 4096)
|
key, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error generating key", "error", err)
|
slog.Error("error generating key", "error", err)
|
||||||
|
@ -33,7 +34,7 @@ func LoadKey() {
|
||||||
// Save key to disk
|
// Save key to disk
|
||||||
km := x509.MarshalPKCS1PrivateKey(key)
|
km := x509.MarshalPKCS1PrivateKey(key)
|
||||||
block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"}
|
block := pem.Block{Bytes: km, Type: "RSA PRIVATE KEY"}
|
||||||
f, err := os.OpenFile(KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
f, err := os.OpenFile(config.Cfg.KeyFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error opening/creating file", "error", err)
|
slog.Error("error opening/creating file", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
@ -43,9 +44,9 @@ func LoadKey() {
|
||||||
slog.Error("error closing file", "error", err)
|
slog.Error("error closing file", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
slog.Info("generated new key")
|
slog.Info("generated new key", "file", config.Cfg.KeyFile)
|
||||||
} else {
|
} else {
|
||||||
km, err := os.ReadFile(KeyFile)
|
km, err := os.ReadFile(config.Cfg.KeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Error("error reading key", "error", err)
|
slog.Error("error reading key", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
|
@ -55,23 +56,27 @@ func LoadKey() {
|
||||||
slog.Error("error parsing key", "error", err)
|
slog.Error("error parsing key", "error", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
slog.Info("loaded private key")
|
slog.Info("loaded private key", "file", config.Cfg.KeyFile)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// publicKey returns our public key as PEM block
|
func init() {
|
||||||
func publicKey(w http.ResponseWriter, r *http.Request) {
|
MustLoadKey()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PublicKey returns our public key as PEM block over http
|
||||||
|
func PublicKey(w http.ResponseWriter, r *http.Request) {
|
||||||
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
||||||
block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"}
|
block := pem.Block{Bytes: key_marshalled, Type: "RSA PUBLIC KEY"}
|
||||||
pem.Encode(w, &block)
|
pem.Encode(w, &block)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateJWT(usr User) (string, error) {
|
func CreateJWT(user storage.User) (string, error) {
|
||||||
t := jwt.NewWithClaims(jwt.SigningMethodRS256,
|
t := jwt.NewWithClaims(jwt.SigningMethodRS256,
|
||||||
jwt.MapClaims{
|
jwt.MapClaims{
|
||||||
"iss": "pye",
|
"iss": "pye",
|
||||||
"uid": usr.Uuid,
|
"uid": user.Uuid,
|
||||||
"sub": usr.Email,
|
"sub": user.Email,
|
||||||
"iat": time.Now().Unix(),
|
"iat": time.Now().Unix(),
|
||||||
"exp": time.Now().Add(time.Hour * 24 * 7).Unix(),
|
"exp": time.Now().Add(time.Hour * 24 * 7).Unix(),
|
||||||
})
|
})
|
||||||
|
@ -85,8 +90,8 @@ func CreateJWT(usr User) (string, error) {
|
||||||
|
|
||||||
// VerifyToken receives a JWT and PEM-encoded public key,
|
// VerifyToken receives a JWT and PEM-encoded public key,
|
||||||
// then returns whether the token is valid
|
// then returns whether the token is valid
|
||||||
func VerifyJWT(token string, publicKey []byte) bool {
|
func VerifyJWT(token string, publicKey []byte) (*jwt.Token, error) {
|
||||||
_, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
t, err := jwt.Parse(token, func(token *jwt.Token) (interface{}, error) {
|
||||||
key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey)
|
key, err := jwt.ParseRSAPublicKeyFromPEM(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -96,10 +101,5 @@ func VerifyJWT(token string, publicKey []byte) bool {
|
||||||
}
|
}
|
||||||
return key, nil
|
return key, nil
|
||||||
})
|
})
|
||||||
slog.Info("Error check", "err", err)
|
return t, err
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
LoadKey()
|
|
||||||
}
|
}
|
51
cmd/main.go
Normal file
51
cmd/main.go
Normal file
|
@ -0,0 +1,51 @@
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/cmd/serve"
|
||||||
|
"git.a71.su/Andrew71/pye/cmd/verify"
|
||||||
|
"git.a71.su/Andrew71/pye/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Run() {
|
||||||
|
// configFlag := flag.String("config", "", "override config file")
|
||||||
|
// flag.Parse()
|
||||||
|
// if *configFlag != "" {
|
||||||
|
// config.Load()
|
||||||
|
// }
|
||||||
|
|
||||||
|
serveCmd := flag.NewFlagSet("serve", flag.ExitOnError)
|
||||||
|
servePort := serveCmd.Int("port", 0, "override port")
|
||||||
|
serveDb := serveCmd.String("db", "", "override sqlite database")
|
||||||
|
|
||||||
|
verifyCmd := flag.NewFlagSet("verify", flag.ExitOnError)
|
||||||
|
|
||||||
|
if len(os.Args) < 2 {
|
||||||
|
fmt.Println("expected 'serve' or 'verify' subcommands")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch os.Args[1] {
|
||||||
|
case "serve":
|
||||||
|
serveCmd.Parse(os.Args[2:])
|
||||||
|
if *servePort != 0 {
|
||||||
|
config.Cfg.Port = *servePort
|
||||||
|
}
|
||||||
|
if *serveDb != "" {
|
||||||
|
config.Cfg.SQLiteFile = *serveDb
|
||||||
|
}
|
||||||
|
serve.Serve()
|
||||||
|
case "verify":
|
||||||
|
verifyCmd.Parse(os.Args[2:])
|
||||||
|
if len(os.Args) != 4 {
|
||||||
|
fmt.Println("Usage: <jwt> <pem file>")
|
||||||
|
}
|
||||||
|
verify.Verify(os.Args[2], os.Args[3])
|
||||||
|
default:
|
||||||
|
fmt.Println("expected 'serve' or 'verify' subcommands")
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
32
cmd/serve/main.go
Normal file
32
cmd/serve/main.go
Normal file
|
@ -0,0 +1,32 @@
|
||||||
|
package serve
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/auth"
|
||||||
|
"git.a71.su/Andrew71/pye/config"
|
||||||
|
"git.a71.su/Andrew71/pye/storage"
|
||||||
|
"git.a71.su/Andrew71/pye/storage/sqlite"
|
||||||
|
)
|
||||||
|
|
||||||
|
var data storage.Storage
|
||||||
|
|
||||||
|
func Serve() {
|
||||||
|
data = sqlite.MustLoadSQLite(config.Cfg.SQLiteFile)
|
||||||
|
|
||||||
|
router := http.NewServeMux()
|
||||||
|
|
||||||
|
router.HandleFunc("GET /pem", auth.PublicKey)
|
||||||
|
|
||||||
|
router.HandleFunc("POST /register", func(w http.ResponseWriter, r *http.Request) { auth.Register(w, r, data) })
|
||||||
|
router.HandleFunc("POST /login", func(w http.ResponseWriter, r *http.Request) { auth.Login(w, r, data) })
|
||||||
|
|
||||||
|
// Note: likely temporary, possibly to be replaced by a fake "frontend"
|
||||||
|
router.HandleFunc("GET /register", func(w http.ResponseWriter, r *http.Request) { auth.Register(w, r, data) })
|
||||||
|
router.HandleFunc("GET /login", func(w http.ResponseWriter, r *http.Request) { auth.Login(w, r, data) })
|
||||||
|
|
||||||
|
slog.Info("🪐 pye started", "port", config.Cfg.Port)
|
||||||
|
http.ListenAndServe(":"+strconv.Itoa(config.Cfg.Port), router)
|
||||||
|
}
|
17
cmd/verify/main.go
Normal file
17
cmd/verify/main.go
Normal file
|
@ -0,0 +1,17 @@
|
||||||
|
package verify
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/auth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Verify(token, filename string) {
|
||||||
|
key, err := os.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("error reading file", "error", err, "file", filename)
|
||||||
|
}
|
||||||
|
t, err := auth.VerifyJWT(token, key)
|
||||||
|
slog.Info("result", "token", t, "error", err, "ok", err == nil)
|
||||||
|
}
|
20
config/config.go
Normal file
20
config/config.go
Normal file
|
@ -0,0 +1,20 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
Port int `json:"port"`
|
||||||
|
KeyFile string `json:"key-file"`
|
||||||
|
SQLiteFile string `json:"sqlite-file"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var DefaultConfig = Config{
|
||||||
|
Port: 7102,
|
||||||
|
KeyFile: "private.key",
|
||||||
|
SQLiteFile: "data.db",
|
||||||
|
}
|
||||||
|
|
||||||
|
var Cfg = MustLoadConfig()
|
||||||
|
|
||||||
|
// TODO: Implement
|
||||||
|
func MustLoadConfig() Config {
|
||||||
|
return DefaultConfig
|
||||||
|
}
|
73
db.go
73
db.go
|
@ -1,73 +0,0 @@
|
||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"log/slog"
|
|
||||||
"os"
|
|
||||||
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const create string = `
|
|
||||||
CREATE TABLE "users" (
|
|
||||||
"uuid" TEXT NOT NULL UNIQUE,
|
|
||||||
"email" TEXT NOT NULL UNIQUE,
|
|
||||||
"password" TEXT NOT NULL,
|
|
||||||
PRIMARY KEY("uuid")
|
|
||||||
);`
|
|
||||||
|
|
||||||
var (
|
|
||||||
DataFile = "data.db"
|
|
||||||
db *sql.DB = LoadDb()
|
|
||||||
)
|
|
||||||
|
|
||||||
func LoadDb() *sql.DB {
|
|
||||||
// I *think* we need some file, even if only empty
|
|
||||||
if _, err := os.Stat(DataFile); errors.Is(err, os.ErrNotExist) {
|
|
||||||
slog.Error("sqlite3 database file required", "file", DataFile)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
db, err := sql.Open("sqlite3", DataFile)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error opening database", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
|
|
||||||
slog.Info("error initialising database table", "error", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
slog.Info("loaded database")
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
|
|
||||||
func addUser(user User) error {
|
|
||||||
_, err := db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
|
|
||||||
user.Uuid.String(), user.Email, user.Hash)
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("error adding user", "error", err, "user", user)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func byId(uuid string) (User, bool) {
|
|
||||||
row := db.QueryRow("select * from users where uuid = $1", uuid)
|
|
||||||
user := User{}
|
|
||||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
|
||||||
|
|
||||||
return user, err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func byEmail(email string) (User, bool) {
|
|
||||||
row := db.QueryRow("select * from users where email = $1", email)
|
|
||||||
user := User{}
|
|
||||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
|
||||||
|
|
||||||
return user, err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func emailExists(email string) bool {
|
|
||||||
_, ok := byEmail(email)
|
|
||||||
return ok
|
|
||||||
}
|
|
2
go.mod
2
go.mod
|
@ -1,4 +1,4 @@
|
||||||
module pye-auth
|
module git.a71.su/Andrew71/pye
|
||||||
|
|
||||||
go 1.22
|
go 1.22
|
||||||
|
|
||||||
|
|
20
main.go
20
main.go
|
@ -1,23 +1,7 @@
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import "git.a71.su/Andrew71/pye/cmd"
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fmt.Println("=== Working on port 7102 ===")
|
cmd.Run()
|
||||||
|
|
||||||
router := http.NewServeMux()
|
|
||||||
|
|
||||||
router.HandleFunc("GET /pem", publicKey)
|
|
||||||
|
|
||||||
router.HandleFunc("POST /register", Register)
|
|
||||||
router.HandleFunc("POST /login", Login)
|
|
||||||
|
|
||||||
// Note: likely temporary, possibly to be replaced by a fake "frontend"
|
|
||||||
router.HandleFunc("GET /login", Login)
|
|
||||||
router.HandleFunc("GET /register", Register)
|
|
||||||
|
|
||||||
http.ListenAndServe(":7102", router)
|
|
||||||
}
|
}
|
||||||
|
|
79
storage/sqlite/sqlite.go
Normal file
79
storage/sqlite/sqlite.go
Normal file
|
@ -0,0 +1,79 @@
|
||||||
|
package sqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"git.a71.su/Andrew71/pye/storage"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const create string = `
|
||||||
|
CREATE TABLE "users" (
|
||||||
|
"uuid" TEXT NOT NULL UNIQUE,
|
||||||
|
"email" TEXT NOT NULL UNIQUE,
|
||||||
|
"password" TEXT NOT NULL,
|
||||||
|
PRIMARY KEY("uuid")
|
||||||
|
);`
|
||||||
|
|
||||||
|
type SQLiteStorage struct {
|
||||||
|
db *sql.DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s SQLiteStorage) AddUser(email, password string) error {
|
||||||
|
user, err := storage.NewUser(email, password)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
|
||||||
|
user.Uuid.String(), user.Email, user.Hash)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("error adding user to database", "error", err, "user", user)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s SQLiteStorage) ById(uuid string) (storage.User, bool) {
|
||||||
|
row := s.db.QueryRow("select * from users where uuid = $1", uuid)
|
||||||
|
user := storage.User{}
|
||||||
|
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||||
|
|
||||||
|
return user, err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s SQLiteStorage) ByEmail(email string) (storage.User, bool) {
|
||||||
|
row := s.db.QueryRow("select * from users where email = $1", email)
|
||||||
|
user := storage.User{}
|
||||||
|
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||||
|
|
||||||
|
return user, err == nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s SQLiteStorage) EmailExists(email string) bool {
|
||||||
|
_, ok := s.ByEmail(email)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func MustLoadSQLite(dataFile string) SQLiteStorage {
|
||||||
|
// I *think* we need some file, even if only empty
|
||||||
|
if _, err := os.Stat(dataFile); errors.Is(err, os.ErrNotExist) {
|
||||||
|
slog.Error("sqlite3 database file required", "file", dataFile)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
db, err := sql.Open("sqlite3", dataFile)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("error opening database", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Apparently "prepare" works here
|
||||||
|
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
|
||||||
|
slog.Info("error initialising database table", "error", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
slog.Info("loaded database", "file", dataFile)
|
||||||
|
return SQLiteStorage{db}
|
||||||
|
}
|
8
storage/storage.go
Normal file
8
storage/storage.go
Normal file
|
@ -0,0 +1,8 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
type Storage interface {
|
||||||
|
AddUser(email, password string) error
|
||||||
|
ById(uuid string) (User, bool)
|
||||||
|
ByEmail(uuid string) (User, bool)
|
||||||
|
EmailExists(email string) bool
|
||||||
|
}
|
|
@ -1,6 +1,8 @@
|
||||||
package main
|
package storage
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"log/slog"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
@ -19,6 +21,7 @@ func (u User) PasswordFits(password string) bool {
|
||||||
func NewUser(email, password string) (User, error) {
|
func NewUser(email, password string) (User, error) {
|
||||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
hash, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
slog.Error("error creating a new user", "error", err)
|
||||||
return User{}, err
|
return User{}, err
|
||||||
}
|
}
|
||||||
return User{uuid.New(), email, hash}, nil
|
return User{uuid.New(), email, hash}, nil
|
Loading…
Reference in a new issue