refactor get user from req, test stub
parent
7d92021e94
commit
1ed26d4cbc
|
@ -5,7 +5,6 @@ import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"git.tilde.town/tildetown/bbj2/server/cmd/config"
|
"git.tilde.town/tildetown/bbj2/server/cmd/config"
|
||||||
|
@ -90,34 +89,20 @@ func (a *API) Invoke(w http.ResponseWriter, req *http.Request, apiFn APIHandler)
|
||||||
}
|
}
|
||||||
|
|
||||||
func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
|
func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
|
||||||
// TODO abstract sql stuff into db
|
username := req.Header.Get("User")
|
||||||
u = &db.User{}
|
hash := req.Header.Get("Auth")
|
||||||
u.Username = req.Header.Get("User")
|
if username == "" || username == "anon" {
|
||||||
u.Hash = req.Header.Get("Auth")
|
|
||||||
if u.Username == "" || u.Username == "anon" {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
db := opts.DB
|
opts.Logger.Printf("checking auth for %s", username)
|
||||||
stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?")
|
|
||||||
if err != nil {
|
if u, err = db.GetUserByName(opts.DB, username); err != nil {
|
||||||
err = fmt.Errorf("db error: %w", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer stmt.Close()
|
|
||||||
|
|
||||||
opts.Logger.Printf("querying for %s", u.Username)
|
if u.Hash != hash {
|
||||||
|
u = nil
|
||||||
var authHash string
|
|
||||||
if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil {
|
|
||||||
if strings.Contains(err.Error(), "no rows in result") {
|
|
||||||
err = errors.New("no such user")
|
|
||||||
} else {
|
|
||||||
err = fmt.Errorf("db error: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if authHash != u.Hash {
|
|
||||||
err = errors.New("bad credentials")
|
err = errors.New("bad credentials")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ func userCount(db *sql.DB) (count int, err error) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUserRegister(t *testing.T) {
|
func Test_UserRegister(t *testing.T) {
|
||||||
ts := []struct {
|
ts := []struct {
|
||||||
name string
|
name string
|
||||||
req func() *http.Request
|
req func() *http.Request
|
||||||
|
@ -157,12 +157,11 @@ func TestUserRegister(t *testing.T) {
|
||||||
t.Errorf("expected error")
|
t.Errorf("expected error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestInstanceInfo(t *testing.T) {
|
func Test_InstanceInfo(t *testing.T) {
|
||||||
opts, err := createTestState()
|
opts, err := createTestState()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create test state: %s", err.Error())
|
t.Fatalf("failed to create test state: %s", err.Error())
|
||||||
|
@ -229,3 +228,7 @@ func TestInstanceInfo(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_getUserFromReq(t *testing.T) {
|
||||||
|
// TODO
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue