package main import ( "database/sql" _ "embed" "encoding/json" "errors" "flag" "fmt" "net/http" "os" "strings" "time" "git.tilde.town/tildetown/bbj2/server/cmd/api" "git.tilde.town/tildetown/bbj2/server/cmd/config" "git.tilde.town/tildetown/bbj2/server/cmd/db" _ "github.com/mattn/go-sqlite3" ) // TODO tests //go:embed schema.sql var schemaSQL string func main() { var configFlag = flag.String("config", "config.yml", "A path to a config file.") var resetFlag = flag.Bool("reset", false, "reset the database. WARNING this wipes everything.") flag.Parse() io := config.IOStreams{ Err: os.Stderr, Out: os.Stdout, } opts := &config.Options{ ConfigPath: *configFlag, Reset: *resetFlag, IO: io, // TODO use real logger Log: func(s string) { fmt.Fprintln(io.Out, s) }, Logf: func(s string, args ...interface{}) { fmt.Fprintf(io.Out, s, args...) fmt.Fprintf(io.Out, "\n") }, } err := _main(opts) if err != nil { fmt.Fprintf(os.Stderr, "failed: %s", err) } } type Teardown func() func setupDB(opts *config.Options) (Teardown, error) { db, err := sql.Open("sqlite3", opts.Config.DBPath) opts.DB = db return func() { db.Close() }, err } func _main(opts *config.Options) error { cfg, err := config.ParseConfig(opts.ConfigPath) if err != nil { fmt.Fprintf(os.Stderr, "could not read config file '%s'", opts.ConfigPath) os.Exit(1) } opts.Config = *cfg teardown, err := setupDB(opts) if err != nil { return fmt.Errorf("could not initialize DB: %w", err) } defer teardown() err = ensureSchema(*opts) if err != nil { return err } setupAPI(*opts) // TODO TLS or SSL or something opts.Logf("starting server at %s:%d", cfg.Host, cfg.Port) if err := http.ListenAndServe(fmt.Sprintf("%s:%d", cfg.Host, cfg.Port), nil); err != nil { return fmt.Errorf("http server exited with error: %w", err) } return nil } func ensureSchema(opts config.Options) error { db := opts.DB if opts.Reset { err := os.Remove(opts.Config.DBPath) if err != nil { return fmt.Errorf("failed to delete database: %w", err) } } rows, err := db.Query("select version from meta") if err == nil { defer rows.Close() rows.Next() var version string err = rows.Scan(&version) if err != nil { return fmt.Errorf("failed to check database schema version: %w", err) } else if version == "" { return errors.New("database is in unknown state") } return nil } if !strings.Contains(err.Error(), "no such table") { return fmt.Errorf("got error checking database state: %w", err) } _, err = db.Exec(schemaSQL) if err != nil { return fmt.Errorf("failed to initialize database schema: %w", err) } return nil } func handler(opts config.Options, f http.HandlerFunc) http.HandlerFunc { // TODO make this more real return func(w http.ResponseWriter, req *http.Request) { opts.Log(req.URL.Path) // TODO add user info to opts f(w, req) } } // TODO I'm not entirely sold on this hash system; without transport // encryption, it doesn't really help anything. I'd rather have plaintext + // transport encryption and then, on the server side, proper salted hashing. func writeResponse(w http.ResponseWriter, resp api.BBJResponse) { w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) } // NB breaking: i'm not just returning 200 always but using http status codes func writeErrorResponse(w http.ResponseWriter, code int, resp api.BBJResponse) { w.WriteHeader(code) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) } func getUserFromReq(opts config.Options, req *http.Request) (u *db.User, err error) { u = &db.User{} u.Username = req.Header.Get("User") u.Hash = req.Header.Get("Auth") if u.Username == "" || u.Username == "anon" { return } db := opts.DB stmt, err := db.Prepare("select auth_hash, id from users where user_name = ?") if err != nil { err = fmt.Errorf("db error: %w", err) return } defer stmt.Close() opts.Logf("querying for %s", u.Username) var authHash string if err = stmt.QueryRow(u.Username).Scan(&authHash, u.ID); err != nil { if strings.Contains(err.Error(), "no rows in result") { err = errors.New("no such user") } else { err = fmt.Errorf("db error: %w", err) } } if authHash != u.Hash { err = errors.New("bad credentials") } return } func checkAuth(opts config.Options, username, hash string) error { db := opts.DB stmt, err := db.Prepare("select auth_hash from users where user_name = ?") if err != nil { return fmt.Errorf("db error: %w", err) } defer stmt.Close() opts.Logf("querying for %s", username) var authHash string if err = stmt.QueryRow(username).Scan(&authHash); err != nil { if strings.Contains(err.Error(), "no rows in result") { return errors.New("no such user") } return fmt.Errorf("db error: %w", err) } if authHash != hash { return errors.New("bad credentials") } return nil } func setupAPI(opts config.Options) { newAPI := func(opts config.Options, w http.ResponseWriter, req *http.Request) *api.API { user, err := getUserFromReq(opts, req) if err != nil { writeErrorResponse(w, 403, api.BBJResponse{ Error: true, Data: err.Error(), }) return nil } return &api.API{ Opts: opts, User: user, } } invokeAPI := func(w http.ResponseWriter, apiFn api.ApiHandler) { resp, err := apiFn() if err != nil { he := &api.HTTPError{} _ = errors.As(err, &he) resp := api.BBJResponse{ Error: true, Data: he.Msg, } w.WriteHeader(he.Code) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) return } w.WriteHeader(http.StatusOK) w.Header().Set("Content-Type", "application/json") json.NewEncoder(w).Encode(resp) } http.HandleFunc("/instance_info", handler(opts, func(w http.ResponseWriter, req *http.Request) { api := newAPI(opts, w, req) if api == nil { return } invokeAPI(w, api.InstanceInfo) })) /* http.HandleFunc("/instance_info", handler(opts, func(w http.ResponseWriter, req *http.Request) { type instanceInfo struct { InstanceName string `json:"instance_name"` AllowAnon bool `json:"allow_anon"` Admins []string } writeResponse(w, BBJResponse{ Data: instanceInfo{ InstanceName: opts.Config.InstanceName, AllowAnon: opts.Config.AllowAnon, Admins: opts.Config.Admins, }, }) })) http.HandleFunc("/user_register", handler(opts, func(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { badMethod(w) return } type AuthArgs struct { Username string `json:"user_name"` Hash string `json:"auth_hash"` } var args AuthArgs if err := json.NewDecoder(req.Body).Decode(&args); err != nil { invalidArgs(w) return } if args.Hash == "" || args.Username == "" { invalidArgs(w) return } opts.Logf("querying for %s", args.Username) if err := checkAuth(opts, args.Username, args.Hash); err == nil { opts.Logf("found %s", args.Username) // code 4 apparently writeErrorResponse(w, 403, BBJResponse{ Error: true, Data: "user already exists", }) } else if err.Error() != "no such user" { serverErr(w, err) return } db := opts.DB stmt, err := db.Prepare(`INSERT INTO users VALUES (?, ?, ?, "", "", 0, 0, ?)`) id, err := uuid.NewRandom() if err != nil { serverErr(w, err) return } _, err = stmt.Exec(id, args.Username, args.Hash, time.Now()) if err != nil { serverErr(w, err) } writeResponse(w, BBJResponse{ Data: true, // TODO probably something else // TODO prob usermap }) })) http.HandleFunc("/check_auth", handler(opts, func(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { badMethod(w) return } type AuthArgs struct { Username string `json:"target_user"` AuthHash string `json:"target_hash"` } var args AuthArgs if err := json.NewDecoder(req.Body).Decode(&args); err != nil { invalidArgs(w) return } opts.Logf("got %s %s", args.Username, args.AuthHash) db := opts.DB stmt, err := db.Prepare("select auth_hash from users where user_name = ?") if err != nil { serverErr(w, err) return } defer stmt.Close() var authHash string err = stmt.QueryRow(args.Username).Scan(&authHash) if err != nil { if strings.Contains(err.Error(), "no rows in result") { opts.Logf("user not found") writeErrorResponse(w, 404, BBJResponse{ Error: true, Data: "user not found", }) } else { serverErr(w, err) } return } // TODO unique constraint on user_name if authHash != args.AuthHash { http.Error(w, "incorrect password", 403) writeErrorResponse(w, 403, BBJResponse{ Error: true, Data: "incorrect password", }) return } // TODO include usermap? writeResponse(w, BBJResponse{ Data: true, }) })) http.HandleFunc("/thread_index", handler(opts, func(w http.ResponseWriter, req *http.Request) { db := opts.DB rows, err := db.Query("SELECT * FROM threads JOIN messages ON threads.thread_id = messages.thread_id") if err != nil { serverErr(w, err) return } defer rows.Close() for rows.Next() { var id string err = rows.Scan(&id) if err != nil { serverErr(w, err) return } opts.Log(id) } writeResponse(w, BBJResponse{Data: "TODO"}) // TODO })) http.HandleFunc("/thread_create", handler(opts, func(w http.ResponseWriter, req *http.Request) { if req.Method != "POST" { badMethod(w) return } // TODO make this getUserInfoFromReq or similar so we can use the user ID later user, err := getUserFromReq(opts, req) if err != nil { writeErrorResponse(w, 403, BBJResponse{ Error: true, Data: err.Error(), }) return } type threadCreateArgs struct { Title string Body string SendRaw bool `json:"send_raw"` } var args threadCreateArgs if err := json.NewDecoder(req.Body).Decode(&args); err != nil { invalidArgs(w) return } if args.Title == "" || args.Body == "" { invalidArgs(w) return } db := opts.DB tx, err := db.Begin() if err != nil { serverErr(w, err) return } stmt, err := tx.Prepare("insert into threads VALUES ( ?, ?, ?, ?, ?, 0, 0, ? )") if err != nil { serverErr(w, err) return } defer stmt.Close() threadID, err := uuid.NewRandom() if err != nil { serverErr(w, err) return } now := time.Now() if _, err = stmt.Exec( threadID, user.ID, args.Title, now, now, user.Username, ); err != nil { serverErr(w, err) return } stmt, err = tx.Prepare("insert into messages values ( ?, 1, ?, ?, 0, ?, ? )") if err != nil { serverErr(w, err) return } defer stmt.Close() if _, err = stmt.Exec( threadID, user.ID, now, args.Body, args.SendRaw, ); err != nil { serverErr(w, err) return } if err = tx.Commit(); err != nil { serverErr(w, err) return } stmt, err = db.Prepare("select * from threads where thread_id = ? limit 1") if err != nil { serverErr(w, err) return } defer stmt.Close() t := &Thread{} // TODO fill in rest of thread if err = stmt.QueryRow(threadID).Scan( t.ID, t.Author, t.Title, t.LastMod, t.Created, t.ReplyCount, t.Pinned, t.LastAuthor, ); err != nil { serverErr(w, err) return } stmt, err = db.Prepare("select * from messages where thread_id = ?") if err != nil { serverErr(w, err) return } defer stmt.Close() rows, err := stmt.Query(threadID) if err != nil { serverErr(w, err) return } t.Messages = []Message{} for rows.Next() { m := &Message{} if err := rows.Scan( m.ThreadID, m.PostID, m.Author, m.Created, m.Edited, m.Body, m.SendRaw, ); err != nil { serverErr(w, err) return } t.Messages = append(t.Messages, *m) } writeResponse(w, BBJResponse{Data: t}) })) */ } type Thread struct { ID string `json:"thread_id"` Author string Title string LastMod time.Time `json:"last_mod"` Created time.Time ReplyCount int `json:"reply_count"` Pinned int // TODO bool LastAuthor string `json:"last_author"` Messages []Message } type Message struct { ThreadID string `json:"thread_id"` PostID string `json:"post_id"` Author string Created time.Time Edited int // TODO bool Body string SendRaw int `json:"send_raw"` // TODO bool }