refactor get user from req, test stub

trunk
vilmibm 2022-06-21 13:35:34 -05:00
parent 7d92021e94
commit 1ed26d4cbc
2 changed files with 14 additions and 26 deletions

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"time" "time"
"git.tilde.town/tildetown/bbj2/server/cmd/config" "git.tilde.town/tildetown/bbj2/server/cmd/config"
@ -90,34 +89,20 @@ func (a *API) Invoke(w http.ResponseWriter, req *http.Request, apiFn APIHandler)
} }
func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
// TODO abstract sql stuff into db username := req.Header.Get("User")
u = &db.User{} hash := req.Header.Get("Auth")
u.Username = req.Header.Get("User") if username == "" || username == "anon" {
u.Hash = req.Header.Get("Auth")
if u.Username == "" || u.Username == "anon" {
return return
} }
db := opts.DB opts.Logger.Printf("checking auth for %s", username)
stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?")
if err != nil { if u, err = db.GetUserByName(opts.DB, username); err != nil {
err = fmt.Errorf("db error: %w", err)
return return
} }
defer stmt.Close()
opts.Logger.Printf("querying for %s", u.Username) if u.Hash != hash {
u = nil
var authHash string
if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil {
if strings.Contains(err.Error(), "no rows in result") {
err = errors.New("no such user")
} else {
err = fmt.Errorf("db error: %w", err)
}
}
if authHash != u.Hash {
err = errors.New("bad credentials") err = errors.New("bad credentials")
} }

View File

@ -40,7 +40,7 @@ func userCount(db *sql.DB) (count int, err error) {
return return
} }
func TestUserRegister(t *testing.T) { func Test_UserRegister(t *testing.T) {
ts := []struct { ts := []struct {
name string name string
req func() *http.Request req func() *http.Request
@ -157,12 +157,11 @@ func TestUserRegister(t *testing.T) {
t.Errorf("expected error") t.Errorf("expected error")
return return
} }
}) })
} }
} }
func TestInstanceInfo(t *testing.T) { func Test_InstanceInfo(t *testing.T) {
opts, err := createTestState() opts, err := createTestState()
if err != nil { if err != nil {
t.Fatalf("failed to create test state: %s", err.Error()) t.Fatalf("failed to create test state: %s", err.Error())
@ -229,3 +228,7 @@ func TestInstanceInfo(t *testing.T) {
}) })
} }
} }
func Test_getUserFromReq(t *testing.T) {
// TODO
}