From 7fd2547cd121c39b77f53da84c447ea57aaf3760 Mon Sep 17 00:00:00 2001 From: vilmibm Date: Tue, 14 Jun 2022 17:03:46 -0500 Subject: [PATCH] vague unfucking, work on db, test stubs and imp for user-register --- go.mod | 7 +- server/cmd/api/api.go | 76 +++++- server/cmd/api/api_test.go | 82 ++++-- server/cmd/config/config.go | 7 - server/cmd/db/db.go | 35 +++ server/cmd/main.go | 489 +++++++++++++++--------------------- 6 files changed, 373 insertions(+), 323 deletions(-) diff --git a/go.mod b/go.mod index ed636cb..d6cd5f7 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.18 require gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b -require github.com/mattn/go-sqlite3 v1.14.12 - -require github.com/google/uuid v1.3.0 // indirect +require ( + github.com/google/uuid v1.3.0 + github.com/mattn/go-sqlite3 v1.14.12 +) diff --git a/server/cmd/api/api.go b/server/cmd/api/api.go index 618d030..f608647 100644 --- a/server/cmd/api/api.go +++ b/server/cmd/api/api.go @@ -6,6 +6,7 @@ import ( "fmt" "net/http" "strings" + "time" "git.tilde.town/tildetown/bbj2/server/cmd/config" "git.tilde.town/tildetown/bbj2/server/cmd/db" @@ -20,6 +21,14 @@ func (e *HTTPError) Error() string { return fmt.Sprintf("%d %s", e.Code, e.Msg) } +func badMethod() error { + return &HTTPError{Code: 400, Msg: "bad method"} +} + +func invalidArgs(msg string) error { + return &HTTPError{Code: 400, Msg: fmt.Sprintf("invalid args: %s", msg)} +} + type BBJResponse struct { Error bool `json:"error"` Data interface{} `json:"data"` @@ -66,6 +75,7 @@ func Invoke(w http.ResponseWriter, apiFn APIHandler) { } func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { + // TODO abstract sql stuff into db u = &db.User{} u.Username = req.Header.Get("User") u.Hash = req.Header.Get("Auth") @@ -113,22 +123,74 @@ func (a *API) IsPost() bool { return a.Req.Method == "POST" } -func (a *API) InstanceInfo() (*BBJResponse, error) { +func (a *API) InstanceInfo() (resp *BBJResponse, err error) { if !a.IsGet() { - return nil, &HTTPError{Msg: "bad method", Code: 400} + err = badMethod() + return } - return &BBJResponse{ + + resp = &BBJResponse{ Data: instanceInfo{ InstanceName: a.Opts.Config.InstanceName, AllowAnon: a.Opts.Config.AllowAnon, Admins: a.Opts.Config.Admins, }, - }, nil + } + + return } -func (a *API) UserRegister() (*BBJResponse, error) { +func (a *API) UserRegister() (resp *BBJResponse, err error) { if !a.IsPost() { - return nil, &HTTPError{Msg: "bad method", Code: 400} + err = badMethod() + return } - return nil, nil + type AuthArgs struct { + Username string `json:"user_name"` + Hash string `json:"auth_hash"` + } + + var args AuthArgs + if err = json.NewDecoder(a.Req.Body).Decode(&args); err != nil { + err = invalidArgs(err.Error()) + return + } + + if args.Hash == "" || args.Username == "" { + err = invalidArgs(err.Error()) + return + } + + if err = checkAuth(a.Opts, args.Username, args.Hash); err == nil { + a.Opts.Logger.Printf("user %s already registered", args.Username) + err = &HTTPError{Code: 403, Msg: "user already exists"} + return + } else if err.Error() != "no such user" { + err = &HTTPError{Code: 500, Msg: err.Error()} + return + } + + u := db.User{ + Username: args.Username, + Hash: args.Hash, + Created: time.Now(), // TODO inject time + } + + err = db.CreateUser(a.Opts.DB, u) + + return +} + +func checkAuth(opts config.Options, username, hash string) (err error) { + opts.Logger.Printf("querying for %s", username) + var user *db.User + if user, err = db.GetUserByName(opts.DB, username); err != nil { + return + } + + if user.Hash != hash { + err = errors.New("bad credentials") + } + + return } diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 5a10e94..e137b28 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -1,8 +1,6 @@ package api import ( - "bufio" - "bytes" "log" "net/http" "os" @@ -14,22 +12,14 @@ import ( "git.tilde.town/tildetown/bbj2/server/cmd/db" ) -func TestInstanceInfo(t *testing.T) { - // TODO a lot of this needs to be cleaned up and generalized etc - stderr := []byte{} - stdout := []byte{} - testIO := config.IOStreams{ - Err: bufio.NewWriter(bytes.NewBuffer(stderr)), - Out: bufio.NewWriter(bytes.NewBuffer(stdout)), +func createTestState() (opts *config.Options, err error) { + var dbFile *os.File + if dbFile, err = os.CreateTemp("", "bbj2-test"); err != nil { + return } - dbFile, err := os.CreateTemp("", "bbj2-test") - if err != nil { - t.Fatalf("failed to make test db: %s", err.Error()) - } - logger := log.New(os.Stdout, "bbj test", log.Lshortfile) - opts := config.Options{ - IO: testIO, - Logger: logger, + + opts = &config.Options{ + Logger: log.New(os.Stdout, "bbj2 test", log.Lshortfile), Config: config.Config{ Admins: []string{"jillValentine", "rebeccaChambers"}, Port: 666, @@ -40,7 +30,61 @@ func TestInstanceInfo(t *testing.T) { }, } - teardown, err := db.Setup(opts) + return +} + +func TestUserRegister(t *testing.T) { + opts, err := createTestState() + if err != nil { + t.Fatalf("failed to create test state: %s", err.Error()) + return + } + + ts := []struct { + name string + opts config.Options + setup func(opts *config.Options) error + assert func(t *testing.T) error + wantErr *HTTPError + }{ + { + name: "user already exists", + opts: *opts, + setup: func(opts *config.Options) error { + // TODO + return nil + }, + assert: func(t *testing.T) error { + // TODO + return nil + }, + wantErr: &HTTPError{Code: 403, Msg: "user already exists"}, + }, + } + + for _, tt := range ts { + t.Run(tt.name, func(t *testing.T) { + teardown, err := db.Setup(*opts) + if err != nil { + t.Fatalf("could not initialize DB: %s", err.Error()) + return + } + defer teardown() + + // TODO + + }) + } +} + +func TestInstanceInfo(t *testing.T) { + opts, err := createTestState() + if err != nil { + t.Fatalf("failed to create test state: %s", err.Error()) + return + } + + teardown, err := db.Setup(*opts) if err != nil { t.Fatalf("could not initialize DB: %s", err.Error()) return @@ -55,7 +99,7 @@ func TestInstanceInfo(t *testing.T) { }{ { name: "basic", - opts: opts, + opts: *opts, wantData: instanceInfo{ InstanceName: "cool test zone", AllowAnon: true, diff --git a/server/cmd/config/config.go b/server/cmd/config/config.go index 1725ec9..faef903 100644 --- a/server/cmd/config/config.go +++ b/server/cmd/config/config.go @@ -3,7 +3,6 @@ package config import ( "database/sql" "fmt" - "io" "log" "os" @@ -17,11 +16,6 @@ const ( defaultDBPath = "db.sqlite3" ) -type IOStreams struct { - Err io.Writer - Out io.Writer -} - type Config struct { Admins []string Port int @@ -34,7 +28,6 @@ type Config struct { type Options struct { ConfigPath string - IO IOStreams Logger *log.Logger Log func(string) Logf func(string, ...interface{}) diff --git a/server/cmd/db/db.go b/server/cmd/db/db.go index 83d35aa..fc9cbf9 100644 --- a/server/cmd/db/db.go +++ b/server/cmd/db/db.go @@ -10,6 +10,7 @@ import ( "time" "git.tilde.town/tildetown/bbj2/server/cmd/config" + "github.com/google/uuid" _ "github.com/mattn/go-sqlite3" ) @@ -25,6 +26,7 @@ type User struct { ID string Username string Hash string + Created time.Time } type Thread struct { @@ -89,3 +91,36 @@ func EnsureSchema(opts config.Options) error { return nil } + +func GetUserByName(db *sql.DB, username string) (u *User, err error) { + var stmt *sql.Stmt + stmt, err = db.Prepare("select auth_hash from users where user_name = ?") + if err != nil { + return + } + defer stmt.Close() + + if err = stmt.QueryRow(username).Scan(&u); err != nil { + if strings.Contains(err.Error(), "no rows in result") { + err = errors.New("no such user") + } + } + + return +} + +func CreateUser(db *sql.DB, u User) (err error) { + var id uuid.UUID + if id, err = uuid.NewRandom(); err != nil { + return + } + var stmt *sql.Stmt + if stmt, err = db.Prepare(`INSERT INTO users VALUES(?, ?, ?, "", "", 0, 0, ?)`); err != nil { + return + } + defer stmt.Close() + + _, err = stmt.Exec(id, u.Username, u.Hash, u.Created) + + return +} diff --git a/server/cmd/main.go b/server/cmd/main.go index d7dbef3..4d0eafe 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -20,15 +20,10 @@ func main() { var configFlag = flag.String("config", "config.yml", "A path to a config file.") var resetFlag = flag.Bool("reset", false, "reset the database. WARNING this wipes everything.") flag.Parse() - io := config.IOStreams{ - Err: os.Stderr, - Out: os.Stdout, - } - logger := log.New(io.Out, "", log.Ldate|log.Ltime|log.Lshortfile) + logger := log.New(os.Stdout, "", log.Ldate|log.Ltime|log.Lshortfile) opts := &config.Options{ ConfigPath: *configFlag, Reset: *resetFlag, - IO: io, Logger: logger, } @@ -56,8 +51,7 @@ is wild; the error handling is really out of control. I need to think of abstrac func _main(opts *config.Options) error { cfg, err := config.ParseConfig(opts.ConfigPath) if err != nil { - fmt.Fprintf(os.Stderr, "could not read config file '%s'", opts.ConfigPath) - os.Exit(1) + return fmt.Errorf("could not read config file '%s'", opts.ConfigPath) } opts.Config = *cfg @@ -124,313 +118,234 @@ func setupAPI(opts config.Options) { api.Invoke(w, a.UserRegister) })) - /* - func checkAuth(opts config.Options, username, hash string) error { - db := opts.DB - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") - if err != nil { - return fmt.Errorf("db error: %w", err) - } - defer stmt.Close() +} - opts.Logger.Printf("querying for %s", username) - - var authHash string - if err = stmt.QueryRow(username).Scan(&authHash); err != nil { - if strings.Contains(err.Error(), "no rows in result") { - return errors.New("no such user") - } - return fmt.Errorf("db error: %w", err) - } - - if authHash != hash { - return errors.New("bad credentials") - } - - return nil +/* + http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return } - http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + type AuthArgs struct { + Username string `json:"target_user"` + AuthHash string `json:"target_hash"` + } - type AuthArgs struct { - Username string `json:"user_name"` - Hash string `json:"auth_hash"` - } + var args AuthArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + opts.Logf("got %s %s", args.Username, args.AuthHash) - if args.Hash == "" || args.Username == "" { - invalidArgs(w) - return - } + db := opts.DB - opts.Logf("querying for %s", args.Username) + stmt, err := db.Prepare("select auth_hash from users where user_name = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - if err := checkAuth(opts, args.Username, args.Hash); err == nil { - opts.Logf("found %s", args.Username) - // code 4 apparently - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "user already exists", - }) - } else if err.Error() != "no such user" { - serverErr(w, err) - return - } + var authHash string + err = stmt.QueryRow(args.Username).Scan(&authHash) + if err != nil { + if strings.Contains(err.Error(), "no rows in result") { + opts.Logf("user not found") + writeErrorResponse(w, 404, BBJResponse{ + Error: true, + Data: "user not found", + }) + } else { + serverErr(w, err) + } + return + } - db := opts.DB - stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) - id, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } + // TODO unique constraint on user_name - _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) - if err != nil { - serverErr(w, err) - } + if authHash != args.AuthHash { + http.Error(w, "incorrect password", 403) + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: "incorrect password", + }) + return + } - writeResponse(w, BBJResponse{ - Data: true, // TODO probably something else - // TODO prob usermap - }) - })) + // TODO include usermap? + writeResponse(w, BBJResponse{ + Data: true, + }) + })) - http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { + db := opts.DB + rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") + if err != nil { + serverErr(w, err) + return + } + defer rows.Close() + for rows.Next() { + var id string + err = rows.Scan(&id) + if err != nil { + serverErr(w, err) + return + } + opts.Log(id) + } + writeResponse(w, BBJResponse{Data: "TODO"}) + // TODO + })) - type AuthArgs struct { - Username string `json:"target_user"` - AuthHash string `json:"target_hash"` - } + http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + // TODO make this getUserInfoFromReq or similar so we can use the user ID later + user, err := getUserFromReq(opts, req) + if err != nil { + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: err.Error(), + }) + return + } - opts.Logf("got %s %s", args.Username, args.AuthHash) + type threadCreateArgs struct { + Title string + Body string + SendRaw bool `json:"send_raw"` + } - db := opts.DB + var args threadCreateArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + if args.Title == "" || args.Body == "" { + invalidArgs(w) + return + } - var authHash string - err = stmt.QueryRow(args.Username).Scan(&authHash) - if err != nil { - if strings.Contains(err.Error(), "no rows in result") { - opts.Logf("user not found") - writeErrorResponse(w, 404, BBJResponse{ - Error: true, - Data: "user not found", - }) - } else { - serverErr(w, err) - } - return - } + db := opts.DB + tx, err := db.Begin() + if err != nil { + serverErr(w, err) + return + } - // TODO unique constraint on user_name + stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - if authHash != args.AuthHash { - http.Error(w, "incorrect password", 403) - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "incorrect password", - }) - return - } + threadID, err := uuid.NewRandom() + if err != nil { + serverErr(w, err) + return + } + now := time.Now() + if _, err = stmt.Exec( + threadID, + user.ID, + args.Title, + now, + now, + user.Username, + ); err != nil { + serverErr(w, err) + return + } - // TODO include usermap? - writeResponse(w, BBJResponse{ - Data: true, - }) - })) + stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { - db := opts.DB - rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") - if err != nil { - serverErr(w, err) - return - } - defer rows.Close() - for rows.Next() { - var id string - err = rows.Scan(&id) - if err != nil { - serverErr(w, err) - return - } - opts.Log(id) - } - writeResponse(w, BBJResponse{Data: "TODO"}) - // TODO - })) + if _, err = stmt.Exec( + threadID, + user.ID, + now, + args.Body, + args.SendRaw, + ); err != nil { + serverErr(w, err) + return + } - http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + if err = tx.Commit(); err != nil { + serverErr(w, err) + return + } - // TODO make this getUserInfoFromReq or similar so we can use the user ID later - user, err := getUserFromReq(opts, req) - if err != nil { - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: err.Error(), - }) - return - } + stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - type threadCreateArgs struct { - Title string - Body string - SendRaw bool `json:"send_raw"` - } + t := &Thread{} - var args threadCreateArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + // TODO fill in rest of thread + if err = stmt.QueryRow(threadID).Scan( + t.ID, + t.Author, + t.Title, + t.LastMod, + t.Created, + t.ReplyCount, + t.Pinned, + t.LastAuthor, + ); err != nil { + serverErr(w, err) + return + } - if args.Title == "" || args.Body == "" { - invalidArgs(w) - return - } + stmt, err = db.Prepare("select * from messages where thread_id = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() + rows, err := stmt.Query(threadID) + if err != nil { + serverErr(w, err) + return + } - db := opts.DB - tx, err := db.Begin() - if err != nil { - serverErr(w, err) - return - } + t.Messages = []Message{} - stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + for rows.Next() { + m := &Message{} + if err := rows.Scan( + m.ThreadID, + m.PostID, + m.Author, + m.Created, + m.Edited, + m.Body, + m.SendRaw, + ); err != nil { + serverErr(w, err) + return + } + t.Messages = append(t.Messages, *m) + } - threadID, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } - now := time.Now() - if _, err = stmt.Exec( - threadID, - user.ID, - args.Title, - now, - now, - user.Username, - ); err != nil { - serverErr(w, err) - return - } + writeResponse(w, BBJResponse{Data: t}) - stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - - if _, err = stmt.Exec( - threadID, - user.ID, - now, - args.Body, - args.SendRaw, - ); err != nil { - serverErr(w, err) - return - } - - if err = tx.Commit(); err != nil { - serverErr(w, err) - return - } - - stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - - t := &Thread{} - - // TODO fill in rest of thread - if err = stmt.QueryRow(threadID).Scan( - t.ID, - t.Author, - t.Title, - t.LastMod, - t.Created, - t.ReplyCount, - t.Pinned, - t.LastAuthor, - ); err != nil { - serverErr(w, err) - return - } - - stmt, err = db.Prepare("select * from messages where thread_id = ?") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - rows, err := stmt.Query(threadID) - if err != nil { - serverErr(w, err) - return - } - - t.Messages = []Message{} - - for rows.Next() { - m := &Message{} - if err := rows.Scan( - m.ThreadID, - m.PostID, - m.Author, - m.Created, - m.Edited, - m.Body, - m.SendRaw, - ); err != nil { - serverErr(w, err) - return - } - t.Messages = append(t.Messages, *m) - } - - writeResponse(w, BBJResponse{Data: t}) - - })) - */ -} + })) +*/