diff --git a/absences_test.go b/absences_test.go index 710e4fe..1aa0735 100644 --- a/absences_test.go +++ b/absences_test.go @@ -25,6 +25,9 @@ func TestAddAbsence(t *testing.T) { absentees, err = store.GetAbsentees(context.TODO(), now) assert.NoError(t, err) assert.Equal(t, 1, len(absentees)) + + _, err = store.LastOf(context.TODO(), EventAbsence) + require.NoError(t, err) } func TestGetSchedule(t *testing.T) { diff --git a/audit_log.go b/audit_log.go index 72c367f..4a2c6d3 100644 --- a/audit_log.go +++ b/audit_log.go @@ -52,51 +52,61 @@ func EventTypeFromString(ev string) (EventType, error) { type AuditableEvent struct { userId string eventType EventType - Timestamp time.Time + timestamp time.Time + err error } // Audit writes to the audit table, returns nothing because it is meant to be // used in a defered statement on functions that write to the database. func (s *Store) Audit(ev *AuditableEvent) { - ctx := context.Background() - - tx, err := s.db.Begin() - if err != nil { - log.Error().Err(err).Msg("failed to start transaction") - } - defer tx.Commit() //nolint:errcheck - - stmt, err := s.db.PrepareContext(ctx, "INSERT INTO audit_log (userid, event_type, ts) VALUES (?, ?, ?)") - if err != nil { - log.Error().Err(err).Msg("failed to prepare audit log insert") + if ev.err == nil { + ctx := context.Background() + + tx, err := s.db.Begin() + if err != nil { + log.Error().Err(err).Msg("failed to start transaction") + } + defer tx.Commit() //nolint:errcheck + + stmt, err := s.db.PrepareContext(ctx, "INSERT INTO audit_log (userid, event_type, ts) VALUES (?, ?, ?)") + if err != nil { + log.Error().Err(err).Msg("failed to prepare audit log insert") + } + + if _, err := stmt.ExecContext(ctx, ev.userId, ev.eventType.String(), time.Now()); err != nil { + log.Error().Err(err).Msg("failed to insert audit log") + } } +} - if _, err := stmt.ExecContext(ctx, ev.userId, ev.eventType.String(), time.Now()); err != nil { - log.Error().Err(err).Msg("failed to insert audit log") - } +type AuditEvent struct { + id int + userId string + eventType EventType + timestamp time.Time } -func (s *Store) LastOf(ctx context.Context, t EventType) (AuditableEvent, error) { - stmt, err := s.db.PrepareContext(ctx, `SELECT userid, event_type, ts FROM audit_log WHERE event_type = ? ORDER BY ts DESC LIMIT 1`) +func (s *Store) LastOf(ctx context.Context, t EventType) (AuditEvent, error) { + stmt, err := s.db.PrepareContext(ctx, `SELECT id, userid, event_type, ts FROM audit_log WHERE event_type = ? ORDER BY ts DESC LIMIT 1`) if err != nil { - return AuditableEvent{}, fmt.Errorf("failed to get last event of type %s: %w", t.String(), err) + return AuditEvent{}, fmt.Errorf("failed to get last event of type %s: %w", t.String(), err) } row := stmt.QueryRowContext(ctx, t.String()) - ev := AuditableEvent{} + ev := AuditEvent{} var rawEventType string - err = row.Scan(&ev.userId, &rawEventType, &ev.Timestamp) + err = row.Scan(&ev.id, &ev.userId, &rawEventType, &ev.timestamp) if err == sql.ErrNoRows { - return AuditableEvent{}, errors.New("") + return AuditEvent{}, errors.New("no rows found") } if err != nil { - return AuditableEvent{}, fmt.Errorf("failed to scan row: %w", err) + return AuditEvent{}, fmt.Errorf("failed to scan row: %w", err) } ev.eventType, err = EventTypeFromString(rawEventType) if err != nil { - return AuditableEvent{}, fmt.Errorf("failed to parse event type %s: %w", rawEventType, err) + return AuditEvent{}, fmt.Errorf("failed to parse event type %s: %w", rawEventType, err) } return ev, nil diff --git a/store.go b/store.go index fe133ca..594d162 100644 --- a/store.go +++ b/store.go @@ -58,29 +58,34 @@ func (s *Store) Close() error { } func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) { - defer s.Audit(&AuditableEvent{ + audit := &AuditableEvent{ userId: userId, eventType: EventClaim, - }) + } + defer s.Audit(audit) tx, err := s.db.Begin() if err != nil { + audit.err = err return 0, fmt.Errorf("failed to begin transaction: %w", err) } defer tx.Commit() //nolint:errcheck conflicts, err := s.FindConflicts(ctx, userId, province, claimType) if err != nil { + audit.err = err return 0, fmt.Errorf("failed to run conflicts check: %w", err) } if len(conflicts) > 0 { + audit.err = err return 0, ErrConflict{Conflicts: conflicts} } // check that provided name matches the claim type stmt, err := s.db.PrepareContext(ctx, fmt.Sprintf(`SELECT COUNT(1) FROM provinces WHERE LOWER(provinces.%s) = ?`, claimTypeToColumn[claimType])) if err != nil { + audit.err = err return 0, fmt.Errorf("failed to prepare count query: %w", err) } @@ -88,25 +93,30 @@ func (s *Store) Claim(ctx context.Context, userId, player, province string, clai var count int err = row.Scan(&count) if err != nil { + audit.err = err return 0, fmt.Errorf("failed to scan: %w", err) } if count == 0 { + audit.err = err return 0, fmt.Errorf("found no provinces for %s named %s", claimType, province) } stmt, err = s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)") if err != nil { + audit.err = err return 0, fmt.Errorf("failed to prepare claim query: %w", err) } res, err := stmt.ExecContext(ctx, player, claimType, province, userId) if err != nil { + audit.err = err return 0, fmt.Errorf("failed to insert claim: %w", err) } id, err := res.LastInsertId() if err != nil { + audit.err = err return 0, fmt.Errorf("failed to get last ID: %w", err) } @@ -242,26 +252,31 @@ func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error) } func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error { - defer s.Audit(&AuditableEvent{ + audit := &AuditableEvent{ userId: userId, eventType: EventUnclaim, - }) + } + defer s.Audit(audit) stmt, err := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?") if err != nil { + audit.err = err return fmt.Errorf("failed to prepare query: %w", err) } res, err := stmt.ExecContext(ctx, ID, userId) if err != nil { + audit.err = err return fmt.Errorf("failed to delete claim ID %d: %w", ID, err) } rows, err := res.RowsAffected() if err != nil { + audit.err = err return fmt.Errorf("failed to get affected rows: %w", err) } if rows == 0 { + audit.err = ErrNoSuchClaim return ErrNoSuchClaim } return nil @@ -283,13 +298,15 @@ func (s *Store) CountClaims(ctx context.Context) (total, uniquePlayers int, err } func (s *Store) Flush(ctx context.Context, userId string) error { - defer s.Audit(&AuditableEvent{ + audit := &AuditableEvent{ userId: userId, eventType: EventFlush, - }) + } + defer s.Audit(audit) _, err := s.db.ExecContext(ctx, "DELETE FROM claims;") if err != nil { + audit.err = err return fmt.Errorf("failed to execute delete query: %w", err) } return nil diff --git a/store_test.go b/store_test.go index 1107836..be92bbd 100644 --- a/store_test.go +++ b/store_test.go @@ -90,11 +90,17 @@ func TestStore_Claim(t *testing.T) { wantErr: false, }, } + lastAudit := 0 for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if _, err := store.Claim(context.TODO(), tt.args.userId, tt.args.player, tt.args.province, tt.args.claimType); (err != nil) != tt.wantErr { t.Errorf("Store.Claim() error = %v, wantErr %v", err, tt.wantErr) } + + ae, err := store.LastOf(context.TODO(), EventClaim) + require.NoError(t, err) + assert.Greater(t, ae.id, lastAudit) + lastAudit = ae.id }) } } @@ -165,9 +171,17 @@ func TestDeleteClaim(t *testing.T) { err = store.DeleteClaim(context.TODO(), fooId, "000000000000000001") assert.NoError(t, err) + ae, err := store.LastOf(context.TODO(), EventUnclaim) + require.NoError(t, err) + last := ae.id + err = store.DeleteClaim(context.TODO(), barId, "000000000000000001") assert.Error(t, err) assert.ErrorIs(t, err, ErrNoSuchClaim) + + ae, err = store.LastOf(context.TODO(), EventUnclaim) + require.NoError(t, err) + assert.Equal(t, last, ae.id) // no new audit log was added } func TestDescribeClaim(t *testing.T) { @@ -187,7 +201,7 @@ func TestDescribeClaim(t *testing.T) { } func TestCountClaims(t *testing.T) { - store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "TestFlush")) + store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "TestCountClaims")) assert.NoError(t, err) store.Claim(context.TODO(), "000000000000000001", "foo", "Genoa", CLAIM_TYPE_TRADE) @@ -216,4 +230,7 @@ func TestFlush(t *testing.T) { claims, err := store.ListClaims(context.TODO()) assert.NoError(t, err) assert.Equal(t, 0, len(claims)) + + _, err = store.LastOf(context.TODO(), EventFlush) + require.NoError(t, err) }