package api

import (
	"database/sql"
	"errors"
	"log"
	"net/http"
	"os"
	"reflect"
	"strings"
	"testing"
	"time"

	"git.tilde.town/tildetown/bbj2/server/cmd/config"
	"git.tilde.town/tildetown/bbj2/server/cmd/db"
)

// createTestState builds a config.Options for tests. It creates a fresh
// temporary file and stores that file's path in Config.DBPath.
//
// Only the path is needed, so the handle returned by os.CreateTemp is closed
// before returning. The file itself is left on disk; callers that want it
// removed should delete opts.Config.DBPath when they are done (newTestOpts
// does this automatically).
func createTestState() (opts *config.Options, err error) {
	dbFile, err := os.CreateTemp("", "bbj2-test")
	if err != nil {
		return nil, err
	}
	// Close the handle right away: keeping it open leaks one descriptor per
	// call, and nothing in this file reads or writes through it.
	if err = dbFile.Close(); err != nil {
		// Best-effort removal; the close error is the one worth reporting.
		_ = os.Remove(dbFile.Name())
		return nil, err
	}

	opts = &config.Options{
		Logger: log.New(os.Stdout, "bbj2 test", log.Lshortfile),
		Config: config.Config{
			Admins:       []string{"jillValentine", "rebeccaChambers"},
			Port:         666,
			Host:         "hell.cool",
			InstanceName: "cool test zone",
			AllowAnon:    true,
			DBPath:       dbFile.Name(),
		},
	}
	return opts, nil
}

// userCount returns the number of rows in the users table.
func userCount(db *sql.DB) (count int, err error) {
	err = db.QueryRow("SELECT count(*) FROM users").Scan(&count)
	return count, err
}

// mustNewRequest builds a request with the given method and body and an empty
// URL. It panics if the request cannot be constructed, which for the fixed
// inputs used in this file indicates a broken test rather than a runtime
// condition.
func mustNewRequest(method, body string) *http.Request {
	r, err := http.NewRequest(method, "", strings.NewReader(body))
	if err != nil {
		panic(err)
	}
	return r
}

// newTestOpts wraps createTestState for use inside a test: it fails the test
// when the state cannot be created and removes the temporary DB file once the
// test (and its subtests) have finished.
func newTestOpts(t *testing.T) *config.Options {
	t.Helper()

	opts, err := createTestState()
	if err != nil {
		t.Fatalf("failed to create test state: %s", err.Error())
	}
	t.Cleanup(func() {
		// Best-effort removal; a leftover temp file must not fail the test.
		_ = os.Remove(opts.Config.DBPath)
	})
	return opts
}

// newTestDB returns test options on which db.Setup and db.EnsureSchema have
// both succeeded. The teardown function returned by db.Setup is registered
// with t.Cleanup; cleanups run last-in-first-out, so it runs before the
// temporary file is removed.
func newTestDB(t *testing.T) *config.Options {
	t.Helper()

	opts := newTestOpts(t)

	teardown, err := db.Setup(opts)
	if err != nil {
		t.Fatalf("could not initialize DB: %s", err.Error())
	}
	t.Cleanup(func() { teardown() })

	if err := db.EnsureSchema(*opts); err != nil {
		t.Fatalf("could not initialize DB: %s", err.Error())
	}
	return opts
}

// wantUserCount returns an assertion, in the shape used by Test_UserRegister,
// that reports a test error unless the users table holds exactly want rows.
// A failure to query the table is returned to the caller.
func wantUserCount(want int) func(*config.Options, *testing.T) error {
	return func(opts *config.Options, t *testing.T) error {
		count, err := userCount(opts.DB)
		if err != nil {
			return err
		}
		if count != want {
			t.Errorf("expected %d users, got %d", want, count)
		}
		return nil
	}
}

