diff --git a/availability.sql b/availability.sql new file mode 100644 index 0000000..d86d4c8 --- /dev/null +++ b/availability.sql @@ -0,0 +1,6 @@ +SELECT count(distinct name) +FROM claimables +LEFT JOIN claims ON claimables.name = claims.val AND claimables.typ = claims.claim_type +WHERE claims.val IS NULL +AND claimables.typ = 'area' +AND claimables.name LIKE '%bay%'; diff --git a/claim_type.go b/claim_type.go index 3741ecc..9c547ec 100644 --- a/claim_type.go +++ b/claim_type.go @@ -34,12 +34,6 @@ const ( CLAIM_TYPE_TRADE = "trade" ) -var claimTypeToColumn = map[ClaimType]string{ - CLAIM_TYPE_AREA: "area", - CLAIM_TYPE_REGION: "region", - CLAIM_TYPE_TRADE: "trade_node", -} - type Claim struct { ID int Player string diff --git a/conflicts.go b/conflicts.go index e0036f8..a1cb5c7 100644 --- a/conflicts.go +++ b/conflicts.go @@ -19,37 +19,33 @@ func (c Conflict) String() string { return fmt.Sprintf("%s owned by #%d %s %s (%s)", c.Province, c.ClaimID, c.ClaimType, c.Claim, c.Player) } -const conflictQuery string = `SELECT name, player, claim_type, val, id FROM ( - SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id - FROM claims - LEFT JOIN provinces ON claims.val = provinces.trade_node - WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ? - AND provinces.%[1]s = ? - UNION - SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id - FROM claims - LEFT JOIN provinces ON claims.val = provinces.region - WHERE claims.claim_type = 'region' AND claims.userid IS NOT ? - AND provinces.%[1]s = ? - UNION - SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id - FROM claims - LEFT JOIN provinces ON claims.val = provinces.area - WHERE claims.claim_type = 'area' AND claims.userid IS NOT ? - AND provinces.%[1]s = ? -);` +const conflictQuery string = `WITH claiming AS ( + SELECT province FROM claimables + WHERE claimables.typ = ? + AND claimables.name = ? +) +SELECT claimables.province, claims.player, claims.claim_type, claims.val, claims.id +FROM claims +INNER JOIN claimables + ON claims.claim_type = claimables.typ + AND claims.val = claimables.name +INNER JOIN claiming + ON claiming.province = claimables.province +WHERE claims.userid IS NOT ?;` func (s *Store) FindConflicts(ctx context.Context, userId, name string, claimType ClaimType) ([]Conflict, error) { log.Debug().Ctx(ctx).Stringer("claim_type", claimType).Str("userid", userId).Msg("searching for potential conflicts") - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(conflictQuery, claimTypeToColumn[claimType])) + + stmt, err := s.db.PrepareContext(ctx, conflictQuery) if err != nil { return nil, fmt.Errorf("failed to prepare conflicts query: %w", err) } - rows, err := stmt.QueryContext(ctx, userId, name, userId, name, userId, name) + rows, err := stmt.QueryContext(ctx, claimType, name, userId) if err != nil { return nil, fmt.Errorf("failed to get conflicting provinces: %w", err) } + defer stmt.Close() conflicts := make([]Conflict, 0) for rows.Next() { diff --git a/conflicts_test.go b/conflicts_test.go index 2ecab23..339e4be 100644 --- a/conflicts_test.go +++ b/conflicts_test.go @@ -2,19 +2,42 @@ package themis import ( "context" + "errors" "fmt" + "os" "reflect" "testing" + "github.com/rs/zerolog/log" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) +func touchDbFile(path string) error { + log.Debug().Str("path", path).Msg("touching database file") + f, err := os.Open(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + f, err := os.Create(path) + if err != nil { + return err + } + f.Close() + } else { + return err + } + } + f.Close() + + return nil +} + func TestStore_FindConflicts(t *testing.T) { store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "TestStore_FindConflicts")) assert.NoError(t, err) - id, err := store.Claim(context.TODO(), "000000000000000001", "foo", "Bordeaux", CLAIM_TYPE_TRADE) - assert.NoError(t, err) + id, err := store.Claim(context.Background(), "000000000000000001", "foo", "Bordeaux", CLAIM_TYPE_TRADE) + require.NoError(t, err) type args struct { ctx context.Context @@ -42,7 +65,7 @@ func TestStore_FindConflicts(t *testing.T) { { name: "overlapping", args: args{ - context.TODO(), + context.Background(), "000000000000000002", "Iberia", CLAIM_TYPE_REGION, diff --git a/go.mod b/go.mod index e6522d9..2de0931 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module go.wperron.io/themis -go 1.21 +go 1.19 require ( github.com/bwmarrin/discordgo v0.26.1 diff --git a/migrations/20240105153349_create_claimables_view.down.sql b/migrations/20240105153349_create_claimables_view.down.sql new file mode 100644 index 0000000..9bfe924 --- /dev/null +++ b/migrations/20240105153349_create_claimables_view.down.sql @@ -0,0 +1 @@ +drop claimables; diff --git a/migrations/20240105153349_create_claimables_view.up.sql b/migrations/20240105153349_create_claimables_view.up.sql new file mode 100644 index 0000000..fa67086 --- /dev/null +++ b/migrations/20240105153349_create_claimables_view.up.sql @@ -0,0 +1,14 @@ +create view if not exists claimables as + with + trades as (select distinct trade_node from provinces where trade_node != ''), + areas as (select distinct area from provinces where area != ''), + regions as (select distinct region from provinces where region != '') + select 'trade' as typ, provinces.trade_node as name, name as province, id + from provinces inner join trades on trades.trade_node = provinces.trade_node + union + select 'area' as typ, provinces.area as name, name as province, id + from provinces inner join areas on areas.area = provinces.area + union + select 'region' as typ, provinces.region as name, name as province, id + from provinces inner join regions on regions.region = provinces.region +; diff --git a/store.go b/store.go index 9802858..c2aab6a 100644 --- a/store.go +++ b/store.go @@ -45,16 +45,17 @@ func NewStore(conn string) (*Store, error) { return nil, fmt.Errorf("failed to initialize db migrate: %w", err) } + err = m.Up() + if err != nil && !errors.Is(err, migrate.ErrNoChange) { + return nil, fmt.Errorf("failed to roll up migrations: %w", err) + } + ver, dirty, err := m.Version() - if err != nil { + if err != nil && err != migrate.ErrNilVersion { return nil, fmt.Errorf("failed to get database migration version: %w", err) } log.Debug().Uint("current_version", ver).Bool("dirty", dirty).Msg("running database migrations") - err = m.Up() - if err != nil && !errors.Is(err, migrate.ErrNoChange) { - return nil, fmt.Errorf("failed to roll up migrations: %w", err) - } return &Store{ db: db, @@ -71,7 +72,7 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai Ctx(ctx). Str("userid", userId). Str("player", player). - Str("provice", province). + Str("province", province). Stringer("claim_type", claimType). Msg("inserting claim") audit := &AuditableEvent{ @@ -100,13 +101,14 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai } // check that provided name matches the claim type - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE LOWER(provinces.%s) = ?`, claimTypeToColumn[claimType])) + stmt, err := s.db.PrepareContext(ctx, `SELECT COUNT(1) FROM claimables WHERE lower(name) = ? and typ = ?`) if err != nil { audit.err = err return 0, fmt.Errorf("failed to prepare count query: %w", err) } + defer stmt.Close() - row := stmt.QueryRowContext(ctx, strings.ToLower(province)) + row := stmt.QueryRowContext(ctx, strings.ToLower(province), claimType) var count int err = row.Scan(&count) if err != nil { @@ -124,6 +126,7 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai audit.err = err return 0, fmt.Errorf("failed to prepare claim query: %w", err) } + defer stmt.Close() res, err := stmt.ExecContext(ctx, player, claimType, province, userId) if err != nil { @@ -143,20 +146,24 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai func (s *Store) ListAvailability(ctx context.Context, claimType ClaimType, search ...string) ([]string, error) { log.Debug().Ctx(ctx).Stringer("claim_type", claimType).Strs("search_terms", search).Msg("listing available entries") queryParams := []any{string(claimType)} - queryPattern := `SELECT DISTINCT(provinces.%[1]s) - FROM provinces LEFT JOIN claims ON provinces.%[1]s = claims.val AND claims.claim_type = ? + + queryPattern := `SELECT distinct name + FROM claimables + LEFT JOIN claims ON claimables.name = claims.val AND claimables.typ = claims.claim_type WHERE claims.val IS NULL - AND provinces.typ = 'Land'` + AND claimables.typ = ?` + if len(search) > 0 && search[0] != "" { // only take one search param, ignore the rest - queryPattern += `AND provinces.%[1]s LIKE ?` + queryPattern += `AND claimables.name LIKE ?` queryParams = append(queryParams, fmt.Sprintf("%%%s%%", search[0])) } - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(queryPattern, claimTypeToColumn[claimType])) + stmt, err := s.db.PrepareContext(ctx, queryPattern) if err != nil { return nil, fmt.Errorf("failed to prepare query: %w", err) } + defer stmt.Close() rows, err := stmt.QueryContext(ctx, queryParams...) if err != nil { @@ -181,6 +188,7 @@ func (s *Store) ListClaims(ctx context.Context) ([]Claim, error) { if err != nil { return nil, fmt.Errorf("failed to prepare query: %w", err) } + defer stmt.Close() rows, err := stmt.QueryContext(ctx) if err != nil { @@ -227,6 +235,7 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error) if err != nil { return ClaimDetail{}, fmt.Errorf("failed to get claim: %w", err) } + defer stmt.Close() row := stmt.QueryRowContext(ctx, ID) @@ -245,12 +254,13 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error) } c.Type = cl - stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT name FROM provinces where provinces.%s = ?`, claimTypeToColumn[cl])) + stmt, err = s.db.PrepareContext(ctx, `SELECT province FROM claimables WHERE name = ? AND typ = ?`) if err != nil { return ClaimDetail{}, fmt.Errorf("failed to prepare query: %w", err) } + defer stmt.Close() - rows, err := stmt.QueryContext(ctx, c.Name) + rows, err := stmt.QueryContext(ctx, c.Name, cl) if err != nil { return ClaimDetail{}, fmt.Errorf("failed to execute query: %w", err) } @@ -284,6 +294,7 @@ func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error { audit.err = err return fmt.Errorf("failed to prepare query: %w", err) } + defer stmt.Close() res, err := stmt.ExecContext(ctx, ID, userId) if err != nil { @@ -309,6 +320,7 @@ func (s *Store) CountClaims(ctx context.Context) (total, uniquePlayers int, err if err != nil { return 0, 0, fmt.Errorf("failed to prepare query: %w", err) } + defer stmt.Close() res := stmt.QueryRowContext(ctx) diff --git a/store_test.go b/store_test.go index 2316a13..c7ea76b 100644 --- a/store_test.go +++ b/store_test.go @@ -122,22 +122,22 @@ func TestAvailability(t *testing.T) { store.Claim(context.TODO(), "000000000000000001", "foo", "France", CLAIM_TYPE_REGION) store.Claim(context.TODO(), "000000000000000001", "foo", "Italy", CLAIM_TYPE_REGION) - // There's a total of 73 distinct regions, there should be 71 available + // There's a total of 92 distinct regions, there should be 90 available // after the two claims above availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_REGION) assert.NoError(t, err) - assert.Equal(t, 71, len(availability)) + assert.Equal(t, 90, len(availability)) store.Claim(context.TODO(), "000000000000000001", "foo", "Normandy", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Champagne", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Lorraine", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Provence", CLAIM_TYPE_AREA) - // There's a total of 823 distinct regions, there should be 819 available + // There's a total of 882 distinct regions, there should be 878 available // after the four claims above availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA) assert.NoError(t, err) - assert.Equal(t, 819, len(availability)) + assert.Equal(t, 878, len(availability)) // There is both a Trade Node and an Area called 'Valencia', while the trade // node is claimed, the area should show up in the availability list (even @@ -145,11 +145,11 @@ func TestAvailability(t *testing.T) { store.Claim(context.TODO(), "000000000000000001", "foo", "Valencia", CLAIM_TYPE_TRADE) availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA) assert.NoError(t, err) - assert.Equal(t, 819, len(availability)) // availability for areas should be the same as before + assert.Equal(t, 878, len(availability)) // availability for areas should be the same as before availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA, "bay") assert.NoError(t, err) - assert.Equal(t, 3, len(availability)) // availability for areas should be the same as before + assert.Equal(t, 6, len(availability)) // availability for areas should be the same as before } func TestDeleteClaim(t *testing.T) { diff --git a/uptime_darwin.go b/uptime_darwin.go index 29fbd0e..3665cd3 100644 --- a/uptime_darwin.go +++ b/uptime_darwin.go @@ -3,10 +3,11 @@ package themis import ( + "context" "time" ) // Uptime returns the time elapsed since the start of the current process ID. -func Uptime() (time.Duration, error) { +func Uptime(ctx context.Context) (time.Duration, error) { return 0, nil }