Refactor database to separate package
This commit is contained in:
parent
363f557c35
commit
b07f1b080a
9 changed files with 84 additions and 62 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -1,3 +1,3 @@
|
|||
pye-auth
|
||||
pye
|
||||
private.key
|
||||
data.db
|
2
Makefile
2
Makefile
|
@ -2,4 +2,4 @@ build:
|
|||
go build
|
||||
|
||||
run:
|
||||
go build && ./pye-auth
|
||||
go build && ./pye
|
19
auth.go
19
auth.go
|
@ -22,24 +22,18 @@ func Register(w http.ResponseWriter, r *http.Request) {
|
|||
if ok {
|
||||
email = strings.TrimSpace(email)
|
||||
password = strings.TrimSpace(password)
|
||||
if !(validEmail(email) && validPass(password) && !emailExists(email)) {
|
||||
if !(validEmail(email) && validPass(password) && !data.EmailExists(email)) {
|
||||
slog.Debug("Outcome",
|
||||
"email", validEmail(email),
|
||||
"pass", validPass(password),
|
||||
"taken", !emailExists(email))
|
||||
"taken", !data.EmailExists(email))
|
||||
http.Error(w, "invalid auth credentials", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
user, err := NewUser(email, password)
|
||||
err := data.AddUser(email, password)
|
||||
if err != nil {
|
||||
slog.Error("error creating a new user", "error", err)
|
||||
http.Error(w, "error creating a new user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
err = addUser(user)
|
||||
if err != nil {
|
||||
slog.Error("error saving a new user", "error", err)
|
||||
http.Error(w, "error saving a new user", http.StatusInternalServerError)
|
||||
slog.Error("error adding a new user", "error", err)
|
||||
http.Error(w, "error adding a new user", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
|
@ -58,8 +52,9 @@ func Login(w http.ResponseWriter, r *http.Request) {
|
|||
if ok {
|
||||
email = strings.TrimSpace(email)
|
||||
password = strings.TrimSpace(password)
|
||||
user, ok := byEmail(email)
|
||||
user, ok := data.ByEmail(email)
|
||||
if !ok || !user.PasswordFits(password) {
|
||||
w.Header().Set("WWW-Authenticate", `Basic realm="restricted", charset="UTF-8"`)
|
||||
http.Error(w, "you did something wrong", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
|
2
go.mod
2
go.mod
|
@ -1,4 +1,4 @@
|
|||
module pye-auth
|
||||
module git.a71.su/Andrew71/pye
|
||||
|
||||
go 1.22
|
||||
|
||||
|
|
17
jwt.go
17
jwt.go
|
@ -11,17 +11,18 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"git.a71.su/Andrew71/pye/storage"
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
)
|
||||
|
||||
var (
|
||||
KeyFile = "private.key"
|
||||
key *rsa.PrivateKey
|
||||
KeyFile = "private.key"
|
||||
key *rsa.PrivateKey
|
||||
)
|
||||
|
||||
// LoadKey attempts to load a private key from KeyFile.
|
||||
// If the file does not exist, it generates a new key (and saves it)
|
||||
func LoadKey() {
|
||||
func MustLoadKey() {
|
||||
// If the key doesn't exist, create it
|
||||
if _, err := os.Stat(KeyFile); errors.Is(err, os.ErrNotExist) {
|
||||
key, err = rsa.GenerateKey(rand.Reader, 4096)
|
||||
|
@ -59,6 +60,10 @@ func LoadKey() {
|
|||
}
|
||||
}
|
||||
|
||||
func init() {
|
||||
MustLoadKey()
|
||||
}
|
||||
|
||||
// publicKey returns our public key as PEM block
|
||||
func publicKey(w http.ResponseWriter, r *http.Request) {
|
||||
key_marshalled := x509.MarshalPKCS1PublicKey(&key.PublicKey)
|
||||
|
@ -66,7 +71,7 @@ func publicKey(w http.ResponseWriter, r *http.Request) {
|
|||
pem.Encode(w, &block)
|
||||
}
|
||||
|
||||
func CreateJWT(usr User) (string, error) {
|
||||
func CreateJWT(usr storage.User) (string, error) {
|
||||
t := jwt.NewWithClaims(jwt.SigningMethodRS256,
|
||||
jwt.MapClaims{
|
||||
"iss": "pye",
|
||||
|
@ -99,7 +104,3 @@ func VerifyJWT(token string, publicKey []byte) bool {
|
|||
slog.Info("Error check", "err", err)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
LoadKey()
|
||||
}
|
||||
|
|
5
main.go
5
main.go
|
@ -3,8 +3,13 @@ package main
|
|||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"git.a71.su/Andrew71/pye/storage"
|
||||
"git.a71.su/Andrew71/pye/storage/sqlite"
|
||||
)
|
||||
|
||||
var data storage.Storage = sqlite.MustLoadSQLite()
|
||||
|
||||
func main() {
|
||||
fmt.Println("=== Working on port 7102 ===")
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package main
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
@ -6,6 +6,7 @@ import (
|
|||
"log/slog"
|
||||
"os"
|
||||
|
||||
"git.a71.su/Andrew71/pye/storage"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
|
@ -18,11 +19,49 @@ const create string = `
|
|||
);`
|
||||
|
||||
var (
|
||||
DataFile = "data.db"
|
||||
db *sql.DB = LoadDb()
|
||||
DataFile = "data.db"
|
||||
)
|
||||
|
||||
func LoadDb() *sql.DB {
|
||||
type SQLiteStorage struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (s SQLiteStorage) AddUser(email, password string) error {
|
||||
user, err := storage.NewUser(email, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = s.db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
|
||||
user.Uuid.String(), user.Email, user.Hash)
|
||||
if err != nil {
|
||||
slog.Error("error adding user to database", "error", err, "user", user)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s SQLiteStorage) ById(uuid string) (storage.User, bool) {
|
||||
row := s.db.QueryRow("select * from users where uuid = $1", uuid)
|
||||
user := storage.User{}
|
||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||
|
||||
return user, err == nil
|
||||
}
|
||||
|
||||
func (s SQLiteStorage) ByEmail(email string) (storage.User, bool) {
|
||||
row := s.db.QueryRow("select * from users where email = $1", email)
|
||||
user := storage.User{}
|
||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||
|
||||
return user, err == nil
|
||||
}
|
||||
|
||||
func (s SQLiteStorage) EmailExists(email string) bool {
|
||||
_, ok := s.ByEmail(email)
|
||||
return ok
|
||||
}
|
||||
|
||||
func MustLoadSQLite() SQLiteStorage {
|
||||
// I *think* we need some file, even if only empty
|
||||
if _, err := os.Stat(DataFile); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Error("sqlite3 database file required", "file", DataFile)
|
||||
|
@ -33,41 +72,12 @@ func LoadDb() *sql.DB {
|
|||
slog.Error("error opening database", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// TODO: Apparently "prepare" works here
|
||||
if _, err := db.Exec(create); err != nil && err.Error() != "table \"users\" already exists" {
|
||||
slog.Info("error initialising database table", "error", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
slog.Info("loaded database")
|
||||
return db
|
||||
}
|
||||
|
||||
func addUser(user User) error {
|
||||
_, err := db.Exec("insert into users (uuid, email, password) values ($1, $2, $3)",
|
||||
user.Uuid.String(), user.Email, user.Hash)
|
||||
if err != nil {
|
||||
slog.Error("error adding user", "error", err, "user", user)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func byId(uuid string) (User, bool) {
|
||||
row := db.QueryRow("select * from users where uuid = $1", uuid)
|
||||
user := User{}
|
||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||
|
||||
return user, err == nil
|
||||
}
|
||||
|
||||
func byEmail(email string) (User, bool) {
|
||||
row := db.QueryRow("select * from users where email = $1", email)
|
||||
user := User{}
|
||||
err := row.Scan(&user.Uuid, &user.Email, &user.Hash)
|
||||
|
||||
return user, err == nil
|
||||
}
|
||||
|
||||
func emailExists(email string) bool {
|
||||
_, ok := byEmail(email)
|
||||
return ok
|
||||
return SQLiteStorage{db}
|
||||
}
|
8
storage/storage.go
Normal file
8
storage/storage.go
Normal file
|
@ -0,0 +1,8 @@
|
|||
package storage
|
||||
|
||||
type Storage interface {
|
||||
AddUser(email, password string) error
|
||||
ById(uuid string) (User, bool)
|
||||
ByEmail(uuid string) (User, bool)
|
||||
EmailExists(email string) bool
|
||||
}
|
|
@ -1,6 +1,8 @@
|
|||
package main
|
||||
package storage
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
@ -19,6 +21,7 @@ func (u User) PasswordFits(password string) bool {
|
|||
func NewUser(email, password string) (User, error) {
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), 14)
|
||||
if err != nil {
|
||||
slog.Error("error creating a new user", "error", err)
|
||||
return User{}, err
|
||||
}
|
||||
return User{uuid.New(), email, hash}, nil
|
Loading…
Reference in a new issue