diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 3dcb3c9..694aaa2 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -1,6 +1,7 @@ package api import ( + "database/sql" "log" "net/http" "os" @@ -34,13 +35,12 @@ func createTestState() (opts *config.Options, err error) { return } -func TestUserRegister(t *testing.T) { - opts, err := createTestState() - if err != nil { - t.Fatalf("failed to create test state: %s", err.Error()) - return - } +func userCount(db *sql.DB) (count int, err error) { + err = db.QueryRow("SELECT count(*) FROM users").Scan(&count) + return +} +func TestUserRegister(t *testing.T) { ts := []struct { name string req func() *http.Request @@ -57,18 +57,15 @@ func TestUserRegister(t *testing.T) { Created: time.Now(), }) }, - assert: func(opts *config.Options, t *testing.T) error { + assert: func(opts *config.Options, t *testing.T) (err error) { var count int - var err error - if err = opts.DB.QueryRow("SELECT count(*) FROM users").Scan(&count); err != nil { - return err + count, err = userCount(opts.DB) + + if count != 2 { + t.Errorf("expected 2 users, got %d", count) } - if count != 1 { - t.Errorf("expected 1 user, got %d", count) - } - - return nil + return }, wantErr: &HTTPError{Code: 403, Msg: "user already exists"}, }, @@ -83,11 +80,33 @@ func TestUserRegister(t *testing.T) { Code: 400, }, }, - // TODO + { + name: "add new user", + req: func() *http.Request { + r, _ := http.NewRequest("POST", "", strings.NewReader(`{"user_name":"chrisredfield", "auth_hash":"abc123"}`)) + return r + }, + assert: func(opts *config.Options, t *testing.T) (err error) { + var count int + count, err = userCount(opts.DB) + + if count != 2 { + t.Errorf("expected 2 users, got %d", count) + } + + return + }, + }, + // TODO bad args } for _, tt := range ts { t.Run(tt.name, func(t *testing.T) { + opts, err := createTestState() + if err != nil { + t.Fatalf("failed to create test state: %s", err.Error()) + return + } var req *http.Request if tt.req == nil { req, _ = http.NewRequest("POST", "", strings.NewReader(`{"user_name":"albertwesker","auth_hash":"1234abc"}`)) @@ -118,6 +137,15 @@ func TestUserRegister(t *testing.T) { api := &API{Opts: *opts} ctx := &ReqCtx{Req: req} _, err = api.UserRegister(ctx) + + if tt.assert != nil { + err := tt.assert(opts, t) + if err != nil { + t.Fatal(err.Error()) + return + } + } + if tt.wantErr != nil && err != nil { if !reflect.DeepEqual(tt.wantErr, err) { t.Errorf("got unwanted error: %s", err.Error()) @@ -125,14 +153,6 @@ func TestUserRegister(t *testing.T) { return } - if tt.assert != nil { - err = tt.assert(opts, t) - if err != nil { - t.Fatal(err.Error()) - return - } - } - if tt.wantErr != nil && err == nil { t.Errorf("expected error") return