diff --git a/conflicts.go b/conflicts.go new file mode 100644 index 0000000..7448192 --- /dev/null +++ b/conflicts.go @@ -0,0 +1,82 @@ +package themis + +import ( + "context" + "fmt" +) + +type Conflict struct { + Province string + Player string + ClaimType ClaimType + Claim string + ClaimID int +} + +func (c Conflict) String() string { + return fmt.Sprintf("%s owned by #%d %s %s (%s)", c.Province, c.ClaimID, c.ClaimType, c.Claim, c.Player) +} + +const conflictQuery string = `SELECT name, player, claim_type, val, id FROM ( + SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id + FROM claims + LEFT JOIN provinces ON claims.val = provinces.trade_node + WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ? + AND provinces.%[1]s = ? + UNION + SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id + FROM claims + LEFT JOIN provinces ON claims.val = provinces.region + WHERE claims.claim_type = 'region' AND claims.userid IS NOT ? + AND provinces.%[1]s = ? + UNION + SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id + FROM claims + LEFT JOIN provinces ON claims.val = provinces.area + WHERE claims.claim_type = 'area' AND claims.userid IS NOT ? + AND provinces.%[1]s = ? +);` + +func (s *Store) FindConflicts(ctx context.Context, userId, name string, claimType ClaimType) ([]Conflict, error) { + stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(conflictQuery, claimTypeToColumn[claimType])) + if err != nil { + return nil, fmt.Errorf("failed to prepare conflicts query: %w", err) + } + + rows, err := stmt.QueryContext(ctx, userId, name, userId, name, userId, name) + if err != nil { + return nil, fmt.Errorf("failed to get conflicting provinces: %w", err) + } + + conflicts := make([]Conflict, 0) + for rows.Next() { + var ( + province string + player string + sClaimType string + claimName string + claimId int + ) + err = rows.Scan(&province, &player, &sClaimType, &claimName, &claimId) + if err != nil { + return nil, fmt.Errorf("failed to scan row: %w", err) + } + + ct, err := ClaimTypeFromString(sClaimType) + if err != nil { + // In case of an error parsing the claim type, simply default to + // whatever the database sends; this is a read-only function, the + // input validation is assumed to have already been done at insert. + ct = ClaimType(sClaimType) + } + conflicts = append(conflicts, Conflict{ + Province: province, + Player: player, + ClaimType: ct, + Claim: claimName, + ClaimID: claimId, + }) + } + + return conflicts, nil +} diff --git a/conflicts_test.go b/conflicts_test.go new file mode 100644 index 0000000..b13b5c4 --- /dev/null +++ b/conflicts_test.go @@ -0,0 +1,80 @@ +package themis + +import ( + "context" + "reflect" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestStore_FindConflicts(t *testing.T) { + store, err := NewStore(TEST_CONN_STRING) + assert.NoError(t, err) + + id, err := store.Claim(context.TODO(), "000000000000000001", "foo", "Bordeaux", CLAIM_TYPE_TRADE) + assert.NoError(t, err) + + type args struct { + ctx context.Context + userId string + name string + claimType ClaimType + } + tests := []struct { + name string + args args + want []Conflict + wantErr bool + }{ + { + name: "same-player", + args: args{ + context.TODO(), + "000000000000000001", + "France", + CLAIM_TYPE_REGION, + }, + want: []Conflict{}, + wantErr: false, + }, + { + name: "overlapping", + args: args{ + context.TODO(), + "000000000000000002", + "Iberia", + CLAIM_TYPE_REGION, + }, + want: []Conflict{ + {Province: "Navarra", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id}, + {Province: "Rioja", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id}, + {Province: "Vizcaya", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id}, + }, + wantErr: false, + }, + { + name: "no-overlap", + args: args{ + context.TODO(), + "000000000000000002", + "Scandinavia", + CLAIM_TYPE_REGION, + }, + want: []Conflict{}, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := store.FindConflicts(tt.args.ctx, tt.args.userId, tt.args.name, tt.args.claimType) + if (err != nil) != tt.wantErr { + t.Errorf("Store.FindConflicts() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Store.FindConflicts() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/store.go b/store.go index 92ce02a..5b3f9e6 100644 --- a/store.go +++ b/store.go @@ -64,11 +64,11 @@ func (c Claim) String() string { } type ErrConflict struct { - Conflicts []string + Conflicts []Conflict } func (ec ErrConflict) Error() string { - return fmt.Sprintf("found conflicting provinces: %s", strings.Join(ec.Conflicts, ", ")) + return fmt.Sprintf("found %d conflicting provinces", len(ec.Conflicts)) } type Store struct { @@ -98,29 +98,9 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai } defer tx.Commit() - // Check conflicts - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in ( - SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ? - UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.userid IS NOT ? - UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.userid IS NOT ? -)`, claimTypeToColumn[claimType])) + conflicts, err := s.FindConflicts(ctx, userId, province, claimType) if err != nil { - return 0, fmt.Errorf("failed to prepare conflicts query: %w", err) - } - - rows, err := stmt.QueryContext(ctx, province, userId, userId, userId) - if err != nil { - return 0, fmt.Errorf("failed to get conflicting provinces: %w", err) - } - - conflicts := make([]string, 0) - for rows.Next() { - var p string - err = rows.Scan(&p) - if err != nil { - return 0, fmt.Errorf("failed to scan row: %w", err) - } - conflicts = append(conflicts, p) + return 0, fmt.Errorf("failed to run conflicts check: %w", err) } if len(conflicts) > 0 { @@ -128,7 +108,7 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai } // check that provided name matches the claim type - stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) + stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) if err != nil { return 0, fmt.Errorf("failed to prepare count query: %w", err) }