Add userid column to claims table

Fixes #11
absences
William Perron 2 years ago
parent c2d8cf2b83
commit 84df9458be

@ -199,7 +199,9 @@ func main() {
player = i.Member.User.Username player = i.Member.User.Username
} }
err = store.Claim(ctx, player, name, claimType) userId := i.Member.User.ID
_, err = store.Claim(ctx, userId, player, name, claimType)
if err != nil { if err != nil {
conflict, ok := err.(themis.ErrConflict) conflict, ok := err.(themis.ErrConflict)
if ok { if ok {
@ -246,11 +248,8 @@ func main() {
}, },
"delete-claim": func(s *discordgo.Session, i *discordgo.InteractionCreate) { "delete-claim": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
id := i.ApplicationCommandData().Options[0] id := i.ApplicationCommandData().Options[0]
nick := i.Member.Nick userId := i.Member.User.ID
if nick == "" { err := store.DeleteClaim(ctx, int(id.IntValue()), userId)
nick = i.Member.User.Username
}
err := store.DeleteClaim(ctx, int(id.IntValue()), nick)
if err != nil { if err != nil {
msg := "Oops, something went wrong :( blame @wperron" msg := "Oops, something went wrong :( blame @wperron"
if errors.Is(err, themis.ErrNoSuchClaim) { if errors.Is(err, themis.ErrNoSuchClaim) {

@ -0,0 +1,6 @@
-- Part of #11
ALTER TABLE claims ADD COLUMN userid TEXT;
UPDATE claims SET userid='212714834490294273' WHERE player = 'shinemperor';
UPDATE claims SET userid='345340157333078016' WHERE player = 'wperron';
UPDATE claims SET userid='203896675960487936' WHERE player = 'Gillfren';
UPDATE claims SET userid='264861923814801408' WHERE player = 'B i r b';

@ -3952,6 +3952,7 @@ CREATE TABLE IF NOT EXISTS claims (
player TEXT, player TEXT,
claim_type TEXT, claim_type TEXT,
val TEXT, val TEXT,
userid TEXT,
FOREIGN KEY(claim_type) REFERENCES claim_types(claim_type) FOREIGN KEY(claim_type) REFERENCES claim_types(claim_type)
); );

@ -79,26 +79,26 @@ func NewStore(conn string) (*Store, error) {
}, nil }, nil
} }
func (s *Store) Claim(ctx context.Context, player, province string, claimType ClaimType) error { func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) {
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
return fmt.Errorf("failed to begin transaction: %w", err) return 0, fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Commit() defer tx.Commit()
// Check conflicts // Check conflicts
stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in ( stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in (
SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.player IS NOT ? SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ?
UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.player IS NOT ? UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.userid IS NOT ?
UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.player IS NOT ? UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.userid IS NOT ?
)`, claimTypeToColumn[claimType])) )`, claimTypeToColumn[claimType]))
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare conflicts query: %w", err) return 0, fmt.Errorf("failed to prepare conflicts query: %w", err)
} }
rows, err := stmt.QueryContext(ctx, province, player, player, player) rows, err := stmt.QueryContext(ctx, province, userId, userId, userId)
if err != nil { if err != nil {
return fmt.Errorf("failed to get conflicting provinces: %w", err) return 0, fmt.Errorf("failed to get conflicting provinces: %w", err)
} }
conflicts := make([]string, 0) conflicts := make([]string, 0)
@ -106,43 +106,48 @@ func (s *Store) Claim(ctx context.Context, player, province string, claimType Cl
var p string var p string
err = rows.Scan(&p) err = rows.Scan(&p)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan row: %w", err) return 0, fmt.Errorf("failed to scan row: %w", err)
} }
conflicts = append(conflicts, p) conflicts = append(conflicts, p)
} }
if len(conflicts) > 0 { if len(conflicts) > 0 {
return ErrConflict{Conflicts: conflicts} return 0, ErrConflict{Conflicts: conflicts}
} }
// check that provided name matches the claim type // check that provided name matches the claim type
stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType]))
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare count query: %w", err) return 0, fmt.Errorf("failed to prepare count query: %w", err)
} }
row := stmt.QueryRowContext(ctx, province) row := stmt.QueryRowContext(ctx, province)
var count int var count int
err = row.Scan(&count) err = row.Scan(&count)
if err != nil { if err != nil {
return fmt.Errorf("failed to scan: %w", err) return 0, fmt.Errorf("failed to scan: %w", err)
} }
if count == 0 { if count == 0 {
return fmt.Errorf("found no provinces for %s named %s", claimType, province) return 0, fmt.Errorf("found no provinces for %s named %s", claimType, province)
} }
stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val) VALUES (?, ?, ?)") stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)")
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare claim query: %w", err) return 0, fmt.Errorf("failed to prepare claim query: %w", err)
} }
_, err = stmt.ExecContext(ctx, player, claimType, province) res, err := stmt.ExecContext(ctx, player, claimType, province, userId)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert claim: %w", err) return 0, fmt.Errorf("failed to insert claim: %w", err)
} }
return nil id, err := res.LastInsertId()
if err != nil {
return 0, fmt.Errorf("failed to get last ID: %w", err)
}
return int(id), nil
} }
func (s *Store) ListAvailability(ctx context.Context, claimType ClaimType, search ...string) ([]string, error) { func (s *Store) ListAvailability(ctx context.Context, claimType ClaimType, search ...string) ([]string, error) {
@ -270,15 +275,15 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error)
}, nil }, nil
} }
var ErrNoSuchClaim = errors.New("No such claim found for player") var ErrNoSuchClaim = errors.New("no such claim found for player")
func (s *Store) DeleteClaim(ctx context.Context, ID int, player string) error { func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error {
stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND player = ?") stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?")
if err != nil { if err != nil {
return fmt.Errorf("failed to prepare query: %w", err) return fmt.Errorf("failed to prepare query: %w", err)
} }
res, err := stmt.ExecContext(ctx, ID, player) res, err := stmt.ExecContext(ctx, ID, userId)
if err != nil { if err != nil {
return fmt.Errorf("failed to delete claim ID %d: %w", ID, err) return fmt.Errorf("failed to delete claim ID %d: %w", ID, err)
} }

@ -3,6 +3,7 @@ package themis
import ( import (
"context" "context"
_ "embed" _ "embed"
"fmt"
"testing" "testing"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
@ -18,6 +19,7 @@ func TestStore_Claim(t *testing.T) {
type args struct { type args struct {
player string player string
province string province string
userId string
claimType ClaimType claimType ClaimType
} }
tests := []struct { tests := []struct {
@ -31,6 +33,7 @@ func TestStore_Claim(t *testing.T) {
player: "foo", player: "foo",
province: "Italy", province: "Italy",
claimType: CLAIM_TYPE_REGION, claimType: CLAIM_TYPE_REGION,
userId: "000000000000000001",
}, },
wantErr: false, wantErr: false,
}, },
@ -40,6 +43,7 @@ func TestStore_Claim(t *testing.T) {
player: "foo", player: "foo",
province: "Italy", province: "Italy",
claimType: CLAIM_TYPE_TRADE, // Italy is a Region you silly goose claimType: CLAIM_TYPE_TRADE, // Italy is a Region you silly goose
userId: "000000000000000001",
}, },
wantErr: true, wantErr: true,
}, },
@ -49,6 +53,7 @@ func TestStore_Claim(t *testing.T) {
player: "bar", player: "bar",
province: "Genoa", province: "Genoa",
claimType: CLAIM_TYPE_TRADE, claimType: CLAIM_TYPE_TRADE,
userId: "000000000000000002",
}, },
wantErr: true, wantErr: true,
}, },
@ -58,13 +63,14 @@ func TestStore_Claim(t *testing.T) {
player: "foo", // 'foo' has a claim on Italy, which has overlapping provinces player: "foo", // 'foo' has a claim on Italy, which has overlapping provinces
province: "Genoa", province: "Genoa",
claimType: CLAIM_TYPE_TRADE, claimType: CLAIM_TYPE_TRADE,
userId: "000000000000000001",
}, },
wantErr: false, wantErr: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := store.Claim(context.TODO(), tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr { if _, err := store.Claim(context.TODO(), tt.args.userId, tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr {
t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr)
} }
}) })
@ -75,9 +81,9 @@ func TestAvailability(t *testing.T) {
store, err := NewStore(TEST_CONN_STRING) store, err := NewStore(TEST_CONN_STRING)
assert.NoError(t, err) assert.NoError(t, err)
store.Claim(context.TODO(), "foo", "Genoa", CLAIM_TYPE_TRADE) store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE)
store.Claim(context.TODO(), "foo", "Venice", CLAIM_TYPE_TRADE) store.Claim(context.TODO(), "000000000000000001", "foo", "Venice", CLAIM_TYPE_TRADE)
store.Claim(context.TODO(), "foo", "English Channel", CLAIM_TYPE_TRADE) store.Claim(context.TODO(), "000000000000000001", "foo", "English Channel", CLAIM_TYPE_TRADE)
// There's a total of 80 distinct trade nodes, there should be 77 available // There's a total of 80 distinct trade nodes, there should be 77 available
// after the three claims above // after the three claims above
@ -85,8 +91,8 @@ func TestAvailability(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 77, len(availability)) assert.Equal(t, 77, len(availability))
store.Claim(context.TODO(), "foo", "France", CLAIM_TYPE_REGION) store.Claim(context.TODO(), "000000000000000001", "foo", "France", CLAIM_TYPE_REGION)
store.Claim(context.TODO(), "foo", "Italy", CLAIM_TYPE_REGION) store.Claim(context.TODO(), "000000000000000001", "foo", "Italy", CLAIM_TYPE_REGION)
// There's a total of 73 distinct regions, there should be 71 available // There's a total of 73 distinct regions, there should be 71 available
// after the two claims above // after the two claims above
@ -94,10 +100,10 @@ func TestAvailability(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 71, len(availability)) assert.Equal(t, 71, len(availability))
store.Claim(context.TODO(), "foo", "Normandy", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Normandy", CLAIM_TYPE_AREA)
store.Claim(context.TODO(), "foo", "Champagne", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Champagne", CLAIM_TYPE_AREA)
store.Claim(context.TODO(), "foo", "Lorraine", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Lorraine", CLAIM_TYPE_AREA)
store.Claim(context.TODO(), "foo", "Provence", CLAIM_TYPE_AREA) store.Claim(context.TODO(), "000000000000000001", "foo", "Provence", CLAIM_TYPE_AREA)
// There's a total of 823 distinct regions, there should be 819 available // There's a total of 823 distinct regions, there should be 819 available
// after the four claims above // after the four claims above
@ -108,7 +114,7 @@ func TestAvailability(t *testing.T) {
// There is both a Trade Node and an Area called 'Valencia', while the trade // There is both a Trade Node and an Area called 'Valencia', while the trade
// node is claimed, the area should show up in the availability list (even // node is claimed, the area should show up in the availability list (even
// though there are conflicting provinces) // though there are conflicting provinces)
store.Claim(context.TODO(), "foo", "Valencia", CLAIM_TYPE_TRADE) store.Claim(context.TODO(), "000000000000000001", "foo", "Valencia", CLAIM_TYPE_TRADE)
availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA) availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 819, len(availability)) // availability for areas should be the same as before assert.Equal(t, 819, len(availability)) // availability for areas should be the same as before
@ -122,14 +128,23 @@ func TestDeleteClaim(t *testing.T) {
store, err := NewStore(TEST_CONN_STRING) store, err := NewStore(TEST_CONN_STRING)
assert.NoError(t, err) assert.NoError(t, err)
store.Claim(context.TODO(), "foo", "Genoa", CLAIM_TYPE_TRADE) // make sure all claims are gone, this is due to how the in-memory database
store.Claim(context.TODO(), "bar", "Balkans", CLAIM_TYPE_REGION) // with a shared cache interacts with other tests running in parallel
store.Claim(context.TODO(), "baz", "English Channel", CLAIM_TYPE_TRADE) _, err = store.db.ExecContext(context.TODO(), "DELETE FROM claims")
assert.NoError(t, err)
fooId, _ := store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE)
barId, _ := store.Claim(context.TODO(), "000000000000000002", "bar", "Balkans", CLAIM_TYPE_REGION)
store.Claim(context.TODO(), "000000000000000003", "baz", "English Channel", CLAIM_TYPE_TRADE)
claims, err := store.ListClaims(context.TODO())
assert.NoError(t, err)
fmt.Print(claims)
err = store.DeleteClaim(context.TODO(), 1, "foo") err = store.DeleteClaim(context.TODO(), fooId, "000000000000000001")
assert.NoError(t, err) assert.NoError(t, err)
err = store.DeleteClaim(context.TODO(), 2, "foo") err = store.DeleteClaim(context.TODO(), barId, "000000000000000001")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoSuchClaim) assert.ErrorIs(t, err, ErrNoSuchClaim)
} }

Loading…
Cancel
Save