From 41bcfb442aa32262c02778b26a9660413fe98691 Mon Sep 17 00:00:00 2001 From: vilmibm Date: Wed, 27 Jul 2022 20:45:21 -0500 Subject: [PATCH] switch to interface for database --- server/cmd/main.go | 33 +++++++++++++++++++--------- server/db/db.go | 55 +++++++++++++++++++++++++++------------------- 2 files changed, 56 insertions(+), 32 deletions(-) diff --git a/server/cmd/main.go b/server/cmd/main.go index fe724d4..eb1812a 100644 --- a/server/cmd/main.go +++ b/server/cmd/main.go @@ -47,7 +47,11 @@ func _main() (err error) { */ } grpcServer := grpc.NewServer(opts...) - proto.RegisterGameWorldServer(grpcServer, newServer()) + srv, err := newServer() + if err != nil { + return err + } + proto.RegisterGameWorldServer(grpcServer, srv) grpcServer.Serve(l) return nil @@ -57,29 +61,38 @@ func _main() (err error) { type gameWorldServer struct { proto.UnimplementedGameWorldServer + db db.DB mu sync.Mutex // for msgRouter msgRouter map[string]func(*proto.ClientMessage) error } -func newServer() *gameWorldServer { +func newServer() (*gameWorldServer, error) { + // TODO read from env or whatever + db, err := db.NewDB("postgres://vilmibm:vilmibm@localhost:5432/hermeticum") + if err != nil { + return nil, err + } + s := &gameWorldServer{ msgRouter: make(map[string]func(*proto.ClientMessage) error), + db: db, } - return s + + return s, nil } func (s *gameWorldServer) Commands(stream proto.GameWorld_CommandsServer) error { + var sid string for { cmd, err := stream.Recv() if err == io.EOF { - // TODO end session - return nil + return s.db.EndSession(sid) } if err != nil { return err } - sid := cmd.SessionInfo.SessionID + sid = cmd.SessionInfo.SessionID send := s.msgRouter[sid] msg := &proto.ClientMessage{ @@ -122,13 +135,13 @@ func (s *gameWorldServer) Messages(si *proto.SessionInfo, stream proto.GameWorld func (s *gameWorldServer) Register(ctx context.Context, auth *proto.AuthInfo) (si *proto.SessionInfo, err error) { var a *db.Account - a, err = db.CreateAccount(auth.Username, auth.Password) + a, err = s.db.CreateAccount(auth.Username, auth.Password) if err != nil { return nil, err } var sessionID string - sessionID, err = db.StartSession(*a) + sessionID, err = s.db.StartSession(*a) if err != nil { return nil, err } @@ -140,13 +153,13 @@ func (s *gameWorldServer) Register(ctx context.Context, auth *proto.AuthInfo) (s func (s *gameWorldServer) Login(ctx context.Context, auth *proto.AuthInfo) (si *proto.SessionInfo, err error) { var a *db.Account - a, err = db.ValidateCredentials(auth.Username, auth.Password) + a, err = s.db.ValidateCredentials(auth.Username, auth.Password) if err != nil { return } var sessionID string - sessionID, err = db.StartSession(*a) + sessionID, err = s.db.StartSession(*a) if err != nil { return } diff --git a/server/db/db.go b/server/db/db.go index af0d544..15d5e6a 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -4,6 +4,7 @@ import ( "context" _ "embed" "errors" + "log" "github.com/google/uuid" "github.com/jackc/pgx/v4/pgxpool" @@ -12,22 +13,39 @@ import ( // go:embed schema.sql var schema string -func EnsureSchema() { - // TODO look into tern +type Account struct { + ID int + Name string + Pwhash string } -func connect() (*pgxpool.Pool, error) { - // TODO read dburl from environment - conn, err := pgxpool.Connect(context.Background(), "postgres://vilmibm:vilmibm@localhost:5432/hermeticum") +type DB interface { + // EnsureSchema() TODO look into tern + CreateAccount(string, string) (*Account, error) + ValidateCredentials(string, string) (*Account, error) + GetAccount(string) (*Account, error) + StartSession(Account) (string, error) + EndSession(string) error +} + +type pgDB struct { + pool *pgxpool.Pool +} + +func NewDB(connURL string) (DB, error) { + pool, err := pgxpool.Connect(context.Background(), connURL) if err != nil { return nil, err } + pgdb := &pgDB{ + pool: pool, + } - return conn, nil + return pgdb, nil } -func CreateAccount(name, password string) (*Account, error) { - conn, err := connect() +func (db *pgDB) CreateAccount(name, password string) (*Account, error) { + conn, err := db.pool.Acquire(context.Background()) if err != nil { return nil, err } @@ -50,8 +68,8 @@ func CreateAccount(name, password string) (*Account, error) { return a, err } -func ValidateCredentials(name, password string) (*Account, error) { - a, err := GetAccount(name) +func (db *pgDB) ValidateCredentials(name, password string) (*Account, error) { + a, err := db.GetAccount(name) if err != nil { return nil, err } @@ -65,14 +83,8 @@ func ValidateCredentials(name, password string) (*Account, error) { return a, nil } -type Account struct { - ID int - Name string - Pwhash string -} - -func GetAccount(name string) (*Account, error) { - conn, err := connect() +func (db *pgDB) GetAccount(name string) (*Account, error) { + conn, err := db.pool.Acquire(context.Background()) if err != nil { return nil, err } @@ -89,11 +101,10 @@ func GetAccount(name string) (*Account, error) { return a, nil } -func StartSession(a Account) (sessionID string, err error) { - var conn *pgxpool.Pool - conn, err = connect() +func (db *pgDB) StartSession(a Account) (sessionID string, err error) { + conn, err := db.pool.Acquire(context.Background()) if err != nil { - return + return "", err } sessionID = uuid.New().String()