From 1ed26d4cbca8582f80848e3b58d4f0b488d90bcc Mon Sep 17 00:00:00 2001 From: vilmibm Date: Tue, 21 Jun 2022 13:35:34 -0500 Subject: [PATCH] refactor get user from req, test stub --- server/cmd/api/api.go | 31 ++++++++----------------------- server/cmd/api/api_test.go | 9 ++++++--- 2 files changed, 14 insertions(+), 26 deletions(-) diff --git a/server/cmd/api/api.go b/server/cmd/api/api.go index eb27bad..b8d8e92 100644 --- a/server/cmd/api/api.go +++ b/server/cmd/api/api.go @@ -5,7 +5,6 @@ import ( "errors" "fmt" "net/http" - "strings" "time" "git.tilde.town/tildetown/bbj2/server/cmd/config" @@ -90,34 +89,20 @@ func (a *API) Invoke(w http.ResponseWriter, req *http.Request, apiFn APIHandler) } func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { - // TODO abstract sql stuff into db - u = &db.User{} - u.Username = req.Header.Get("User") - u.Hash = req.Header.Get("Auth") - if u.Username == "" || u.Username == "anon" { + username := req.Header.Get("User") + hash := req.Header.Get("Auth") + if username == "" || username == "anon" { return } - db := opts.DB - stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?") - if err != nil { - err = fmt.Errorf("db error: %w", err) + opts.Logger.Printf("checking auth for %s", username) + + if u, err = db.GetUserByName(opts.DB, username); err != nil { return } - defer stmt.Close() - opts.Logger.Printf("querying for %s", u.Username) - - var authHash string - if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil { - if strings.Contains(err.Error(), "no rows in result") { - err = errors.New("no such user") - } else { - err = fmt.Errorf("db error: %w", err) - } - } - - if authHash != u.Hash { + if u.Hash != hash { + u = nil err = errors.New("bad credentials") } diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 694aaa2..990eeb9 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -40,7 +40,7 @@ func userCount(db *sql.DB) (count int, err error) { return } -func TestUserRegister(t *testing.T) { +func Test_UserRegister(t *testing.T) { ts := []struct { name string req func() *http.Request @@ -157,12 +157,11 @@ func TestUserRegister(t *testing.T) { t.Errorf("expected error") return } - }) } } -func TestInstanceInfo(t *testing.T) { +func Test_InstanceInfo(t *testing.T) { opts, err := createTestState() if err != nil { t.Fatalf("failed to create test state: %s", err.Error()) @@ -229,3 +228,7 @@ func TestInstanceInfo(t *testing.T) { }) } } + +func Test_getUserFromReq(t *testing.T) { + // TODO +}