diff --git a/server/cmd/api/api.go b/server/cmd/api/api.go index bc8ed5c..618d030 100644 --- a/server/cmd/api/api.go +++ b/server/cmd/api/api.go @@ -1,7 +1,11 @@ package api import ( + "encoding/json" + "errors" "fmt" + "net/http" + "strings" "git.tilde.town/tildetown/bbj2/server/cmd/config" "git.tilde.town/tildetown/bbj2/server/cmd/db" @@ -22,9 +26,77 @@ type BBJResponse struct { Usermap map[string]db.User `json:"usermap"` } +type APIHandler func() (*BBJResponse, error) + type API struct { User *db.User Opts config.Options + Req *http.Request +} + +func NewAPI(opts config.Options, req *http.Request) (*API, error) { + user, err := getUserFromReq(opts, req) + if err != nil { + return nil, &HTTPError{Msg: err.Error(), Code: 403} + } + return &API{ + Opts: opts, + User: user, + Req: req, + }, nil +} + +func Invoke(w http.ResponseWriter, apiFn APIHandler) { + resp, err := apiFn() + if err != nil { + he := &HTTPError{} + _ = errors.As(err, &he) + resp := BBJResponse{ + Error: true, + Data: he.Msg, + } + w.WriteHeader(he.Code) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) + return + } + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(resp) +} + +func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { + u = &db.User{} + u.Username = req.Header.Get("User") + u.Hash = req.Header.Get("Auth") + if u.Username == "" || u.Username == "anon" { + return + } + + db := opts.DB + stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?") + if err != nil { + err = fmt.Errorf("db error: %w", err) + return + } + defer stmt.Close() + + opts.Logger.Printf("querying for %s", u.Username) + + var authHash string + if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil { + if strings.Contains(err.Error(), "no rows in result") { + err = errors.New("no such user") + } else { + err = fmt.Errorf("db error: %w", err) + } + } + + if authHash != u.Hash { + err = errors.New("bad credentials") + } + + return } type instanceInfo struct { @@ -33,7 +105,18 @@ type instanceInfo struct { Admins []string } +func (a *API) IsGet() bool { + return a.Req.Method == "GET" +} + +func (a *API) IsPost() bool { + return a.Req.Method == "POST" +} + func (a *API) InstanceInfo() (*BBJResponse, error) { + if !a.IsGet() { + return nil, &HTTPError{Msg: "bad method", Code: 400} + } return &BBJResponse{ Data: instanceInfo{ InstanceName: a.Opts.Config.InstanceName, @@ -43,4 +126,9 @@ func (a *API) InstanceInfo() (*BBJResponse, error) { }, nil } -type ApiHandler func() (*BBJResponse, error) +func (a *API) UserRegister() (*BBJResponse, error) { + if !a.IsPost() { + return nil, &HTTPError{Msg: "bad method", Code: 400} + } + return nil, nil +} diff --git a/server/cmd/api/api_test.go b/server/cmd/api/api_test.go index 77bed04..77e7bfc 100644 --- a/server/cmd/api/api_test.go +++ b/server/cmd/api/api_test.go @@ -3,15 +3,18 @@ package api import ( "bufio" "bytes" - "fmt" + "log" + "net/http" "os" "reflect" + "strings" "testing" "git.tilde.town/tildetown/bbj2/server/cmd/config" ) func TestInstanceInfo(t *testing.T) { + // TODO a lot of this needs to be cleaned up and generalized etc stderr := []byte{} stdout := []byte{} testIO := config.IOStreams{ @@ -22,13 +25,10 @@ func TestInstanceInfo(t *testing.T) { if err != nil { t.Fatalf("failed to make test db: %s", err.Error()) } + logger := log.New(os.Stdout, "bbj test", log.Lshortfile) defaultOptions := config.Options{ - IO: testIO, - Log: func(s string) { fmt.Fprintln(testIO.Out, s) }, - Logf: func(s string, args ...interface{}) { - fmt.Fprintf(testIO.Out, s, args...) - fmt.Fprintln(testIO.Out) - }, + IO: testIO, + Logger: logger, Config: config.Config{ Admins: []string{"jillValentine", "rebeccaChambers"}, Port: 666, @@ -57,8 +57,10 @@ func TestInstanceInfo(t *testing.T) { for _, tt := range ts { t.Run(tt.name, func(t *testing.T) { + req, _ := http.NewRequest("GET", "", strings.NewReader("")) api := &API{ Opts: tt.opts, + Req: req, } resp, err := api.InstanceInfo() if tt.wantErr != nil && err != nil { diff --git a/server/cmd/config/config.go b/server/cmd/config/config.go index b4e50c1..1725ec9 100644 --- a/server/cmd/config/config.go +++ b/server/cmd/config/config.go @@ -4,6 +4,7 @@ import ( "database/sql" "fmt" "io" + "log" "os" yaml "gopkg.in/yaml.v3" @@ -34,6 +35,7 @@ type Config struct { type Options struct { ConfigPath string IO IOStreams + Logger *log.Logger Log func(string) Logf func(string, ...interface{}) Config Config diff --git a/server/cmd/main.go b/server/cmd/main.go index e1c5130..b5dca2c 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -7,13 +7,13 @@ import ( "errors" "flag" "fmt" + "log" "net/http" "os" "strings" "git.tilde.town/tildetown/bbj2/server/cmd/api" "git.tilde.town/tildetown/bbj2/server/cmd/config" - "git.tilde.town/tildetown/bbj2/server/cmd/db" _ "github.com/mattn/go-sqlite3" ) @@ -30,23 +30,17 @@ func main() { Err: os.Stderr, Out: os.Stdout, } + logger := log.New(io.Out, "", log.Ldate|log.Ltime|log.Lshortfile) opts := &config.Options{ ConfigPath: *configFlag, Reset: *resetFlag, IO: io, - // TODO use real logger - Log: func(s string) { - fmt.Fprintln(io.Out, s) - }, - Logf: func(s string, args ...interface{}) { - fmt.Fprintf(io.Out, s, args...) - fmt.Fprintf(io.Out, "\n") - }, + Logger: logger, } err := _main(opts) if err != nil { - fmt.Fprintf(os.Stderr, "failed: %s", err) + logger.Fatalln(err.Error()) } } @@ -81,7 +75,7 @@ func _main(opts *config.Options) error { setupAPI(*opts) // TODO TLS or SSL or something - opts.Logf("starting server at %s:%d", cfg.Host, cfg.Port) + opts.Logger.Printf("starting server at %s:%d", cfg.Host, cfg.Port) if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), nil); err != nil { return fmt.Errorf("http server exited with error: %w", err) } @@ -127,7 +121,7 @@ func ensureSchema(opts config.Options) error { func handler(opts config.Options, f http.HandlerFunc) http.HandlerFunc { // TODO make this more real return func(w http.ResponseWriter, req *http.Request) { - opts.Log(req.URL.Path) + opts.Logger.Printf("<- %s", req.URL.Path) // TODO add user info to opts f(w, req) } @@ -137,54 +131,7 @@ func handler(opts config.Options, f http.HandlerFunc) http.HandlerFunc { // encryption, it doesn't really help anything. I'd rather have plaintext + // transport encryption and then, on the server side, proper salted hashing. -// TODO get rid of these - -func writeResponse(w http.ResponseWriter, resp api.BBJResponse) { - w.WriteHeader(http.StatusOK) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) -} - // NB breaking: i'm not just returning 200 always but using http status codes -func writeErrorResponse(w http.ResponseWriter, code int, resp api.BBJResponse) { - w.WriteHeader(code) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) -} - -func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { - u = &db.User{} - u.Username = req.Header.Get("User") - u.Hash = req.Header.Get("Auth") - if u.Username == "" || u.Username == "anon" { - return - } - - db := opts.DB - stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?") - if err != nil { - err = fmt.Errorf("db error: %w", err) - return - } - defer stmt.Close() - - opts.Logf("querying for %s", u.Username) - - var authHash string - if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil { - if strings.Contains(err.Error(), "no rows in result") { - err = errors.New("no such user") - } else { - err = fmt.Errorf("db error: %w", err) - } - } - - if authHash != u.Hash { - err = errors.New("bad credentials") - } - - return -} func checkAuth(opts config.Options, username, hash string) error { db := opts.DB @@ -194,7 +141,7 @@ func checkAuth(opts config.Options, username, hash string) error { } defer stmt.Close() - opts.Logf("querying for %s", username) + opts.Logger.Printf("querying for %s", username) var authHash string if err = stmt.QueryRow(username).Scan(&authHash); err != nil { @@ -212,347 +159,318 @@ func checkAuth(opts config.Options, username, hash string) error { } func setupAPI(opts config.Options) { - newAPI := func(opts config.Options, w http.ResponseWriter, req *http.Request) *api.API { - user, err := getUserFromReq(opts, req) - if err != nil { - writeErrorResponse(w, 403, api.BBJResponse{ - Error: true, - Data: err.Error(), - }) - return nil - } - return &api.API{ - Opts: opts, - User: user, - } - } - - invokeAPI := func(w http.ResponseWriter, apiFn api.ApiHandler) { - resp, err := apiFn() - if err != nil { - he := &api.HTTPError{} - _ = errors.As(err, &he) - resp := api.BBJResponse{ - Error: true, - Data: he.Msg, - } - w.WriteHeader(he.Code) - w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) - return - } - - w.WriteHeader(http.StatusOK) + handleFailedAPICreate := func(w http.ResponseWriter, err error) { + opts.Logger.Printf("failed to create API: %s", err.Error()) + w.WriteHeader(500) w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(resp) + json.NewEncoder(w).Encode(api.BBJResponse{ + Error: true, + Data: "server error check logs", + }) } + // TODO could probably generalize this even further but it's fine for now + http.HandleFunc("/instance_info", handler(opts, func(w http.ResponseWriter, req *http.Request) { - api := newAPI(opts, w, req) - if api == nil { + a, err := api.NewAPI(opts, req) + if err != nil { + handleFailedAPICreate(w, err) return } - invokeAPI(w, api.InstanceInfo) + api.Invoke(w, a.InstanceInfo) + })) + + http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { + a, err := api.NewAPI(opts, req) + if err != nil { + handleFailedAPICreate(w, err) + return + } + api.Invoke(w, a.UserRegister) })) /* - http.HandleFunc("/instance_info", handler(opts, func(w http.ResponseWriter, req *http.Request) { - type instanceInfo struct { - InstanceName string `json:"instance_name"` - AllowAnon bool `json:"allow_anon"` - Admins []string + http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return } + + type AuthArgs struct { + Username string `json:"user_name"` + Hash string `json:"auth_hash"` + } + + var args AuthArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } + + if args.Hash == "" || args.Username == "" { + invalidArgs(w) + return + } + + opts.Logf("querying for %s", args.Username) + + if err := checkAuth(opts, args.Username, args.Hash); err == nil { + opts.Logf("found %s", args.Username) + // code 4 apparently + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: "user already exists", + }) + } else if err.Error() != "no such user" { + serverErr(w, err) + return + } + + db := opts.DB + stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) + id, err := uuid.NewRandom() + if err != nil { + serverErr(w, err) + return + } + + _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) + if err != nil { + serverErr(w, err) + } + writeResponse(w, BBJResponse{ - Data: instanceInfo{ - InstanceName: opts.Config.InstanceName, - AllowAnon: opts.Config.AllowAnon, - Admins: opts.Config.Admins, - }, + Data: true, // TODO probably something else + // TODO prob usermap }) })) + http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } - http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + type AuthArgs struct { + Username string `json:"target_user"` + AuthHash string `json:"target_hash"` + } - type AuthArgs struct { - Username string `json:"user_name"` - Hash string `json:"auth_hash"` - } + var args AuthArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + opts.Logf("got %s %s", args.Username, args.AuthHash) - if args.Hash == "" || args.Username == "" { - invalidArgs(w) - return - } + db := opts.DB - opts.Logf("querying for %s", args.Username) + stmt, err := db.Prepare("select auth_hash from users where user_name = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - if err := checkAuth(opts, args.Username, args.Hash); err == nil { - opts.Logf("found %s", args.Username) - // code 4 apparently - writeErrorResponse(w, 403, BBJResponse{ + var authHash string + err = stmt.QueryRow(args.Username).Scan(&authHash) + if err != nil { + if strings.Contains(err.Error(), "no rows in result") { + opts.Logf("user not found") + writeErrorResponse(w, 404, BBJResponse{ Error: true, - Data: "user already exists", + Data: "user not found", }) - } else if err.Error() != "no such user" { - serverErr(w, err) - return - } - - db := opts.DB - stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) - id, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } - - _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) - if err != nil { + } else { serverErr(w, err) } + return + } - writeResponse(w, BBJResponse{ - Data: true, // TODO probably something else - // TODO prob usermap + // TODO unique constraint on user_name + + if authHash != args.AuthHash { + http.Error(w, "incorrect password", 403) + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: "incorrect password", }) - })) + return + } - http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + // TODO include usermap? + writeResponse(w, BBJResponse{ + Data: true, + }) + })) - type AuthArgs struct { - Username string `json:"target_user"` - AuthHash string `json:"target_hash"` - } - - var args AuthArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } - - opts.Logf("got %s %s", args.Username, args.AuthHash) - - db := opts.DB - - stmt, err := db.Prepare("select auth_hash from users where user_name = ?") + http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { + db := opts.DB + rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") + if err != nil { + serverErr(w, err) + return + } + defer rows.Close() + for rows.Next() { + var id string + err = rows.Scan(&id) if err != nil { serverErr(w, err) return } - defer stmt.Close() + opts.Log(id) + } + writeResponse(w, BBJResponse{Data: "TODO"}) + // TODO + })) - var authHash string - err = stmt.QueryRow(args.Username).Scan(&authHash) - if err != nil { - if strings.Contains(err.Error(), "no rows in result") { - opts.Logf("user not found") - writeErrorResponse(w, 404, BBJResponse{ - Error: true, - Data: "user not found", - }) - } else { - serverErr(w, err) - } - return - } + http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { + if req.Method != "POST" { + badMethod(w) + return + } - // TODO unique constraint on user_name - - if authHash != args.AuthHash { - http.Error(w, "incorrect password", 403) - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: "incorrect password", - }) - return - } - - // TODO include usermap? - writeResponse(w, BBJResponse{ - Data: true, + // TODO make this getUserInfoFromReq or similar so we can use the user ID later + user, err := getUserFromReq(opts, req) + if err != nil { + writeErrorResponse(w, 403, BBJResponse{ + Error: true, + Data: err.Error(), }) - })) + return + } - http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { - db := opts.DB - rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") - if err != nil { - serverErr(w, err) - return - } - defer rows.Close() - for rows.Next() { - var id string - err = rows.Scan(&id) - if err != nil { - serverErr(w, err) - return - } - opts.Log(id) - } - writeResponse(w, BBJResponse{Data: "TODO"}) - // TODO - })) + type threadCreateArgs struct { + Title string + Body string + SendRaw bool `json:"send_raw"` + } - http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { - if req.Method != "POST" { - badMethod(w) - return - } + var args threadCreateArgs + if err := json.NewDecoder(req.Body).Decode(&args); err != nil { + invalidArgs(w) + return + } - // TODO make this getUserInfoFromReq or similar so we can use the user ID later - user, err := getUserFromReq(opts, req) - if err != nil { - writeErrorResponse(w, 403, BBJResponse{ - Error: true, - Data: err.Error(), - }) - return - } + if args.Title == "" || args.Body == "" { + invalidArgs(w) + return + } - type threadCreateArgs struct { - Title string - Body string - SendRaw bool `json:"send_raw"` - } + db := opts.DB + tx, err := db.Begin() + if err != nil { + serverErr(w, err) + return + } - var args threadCreateArgs - if err := json.NewDecoder(req.Body).Decode(&args); err != nil { - invalidArgs(w) - return - } + stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - if args.Title == "" || args.Body == "" { - invalidArgs(w) - return - } + threadID, err := uuid.NewRandom() + if err != nil { + serverErr(w, err) + return + } + now := time.Now() + if _, err = stmt.Exec( + threadID, + user.ID, + args.Title, + now, + now, + user.Username, + ); err != nil { + serverErr(w, err) + return + } - db := opts.DB - tx, err := db.Begin() - if err != nil { - serverErr(w, err) - return - } + stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() - stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + if _, err = stmt.Exec( + threadID, + user.ID, + now, + args.Body, + args.SendRaw, + ); err != nil { + serverErr(w, err) + return + } - threadID, err := uuid.NewRandom() - if err != nil { - serverErr(w, err) - return - } - now := time.Now() - if _, err = stmt.Exec( - threadID, - user.ID, - args.Title, - now, - now, - user.Username, + if err = tx.Commit(); err != nil { + serverErr(w, err) + return + } + + stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() + + t := &Thread{} + + // TODO fill in rest of thread + if err = stmt.QueryRow(threadID).Scan( + t.ID, + t.Author, + t.Title, + t.LastMod, + t.Created, + t.ReplyCount, + t.Pinned, + t.LastAuthor, + ); err != nil { + serverErr(w, err) + return + } + + stmt, err = db.Prepare("select * from messages where thread_id = ?") + if err != nil { + serverErr(w, err) + return + } + defer stmt.Close() + rows, err := stmt.Query(threadID) + if err != nil { + serverErr(w, err) + return + } + + t.Messages = []Message{} + + for rows.Next() { + m := &Message{} + if err := rows.Scan( + m.ThreadID, + m.PostID, + m.Author, + m.Created, + m.Edited, + m.Body, + m.SendRaw, ); err != nil { serverErr(w, err) return } + t.Messages = append(t.Messages, *m) + } - stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() + writeResponse(w, BBJResponse{Data: t}) - if _, err = stmt.Exec( - threadID, - user.ID, - now, - args.Body, - args.SendRaw, - ); err != nil { - serverErr(w, err) - return - } - - if err = tx.Commit(); err != nil { - serverErr(w, err) - return - } - - stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - - t := &Thread{} - - // TODO fill in rest of thread - if err = stmt.QueryRow(threadID).Scan( - t.ID, - t.Author, - t.Title, - t.LastMod, - t.Created, - t.ReplyCount, - t.Pinned, - t.LastAuthor, - ); err != nil { - serverErr(w, err) - return - } - - stmt, err = db.Prepare("select * from messages where thread_id = ?") - if err != nil { - serverErr(w, err) - return - } - defer stmt.Close() - rows, err := stmt.Query(threadID) - if err != nil { - serverErr(w, err) - return - } - - t.Messages = []Message{} - - for rows.Next() { - m := &Message{} - if err := rows.Scan( - m.ThreadID, - m.PostID, - m.Author, - m.Created, - m.Edited, - m.Body, - m.SendRaw, - ); err != nil { - serverErr(w, err) - return - } - t.Messages = append(t.Messages, *m) - } - - writeResponse(w, BBJResponse{Data: t}) - - })) + })) */ }