forked from tildetown/bbj2
		
	refactor get user from req, test stub
This commit is contained in:
		
							parent
							
								
									7d92021e94
								
							
						
					
					
						commit
						1ed26d4cbc
					
				@ -5,7 +5,6 @@ import (
 | 
				
			|||||||
	"errors"
 | 
						"errors"
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"net/http"
 | 
						"net/http"
 | 
				
			||||||
	"strings"
 | 
					 | 
				
			||||||
	"time"
 | 
						"time"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	"git.tilde.town/tildetown/bbj2/server/cmd/config"
 | 
						"git.tilde.town/tildetown/bbj2/server/cmd/config"
 | 
				
			||||||
@ -90,34 +89,20 @@ func (a *API) Invoke(w http.ResponseWriter, req *http.Request, apiFn APIHandler)
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
 | 
					func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
 | 
				
			||||||
	// TODO abstract sql stuff into db
 | 
						username := req.Header.Get("User")
 | 
				
			||||||
	u = &db.User{}
 | 
						hash := req.Header.Get("Auth")
 | 
				
			||||||
	u.Username = req.Header.Get("User")
 | 
						if username == "" || username == "anon" {
 | 
				
			||||||
	u.Hash = req.Header.Get("Auth")
 | 
					 | 
				
			||||||
	if u.Username == "" || u.Username == "anon" {
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	db := opts.DB
 | 
						opts.Logger.Printf("checking auth for %s", username)
 | 
				
			||||||
	stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?")
 | 
					
 | 
				
			||||||
	if err != nil {
 | 
						if u, err = db.GetUserByName(opts.DB, username); err != nil {
 | 
				
			||||||
		err = fmt.Errorf("db error: %w", err)
 | 
					 | 
				
			||||||
		return
 | 
							return
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	defer stmt.Close()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
	opts.Logger.Printf("querying for %s", u.Username)
 | 
						if u.Hash != hash {
 | 
				
			||||||
 | 
							u = nil
 | 
				
			||||||
	var authHash string
 | 
					 | 
				
			||||||
	if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil {
 | 
					 | 
				
			||||||
		if strings.Contains(err.Error(), "no rows in result") {
 | 
					 | 
				
			||||||
			err = errors.New("no such user")
 | 
					 | 
				
			||||||
		} else {
 | 
					 | 
				
			||||||
			err = fmt.Errorf("db error: %w", err)
 | 
					 | 
				
			||||||
		}
 | 
					 | 
				
			||||||
	}
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
	if authHash != u.Hash {
 | 
					 | 
				
			||||||
		err = errors.New("bad credentials")
 | 
							err = errors.New("bad credentials")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
				
			|||||||
@ -40,7 +40,7 @@ func userCount(db *sql.DB) (count int, err error) {
 | 
				
			|||||||
	return
 | 
						return
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestUserRegister(t *testing.T) {
 | 
					func Test_UserRegister(t *testing.T) {
 | 
				
			||||||
	ts := []struct {
 | 
						ts := []struct {
 | 
				
			||||||
		name    string
 | 
							name    string
 | 
				
			||||||
		req     func() *http.Request
 | 
							req     func() *http.Request
 | 
				
			||||||
@ -157,12 +157,11 @@ func TestUserRegister(t *testing.T) {
 | 
				
			|||||||
				t.Errorf("expected error")
 | 
									t.Errorf("expected error")
 | 
				
			||||||
				return
 | 
									return
 | 
				
			||||||
			}
 | 
								}
 | 
				
			||||||
 | 
					 | 
				
			||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestInstanceInfo(t *testing.T) {
 | 
					func Test_InstanceInfo(t *testing.T) {
 | 
				
			||||||
	opts, err := createTestState()
 | 
						opts, err := createTestState()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		t.Fatalf("failed to create test state: %s", err.Error())
 | 
							t.Fatalf("failed to create test state: %s", err.Error())
 | 
				
			||||||
@ -229,3 +228,7 @@ func TestInstanceInfo(t *testing.T) {
 | 
				
			|||||||
		})
 | 
							})
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func Test_getUserFromReq(t *testing.T) {
 | 
				
			||||||
 | 
						// TODO
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user