switch to interface for database

trunk
vilmibm 2022-07-27 20:45:21 -05:00
parent 8fd90a5385
commit 41bcfb442a
2 changed files with 56 additions and 32 deletions

View File

@ -47,7 +47,11 @@ func _main() (err error) {
*/
}
grpcServer := grpc.NewServer(opts...)
proto.RegisterGameWorldServer(grpcServer, newServer())
srv, err := newServer()
if err != nil {
return err
}
proto.RegisterGameWorldServer(grpcServer, srv)
grpcServer.Serve(l)
return nil
@ -57,29 +61,38 @@ func _main() (err error) {
type gameWorldServer struct {
proto.UnimplementedGameWorldServer
db db.DB
mu sync.Mutex // for msgRouter
msgRouter map[string]func(*proto.ClientMessage) error
}
func newServer() *gameWorldServer {
func newServer() (*gameWorldServer, error) {
// TODO read from env or whatever
db, err := db.NewDB("postgres://vilmibm:vilmibm@localhost:5432/hermeticum")
if err != nil {
return nil, err
}
s := &gameWorldServer{
msgRouter: make(map[string]func(*proto.ClientMessage) error),
db: db,
}
return s
return s, nil
}
func (s *gameWorldServer) Commands(stream proto.GameWorld_CommandsServer) error {
var sid string
for {
cmd, err := stream.Recv()
if err == io.EOF {
// TODO end session
return nil
return s.db.EndSession(sid)
}
if err != nil {
return err
}
sid := cmd.SessionInfo.SessionID
sid = cmd.SessionInfo.SessionID
send := s.msgRouter[sid]
msg := &proto.ClientMessage{
@ -122,13 +135,13 @@ func (s *gameWorldServer) Messages(si *proto.SessionInfo, stream proto.GameWorld
func (s *gameWorldServer) Register(ctx context.Context, auth *proto.AuthInfo) (si *proto.SessionInfo, err error) {
var a *db.Account
a, err = db.CreateAccount(auth.Username, auth.Password)
a, err = s.db.CreateAccount(auth.Username, auth.Password)
if err != nil {
return nil, err
}
var sessionID string
sessionID, err = db.StartSession(*a)
sessionID, err = s.db.StartSession(*a)
if err != nil {
return nil, err
}
@ -140,13 +153,13 @@ func (s *gameWorldServer) Register(ctx context.Context, auth *proto.AuthInfo) (s
func (s *gameWorldServer) Login(ctx context.Context, auth *proto.AuthInfo) (si *proto.SessionInfo, err error) {
var a *db.Account
a, err = db.ValidateCredentials(auth.Username, auth.Password)
a, err = s.db.ValidateCredentials(auth.Username, auth.Password)
if err != nil {
return
}
var sessionID string
sessionID, err = db.StartSession(*a)
sessionID, err = s.db.StartSession(*a)
if err != nil {
return
}

View File

@ -4,6 +4,7 @@ import (
"context"
_ "embed"
"errors"
"log"
"github.com/google/uuid"
"github.com/jackc/pgx/v4/pgxpool"
@ -12,22 +13,39 @@ import (
// go:embed schema.sql
var schema string
func EnsureSchema() {
// TODO look into tern
type Account struct {
ID int
Name string
Pwhash string
}
func connect() (*pgxpool.Pool, error) {
// TODO read dburl from environment
conn, err := pgxpool.Connect(context.Background(), "postgres://vilmibm:vilmibm@localhost:5432/hermeticum")
type DB interface {
// EnsureSchema() TODO look into tern
CreateAccount(string, string) (*Account, error)
ValidateCredentials(string, string) (*Account, error)
GetAccount(string) (*Account, error)
StartSession(Account) (string, error)
EndSession(string) error
}
type pgDB struct {
pool *pgxpool.Pool
}
func NewDB(connURL string) (DB, error) {
pool, err := pgxpool.Connect(context.Background(), connURL)
if err != nil {
return nil, err
}
pgdb := &pgDB{
pool: pool,
}
return conn, nil
return pgdb, nil
}
func CreateAccount(name, password string) (*Account, error) {
conn, err := connect()
func (db *pgDB) CreateAccount(name, password string) (*Account, error) {
conn, err := db.pool.Acquire(context.Background())
if err != nil {
return nil, err
}
@ -50,8 +68,8 @@ func CreateAccount(name, password string) (*Account, error) {
return a, err
}
func ValidateCredentials(name, password string) (*Account, error) {
a, err := GetAccount(name)
func (db *pgDB) ValidateCredentials(name, password string) (*Account, error) {
a, err := db.GetAccount(name)
if err != nil {
return nil, err
}
@ -65,14 +83,8 @@ func ValidateCredentials(name, password string) (*Account, error) {
return a, nil
}
type Account struct {
ID int
Name string
Pwhash string
}
func GetAccount(name string) (*Account, error) {
conn, err := connect()
func (db *pgDB) GetAccount(name string) (*Account, error) {
conn, err := db.pool.Acquire(context.Background())
if err != nil {
return nil, err
}
@ -89,11 +101,10 @@ func GetAccount(name string) (*Account, error) {
return a, nil
}
func StartSession(a Account) (sessionID string, err error) {
var conn *pgxpool.Pool
conn, err = connect()
func (db *pgDB) StartSession(a Account) (sessionID string, err error) {
conn, err := db.pool.Acquire(context.Background())
if err != nil {
return
return "", err
}
sessionID = uuid.New().String()