From 910bb3d00af5f563c48daf1b8918dda30e859732 Mon Sep 17 00:00:00 2001 From: vilmibm Date: Thu, 30 Jun 2022 20:44:58 -0500 Subject: [PATCH] implement check auth --- server/cmd/api/api.go | 49 +++++++++++++++++++++++++-------- server/cmd/main.go | 63 +++---------------------------------------- 2 files changed, 41 insertions(+), 71 deletions(-) diff --git a/server/cmd/api/api.go b/server/cmd/api/api.go index b8d8e92..df94752 100644 --- a/server/cmd/api/api.go +++ b/server/cmd/api/api.go @@ -97,14 +97,7 @@ func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err err opts.Logger.Printf("checking auth for %s", username) - if u, err = db.GetUserByName(opts.DB, username); err != nil { - return - } - - if u.Hash != hash { - u = nil - err = errors.New("bad credentials") - } + u, err = checkAuth(opts, username, hash) return } @@ -132,6 +125,40 @@ func (a *API) InstanceInfo(ctx *ReqCtx) (resp *BBJResponse, err error) { return } +func (a *API) CheckAuth(ctx *ReqCtx) (resp *BBJResponse, err error) { + if !ctx.IsPost() { + err = badMethod() + return + } + type AuthArgs struct { + Username string `json:"target_user"` + Hash string `json:"target_hash"` + } + + var args AuthArgs + if err = json.NewDecoder(ctx.Req.Body).Decode(&args); err != nil { + err = invalidArgs(err.Error()) + return + } + + if args.Hash == "" || args.Username == "" { + err = invalidArgs(err.Error()) + return + } + + _, err = checkAuth(a.Opts, args.Username, args.Hash) + if err != nil { + return + } + + // TODO usermap + resp = &BBJResponse{ + Data: true, + } + + return +} + func (a *API) UserRegister(ctx *ReqCtx) (resp *BBJResponse, err error) { if !ctx.IsPost() { err = badMethod() @@ -153,7 +180,7 @@ func (a *API) UserRegister(ctx *ReqCtx) (resp *BBJResponse, err error) { return } - if err = checkAuth(a.Opts, args.Username, args.Hash); err == nil { + if _, err = checkAuth(a.Opts, args.Username, args.Hash); err == nil { a.Opts.Logger.Printf("user %s already registered", args.Username) err = &HTTPError{Code: 403, Msg: "user already exists"} return @@ -173,9 +200,9 @@ func (a *API) UserRegister(ctx *ReqCtx) (resp *BBJResponse, err error) { return } -func checkAuth(opts config.Options, username, hash string) (err error) { +// checkAuth returns an error if username is not associated with hash +func checkAuth(opts config.Options, username, hash string) (user *db.User, err error) { opts.Logger.Printf("querying for %s", username) - var user *db.User if user, err = db.GetUserByName(opts.DB, username); err != nil { return } diff --git a/server/cmd/main.go b/server/cmd/main.go index 3980750..0062427 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -88,69 +88,12 @@ func setupAPI(opts config.Options) { a.Invoke(w, req, a.UserRegister) }) + http.HandleFunc("/check_auth", func(w http.ResponseWriter, req *http.Request) { + a.Invoke(w, req, a.CheckAuth) + }) } /* - http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } - - type AuthArgs struct { - Username string `json:"target_user"` - AuthHash string `json:"target_hash"` - } - - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } - - opts.Logf("got %s %s", args.Username, args.AuthHash) - - db := opts.DB - - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - - var authHash string - err = stmt.QueryRow(args.Username).Scan(&authHash) - if err != nil { - if strings.Contains(err.Error(), "no rows in result") { - opts.Logf("user not found") - writeErrorResponse(w, 404, BBJResponse{ - Error: true, - Data: "user not found", - }) - } else { - serverErr(w, err) - } - return - } - - // TODO unique constraint on user_name - - if authHash != args.AuthHash { - http.Error(w, "incorrect password", 403) - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "incorrect password", - }) - return - } - - // TODO include usermap? - writeResponse(w, BBJResponse{ - Data: true, - }) - })) - http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { db := opts.DB rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id")