diff --git a/cmd/themis-server/main.go b/cmd/themis-server/main.go index c24f681..6a44196 100644 --- a/cmd/themis-server/main.go +++ b/cmd/themis-server/main.go @@ -199,7 +199,9 @@ func main() { player = i.Member.User.Username } - err = store.Claim(ctx, player, name, claimType) + userId := i.Member.User.ID + + _, err = store.Claim(ctx, userId, player, name, claimType) if err != nil { conflict, ok := err.(themis.ErrConflict) if ok { @@ -246,11 +248,8 @@ func main() { }, "delete-claim": func(s *discordgo.Session, i *discordgo.InteractionCreate) { id := i.ApplicationCommandData().Options[0] - nick := i.Member.Nick - if nick == "" { - nick = i.Member.User.Username - } - err := store.DeleteClaim(ctx, int(id.IntValue()), nick) + userId := i.Member.User.ID + err := store.DeleteClaim(ctx, int(id.IntValue()), userId) if err != nil { msg := "Oops, something went wrong :( blame @wperron" if errors.Is(err, themis.ErrNoSuchClaim) { diff --git a/migrations/20220912-add-user-id-col.sql b/migrations/20220912-add-user-id-col.sql new file mode 100644 index 0000000..941c043 --- /dev/null +++ b/migrations/20220912-add-user-id-col.sql @@ -0,0 +1,6 @@ +-- Part of #11 +ALTER TABLE claims ADD COLUMN userid TEXT; +UPDATE claims SET userid='212714834490294273' WHERE player = 'shinemperor'; +UPDATE claims SET userid='345340157333078016' WHERE player = 'wperron'; +UPDATE claims SET userid='203896675960487936' WHERE player = 'Gillfren'; +UPDATE claims SET userid='264861923814801408' WHERE player = 'B i r b'; \ No newline at end of file diff --git a/migrations/init.sql b/migrations/init.sql index bd8cd9f..5853437 100644 --- a/migrations/init.sql +++ b/migrations/init.sql @@ -3952,6 +3952,7 @@ CREATE TABLE IF NOT EXISTS claims ( player TEXT, claim_type TEXT, val TEXT, + userid TEXT, FOREIGN KEY(claim_type) REFERENCES claim_types(claim_type) ); diff --git a/store.go b/store.go index b2aeffa..703b005 100644 --- a/store.go +++ b/store.go @@ -79,26 +79,26 @@ func NewStore(conn string) (*Store, error) { }, nil } -func (s *Store) Claim(ctx context.Context, player, province string, claimType ClaimType) error { +func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) { tx, err := s.db.Begin() if err != nil { - return fmt.Errorf("failed to begin transaction: %w", err) + return 0, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Commit() // Check conflicts stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in ( - SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.player IS NOT ? - UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.player IS NOT ? - UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.player IS NOT ? + SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ? + UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.userid IS NOT ? + UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.userid IS NOT ? )`, claimTypeToColumn[claimType])) if err != nil { - return fmt.Errorf("failed to prepare conflicts query: %w", err) + return 0, fmt.Errorf("failed to prepare conflicts query: %w", err) } - rows, err := stmt.QueryContext(ctx, province, player, player, player) + rows, err := stmt.QueryContext(ctx, province, userId, userId, userId) if err != nil { - return fmt.Errorf("failed to get conflicting provinces: %w", err) + return 0, fmt.Errorf("failed to get conflicting provinces: %w", err) } conflicts := make([]string, 0) @@ -106,43 +106,48 @@ func (s *Store) Claim(ctx context.Context, player, province string, claimType Cl var p string err = rows.Scan(&p) if err != nil { - return fmt.Errorf("failed to scan row: %w", err) + return 0, fmt.Errorf("failed to scan row: %w", err) } conflicts = append(conflicts, p) } if len(conflicts) > 0 { - return ErrConflict{Conflicts: conflicts} + return 0, ErrConflict{Conflicts: conflicts} } // check that provided name matches the claim type stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) if err != nil { - return fmt.Errorf("failed to prepare count query: %w", err) + return 0, fmt.Errorf("failed to prepare count query: %w", err) } row := stmt.QueryRowContext(ctx, province) var count int err = row.Scan(&count) if err != nil { - return fmt.Errorf("failed to scan: %w", err) + return 0, fmt.Errorf("failed to scan: %w", err) } if count == 0 { - return fmt.Errorf("found no provinces for %s named %s", claimType, province) + return 0, fmt.Errorf("found no provinces for %s named %s", claimType, province) } - stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val) VALUES (?, ?, ?)") + stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)") if err != nil { - return fmt.Errorf("failed to prepare claim query: %w", err) + return 0, fmt.Errorf("failed to prepare claim query: %w", err) } - _, err = stmt.ExecContext(ctx, player, claimType, province) + res, err := stmt.ExecContext(ctx, player, claimType, province, userId) if err != nil { - return fmt.Errorf("failed to insert claim: %w", err) + return 0, fmt.Errorf("failed to insert claim: %w", err) } - return nil + id, err := res.LastInsertId() + if err != nil { + return 0, fmt.Errorf("failed to get last ID: %w", err) + } + + return int(id), nil } func (s *Store) ListAvailability(ctx context.Context, claimType ClaimType, search ...string) ([]string, error) { @@ -270,15 +275,15 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error) }, nil } -var ErrNoSuchClaim = errors.New("No such claim found for player") +var ErrNoSuchClaim = errors.New("no such claim found for player") -func (s *Store) DeleteClaim(ctx context.Context, ID int, player string) error { - stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND player = ?") +func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error { + stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?") if err != nil { return fmt.Errorf("failed to prepare query: %w", err) } - res, err := stmt.ExecContext(ctx, ID, player) + res, err := stmt.ExecContext(ctx, ID, userId) if err != nil { return fmt.Errorf("failed to delete claim ID %d: %w", ID, err) } diff --git a/store_test.go b/store_test.go index 5457b79..18780aa 100644 --- a/store_test.go +++ b/store_test.go @@ -3,6 +3,7 @@ package themis import ( "context" _ "embed" + "fmt" "testing" _ "github.com/mattn/go-sqlite3" @@ -18,6 +19,7 @@ func TestStore_Claim(t *testing.T) { type args struct { player string province string + userId string claimType ClaimType } tests := []struct { @@ -31,6 +33,7 @@ func TestStore_Claim(t *testing.T) { player: "foo", province: "Italy", claimType: CLAIM_TYPE_REGION, + userId: "000000000000000001", }, wantErr: false, }, @@ -40,6 +43,7 @@ func TestStore_Claim(t *testing.T) { player: "foo", province: "Italy", claimType: CLAIM_TYPE_TRADE, // Italy is a Region you silly goose + userId: "000000000000000001", }, wantErr: true, }, @@ -49,6 +53,7 @@ func TestStore_Claim(t *testing.T) { player: "bar", province: "Genoa", claimType: CLAIM_TYPE_TRADE, + userId: "000000000000000002", }, wantErr: true, }, @@ -58,13 +63,14 @@ func TestStore_Claim(t *testing.T) { player: "foo", // 'foo' has a claim on Italy, which has overlapping provinces province: "Genoa", claimType: CLAIM_TYPE_TRADE, + userId: "000000000000000001", }, wantErr: false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := store.Claim(context.TODO(), tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr { + if _, err := store.Claim(context.TODO(), tt.args.userId, tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr { t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -75,9 +81,9 @@ func TestAvailability(t *testing.T) { store, err := NewStore(TEST_CONN_STRING) assert.NoError(t, err) - store.Claim(context.TODO(), "foo", "Genoa", CLAIM_TYPE_TRADE) - store.Claim(context.TODO(), "foo", "Venice", CLAIM_TYPE_TRADE) - store.Claim(context.TODO(), "foo", "English Channel", CLAIM_TYPE_TRADE) + store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE) + store.Claim(context.TODO(), "000000000000000001", "foo", "Venice", CLAIM_TYPE_TRADE) + store.Claim(context.TODO(), "000000000000000001", "foo", "English Channel", CLAIM_TYPE_TRADE) // There's a total of 80 distinct trade nodes, there should be 77 available // after the three claims above @@ -85,8 +91,8 @@ func TestAvailability(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 77, len(availability)) - store.Claim(context.TODO(), "foo", "France", CLAIM_TYPE_REGION) - store.Claim(context.TODO(), "foo", "Italy", CLAIM_TYPE_REGION) + store.Claim(context.TODO(), "000000000000000001", "foo", "France", CLAIM_TYPE_REGION) + store.Claim(context.TODO(), "000000000000000001", "foo", "Italy", CLAIM_TYPE_REGION) // There's a total of 73 distinct regions, there should be 71 available // after the two claims above @@ -94,10 +100,10 @@ func TestAvailability(t *testing.T) { assert.NoError(t, err) assert.Equal(t, 71, len(availability)) - store.Claim(context.TODO(), "foo", "Normandy", CLAIM_TYPE_AREA) - store.Claim(context.TODO(), "foo", "Champagne", CLAIM_TYPE_AREA) - store.Claim(context.TODO(), "foo", "Lorraine", CLAIM_TYPE_AREA) - store.Claim(context.TODO(), "foo", "Provence", CLAIM_TYPE_AREA) + store.Claim(context.TODO(), "000000000000000001", "foo", "Normandy", CLAIM_TYPE_AREA) + store.Claim(context.TODO(), "000000000000000001", "foo", "Champagne", CLAIM_TYPE_AREA) + store.Claim(context.TODO(), "000000000000000001", "foo", "Lorraine", CLAIM_TYPE_AREA) + store.Claim(context.TODO(), "000000000000000001", "foo", "Provence", CLAIM_TYPE_AREA) // There's a total of 823 distinct regions, there should be 819 available // after the four claims above @@ -108,7 +114,7 @@ func TestAvailability(t *testing.T) { // There is both a Trade Node and an Area called 'Valencia', while the trade // node is claimed, the area should show up in the availability list (even // though there are conflicting provinces) - store.Claim(context.TODO(), "foo", "Valencia", CLAIM_TYPE_TRADE) + store.Claim(context.TODO(), "000000000000000001", "foo", "Valencia", CLAIM_TYPE_TRADE) availability, err = store.ListAvailability(context.TODO(), CLAIM_TYPE_AREA) assert.NoError(t, err) assert.Equal(t, 819, len(availability)) // availability for areas should be the same as before @@ -122,14 +128,23 @@ func TestDeleteClaim(t *testing.T) { store, err := NewStore(TEST_CONN_STRING) assert.NoError(t, err) - store.Claim(context.TODO(), "foo", "Genoa", CLAIM_TYPE_TRADE) - store.Claim(context.TODO(), "bar", "Balkans", CLAIM_TYPE_REGION) - store.Claim(context.TODO(), "baz", "English Channel", CLAIM_TYPE_TRADE) + // make sure all claims are gone, this is due to how the in-memory database + // with a shared cache interacts with other tests running in parallel + _, err = store.db.ExecContext(context.TODO(), "DELETE FROM claims") + assert.NoError(t, err) + + fooId, _ := store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE) + barId, _ := store.Claim(context.TODO(), "000000000000000002", "bar", "Balkans", CLAIM_TYPE_REGION) + store.Claim(context.TODO(), "000000000000000003", "baz", "English Channel", CLAIM_TYPE_TRADE) + + claims, err := store.ListClaims(context.TODO()) + assert.NoError(t, err) + fmt.Print(claims) - err = store.DeleteClaim(context.TODO(), 1, "foo") + err = store.DeleteClaim(context.TODO(), fooId, "000000000000000001") assert.NoError(t, err) - err = store.DeleteClaim(context.TODO(), 2, "foo") + err = store.DeleteClaim(context.TODO(), barId, "000000000000000001") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoSuchClaim) }