diff --git a/store.go b/store.go index 49648ee..a75217d 100644 --- a/store.go +++ b/store.go @@ -5,6 +5,7 @@ import ( "database/sql" _ "embed" "fmt" + "strings" _ "github.com/mattn/go-sqlite3" ) @@ -12,34 +13,55 @@ import ( //go:embed migrations/init.sql var initScript string -const ( - CLAIM_TYPE_AREA = iota - CLAIM_TYPE_REGION - CLAIM_TYPE_TRADE -) +type ClaimType string -var claimTypeEnum = map[string]int{ - "area": CLAIM_TYPE_AREA, - "region": CLAIM_TYPE_REGION, - "trade": CLAIM_TYPE_TRADE, +func ClaimTypeFromString(s string) (ClaimType, error) { + switch s { + case CLAIM_TYPE_AREA: + return CLAIM_TYPE_AREA, nil + case CLAIM_TYPE_REGION: + return CLAIM_TYPE_REGION, nil + case CLAIM_TYPE_TRADE: + return CLAIM_TYPE_TRADE, nil + } + return "", fmt.Errorf("no claim type matching '%s'", s) } -var claimTypeEnumVals = map[int]string{ +const ( + CLAIM_TYPE_AREA = "area" + CLAIM_TYPE_REGION = "region" + CLAIM_TYPE_TRADE = "trade" +) + +var claimTypeToColumn = map[ClaimType]string{ CLAIM_TYPE_AREA: "area", CLAIM_TYPE_REGION: "region", - CLAIM_TYPE_TRADE: "trade", -} - -var claimTypeToColumn = map[string]string{ - "area": "area", - "region": "region", - "trade": "trade_node", + CLAIM_TYPE_TRADE: "trade_node", } type Store struct { db *sql.DB } +type Claim struct { + ID int + Player string + Name string + Type ClaimType +} + +func (c Claim) String() string { + return fmt.Sprintf("id=%d player=%s claim_type=%s name=%s", c.ID, c.Player, c.Type, c.Name) +} + +type ErrConflict struct { + Conflicts []string +} + +func (ec ErrConflict) Error() string { + return fmt.Sprintf("found conflicting provinces: %s", strings.Join(ec.Conflicts, ", ")) +} + func NewStore(conn string) (*Store, error) { db, err := sql.Open("sqlite3", conn) if err != nil { @@ -56,7 +78,7 @@ func NewStore(conn string) (*Store, error) { }, nil } -func (s *Store) Claim(ctx context.Context, player, province string, typ int) error { +func (s *Store) Claim(ctx context.Context, player, province string, claimType ClaimType) error { tx, err := s.db.Begin() if err != nil { return fmt.Errorf("failed to begin transaction: %w", err) @@ -64,24 +86,32 @@ func (s *Store) Claim(ctx context.Context, player, province string, typ int) err defer tx.Commit() // Check conflicts - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ? and provinces.name in ( + stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in ( SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' -)`, claimTypeToColumn[claimTypeEnumVals[typ]])) +)`, claimTypeToColumn[claimType])) if err != nil { return fmt.Errorf("failed to prepare conflicts query: %w", err) } - row := stmt.QueryRowContext(ctx, province) - var count int - err = row.Scan(&count) + rows, err := stmt.QueryContext(ctx, province) if err != nil { - return fmt.Errorf("failed to get count of conflicting provinces: %w", err) + return fmt.Errorf("failed to get conflicting provinces: %w", err) + } + + conflicts := make([]string, 0) + for rows.Next() { + var p string + err = rows.Scan(&p) + if err != nil { + return fmt.Errorf("failed to scan row: %w", err) + } + conflicts = append(conflicts, p) } - if count > 0 { - return fmt.Errorf("found %d conflicting provinces", count) + if len(conflicts) > 0 { + return ErrConflict{Conflicts: conflicts} } stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val) VALUES (?, ?, ?)") @@ -89,10 +119,41 @@ func (s *Store) Claim(ctx context.Context, player, province string, typ int) err return fmt.Errorf("failed to prepare claim query: %w", err) } - _, err = stmt.ExecContext(ctx, player, claimTypeEnumVals[typ], province) + _, err = stmt.ExecContext(ctx, player, claimType, province) if err != nil { return fmt.Errorf("failed to insert claim: %w", err) } return nil } + +func (s *Store) ListClaims(ctx context.Context) ([]Claim, error) { + stmt, err := s.db.PrepareContext(ctx, `SELECT id, player, claim_type, val FROM claims`) + if err != nil { + return nil, fmt.Errorf("failed to prepare query: %w", err) + } + + rows, err := stmt.QueryContext(ctx) + if err != nil { + return nil, fmt.Errorf("failed to execute query: %w", err) + } + + claims := make([]Claim, 0) + for rows.Next() { + c := Claim{} + var rawType string + err = rows.Scan(&c.ID, &c.Player, &rawType, &c.Name) + if err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + cl, err := ClaimTypeFromString(rawType) + if err != nil { + return nil, fmt.Errorf("unexpected error converting raw claim type: %w", err) + } + c.Type = cl + + claims = append(claims, c) + } + + return claims, nil +}