diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 75cbe9a..3dcb3c9 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -44,8 +44,8 @@ func TestUserRegister(t *testing.T) { ts := []struct { name string req func() *http.Request - setup func(opts *config.Options) error - assert func(t *testing.T) error + setup func(*config.Options) error + assert func(*config.Options, *testing.T) error wantErr *HTTPError }{ { @@ -57,12 +57,32 @@ func TestUserRegister(t *testing.T) { Created: time.Now(), }) }, - assert: func(t *testing.T) error { - // TODO ensure user count is still 1 + assert: func(opts *config.Options, t *testing.T) error { + var count int + var err error + if err = opts.DB.QueryRow("SELECT count(*) FROM users").Scan(&count); err != nil { + return err + } + + if count != 1 { + t.Errorf("expected 1 user, got %d", count) + } + return nil }, wantErr: &HTTPError{Code: 403, Msg: "user already exists"}, }, + { + name: "bad method", + req: func() *http.Request { + r, _ := http.NewRequest("GET", "", strings.NewReader("")) + return r + }, + wantErr: &HTTPError{ + Msg: "bad method", + Code: 400, + }, + }, // TODO } @@ -70,7 +90,7 @@ func TestUserRegister(t *testing.T) { t.Run(tt.name, func(t *testing.T) { var req *http.Request if tt.req == nil { - req, _ = http.NewRequest("GET", "", strings.NewReader("")) + req, _ = http.NewRequest("POST", "", strings.NewReader(`{"user_name":"albertwesker","auth_hash":"1234abc"}`)) } else { req = tt.req() } @@ -87,15 +107,17 @@ func TestUserRegister(t *testing.T) { return } - err = tt.setup(opts) - if err != nil { - t.Fatalf("setup failed: %s", err.Error()) - return + if tt.setup != nil { + err = tt.setup(opts) + if err != nil { + t.Fatalf("setup failed: %s", err.Error()) + return + } } api := &API{Opts: *opts} ctx := &ReqCtx{Req: req} - _, err = api.InstanceInfo(ctx) + _, err = api.UserRegister(ctx) if tt.wantErr != nil && err != nil { if !reflect.DeepEqual(tt.wantErr, err) { t.Errorf("got unwanted error: %s", err.Error()) @@ -103,16 +125,19 @@ func TestUserRegister(t *testing.T) { return } + if tt.assert != nil { + err = tt.assert(opts, t) + if err != nil { + t.Fatal(err.Error()) + return + } + } + if tt.wantErr != nil && err == nil { t.Errorf("expected error") return } - err = tt.assert(t) - if err != nil { - t.Fatal(err.Error()) - return - } }) } } diff --git a/server/cmd/db/db.go b/server/cmd/db/db.go index 68500ff..77bd771 100644 --- a/server/cmd/db/db.go +++ b/server/cmd/db/db.go @@ -95,13 +95,20 @@ func EnsureSchema(opts config.Options) error { func GetUserByName(db *sql.DB, username string) (u *User, err error) { var stmt *sql.Stmt - stmt, err = db.Prepare("select auth_hash from users where user_name = ?") + stmt, err = db.Prepare("select user_id, user_name, auth_hash from users where user_name = ?") if err != nil { return } defer stmt.Close() - if err = stmt.QueryRow(username).Scan(&u); err != nil { + u = &User{} + + if err = stmt.QueryRow(username).Scan( + &u.ID, + &u.Username, + &u.Hash, + // TODO support the rest + ); err != nil { if strings.Contains(err.Error(), "no rows in result") { err = errors.New("no such user") }