From bd490bc18f9f0c15c6e9c7d4bd2611da1f5d67be Mon Sep 17 00:00:00 2001 From: William Perron Date: Mon, 19 Sep 2022 23:08:23 +0000 Subject: [PATCH] make claim command case insensitive Fixes #3 --- store.go | 4 ++-- store_test.go | 20 ++++++++++++++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/store.go b/store.go index d17901e..faeab8e 100644 --- a/store.go +++ b/store.go @@ -108,12 +108,12 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai } // check that provided name matches the claim type - stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE provinces.%s = ?`, claimTypeToColumn[claimType])) + stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE LOWER(provinces.%s) = ?`, claimTypeToColumn[claimType])) if err != nil { return 0, fmt.Errorf("failed to prepare count query: %w", err) } - row := stmt.QueryRowContext(ctx, province) + row := stmt.QueryRowContext(ctx, strings.ToLower(province)) var count int err = row.Scan(&count) if err != nil { diff --git a/store_test.go b/store_test.go index 8ab98c5..3cdd03e 100644 --- a/store_test.go +++ b/store_test.go @@ -68,6 +68,26 @@ func TestStore_Claim(t *testing.T) { }, wantErr: false, }, + { + name: "case sensitivity lower", + args: args{ + player: "foo", + province: "wien", + claimType: CLAIM_TYPE_TRADE, + userId: "000000000000000001", + }, + wantErr: false, + }, + { + name: "case sensitivity upper", + args: args{ + player: "foo", + province: "CONSTANTINOPLE", + claimType: CLAIM_TYPE_TRADE, + userId: "000000000000000001", + }, + wantErr: false, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {