Refactor database to separate package

This commit is contained in:
Andrew-71 2024-10-12 20:41:30 +03:00
parent 363f557c35
commit b07f1b080a
9 changed files with 84 additions and 62 deletions

2
.gitignore vendored
View file

@ -1,3 +1,3 @@
pye-auth pye
private.key private.key
data.db data.db

View file

@ -2,4 +2,4 @@ build:
go build go build
run: run:
go build && ./pye-auth go build && ./pye

19
auth.go
View file

@ -22,24 +22,18 @@ func Register(w http.ResponseWriter, r *http.Request) {
if ok { if ok {
email = strings.TrimSpace(email) email = strings.TrimSpace(email)
password = strings.TrimSpace(password) password = strings.TrimSpace(password)
if !(validEmail(email) && validPass(password) && !emailExists(email)) { if !(validEmail(email) && validPass(password) && !data.EmailExists(email)) {
slog.Debug("Outcome", slog.Debug("Outcome",
"email", validEmail(email), "email", validEmail(email),
"pass", validPass(password), "pass", validPass(password),
"taken", !emailExists(email)) "taken", !data.EmailExists(email))
http.Error(w, "invalid auth credentials", http.StatusBadRequest) http.Error(w, "invalid auth credentials", http.StatusBadRequest)
return return
} }
user, err := NewUser(email, password) err := data.AddUser(email, password)
if err != nil { if err != nil {
slog.Error("error creating a new user", "error", err) slog.Error("error adding a new user", "error", err)
http.Error(w, "error creating a new user", http.StatusInternalServerError) http.Error(w, "error adding a new user", http.StatusInternalServerError)
return
}
err = addUser(user)
if err != nil {
slog.Error("error saving a new user", "error", err)
http.Error(w, "error saving a new user", http.StatusInternalServerError)
return return
} }
w.WriteHeader(http.StatusCreated) w.WriteHeader(http.StatusCreated)
@ -58,8 +52,9 @@ func Login(w http.ResponseWriter, r *http.Request) {
if ok { if ok {
email = strings.TrimSpace(email) email = strings.TrimSpace(email)
password = strings.TrimSpace(password) password = strings.TrimSpace(password)
user, ok := byEmail(email) user, ok := data.ByEmail(email)
if !ok || !user.PasswordFits(password) { if !ok || !user.PasswordFits(password) {
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
http.Error(w, "you did something wrong", http.StatusUnauthorized) http.Error(w, "you did something wrong", http.StatusUnauthorized)
return return
} }

2
go.mod
View file

@ -1,4 +1,4 @@
module pye-auth module git.a71.su/Andrew71/pye
go 1.22 go 1.22

13
jwt.go
View file

@ -11,6 +11,7 @@ import (
"os" "os"
"time" "time"
"git.a71.su/Andrew71/pye/storage"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
) )
@ -21,7 +22,7 @@ var (
// LoadKey attempts to load a private key from KeyFile. // LoadKey attempts to load a private key from KeyFile.
// If the file does not exist, it generates a new key (and saves it) // If the file does not exist, it generates a new key (and saves it)
func LoadKey() { func MustLoadKey() {
// If the key doesn't exist, create it // If the key doesn't exist, create it
if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) {
key, err = rsa.GenerateKey(rand.Reader, 4096) key, err = rsa.GenerateKey(rand.Reader, 4096)
@ -59,6 +60,10 @@ func LoadKey() {
} }
} }
func init() {
MustLoadKey()
}
// publicKey returns our public key as PEM block // publicKey returns our public key as PEM block
func publicKey(w http.ResponseWriter, r *http.Request) { func publicKey(w http.ResponseWriter, r *http.Request) {
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey) key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
@ -66,7 +71,7 @@ func publicKey(w http.ResponseWriter, r *http.Request) {
pem.Encode(w, &block) pem.Encode(w, &block)
} }
func CreateJWT(usr User) (string, error) { func CreateJWT(usr storage.User) (string, error) {
t := jwt.NewWithClaims(jwt.SigningMethodRS256, t := jwt.NewWithClaims(jwt.SigningMethodRS256,
jwt.MapClaims{ jwt.MapClaims{
"iss": "pye", "iss": "pye",
@ -99,7 +104,3 @@ func VerifyJWT(token string, publicKey []byte) bool {
slog.Info("Error check", "err", err) slog.Info("Error check", "err", err)
return err == nil return err == nil
} }
func init() {
LoadKey()
}

View file

@ -3,8 +3,13 @@ package main
import ( import (
"fmt" "fmt"
"net/http" "net/http"
"git.a71.su/Andrew71/pye/storage"
"git.a71.su/Andrew71/pye/storage/sqlite"
) )
var data storage.Storage = sqlite.MustLoadSQLite()
func main() { func main() {
fmt.Println("=== Working on port 7102 ===") fmt.Println("=== Working on port 7102 ===")

View file

@ -1,4 +1,4 @@
package main package sqlite
import ( import (
"database/sql" "database/sql"
@ -6,6 +6,7 @@ import (
"log/slog" "log/slog"
"os" "os"
"git.a71.su/Andrew71/pye/storage"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
@ -19,10 +20,48 @@ const create string = `
var ( var (
DataFile = "data.db" DataFile = "data.db"
db *sql.DB = LoadDb()
) )
func LoadDb() *sql.DB { type SQLiteStorage struct {
db *sql.DB
}
func (s SQLiteStorage) AddUser(email, password string) error {
user, err := storage.NewUser(email, password)
if err != nil {
return err
}
_, err = s.db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
user.Uuid.String(), user.Email, user.Hash)
if err != nil {
slog.Error("error adding user to database", "error", err, "user", user)
return err
}
return nil
}
func (s SQLiteStorage) ById(uuid string) (storage.User, bool) {
row := s.db.QueryRow("select * from users where uuid = $1", uuid)
user := storage.User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func (s SQLiteStorage) ByEmail(email string) (storage.User, bool) {
row := s.db.QueryRow("select * from users where email = $1", email)
user := storage.User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func (s SQLiteStorage) EmailExists(email string) bool {
_, ok := s.ByEmail(email)
return ok
}
func MustLoadSQLite() SQLiteStorage {
// I *think* we need some file, even if only empty // I *think* we need some file, even if only empty
if _, err := os.Stat(DataFile); errors.Is(err, os.ErrNotExist) { if _, err := os.Stat(DataFile); errors.Is(err, os.ErrNotExist) {
slog.Error("sqlite3 database file required", "file", DataFile) slog.Error("sqlite3 database file required", "file", DataFile)
@ -33,41 +72,12 @@ func LoadDb() *sql.DB {
slog.Error("error opening database", "error", err) slog.Error("error opening database", "error", err)
os.Exit(1) os.Exit(1)
} }
// TODO: Apparently "prepare" works here
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" { if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
slog.Info("error initialising database table", "error", err) slog.Info("error initialising database table", "error", err)
os.Exit(1) os.Exit(1)
} }
slog.Info("loaded database") slog.Info("loaded database")
return db return SQLiteStorage{db}
}
func addUser(user User) error {
_, err := db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
user.Uuid.String(), user.Email, user.Hash)
if err != nil {
slog.Error("error adding user", "error", err, "user", user)
return err
}
return nil
}
func byId(uuid string) (User, bool) {
row := db.QueryRow("select * from users where uuid = $1", uuid)
user := User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func byEmail(email string) (User, bool) {
row := db.QueryRow("select * from users where email = $1", email)
user := User{}
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
return user, err == nil
}
func emailExists(email string) bool {
_, ok := byEmail(email)
return ok
} }

8
storage/storage.go Normal file
View file

@ -0,0 +1,8 @@
package storage
type Storage interface {
AddUser(email, password string) error
ById(uuid string) (User, bool)
ByEmail(uuid string) (User, bool)
EmailExists(email string) bool
}

View file

@ -1,6 +1,8 @@
package main package storage
import ( import (
"log/slog"
"github.com/google/uuid" "github.com/google/uuid"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
@ -19,6 +21,7 @@ func (u User) PasswordFits(password string) bool {
func NewUser(email, password string) (User, error) { func NewUser(email, password string) (User, error) {
hash, err := bcrypt.GenerateFromPassword([]byte(password), 14) hash, err := bcrypt.GenerateFromPassword([]byte(password), 14)
if err != nil { if err != nil {
slog.Error("error creating a new user", "error", err)
return User{}, err return User{}, err
} }
return User{uuid.New(), email, hash}, nil return User{uuid.New(), email, hash}, nil