// Test_UserRegister exercises API.UserRegister against a fresh database for
// each case. Cases without a req function send a POST registering
// "albertwesker".
func Test_UserRegister(t *testing.T) {
	ts := []struct {
		name    string
		req     func() *http.Request
		setup   func(*config.Options) error
		assert  func(*config.Options, *testing.T) error
		wantErr *HTTPError
	}{
		{
			name: "user already exists",
			setup: func(opts *config.Options) error {
				return db.CreateUser(opts.DB, db.User{
					Username: "albertwesker",
					Hash:     "1234abc",
					Created:  time.Now(),
				})
			},
			// NOTE(review): the expected total of 2 presumably includes a
			// user seeded by db.EnsureSchema; confirm against the schema.
			assert:  wantUserCount(2),
			wantErr: &HTTPError{Code: 403, Msg: "user already exists"},
		},
		{
			name: "bad method",
			req: func() *http.Request {
				return mustNewRequest("GET", "")
			},
			wantErr: &HTTPError{
				Msg:  "bad method",
				Code: 400,
			},
		},
		{
			name: "add new user",
			req: func() *http.Request {
				return mustNewRequest("POST", `{"user_name":"chrisredfield", "auth_hash":"abc123"}`)
			},
			assert: wantUserCount(2),
		},
		// TODO bad args
	}

	for _, tt := range ts {
		t.Run(tt.name, func(t *testing.T) {
			opts := newTestDB(t)

			var req *http.Request
			if tt.req == nil {
				req = mustNewRequest("POST", `{"user_name":"albertwesker","auth_hash":"1234abc"}`)
			} else {
				req = tt.req()
			}

			if tt.setup != nil {
				if err := tt.setup(opts); err != nil {
					t.Fatalf("setup failed: %s", err.Error())
				}
			}

			a := &API{Opts: *opts}
			ctx := &ReqCtx{Req: req}

			// Use a fresh variable so the nil comparison below is made on the
			// handler's own return type; assigning into an existing `error`
			// variable could hide a nil result behind a non-nil interface.
			_, regErr := a.UserRegister(ctx)

			switch {
			case tt.wantErr == nil && regErr != nil:
				// Previously this combination passed silently.
				t.Errorf("got unexpected error: %s", regErr.Error())
			case tt.wantErr != nil && regErr == nil:
				t.Error("expected error")
			case tt.wantErr != nil && !reflect.DeepEqual(tt.wantErr, regErr):
				t.Errorf("got unwanted error: %s", regErr.Error())
			}

			if tt.assert != nil {
				if err := tt.assert(opts, t); err != nil {
					t.Fatal(err.Error())
				}
			}
		})
	}
}

// Test_InstanceInfo checks that API.InstanceInfo reports the configured
// instance name, anonymous-access flag and admin list, and that it rejects a
// POST request. Cases without a req function send a GET.
func Test_InstanceInfo(t *testing.T) {
	opts := newTestOpts(t)

	ts := []struct {
		name     string
		req      func() *http.Request
		wantData instanceInfo
		wantErr  *HTTPError
	}{
		{
			name: "basic",
			wantData: instanceInfo{
				InstanceName: "cool test zone",
				AllowAnon:    true,
				Admins:       []string{"jillValentine", "rebeccaChambers"},
			},
		},
		{
			name: "bad method",
			req: func() *http.Request {
				return mustNewRequest("POST", "")
			},
			wantErr: &HTTPError{
				Msg:  "bad method",
				Code: 400,
			},
		},
	}

	for _, tt := range ts {
		t.Run(tt.name, func(t *testing.T) {
			var req *http.Request
			if tt.req == nil {
				req = mustNewRequest("GET", "")
			} else {
				req = tt.req()
			}

			a := &API{Opts: *opts}
			ctx := &ReqCtx{Req: req}

			resp, err := a.InstanceInfo(ctx)

			if tt.wantErr != nil {
				if err == nil {
					t.Error("expected error")
					return
				}
				if !reflect.DeepEqual(tt.wantErr, err) {
					t.Errorf("got unwanted error: %s", err.Error())
				}
				return
			}

			// Stop here on an unexpected error instead of inspecting a
			// response that may not be usable.
			if err != nil {
				t.Fatalf("got unexpected error: %s", err.Error())
			}

			ii, ok := resp.Data.(instanceInfo)
			if !ok {
				// Comparing a zero value below would only add noise.
				t.Fatalf("could not cast data in %s", tt.name)
			}
			if !reflect.DeepEqual(ii, tt.wantData) {
				t.Errorf("did not get expected data in %s", tt.name)
			}
		})
	}
}

