From 98f6d67eca1b219347d75e83ff208c112fce00d1 Mon Sep 17 00:00:00 2001 From: vilmibm Date: Tue, 14 Jun 2022 16:05:39 -0500 Subject: [PATCH] thinkin bout db --- server/cmd/api/api_test.go | 14 +- server/cmd/db/db.go | 62 +++- server/cmd/main.go | 622 +++++++++++++++++-------------------- server/cmd/schema.sql | 52 ---- 4 files changed, 363 insertions(+), 387 deletions(-) delete mode 100644 server/cmd/schema.sql diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 77e7bfc..5a10e94 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -11,6 +11,7 @@ import ( "testing" "git.tilde.town/tildetown/bbj2/server/cmd/config" + "git.tilde.town/tildetown/bbj2/server/cmd/db" ) func TestInstanceInfo(t *testing.T) { @@ -26,7 +27,7 @@ func TestInstanceInfo(t *testing.T) { t.Fatalf("failed to make test db: %s", err.Error()) } logger := log.New(os.Stdout, "bbj test", log.Lshortfile) - defaultOptions := config.Options{ + opts := config.Options{ IO: testIO, Logger: logger, Config: config.Config{ @@ -38,6 +39,14 @@ func TestInstanceInfo(t *testing.T) { DBPath: dbFile.Name(), }, } + + teardown, err := db.Setup(opts) + if err != nil { + t.Fatalf("could not initialize DB: %s", err.Error()) + return + } + defer teardown() + ts := []struct { name string opts config.Options @@ -46,7 +55,7 @@ func TestInstanceInfo(t *testing.T) { }{ { name: "basic", - opts: defaultOptions, + opts: opts, wantData: instanceInfo{ InstanceName: "cool test zone", AllowAnon: true, @@ -54,7 +63,6 @@ func TestInstanceInfo(t *testing.T) { }, }, } - for _, tt := range ts { t.Run(tt.name, func(t *testing.T) { req, _ := http.NewRequest("GET", "", strings.NewReader("")) diff --git a/server/cmd/db/db.go b/server/cmd/db/db.go index 7c3f532..83d35aa 100644 --- a/server/cmd/db/db.go +++ b/server/cmd/db/db.go @@ -1,7 +1,26 @@ package db -import "time" +import ( + "database/sql" + _ "embed" + "errors" + "fmt" + "os" + "strings" + "time" + "git.tilde.town/tildetown/bbj2/server/cmd/config" + _ "github.com/mattn/go-sqlite3" +) + +//go:embed schema.sql +var schemaSQL string + +// TODO I'm not sold on this hash system; without transport encryption, it +// doesn't really help anything. I'd rather have plaintext + transport +// encryption and then, on the server side, proper salted hashing. I can't +// figure out if there was a reason for this approach that I'm just +// overlooking. type User struct { ID string Username string @@ -29,3 +48,44 @@ type Message struct { Body string SendRaw int `json:"send_raw"` // TODO bool } + +func Setup(opts config.Options) (func(), error) { + db, err := sql.Open("sqlite3", opts.Config.DBPath) + opts.DB = db + return func() { db.Close() }, err +} + +func EnsureSchema(opts config.Options) error { + db := opts.DB + + if opts.Reset { + err := os.Remove(opts.Config.DBPath) + if err != nil { + return fmt.Errorf("failed to delete database: %w", err) + } + } + rows, err := db.Query("select version from meta") + if err == nil { + defer rows.Close() + rows.Next() + var version string + err = rows.Scan(&version) + if err != nil { + return fmt.Errorf("failed to check database schema version: %w", err) + } else if version == "" { + return errors.New("database is in unknown state") + } + return nil + } + + if !strings.Contains(err.Error(), "no such table") { + return fmt.Errorf("got error checking database state: %w", err) + } + + _, err = db.Exec(schemaSQL) + if err != nil { + return fmt.Errorf("failed to initialize database schema: %w", err) + } + + return nil +} diff --git a/server/cmd/main.go b/server/cmd/main.go index b5dca2c..d7dbef3 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -1,27 +1,21 @@ package main import ( - "database/sql" - _ "embed" "encoding/json" - "errors" "flag" "fmt" "log" "net/http" "os" - "strings" "git.tilde.town/tildetown/bbj2/server/cmd/api" "git.tilde.town/tildetown/bbj2/server/cmd/config" + "git.tilde.town/tildetown/bbj2/server/cmd/db" _ "github.com/mattn/go-sqlite3" ) // TODO tests -//go:embed schema.sql -var schemaSQL string - func main() { var configFlag = flag.String("config", "config.yml", "A path to a config file.") var resetFlag = flag.Bool("reset", false, "reset the database. WARNING this wipes everything.") @@ -44,13 +38,20 @@ func main() { } } -type Teardown func() +/* -func setupDB(opts *config.Options) (Teardown, error) { - db, err := sql.Open("sqlite3", opts.Config.DBPath) - opts.DB = db - return func() { db.Close() }, err -} +TODO my next initiative is doing /something/ about the database layer. + +The amount of boiler plate involved in: + +- prepare a statement +- prepare a result struct +- execute statement +- scan into result + +is wild; the error handling is really out of control. I need to think of abstractions for this. The "easiest" is just making blunt, non performant functions that return structs and a single error, but that could get out of control too. In general I think not having raw sql ever in application code is a good place to start. + +*/ func _main(opts *config.Options) error { cfg, err := config.ParseConfig(opts.ConfigPath) @@ -61,13 +62,13 @@ func _main(opts *config.Options) error { opts.Config = *cfg - teardown, err := setupDB(opts) + teardown, err := db.Setup(*opts) if err != nil { return fmt.Errorf("could not initialize DB: %w", err) } defer teardown() - err = ensureSchema(*opts) + err = db.EnsureSchema(*opts) if err != nil { return err } @@ -83,41 +84,6 @@ func _main(opts *config.Options) error { return nil } -func ensureSchema(opts config.Options) error { - db := opts.DB - - if opts.Reset { - err := os.Remove(opts.Config.DBPath) - if err != nil { - return fmt.Errorf("failed to delete database: %w", err) - } - } - rows, err := db.Query("select version from meta") - if err == nil { - defer rows.Close() - rows.Next() - var version string - err = rows.Scan(&version) - if err != nil { - return fmt.Errorf("failed to check database schema version: %w", err) - } else if version == "" { - return errors.New("database is in unknown state") - } - return nil - } - - if !strings.Contains(err.Error(), "no such table") { - return fmt.Errorf("got error checking database state: %w", err) - } - - _, err = db.Exec(schemaSQL) - if err != nil { - return fmt.Errorf("failed to initialize database schema: %w", err) - } - - return nil -} - func handler(opts config.Options, f http.HandlerFunc) http.HandlerFunc { // TODO make this more real return func(w http.ResponseWriter, req *http.Request) { @@ -127,37 +93,6 @@ func handler(opts config.Options, f http.HandlerFunc) http.HandlerFunc { } } -// TODO I'm not entirely sold on this hash system; without transport -// encryption, it doesn't really help anything. I'd rather have plaintext + -// transport encryption and then, on the server side, proper salted hashing. - -// NB breaking: i'm not just returning 200 always but using http status codes - -func checkAuth(opts config.Options, username, hash string) error { - db := opts.DB - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") - if err != nil { - return fmt.Errorf("db error: %w", err) - } - defer stmt.Close() - - opts.Logger.Printf("querying for %s", username) - - var authHash string - if err = stmt.QueryRow(username).Scan(&authHash); err != nil { - if strings.Contains(err.Error(), "no rows in result") { - return errors.New("no such user") - } - return fmt.Errorf("db error: %w", err) - } - - if authHash != hash { - return errors.New("bad credentials") - } - - return nil -} - func setupAPI(opts config.Options) { handleFailedAPICreate := func(w http.ResponseWriter, err error) { opts.Logger.Printf("failed to create API: %s", err.Error()) @@ -190,287 +125,312 @@ func setupAPI(opts config.Options) { })) /* - http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } - - type AuthArgs struct { - Username string `json:"user_name"` - Hash string `json:"auth_hash"` - } - - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } - - if args.Hash == "" || args.Username == "" { - invalidArgs(w) - return - } - - opts.Logf("querying for %s", args.Username) - - if err := checkAuth(opts, args.Username, args.Hash); err == nil { - opts.Logf("found %s", args.Username) - // code 4 apparently - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "user already exists", - }) - } else if err.Error() != "no such user" { - serverErr(w, err) - return - } - + func checkAuth(opts config.Options, username, hash string) error { db := opts.DB - stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) - id, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } - - _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) - if err != nil { - serverErr(w, err) - } - - writeResponse(w, BBJResponse{ - Data: true, // TODO probably something else - // TODO prob usermap - }) - })) - - http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } - - type AuthArgs struct { - Username string `json:"target_user"` - AuthHash string `json:"target_hash"` - } - - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } - - opts.Logf("got %s %s", args.Username, args.AuthHash) - - db := opts.DB - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") if err != nil { - serverErr(w, err) - return + return fmt.Errorf("db error: %w", err) } defer stmt.Close() + opts.Logger.Printf("querying for %s", username) + var authHash string - err = stmt.QueryRow(args.Username).Scan(&authHash) - if err != nil { + if err = stmt.QueryRow(username).Scan(&authHash); err != nil { if strings.Contains(err.Error(), "no rows in result") { - opts.Logf("user not found") - writeErrorResponse(w, 404, BBJResponse{ - Error: true, - Data: "user not found", + return errors.New("no such user") + } + return fmt.Errorf("db error: %w", err) + } + + if authHash != hash { + return errors.New("bad credentials") + } + + return nil + } + + http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } + + type AuthArgs struct { + Username string `json:"user_name"` + Hash string `json:"auth_hash"` + } + + var args AuthArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } + + if args.Hash == "" || args.Username == "" { + invalidArgs(w) + return + } + + opts.Logf("querying for %s", args.Username) + + if err := checkAuth(opts, args.Username, args.Hash); err == nil { + opts.Logf("found %s", args.Username) + // code 4 apparently + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: "user already exists", + }) + } else if err.Error() != "no such user" { + serverErr(w, err) + return + } + + db := opts.DB + stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) + id, err := uuid.NewRandom() + if err != nil { + serverErr(w, err) + return + } + + _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) + if err != nil { + serverErr(w, err) + } + + writeResponse(w, BBJResponse{ + Data: true, // TODO probably something else + // TODO prob usermap }) - } else { - serverErr(w, err) - } - return - } + })) - // TODO unique constraint on user_name + http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } - if authHash != args.AuthHash { - http.Error(w, "incorrect password", 403) - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "incorrect password", - }) - return - } + type AuthArgs struct { + Username string `json:"target_user"` + AuthHash string `json:"target_hash"` + } - // TODO include usermap? - writeResponse(w, BBJResponse{ - Data: true, - }) - })) + var args AuthArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { - db := opts.DB - rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") - if err != nil { - serverErr(w, err) - return - } - defer rows.Close() - for rows.Next() { - var id string - err = rows.Scan(&id) - if err != nil { - serverErr(w, err) - return - } - opts.Log(id) - } - writeResponse(w, BBJResponse{Data: "TODO"}) - // TODO - })) + opts.Logf("got %s %s", args.Username, args.AuthHash) - http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + db := opts.DB - // TODO make this getUserInfoFromReq or similar so we can use the user ID later - user, err := getUserFromReq(opts, req) - if err != nil { - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: err.Error(), - }) - return - } + stmt, err := db.Prepare("select auth_hash from users where user_name = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - type threadCreateArgs struct { - Title string - Body string - SendRaw bool `json:"send_raw"` - } + var authHash string + err = stmt.QueryRow(args.Username).Scan(&authHash) + if err != nil { + if strings.Contains(err.Error(), "no rows in result") { + opts.Logf("user not found") + writeErrorResponse(w, 404, BBJResponse{ + Error: true, + Data: "user not found", + }) + } else { + serverErr(w, err) + } + return + } - var args threadCreateArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + // TODO unique constraint on user_name - if args.Title == "" || args.Body == "" { - invalidArgs(w) - return - } + if authHash != args.AuthHash { + http.Error(w, "incorrect password", 403) + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: "incorrect password", + }) + return + } - db := opts.DB - tx, err := db.Begin() - if err != nil { - serverErr(w, err) - return - } + // TODO include usermap? + writeResponse(w, BBJResponse{ + Data: true, + }) + })) - stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { + db := opts.DB + rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") + if err != nil { + serverErr(w, err) + return + } + defer rows.Close() + for rows.Next() { + var id string + err = rows.Scan(&id) + if err != nil { + serverErr(w, err) + return + } + opts.Log(id) + } + writeResponse(w, BBJResponse{Data: "TODO"}) + // TODO + })) - threadID, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } - now := time.Now() - if _, err = stmt.Exec( - threadID, - user.ID, - args.Title, - now, - now, - user.Username, - ); err != nil { - serverErr(w, err) - return - } + http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } - stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + // TODO make this getUserInfoFromReq or similar so we can use the user ID later + user, err := getUserFromReq(opts, req) + if err != nil { + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: err.Error(), + }) + return + } - if _, err = stmt.Exec( - threadID, - user.ID, - now, - args.Body, - args.SendRaw, - ); err != nil { - serverErr(w, err) - return - } + type threadCreateArgs struct { + Title string + Body string + SendRaw bool `json:"send_raw"` + } - if err = tx.Commit(); err != nil { - serverErr(w, err) - return - } + var args threadCreateArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + if args.Title == "" || args.Body == "" { + invalidArgs(w) + return + } - t := &Thread{} + db := opts.DB + tx, err := db.Begin() + if err != nil { + serverErr(w, err) + return + } - // TODO fill in rest of thread - if err = stmt.QueryRow(threadID).Scan( - t.ID, - t.Author, - t.Title, - t.LastMod, - t.Created, - t.ReplyCount, - t.Pinned, - t.LastAuthor, - ); err != nil { - serverErr(w, err) - return - } + stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - stmt, err = db.Prepare("select * from messages where thread_id = ?") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - rows, err := stmt.Query(threadID) - if err != nil { - serverErr(w, err) - return - } + threadID, err := uuid.NewRandom() + if err != nil { + serverErr(w, err) + return + } + now := time.Now() + if _, err = stmt.Exec( + threadID, + user.ID, + args.Title, + now, + now, + user.Username, + ); err != nil { + serverErr(w, err) + return + } - t.Messages = []Message{} + stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - for rows.Next() { - m := &Message{} - if err := rows.Scan( - m.ThreadID, - m.PostID, - m.Author, - m.Created, - m.Edited, - m.Body, - m.SendRaw, - ); err != nil { - serverErr(w, err) - return - } - t.Messages = append(t.Messages, *m) - } + if _, err = stmt.Exec( + threadID, + user.ID, + now, + args.Body, + args.SendRaw, + ); err != nil { + serverErr(w, err) + return + } - writeResponse(w, BBJResponse{Data: t}) + if err = tx.Commit(); err != nil { + serverErr(w, err) + return + } - })) + stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() + + t := &Thread{} + + // TODO fill in rest of thread + if err = stmt.QueryRow(threadID).Scan( + t.ID, + t.Author, + t.Title, + t.LastMod, + t.Created, + t.ReplyCount, + t.Pinned, + t.LastAuthor, + ); err != nil { + serverErr(w, err) + return + } + + stmt, err = db.Prepare("select * from messages where thread_id = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() + rows, err := stmt.Query(threadID) + if err != nil { + serverErr(w, err) + return + } + + t.Messages = []Message{} + + for rows.Next() { + m := &Message{} + if err := rows.Scan( + m.ThreadID, + m.PostID, + m.Author, + m.Created, + m.Edited, + m.Body, + m.SendRaw, + ); err != nil { + serverErr(w, err) + return + } + t.Messages = append(t.Messages, *m) + } + + writeResponse(w, BBJResponse{Data: t}) + + })) */ } diff --git a/server/cmd/schema.sql b/server/cmd/schema.sql deleted file mode 100644 index eb72af6..0000000 --- a/server/cmd/schema.sql +++ /dev/null @@ -1,52 +0,0 @@ -create table meta ( - version text -- schema version -); - -insert into meta values ("1.0.0"); - -create table users ( - user_id text, -- string (uuid1) - user_name text, -- string - auth_hash text, -- string (sha256 hash) - quip text, -- string (possibly empty) - bio text, -- string (possibly empty) - color int, -- int (from 0 to 6) - is_admin int, -- bool - created real -- floating point unix timestamp (when this user registered) -); - -insert into users values ( - "be105a40-6bd1-405f-9716-aa6158ac1eef", -- TODO replace UUID with incrementing int - "anon", - "8e97c0b197816a652fb489b21e63f664863daa991e2f8fd56e2df71593c2793f", - "", - "", - 0, - 0, - 1650819851 -); - --- TODO unique constraint on user_name? --- TODO foreign keys - -create table threads ( - thread_id text, -- uuid string - author text, -- string (uuid1, user.user_id) - title text, -- string - last_mod real, -- floating point unix timestamp (of last post or post edit) - created real, -- floating point unix timestamp (when thread was made) - reply_count int, -- integer (incremental, starting with 0) - pinned int, -- boolean - last_author text -- uuid string -); - - -create table messages ( - thread_id text, -- string (uuid1 of parent thread) - post_id int, -- integer (incrementing from 1) - author text, -- string (uuid1, user.user_id) - created real, -- floating point unix timestamp (when reply was posted) - edited int, -- bool - body text, -- string - send_raw int -- bool (1/true == never apply formatting) -);