fix test, add test
This commit is contained in:
		
							parent
							
								
									62394e3f77
								
							
						
					
					
						commit
						7d92021e94
					
				@ -1,6 +1,7 @@
 | 
			
		||||
package api
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"database/sql"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
@ -34,13 +35,12 @@ func createTestState() (opts *config.Options, err error) {
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUserRegister(t *testing.T) {
 | 
			
		||||
	opts, err := createTestState()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		t.Fatalf("failed to create test state: %s", err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
func userCount(db *sql.DB) (count int, err error) {
 | 
			
		||||
	err = db.QueryRow("SELECT count(*) FROM users").Scan(&count)
 | 
			
		||||
	return
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestUserRegister(t *testing.T) {
 | 
			
		||||
	ts := []struct {
 | 
			
		||||
		name    string
 | 
			
		||||
		req     func() *http.Request
 | 
			
		||||
@ -57,18 +57,15 @@ func TestUserRegister(t *testing.T) {
 | 
			
		||||
					Created:  time.Now(),
 | 
			
		||||
				})
 | 
			
		||||
			},
 | 
			
		||||
			assert: func(opts *config.Options, t *testing.T) error {
 | 
			
		||||
			assert: func(opts *config.Options, t *testing.T) (err error) {
 | 
			
		||||
				var count int
 | 
			
		||||
				var err error
 | 
			
		||||
				if err = opts.DB.QueryRow("SELECT count(*) FROM users").Scan(&count); err != nil {
 | 
			
		||||
					return err
 | 
			
		||||
				count, err = userCount(opts.DB)
 | 
			
		||||
 | 
			
		||||
				if count != 2 {
 | 
			
		||||
					t.Errorf("expected 2 users, got %d", count)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if count != 1 {
 | 
			
		||||
					t.Errorf("expected 1 user, got %d", count)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return nil
 | 
			
		||||
				return
 | 
			
		||||
			},
 | 
			
		||||
			wantErr: &HTTPError{Code: 403, Msg: "user already exists"},
 | 
			
		||||
		},
 | 
			
		||||
@ -83,11 +80,33 @@ func TestUserRegister(t *testing.T) {
 | 
			
		||||
				Code: 400,
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		// TODO
 | 
			
		||||
		{
 | 
			
		||||
			name: "add new user",
 | 
			
		||||
			req: func() *http.Request {
 | 
			
		||||
				r, _ := http.NewRequest("POST", "", strings.NewReader(`{"user_name":"chrisredfield", "auth_hash":"abc123"}`))
 | 
			
		||||
				return r
 | 
			
		||||
			},
 | 
			
		||||
			assert: func(opts *config.Options, t *testing.T) (err error) {
 | 
			
		||||
				var count int
 | 
			
		||||
				count, err = userCount(opts.DB)
 | 
			
		||||
 | 
			
		||||
				if count != 2 {
 | 
			
		||||
					t.Errorf("expected 2 users, got %d", count)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				return
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		// TODO bad args
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, tt := range ts {
 | 
			
		||||
		t.Run(tt.name, func(t *testing.T) {
 | 
			
		||||
			opts, err := createTestState()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				t.Fatalf("failed to create test state: %s", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			var req *http.Request
 | 
			
		||||
			if tt.req == nil {
 | 
			
		||||
				req, _ = http.NewRequest("POST", "", strings.NewReader(`{"user_name":"albertwesker","auth_hash":"1234abc"}`))
 | 
			
		||||
@ -118,6 +137,15 @@ func TestUserRegister(t *testing.T) {
 | 
			
		||||
			api := &API{Opts: *opts}
 | 
			
		||||
			ctx := &ReqCtx{Req: req}
 | 
			
		||||
			_, err = api.UserRegister(ctx)
 | 
			
		||||
 | 
			
		||||
			if tt.assert != nil {
 | 
			
		||||
				err := tt.assert(opts, t)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err.Error())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if tt.wantErr != nil && err != nil {
 | 
			
		||||
				if !reflect.DeepEqual(tt.wantErr, err) {
 | 
			
		||||
					t.Errorf("got unwanted error: %s", err.Error())
 | 
			
		||||
@ -125,14 +153,6 @@ func TestUserRegister(t *testing.T) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if tt.assert != nil {
 | 
			
		||||
				err = tt.assert(opts, t)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					t.Fatal(err.Error())
 | 
			
		||||
					return
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if tt.wantErr != nil && err == nil {
 | 
			
		||||
				t.Errorf("expected error")
 | 
			
		||||
				return
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user