// Test_getUserFromReq exercises getUserFromReq against a fresh database for
// each case, covering requests with no auth headers, the "anon" user, an
// unknown user, and wrong and correct credentials for an existing user.
func Test_getUserFromReq(t *testing.T) {
	ts := []struct {
		name    string
		req     func() *http.Request
		setup   func(*config.Options) error
		assert  func(*db.User, *testing.T)
		wantErr error
	}{
		{
			name: "no auth attempt",
			req: func() *http.Request {
				return mustNewRequest("GET", "")
			},
			assert: func(u *db.User, t *testing.T) {
				if u != nil {
					t.Errorf("expected nil, got %v", u)
				}
			},
		},
		{
			name: "anon",
			req: func() *http.Request {
				r := mustNewRequest("GET", "")
				r.Header.Set("User", "anon")
				return r
			},
			assert: func(u *db.User, t *testing.T) {
				if u != nil {
					t.Errorf("expected nil, got %v", u)
				}
			},
		},
		{
			name: "no such user",
			req: func() *http.Request {
				r := mustNewRequest("GET", "")
				r.Header.Set("User", "williambirkin")
				r.Header.Set("Auth", "abc123")
				return r
			},
			wantErr: errors.New("no such user"),
		},
		{
			name: "bad creds",
			setup: func(opts *config.Options) error {
				return db.CreateUser(opts.DB, db.User{
					Username: "jillvalentine",
					Hash:     "abc123",
					Created:  time.Now(),
				})
			},
			req: func() *http.Request {
				r := mustNewRequest("GET", "")
				r.Header.Set("User", "jillvalentine")
				r.Header.Set("Auth", "xyz789")
				return r
			},
			wantErr: errors.New("bad credentials"),
		},
		{
			name: "good creds",
			setup: func(opts *config.Options) error {
				return db.CreateUser(opts.DB, db.User{
					Username: "jillvalentine",
					Hash:     "abc123",
					Created:  time.Now(),
				})
			},
			req: func() *http.Request {
				r := mustNewRequest("GET", "")
				r.Header.Set("User", "jillvalentine")
				r.Header.Set("Auth", "abc123")
				return r
			},
			assert: func(u *db.User, t *testing.T) {
				// Fail cleanly rather than panic on a nil dereference.
				if u == nil {
					t.Fatal("expected a user, got nil")
				}
				if u.Username != "jillvalentine" {
					t.Errorf("expected 'jillvalentine' got %s", u.Username)
				}
			},
		},
	}

	for _, tt := range ts {
		t.Run(tt.name, func(t *testing.T) {
			opts := newTestDB(t)

			var req *http.Request
			if tt.req == nil {
				req = mustNewRequest("POST", `{"user_name":"albertwesker","auth_hash":"1234abc"}`)
			} else {
				req = tt.req()
			}

			if tt.setup != nil {
				if err := tt.setup(opts); err != nil {
					t.Fatalf("setup failed: %s", err.Error())
				}
			}

			u, err := getUserFromReq(*opts, req)
			if err != nil {
				// Errors are compared by message only.
				if tt.wantErr == nil || tt.wantErr.Error() != err.Error() {
					t.Errorf("got unexpected error: %s", err)
				}
				return
			}
			if tt.wantErr != nil {
				t.Error("expected error, got none")
				return
			}

			if tt.assert != nil {
				tt.assert(u, t)
			}
		})
	}
}