// Package invites stores and looks up invitation codes in a SQLite database.
//
// An invite code is the standard base64 encoding of "<random> <email>", where
// <random> is codeLen characters drawn from codeCharset (which contains no
// space). Decode reverses that encoding.
package invites

import (
	"crypto/rand"
	"database/sql"
	"encoding/base64"
	"errors"
	"math/big"
	"strings"
	"time"

	_ "github.com/mattn/go-sqlite3"
)

const (
	// dsn is the data source name handed to the sqlite3 driver. The "file:"
	// prefix is required for mattn/go-sqlite3 to forward the query string
	// (mode=rw) to SQLite; without it the option is dropped and a missing
	// database file would be created empty instead of reported as an error.
	dsn = "file:/town/var/invites/invites.db?mode=rw"

	// codeLen is the number of random characters at the start of each code.
	codeLen = 32

	// codeCharset is the alphabet for the random part of an invite code. It
	// must not contain a space: a space separates the random part from the
	// email address, and Decode splits on it.
	codeCharset = "abcdefghijklmnopqrstuvwxyz" +
		"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"0123456789" +
		"`~!@#$%^&*()-=_+[]{}|;:,./<>?"
)

// Invite is one row of the invites table.
type Invite struct {
	ID      int64
	Created time.Time
	Code    string
	Email   string
	Used    bool
}

// Insert generates a fresh code for i.Email and stores a new invite row with
// that code and email. On success the generated code is also written to
// i.Code so the caller can deliver it; i.ID and i.Created are not populated.
func (i *Invite) Insert(db *sql.DB) error {
	code, err := newCode(i.Email)
	if err != nil {
		return err
	}
	if _, err := db.Exec(`
	INSERT INTO invites (code, email)
	VALUES (?, ?)
	`, code, i.Email); err != nil {
		return err
	}
	i.Code = code
	return nil
}

// ConnectDB returns a handle to the invites database described by dsn.
//
// sql.Open only validates its arguments and does not open a connection, so
// problems such as a missing or unreadable database file surface on the
// first query rather than here.
func ConnectDB() (*sql.DB, error) {
	db, err := sql.Open("sqlite3", dsn)
	if err != nil {
		return nil, err
	}
	return db, nil
}

// generateCode returns a new invite code for email (see newCode). It panics
// if the system's secure random source fails; Insert calls newCode directly
// so that such a failure is reported as an error instead.
func generateCode(email string) string {
	code, err := newCode(email)
	if err != nil {
		panic("invites: reading secure random source: " + err.Error())
	}
	return code
}

// newCode returns the base64 (standard encoding) of "<random> <email>", where
// <random> is codeLen characters chosen uniformly from codeCharset.
//
// crypto/rand is used so codes cannot be predicted from the time they were
// created, and big.Int sampling avoids the modulo bias of reducing raw bytes.
func newCode(email string) (string, error) {
	limit := big.NewInt(int64(len(codeCharset)))
	buf := make([]byte, 0, codeLen+1+len(email))
	for len(buf) < codeLen {
		n, err := rand.Int(rand.Reader, limit)
		if err != nil {
			return "", err
		}
		buf = append(buf, codeCharset[n.Int64()])
	}
	buf = append(buf, ' ')
	buf = append(buf, email...)
	return base64.StdEncoding.EncodeToString(buf), nil
}

// Decode base64-decodes an invite code and splits the result at the first
// space, normally yielding [random part, email]. If the decoded text contains
// no space, the returned slice has a single element. Only base64 decoding
// errors are reported.
func Decode(code string) ([]string, error) {
	decoded, err := base64.StdEncoding.DecodeString(code)
	if err != nil {
		return nil, err
	}
	// Split only once: the random part never contains a space, but the
	// email part is not guaranteed to be free of them.
	return strings.SplitN(string(decoded), " ", 2), nil
}

// Get looks up the invite whose code column equals code.
//
// When no row matches, the error is sql.ErrNoRows from database/sql. If the
// row is found but its created column cannot be parsed, the returned Invite
// has all other fields filled in and is returned together with the parse
// error.
func Get(db *sql.DB, code string) (*Invite, error) {
	inv := &Invite{Code: code}
	var (
		created string
		used    int
	)
	err := db.QueryRow(`
	SELECT id, created, email, used
	FROM invites
	WHERE code = ?`, code).Scan(&inv.ID, &created, &inv.Email, &used)
	if err != nil {
		return nil, err
	}
	inv.Used = used > 0
	if inv.Created, err = parseCreated(created); err != nil {
		return inv, err
	}
	return inv, nil
}

// createdLayouts lists the layouts accepted for the created column, in the
// order they are tried. The first is the original layout, tried first so
// rows that parsed before still parse identically. The others cover SQLite's
// CURRENT_TIMESTAMP text ("YYYY-MM-DD HH:MM:SS") and the RFC 3339 text that
// database/sql produces when a driver-supplied time.Time is scanned into a
// string.
var createdLayouts = []string{
	"2006-01-02T15:04",
	"2006-01-02 15:04:05",
	time.RFC3339Nano,
}

// parseCreated parses s with the first matching layout in createdLayouts.
// If none match, it returns the error from the first (original) layout.
func parseCreated(s string) (time.Time, error) {
	var firstErr error
	for _, layout := range createdLayouts {
		t, err := time.Parse(layout, s)
		if err == nil {
			return t, nil
		}
		if firstErr == nil {
			firstErr = err
		}
	}
	return time.Time{}, firstErr
}

// MarkUsed sets used = 1 on the invite row whose id is i.ID and, on success,
// sets i.Used to true. It returns an error if the update fails or if no row
// was affected (for example, when no invite has that id).
func (i *Invite) MarkUsed(db *sql.DB) error {
	result, err := db.Exec(`UPDATE invites SET used = 1 WHERE id = ?`, i.ID)
	if err != nil {
		return err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if rowsAffected == 0 {
		return errors.New("no rows affected")
	}
	i.Used = true
	return nil
}