Add claim details to conflicts output

Fixes #2
absences
William Perron 2 years ago
parent 955a2648dd
commit 541e8691a9

@ -0,0 +1,82 @@
package themis
import (
"context"
"fmt"
)
type Conflict struct {
Province string
Player string
ClaimType ClaimType
Claim string
ClaimID int
}
func (c Conflict) String() string {
return fmt.Sprintf("%s owned by #%d %s %s (%s)", c.Province, c.ClaimID, c.ClaimType, c.Claim, c.Player)
}
const conflictQuery string = `SELECT name, player, claim_type, val, id FROM (
SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id
FROM claims
LEFT JOIN provinces ON claims.val = provinces.trade_node
WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ?
AND provinces.%[1]s = ?
UNION
SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id
FROM claims
LEFT JOIN provinces ON claims.val = provinces.region
WHERE claims.claim_type = 'region' AND claims.userid IS NOT ?
AND provinces.%[1]s = ?
UNION
SELECT provinces.name, claims.player, claims.claim_type, claims.val, claims.id
FROM claims
LEFT JOIN provinces ON claims.val = provinces.area
WHERE claims.claim_type = 'area' AND claims.userid IS NOT ?
AND provinces.%[1]s = ?
);`
func (s *Store) FindConflicts(ctx context.Context, userId, name string, claimType ClaimType) ([]Conflict, error) {
stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(conflictQuery, claimTypeToColumn[claimType]))
if err != nil {
return nil, fmt.Errorf("failed to prepare conflicts query: %w", err)
}
rows, err := stmt.QueryContext(ctx, userId, name, userId, name, userId, name)
if err != nil {
return nil, fmt.Errorf("failed to get conflicting provinces: %w", err)
}
conflicts := make([]Conflict, 0)
for rows.Next() {
var (
province string
player string
sClaimType string
claimName string
claimId int
)
err = rows.Scan(&province, &player, &sClaimType, &claimName, &claimId)
if err != nil {
return nil, fmt.Errorf("failed to scan row: %w", err)
}
ct, err := ClaimTypeFromString(sClaimType)
if err != nil {
// In case of an error parsing the claim type, simply default to
// whatever the database sends; this is a read-only function, the
// input validation is assumed to have already been done at insert.
ct = ClaimType(sClaimType)
}
conflicts = append(conflicts, Conflict{
Province: province,
Player: player,
ClaimType: ct,
Claim: claimName,
ClaimID: claimId,
})
}
return conflicts, nil
}

@ -0,0 +1,80 @@
package themis
import (
"context"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
)
func TestStore_FindConflicts(t *testing.T) {
store, err := NewStore(TEST_CONN_STRING)
assert.NoError(t, err)
id, err := store.Claim(context.TODO(), "000000000000000001", "foo", "Bordeaux", CLAIM_TYPE_TRADE)
assert.NoError(t, err)
type args struct {
ctx context.Context
userId string
name string
claimType ClaimType
}
tests := []struct {
name string
args args
want []Conflict
wantErr bool
}{
{
name: "same-player",
args: args{
context.TODO(),
"000000000000000001",
"France",
CLAIM_TYPE_REGION,
},
want: []Conflict{},
wantErr: false,
},
{
name: "overlapping",
args: args{
context.TODO(),
"000000000000000002",
"Iberia",
CLAIM_TYPE_REGION,
},
want: []Conflict{
{Province: "Navarra", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id},
{Province: "Rioja", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id},
{Province: "Vizcaya", Player: "foo", ClaimType: "trade", Claim: "Bordeaux", ClaimID: id},
},
wantErr: false,
},
{
name: "no-overlap",
args: args{
context.TODO(),
"000000000000000002",
"Scandinavia",
CLAIM_TYPE_REGION,
},
want: []Conflict{},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := store.FindConflicts(tt.args.ctx, tt.args.userId, tt.args.name, tt.args.claimType)
if (err != nil) != tt.wantErr {
t.Errorf("Store.FindConflicts() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("Store.FindConflicts() = %v, want %v", got, tt.want)
}
})
}
}

@ -64,11 +64,11 @@ func (c Claim) String() string {
} }
type ErrConflict struct { type ErrConflict struct {
Conflicts []string Conflicts []Conflict
} }
func (ec ErrConflict) Error() string { func (ec ErrConflict) Error() string {
return fmt.Sprintf("found conflicting provinces: %s", strings.Join(ec.Conflicts, ", ")) return fmt.Sprintf("found %d conflicting provinces", len(ec.Conflicts))
} }
type Store struct { type Store struct {
@ -98,29 +98,9 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai
} }
defer tx.Commit() defer tx.Commit()
// Check conflicts conflicts, err := s.FindConflicts(ctx, userId, province, claimType)
stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT provinces.name FROM provinces WHERE provinces.%s = ? and provinces.name in (
SELECT provinces.name FROM claims LEFT JOIN provinces ON claims.val = provinces.trade_node WHERE claims.claim_type = 'trade' AND claims.userid IS NOT ?
UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.region WHERE claims.claim_type = 'region' AND claims.userid IS NOT ?
UNION SELECT provinces.name from claims LEFT JOIN provinces ON claims.val = provinces.area WHERE claims.claim_type = 'area' AND claims.userid IS NOT ?
)`, claimTypeToColumn[claimType]))
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to prepare conflicts query: %w", err) return 0, fmt.Errorf("failed to run conflicts check: %w", err)
}
rows, err := stmt.QueryContext(ctx, province, userId, userId, userId)
if err != nil {
return 0, fmt.Errorf("failed to get conflicting provinces: %w", err)
}
conflicts := make([]string, 0)
for rows.Next() {
var p string
err = rows.Scan(&p)
if err != nil {
return 0, fmt.Errorf("failed to scan row: %w", err)
}
conflicts = append(conflicts, p)
} }
if len(conflicts) > 0 { if len(conflicts) > 0 {
@ -128,7 +108,7 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai
} }
// check that provided name matches the claim type // check that provided name matches the claim type
stmt, err = s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType]))
if err != nil { if err != nil {
return 0, fmt.Errorf("failed to prepare count query: %w", err) return 0, fmt.Errorf("failed to prepare count query: %w", err)
} }

Loading…
Cancel
Save