new-sql-view
William Perron 11 months ago
parent b040694c2a
commit f0c401f604
Signed by: wperron
GPG Key ID: BFDB4EF72D73C5F2

@ -25,6 +25,9 @@ func TestAddAbsence(t *testing.T) {
absentees, err = store.GetAbsentees(context.TODO(), now) absentees, err = store.GetAbsentees(context.TODO(), now)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(absentees)) assert.Equal(t, 1, len(absentees))
_, err = store.LastOf(context.TODO(), EventAbsence)
require.NoError(t, err)
} }
func TestGetSchedule(t *testing.T) { func TestGetSchedule(t *testing.T) {

@ -52,51 +52,61 @@ func EventTypeFromString(ev string) (EventType, error) {
type AuditableEvent struct { type AuditableEvent struct {
userId string userId string
eventType EventType eventType EventType
Timestamp time.Time timestamp time.Time
err error
} }
// Audit writes to the audit table, returns nothing because it is meant to be // Audit writes to the audit table, returns nothing because it is meant to be
// used in a defered statement on functions that write to the database. // used in a defered statement on functions that write to the database.
func (s *Store) Audit(ev *AuditableEvent) { func (s *Store) Audit(ev *AuditableEvent) {
ctx := context.Background() if ev.err == nil {
ctx := context.Background()
tx, err := s.db.Begin()
if err != nil { tx, err := s.db.Begin()
log.Error().Err(err).Msg("failed to start transaction") if err != nil {
} log.Error().Err(err).Msg("failed to start transaction")
defer tx.Commit() //nolint:errcheck }
defer tx.Commit() //nolint:errcheck
stmt, err := s.db.PrepareContext(ctx, "INSERT INTO audit_log (userid, event_type, ts) VALUES (?, ?, ?)")
if err != nil { stmt, err := s.db.PrepareContext(ctx, "INSERT INTO audit_log (userid, event_type, ts) VALUES (?, ?, ?)")
log.Error().Err(err).Msg("failed to prepare audit log insert") if err != nil {
log.Error().Err(err).Msg("failed to prepare audit log insert")
}
if _, err := stmt.ExecContext(ctx, ev.userId, ev.eventType.String(), time.Now()); err != nil {
log.Error().Err(err).Msg("failed to insert audit log")
}
} }
}
if _, err := stmt.ExecContext(ctx, ev.userId, ev.eventType.String(), time.Now()); err != nil { type AuditEvent struct {
log.Error().Err(err).Msg("failed to insert audit log") id int
} userId string
eventType EventType
timestamp time.Time
} }
func (s *Store) LastOf(ctx context.Context, t EventType) (AuditableEvent, error) { func (s *Store) LastOf(ctx context.Context, t EventType) (AuditEvent, error) {
stmt, err := s.db.PrepareContext(ctx, `SELECT userid, event_type, ts FROM audit_log WHERE event_type = ? ORDER BY ts DESC LIMIT 1`) stmt, err := s.db.PrepareContext(ctx, `SELECT id, userid, event_type, ts FROM audit_log WHERE event_type = ? ORDER BY ts DESC LIMIT 1`)
if err != nil { if err != nil {
return AuditableEvent{}, fmt.Errorf("failed to get last event of type %s: %w", t.String(), err) return AuditEvent{}, fmt.Errorf("failed to get last event of type %s: %w", t.String(), err)
} }
row := stmt.QueryRowContext(ctx, t.String()) row := stmt.QueryRowContext(ctx, t.String())
ev := AuditableEvent{} ev := AuditEvent{}
var rawEventType string var rawEventType string
err = row.Scan(&ev.userId, &rawEventType, &ev.Timestamp) err = row.Scan(&ev.id, &ev.userId, &rawEventType, &ev.timestamp)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return AuditableEvent{}, errors.New("") return AuditEvent{}, errors.New("no rows found")
} }
if err != nil { if err != nil {
return AuditableEvent{}, fmt.Errorf("failed to scan row: %w", err) return AuditEvent{}, fmt.Errorf("failed to scan row: %w", err)
} }
ev.eventType, err = EventTypeFromString(rawEventType) ev.eventType, err = EventTypeFromString(rawEventType)
if err != nil { if err != nil {
return AuditableEvent{}, fmt.Errorf("failed to parse event type %s: %w", rawEventType, err) return AuditEvent{}, fmt.Errorf("failed to parse event type %s: %w", rawEventType, err)
} }
return ev, nil return ev, nil

@ -58,29 +58,34 @@ func (s *Store) Close() error {
} }
func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) { func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) {
defer s.Audit(&AuditableEvent{ audit := &AuditableEvent{
userId: userId, userId: userId,
eventType: EventClaim, eventType: EventClaim,
}) }
defer s.Audit(audit)
tx, err := s.db.Begin() tx, err := s.db.Begin()
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to begin transaction: %w", err) return 0, fmt.Errorf("failed to begin transaction: %w", err)
} }
defer tx.Commit() //nolint:errcheck defer tx.Commit() //nolint:errcheck
conflicts, err := s.FindConflicts(ctx, userId, province, claimType) conflicts, err := s.FindConflicts(ctx, userId, province, claimType)
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to run conflicts check: %w", err) return 0, fmt.Errorf("failed to run conflicts check: %w", err)
} }
if len(conflicts) > 0 { if len(conflicts) > 0 {
audit.err = err
return 0, ErrConflict{Conflicts: conflicts} return 0, ErrConflict{Conflicts: conflicts}
} }
// check that provided name matches the claim type // check that provided name matches the claim type
stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE LOWER(provinces.%s) = ?`, claimTypeToColumn[claimType])) stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE LOWER(provinces.%s) = ?`, claimTypeToColumn[claimType]))
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to prepare count query: %w", err) return 0, fmt.Errorf("failed to prepare count query: %w", err)
} }
@ -88,25 +93,30 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai
var count int var count int
err = row.Scan(&count) err = row.Scan(&count)
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to scan: %w", err) return 0, fmt.Errorf("failed to scan: %w", err)
} }
if count == 0 { if count == 0 {
audit.err = err
return 0, fmt.Errorf("found no provinces for %s named %s", claimType, province) return 0, fmt.Errorf("found no provinces for %s named %s", claimType, province)
} }
stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)") stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)")
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to prepare claim query: %w", err) return 0, fmt.Errorf("failed to prepare claim query: %w", err)
} }
res, err := stmt.ExecContext(ctx, player, claimType, province, userId) res, err := stmt.ExecContext(ctx, player, claimType, province, userId)
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to insert claim: %w", err) return 0, fmt.Errorf("failed to insert claim: %w", err)
} }
id, err := res.LastInsertId() id, err := res.LastInsertId()
if err != nil { if err != nil {
audit.err = err
return 0, fmt.Errorf("failed to get last ID: %w", err) return 0, fmt.Errorf("failed to get last ID: %w", err)
} }
@ -242,26 +252,31 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error)
} }
func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error { func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error {
defer s.Audit(&AuditableEvent{ audit := &AuditableEvent{
userId: userId, userId: userId,
eventType: EventUnclaim, eventType: EventUnclaim,
}) }
defer s.Audit(audit)
stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?") stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?")
if err != nil { if err != nil {
audit.err = err
return fmt.Errorf("failed to prepare query: %w", err) return fmt.Errorf("failed to prepare query: %w", err)
} }
res, err := stmt.ExecContext(ctx, ID, userId) res, err := stmt.ExecContext(ctx, ID, userId)
if err != nil { if err != nil {
audit.err = err
return fmt.Errorf("failed to delete claim ID %d: %w", ID, err) return fmt.Errorf("failed to delete claim ID %d: %w", ID, err)
} }
rows, err := res.RowsAffected() rows, err := res.RowsAffected()
if err != nil { if err != nil {
audit.err = err
return fmt.Errorf("failed to get affected rows: %w", err) return fmt.Errorf("failed to get affected rows: %w", err)
} }
if rows == 0 { if rows == 0 {
audit.err = ErrNoSuchClaim
return ErrNoSuchClaim return ErrNoSuchClaim
} }
return nil return nil
@ -283,13 +298,15 @@ func (s *Store) CountClaims(ctx context.Context) (total, uniquePlayers int, err
} }
func (s *Store) Flush(ctx context.Context, userId string) error { func (s *Store) Flush(ctx context.Context, userId string) error {
defer s.Audit(&AuditableEvent{ audit := &AuditableEvent{
userId: userId, userId: userId,
eventType: EventFlush, eventType: EventFlush,
}) }
defer s.Audit(audit)
_, err := s.db.ExecContext(ctx, "DELETE FROM claims;") _, err := s.db.ExecContext(ctx, "DELETE FROM claims;")
if err != nil { if err != nil {
audit.err = err
return fmt.Errorf("failed to execute delete query: %w", err) return fmt.Errorf("failed to execute delete query: %w", err)
} }
return nil return nil

@ -90,11 +90,17 @@ func TestStore_Claim(t *testing.T) {
wantErr: false, wantErr: false,
}, },
} }
lastAudit := 0
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if _, err := store.Claim(context.TODO(), tt.args.userId, tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr { if _, err := store.Claim(context.TODO(), tt.args.userId, tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr {
t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr)
} }
ae, err := store.LastOf(context.TODO(), EventClaim)
require.NoError(t, err)
assert.Greater(t, ae.id, lastAudit)
lastAudit = ae.id
}) })
} }
} }
@ -165,9 +171,17 @@ func TestDeleteClaim(t *testing.T) {
err = store.DeleteClaim(context.TODO(), fooId, "000000000000000001") err = store.DeleteClaim(context.TODO(), fooId, "000000000000000001")
assert.NoError(t, err) assert.NoError(t, err)
ae, err := store.LastOf(context.TODO(), EventUnclaim)
require.NoError(t, err)
last := ae.id
err = store.DeleteClaim(context.TODO(), barId, "000000000000000001") err = store.DeleteClaim(context.TODO(), barId, "000000000000000001")
assert.Error(t, err) assert.Error(t, err)
assert.ErrorIs(t, err, ErrNoSuchClaim) assert.ErrorIs(t, err, ErrNoSuchClaim)
ae, err = store.LastOf(context.TODO(), EventUnclaim)
require.NoError(t, err)
assert.Equal(t, last, ae.id) // no new audit log was added
} }
func TestDescribeClaim(t *testing.T) { func TestDescribeClaim(t *testing.T) {
@ -187,7 +201,7 @@ func TestDescribeClaim(t *testing.T) {
} }
func TestCountClaims(t *testing.T) { func TestCountClaims(t *testing.T) {
store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "TestFlush")) store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "TestCountClaims"))
assert.NoError(t, err) assert.NoError(t, err)
store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE) store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE)
@ -216,4 +230,7 @@ func TestFlush(t *testing.T) {
claims, err := store.ListClaims(context.TODO()) claims, err := store.ListClaims(context.TODO())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, len(claims)) assert.Equal(t, 0, len(claims))
_, err = store.LastOf(context.TODO(), EventFlush)
require.NoError(t, err)
} }

Loading…
Cancel
Save