bbj2/server/cmd/api/api.go

204 lines
4.3 KiB
Go

package api
import (
"encoding/json"
"errors"
"fmt"
"net/http"
"strings"
"time"
"git.tilde.town/tildetown/bbj2/server/cmd/config"
"git.tilde.town/tildetown/bbj2/server/cmd/db"
)
// HTTPError is an error that also carries the HTTP status code the API
// should respond with.
type HTTPError struct {
	Msg  string
	Code int
}

// Error renders the error as "<code> <msg>", e.g. "400 bad method".
func (e *HTTPError) Error() string {
	return fmt.Sprintf("%v %v", e.Code, e.Msg)
}

// badMethod reports a request made with an HTTP method the endpoint does
// not accept.
func badMethod() error {
	return &HTTPError{Msg: "bad method", Code: http.StatusBadRequest}
}

// invalidArgs reports a malformed or incomplete request; msg explains why.
func invalidArgs(msg string) error {
	return &HTTPError{Msg: "invalid args: " + msg, Code: http.StatusBadRequest}
}
type (
	// BBJResponse is the JSON envelope that Invoke writes for every API call.
	BBJResponse struct {
		// Error is true when the call failed; Invoke then puts the error
		// message in Data.
		Error bool `json:"error"`
		// Data is the endpoint-specific payload (or the error message).
		Data interface{} `json:"data"`
		// Usermap is not populated anywhere in this file.
		// NOTE(review): presumably maps user IDs to users referenced by Data — confirm.
		Usermap map[string]db.User `json:"usermap"`
	}

	// APIHandler implements a single endpoint. Returning an *HTTPError
	// selects the HTTP status code Invoke responds with.
	APIHandler func(*ReqCtx) (*BBJResponse, error)
)
// API holds the server options (database, logger, instance config) shared by
// all endpoint handlers.
type API struct {
	Opts config.Options
}

// NewAPI returns an API backed by opts.
func NewAPI(opts config.Options) *API {
	api := new(API)
	api.Opts = opts
	return api
}
// ReqCtx bundles the request being served with the user that Invoke
// resolved from its headers (which may be anonymous).
type ReqCtx struct {
	User db.User
	Req  *http.Request
}

// IsGet reports whether the request was made with the GET method.
func (c *ReqCtx) IsGet() bool {
	return http.MethodGet == c.Req.Method
}

// IsPost reports whether the request was made with the POST method.
func (c *ReqCtx) IsPost() bool {
	return http.MethodPost == c.Req.Method
}
// Invoke resolves the requesting user, runs apiFn, and writes the result to w
// as a JSON-encoded BBJResponse.
//
// Status codes:
//   - 500 "server error check logs" if the user cannot be resolved from the
//     request headers (the cause is only logged);
//   - the Code/Msg of an *HTTPError returned by apiFn;
//   - 500 "server error check logs" for any other apiFn error (logged);
//   - 200 with apiFn's response otherwise (an empty response if it is nil).
//
// NOTE(review): bad credentials / unknown users currently surface as 500;
// consider mapping them to 401/403.
func (a *API) Invoke(w http.ResponseWriter, req *http.Request, apiFn APIHandler) {
	a.Opts.Logger.Printf("<- %s", req.URL.Path)

	// writeJSON sets headers BEFORE WriteHeader: header changes made after
	// WriteHeader are not sent to the client.
	writeJSON := func(code int, body *BBJResponse) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(code)
		if err := json.NewEncoder(w).Encode(body); err != nil {
			a.Opts.Logger.Printf("failed to encode response: %s", err.Error())
		}
	}

	user, err := getUserFromReq(a.Opts, req)
	if err != nil {
		a.Opts.Logger.Printf("failed to get user from req: %s", err.Error())
		writeJSON(500, &BBJResponse{
			Error: true,
			Data:  "server error check logs",
		})
		// Stop here: the handler must not run, and a second response must
		// not be written.
		return
	}

	resp, err := apiFn(&ReqCtx{*user, req})
	if err != nil {
		var he *HTTPError
		if !errors.As(err, &he) {
			// A plain error has no status code; WriteHeader(0) would panic.
			a.Opts.Logger.Printf("handler for %s failed: %s", req.URL.Path, err.Error())
			he = &HTTPError{Code: 500, Msg: "server error check logs"}
		}
		writeJSON(he.Code, &BBJResponse{
			Error: true,
			Data:  he.Msg,
		})
		return
	}
	if resp == nil {
		// Keep the envelope shape instead of encoding a bare `null`.
		resp = &BBJResponse{}
	}
	writeJSON(http.StatusOK, resp)
}
// getUserFromReq builds the requesting user from the "User" and "Auth"
// request headers.
//
// An empty or "anon" username is returned as-is without touching the
// database. Otherwise the stored auth_hash and id for that username are
// loaded, and the "Auth" header must equal the stored hash.
//
// Errors: "no such user", "bad credentials", or a wrapped "db error: ...".
// The returned user is always non-nil, even on error, because callers
// dereference it.
func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) {
	// TODO abstract sql stuff into db
	u = &db.User{
		Username: req.Header.Get("User"),
		Hash:     req.Header.Get("Auth"),
	}
	if u.Username == "" || u.Username == "anon" {
		return u, nil
	}

	// Named conn (not db) to avoid shadowing the db package.
	conn := opts.DB
	stmt, err := conn.Prepare("select auth_hash, id from users where user_name = ?")
	if err != nil {
		return u, fmt.Errorf("db error: %w", err)
	}
	defer stmt.Close()

	opts.Logger.Printf("querying for %s", u.Username)
	var authHash string
	// Scan needs a pointer for every destination, including the ID field.
	if err = stmt.QueryRow(u.Username).Scan(&authHash, &u.ID); err != nil {
		// Return immediately so the credential check below cannot overwrite
		// the real cause.
		if strings.Contains(err.Error(), "no rows in result") {
			return u, errors.New("no such user")
		}
		return u, fmt.Errorf("db error: %w", err)
	}
	if authHash != u.Hash {
		return u, errors.New("bad credentials")
	}
	return u, nil
}
// instanceInfo is the payload returned by InstanceInfo.
//
// NOTE(review): Admins has no json tag, so it serializes as "Admins" while
// the other fields are snake_case; changing it would alter the wire format.
type instanceInfo struct {
	InstanceName string `json:"instance_name"`
	AllowAnon    bool   `json:"allow_anon"`
	Admins       []string
}

// InstanceInfo returns the configured instance name, allow_anon setting and
// admin list. Only GET is accepted; any other method yields badMethod().
func (a *API) InstanceInfo(ctx *ReqCtx) (resp *BBJResponse, err error) {
	if ctx.IsGet() {
		info := instanceInfo{
			InstanceName: a.Opts.Config.InstanceName,
			AllowAnon:    a.Opts.Config.AllowAnon,
			Admins:       a.Opts.Config.Admins,
		}
		return &BBJResponse{Data: info}, nil
	}
	return nil, badMethod()
}
// UserRegister creates a user from a JSON body of the form
// {"user_name": "...", "auth_hash": "..."}.
//
// Errors (as *HTTPError):
//   - 400 if the method is not POST, the body is not valid JSON, or either
//     field is empty;
//   - 403 if the username is already registered;
//   - 500 if the lookup or the insert fails.
//
// On success it returns an empty BBJResponse.
func (a *API) UserRegister(ctx *ReqCtx) (resp *BBJResponse, err error) {
	if !ctx.IsPost() {
		return nil, badMethod()
	}
	type AuthArgs struct {
		Username string `json:"user_name"`
		Hash     string `json:"auth_hash"`
	}
	var args AuthArgs
	if err = json.NewDecoder(ctx.Req.Body).Decode(&args); err != nil {
		return nil, invalidArgs(err.Error())
	}
	if args.Hash == "" || args.Username == "" {
		// err is nil here, so the message must not be derived from it.
		return nil, invalidArgs("user_name and auth_hash are required")
	}

	// Only existence matters. Checking the hash (as checkAuth does) would
	// report a taken name with a different hash as a 500 "bad credentials".
	// NOTE(review): "no such user" is presumably the error db.GetUserByName
	// returns for a missing user (the original code relied on it via
	// checkAuth). TODO: replace with a sentinel error and errors.Is.
	_, err = db.GetUserByName(a.Opts.DB, args.Username)
	switch {
	case err == nil:
		a.Opts.Logger.Printf("user %s already registered", args.Username)
		return nil, &HTTPError{Code: 403, Msg: "user already exists"}
	case err.Error() != "no such user":
		return nil, &HTTPError{Code: 500, Msg: err.Error()}
	}

	u := db.User{
		Username: args.Username,
		Hash:     args.Hash,
		Created:  time.Now(), // TODO inject time
	}
	if err = db.CreateUser(a.Opts.DB, u); err != nil {
		// Wrap so the caller can pick a status code from the error.
		return nil, &HTTPError{Code: 500, Msg: err.Error()}
	}
	return &BBJResponse{}, nil
}
// checkAuth looks up username via db.GetUserByName and verifies that its
// stored hash equals hash. It returns the lookup error unchanged, or
// "bad credentials" on a hash mismatch.
func checkAuth(opts config.Options, username, hash string) (err error) {
	opts.Logger.Printf("querying for %s", username)
	found, lookupErr := db.GetUserByName(opts.DB, username)
	switch {
	case lookupErr != nil:
		return lookupErr
	case found.Hash != hash:
		return errors.New("bad credentials")
	default:
		return nil
	}
}