You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
themis/store.go

469 lines
13 KiB

package themis
import (
"context"
"database/sql"
"embed"
"errors"
"fmt"
"strings"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
semconv "go.opentelemetry.io/otel/semconv/v1.24.0"
"go.opentelemetry.io/otel/trace"
)
// migrations is the embedded file system holding the SQL migration files
// matched by migrations/*.sql; NewStore reads it to migrate the database.
//
//go:embed migrations/*.sql
var migrations embed.FS
// tracer is the package-wide tracer that every Store method starts its spans
// from.
var tracer trace.Tracer

// init fetches the tracer named "themis" from the globally registered
// OpenTelemetry tracer provider.
func init() {
	tracer = otel.GetTracerProvider().Tracer("themis")
}
// Store bundles the database handle that all claim queries run against with
// the logger its methods write debug messages to.
type Store struct {
// db is the handle handed to NewStore; Close closes it.
db *sql.DB
// logger receives the debug output of every Store method.
logger zerolog.Logger
}
// NewStore applies the embedded SQL migrations to db and returns a Store
// backed by it.
//
// The migrations are read from the embedded "migrations" directory and rolled
// up; an already up-to-date schema (migrate.ErrNoChange) is not an error. The
// resulting schema version and dirty flag are logged at debug level. Any
// failure while setting up or running the migrations is returned wrapped and
// no Store is created.
func NewStore(db *sql.DB, logger zerolog.Logger) (*Store, error) {
	src, err := iofs.New(migrations, "migrations")
	if err != nil {
		return nil, fmt.Errorf("failed to open iofs migration source: %w", err)
	}

	driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
	if err != nil {
		return nil, fmt.Errorf("failed to initialize sqlite3 migrate driver: %w", err)
	}

	m, err := migrate.NewWithInstance("iofs", src, "main", driver)
	if err != nil {
		return nil, fmt.Errorf("failed to initialize db migrate: %w", err)
	}

	err = m.Up()
	if err != nil && !errors.Is(err, migrate.ErrNoChange) {
		return nil, fmt.Errorf("failed to roll up migrations: %w", err)
	}

	// ErrNilVersion is tolerated rather than treated as a failure. errors.Is
	// is used (as for ErrNoChange above) so a wrapped sentinel still matches.
	ver, dirty, err := m.Version()
	if err != nil && !errors.Is(err, migrate.ErrNilVersion) {
		return nil, fmt.Errorf("failed to get database migration version: %w", err)
	}
	logger.Debug().Uint("current_version", ver).Bool("dirty", dirty).Msg("running database migrations")

	return &Store{
		logger: logger,
		db:     db,
	}, nil
}
// Close emits a debug log line and then closes the underlying database
// handle, handing back whatever error that close reports.
func (s *Store) Close() error {
	s.logger.Debug().Msg("closing database")

	closeErr := s.db.Close()
	return closeErr
}
// Claim records a claim of the given type on province for player, made on
// behalf of userId, and returns the ID of the inserted claims row.
//
// The claim is refused with an ErrConflict when FindConflicts reports
// conflicting claims, and with a plain error when no claimables row of the
// given type has that name (compared case-insensitively). Every attempt is
// handed to Audit when the method returns; audit.err carries the failure, if
// any, so that rejected claims are not audited as successes.
//
// No explicit transaction is opened: FindConflicts cannot be handed one, so a
// transaction here could not make the check-then-insert sequence atomic and
// would only pin a pool connection. Each statement runs on the pool on its
// own.
func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) {
	ctx, span := tracer.Start(ctx, "claim", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("insert"),
		attribute.String("user_id", userId),
		attribute.String("claim_name", province),
		attribute.Stringer("claim_type", claimType),
	))
	defer span.End()

	s.logger.Debug().
		Ctx(ctx).
		Str("userid", userId).
		Str("player", player).
		Str("province", province).
		Stringer("claim_type", claimType).
		Msg("inserting claim")

	audit := &AuditableEvent{
		userId:    userId,
		eventType: EventClaim,
	}
	// Audit receives the pointer, so it observes audit.err as set below.
	defer s.Audit(ctx, audit)

	// fail marks both the audit record and the span as failed.
	fail := func(cause error, status string) {
		audit.err = cause
		span.RecordError(cause)
		span.SetStatus(codes.Error, status)
	}

	conflicts, err := s.FindConflicts(ctx, userId, province, claimType)
	if err != nil {
		fail(err, "conflict check failed")
		return 0, fmt.Errorf("failed to run conflicts check: %w", err)
	}
	if len(conflicts) > 0 {
		s.logger.Debug().Ctx(ctx).Int("len", len(conflicts)).Msg("found conflicts")
		audit.err = errors.New("found conflicts")
		return 0, ErrConflict{Conflicts: conflicts}
	}

	// check that provided name matches the claim type
	checkStmt, err := s.db.PrepareContext(ctx, `SELECT COUNT(1) FROM claimables WHERE lower(name) = ? and typ = ?`)
	if err != nil {
		fail(err, "failed to prepare query")
		return 0, fmt.Errorf("failed to prepare count query: %w", err)
	}
	defer checkStmt.Close()

	var count int
	if err := checkStmt.QueryRowContext(ctx, strings.ToLower(province), claimType).Scan(&count); err != nil {
		fail(err, "failed to scan row")
		return 0, fmt.Errorf("failed to scan: %w", err)
	}
	if count == 0 {
		// The scan succeeded, so there is no driver error to report here;
		// build the rejection first so audit and span record a real failure.
		err := fmt.Errorf("found no provinces for %s named %s", claimType, province)
		fail(err, "no matching provinces found")
		return 0, err
	}

	insertStmt, err := s.db.PrepareContext(ctx, "INSERT INTO claims (player, claim_type, val, userid) VALUES (?, ?, ?, ?)")
	if err != nil {
		fail(err, "failed to prepare query")
		return 0, fmt.Errorf("failed to prepare claim query: %w", err)
	}
	defer insertStmt.Close()

	res, err := insertStmt.ExecContext(ctx, player, claimType, province, userId)
	if err != nil {
		fail(err, "failed to execute query")
		return 0, fmt.Errorf("failed to insert claim: %w", err)
	}

	id, err := res.LastInsertId()
	if err != nil {
		fail(err, "failed to last insert id")
		return 0, fmt.Errorf("failed to get last ID: %w", err)
	}
	return int(id), nil
}
// ListAvailability returns the distinct names of claimables of the given type
// that have no matching row in claims.
//
// Only the first element of search is honoured, the rest are ignored. When it
// is non-empty the result is narrowed to names containing it; the term is
// embedded in a SQL LIKE pattern, so % and _ inside it act as wildcards. The
// returned slice is non-nil even when nothing is available.
func (s *Store) ListAvailability(ctx context.Context, claimType ClaimType, search ...string) ([]string, error) {
	ctx, span := tracer.Start(ctx, "list_availability", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claimables"),
		semconv.DBOperation("select"),
		attribute.StringSlice("search_terms", search),
		attribute.Stringer("claim_type", claimType),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Stringer("claim_type", claimType).Strs("search_terms", search).Msg("listing available entries")

	queryParams := []any{string(claimType)}
	queryPattern := `SELECT distinct name
		FROM claimables
		LEFT JOIN claims ON claimables.name = claims.val AND claimables.typ = claims.claim_type
		WHERE claims.val IS NULL
		AND claimables.typ = ?`
	if len(search) > 0 && search[0] != "" {
		// only take one search param, ignore the rest
		// The leading space keeps the clause separated from the preceding "?".
		queryPattern += ` AND claimables.name LIKE ?`
		queryParams = append(queryParams, fmt.Sprintf("%%%s%%", search[0]))
	}

	stmt, err := s.db.PrepareContext(ctx, queryPattern)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to prepare query")
		return nil, fmt.Errorf("failed to prepare query: %w", err)
	}
	defer stmt.Close()

	rows, err := stmt.QueryContext(ctx, queryParams...)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to execute query")
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	// Release the result set on every path, including early returns on a
	// scan error.
	defer rows.Close()

	avail := make([]string, 0)
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to scan row")
			return nil, fmt.Errorf("failed to scan rows: %w", err)
		}
		avail = append(avail, name)
	}
	// Next returning false can mean an iteration error rather than the end of
	// the result set; do not report a truncated list as success.
	if err := rows.Err(); err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to iterate rows")
		return nil, fmt.Errorf("failed to iterate rows: %w", err)
	}
	return avail, nil
}
// ListClaims returns every row of the claims table as a Claim.
//
// The stored claim_type column is converted with ClaimTypeFromString; a value
// that fails to convert aborts the listing with an error. The returned slice
// is non-nil even when the table is empty.
func (s *Store) ListClaims(ctx context.Context) ([]Claim, error) {
	ctx, span := tracer.Start(ctx, "list_claims", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("select"),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Msg("listing all claims currently in database")

	stmt, err := s.db.PrepareContext(ctx, `SELECT id, player, claim_type, val FROM claims`)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to prepare query")
		return nil, fmt.Errorf("failed to prepare query: %w", err)
	}
	defer stmt.Close()

	rows, err := stmt.QueryContext(ctx)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to execute query")
		return nil, fmt.Errorf("failed to execute query: %w", err)
	}
	// Release the result set on every path, including the early returns
	// inside the loop.
	defer rows.Close()

	claims := make([]Claim, 0)
	for rows.Next() {
		c := Claim{}
		var rawType string
		if err := rows.Scan(&c.ID, &c.Player, &rawType, &c.Name); err != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to scan row")
			return nil, fmt.Errorf("failed to scan row: %w", err)
		}

		cl, err := ClaimTypeFromString(rawType)
		if err != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to parse claim type")
			return nil, fmt.Errorf("unexpected error converting raw claim type: %w", err)
		}
		c.Type = cl
		claims = append(claims, c)
	}
	// Next returning false can mean an iteration error rather than the end of
	// the result set; do not report a truncated list as success.
	if err := rows.Err(); err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to iterate rows")
		return nil, fmt.Errorf("failed to iterate rows: %w", err)
	}
	return claims, nil
}
// ClaimDetail is a Claim together with the provinces looked up for it, as
// assembled by DescribeClaim.
type ClaimDetail struct {
Claim
// Provinces holds the province values read from claimables for the claim's
// name and type.
Provinces []string
}
// String renders the embedded Claim on the first line, followed by one
// "- <province>" line for each entry in Provinces.
func (cd ClaimDetail) String() string {
	var out strings.Builder

	fmt.Fprintf(&out, "%s\n", cd.Claim)
	for i := range cd.Provinces {
		fmt.Fprintf(&out, "- %s\n", cd.Provinces[i])
	}

	return out.String()
}
// DescribeClaim loads the claim with the given ID together with the provinces
// that claimables lists for the claim's name and type.
//
// It returns ErrNoSuchClaim when no claims row has that ID. Any other failure
// is returned wrapped, alongside a zero ClaimDetail. Provinces is non-nil even
// when no claimables row matches.
func (s *Store) DescribeClaim(ctx context.Context, ID int) (ClaimDetail, error) {
	ctx, span := tracer.Start(ctx, "describe_claim", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("select"),
		attribute.Int("claim_id", ID),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Int("id", ID).Msg("describing claim")

	claimStmt, err := s.db.PrepareContext(ctx, `SELECT id, player, claim_type, val FROM claims WHERE id = ?`)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to prepare query")
		return ClaimDetail{}, fmt.Errorf("failed to prepare select claim query: %w", err)
	}
	defer claimStmt.Close()

	c := Claim{}
	var rawType string
	err = claimStmt.QueryRowContext(ctx, ID).Scan(&c.ID, &c.Player, &rawType, &c.Name)
	// errors.Is keeps the not-found mapping working if the sentinel is wrapped.
	if errors.Is(err, sql.ErrNoRows) {
		// An unknown ID is recorded on the span but not flagged as an error
		// status.
		span.RecordError(ErrNoSuchClaim)
		return ClaimDetail{}, ErrNoSuchClaim
	}
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to scan row")
		return ClaimDetail{}, fmt.Errorf("failed to scan row: %w", err)
	}

	cl, err := ClaimTypeFromString(rawType)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to parse claim type")
		return ClaimDetail{}, fmt.Errorf("unexpected error converting raw claim type: %w", err)
	}
	c.Type = cl

	provinceStmt, err := s.db.PrepareContext(ctx, `SELECT province FROM claimables WHERE name = ? AND typ = ?`)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to prepare query")
		return ClaimDetail{}, fmt.Errorf("failed to prepare query: %w", err)
	}
	defer provinceStmt.Close()

	rows, err := provinceStmt.QueryContext(ctx, c.Name, cl)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to execute query")
		return ClaimDetail{}, fmt.Errorf("failed to execute query: %w", err)
	}
	// Release the result set on every path, including an early return on a
	// scan error.
	defer rows.Close()

	provinces := make([]string, 0)
	for rows.Next() {
		var p string
		if err := rows.Scan(&p); err != nil {
			span.RecordError(err)
			span.SetStatus(codes.Error, "failed to scan row")
			return ClaimDetail{}, fmt.Errorf("failed to scan result set: %w", err)
		}
		provinces = append(provinces, p)
	}
	// Next returning false can mean an iteration error rather than the end of
	// the result set; do not report a truncated province list as success.
	if err := rows.Err(); err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to iterate rows")
		return ClaimDetail{}, fmt.Errorf("failed to iterate result set: %w", err)
	}

	return ClaimDetail{
		Claim:     c,
		Provinces: provinces,
	}, nil
}
// DeleteClaim removes the claims row that has the given ID and belongs to
// userId.
//
// When the delete affects no row, ErrNoSuchClaim is returned. The attempt is
// handed to Audit when the method returns, carrying the failure if there was
// one.
func (s *Store) DeleteClaim(ctx context.Context, ID int, userId string) error {
	ctx, span := tracer.Start(ctx, "delete_claim", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("delete"),
		attribute.Int("claim_id", ID),
		attribute.String("user_id", userId),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Str("userid", userId).Int("id", ID).Msg("deleting claim")

	event := &AuditableEvent{
		userId:    userId,
		eventType: EventUnclaim,
	}
	defer s.Audit(ctx, event)

	// markFailed stamps the failure on the audit event and on the span.
	markFailed := func(cause error, status string) {
		event.err = cause
		span.RecordError(cause)
		span.SetStatus(codes.Error, status)
	}

	deleter, prepErr := s.db.PrepareContext(ctx, "DELETE FROM claims WHERE id = ? AND userid = ?")
	if prepErr != nil {
		markFailed(prepErr, "failed to prepare query")
		return fmt.Errorf("failed to prepare query: %w", prepErr)
	}
	defer deleter.Close()

	outcome, execErr := deleter.ExecContext(ctx, ID, userId)
	if execErr != nil {
		markFailed(execErr, "failed to execute query")
		return fmt.Errorf("failed to delete claim ID %d: %w", ID, execErr)
	}

	affected, countErr := outcome.RowsAffected()
	if countErr != nil {
		markFailed(countErr, "failed to get number of affected rows")
		return fmt.Errorf("failed to get affected rows: %w", countErr)
	}

	if affected != 0 {
		return nil
	}

	// Nothing matched: recorded on audit and span, but the span status is
	// left unset.
	event.err = ErrNoSuchClaim
	span.RecordError(ErrNoSuchClaim)
	return ErrNoSuchClaim
}
// CountClaims reports how many rows the claims table holds (total) and how
// many distinct userid values appear in it (uniquePlayers). On failure both
// counts are returned as zero together with the wrapped error.
func (s *Store) CountClaims(ctx context.Context) (total, uniquePlayers int, err error) {
	ctx, span := tracer.Start(ctx, "count_claims", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("select"),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Msg("counting all claims and unique users")

	const query = "SELECT COUNT(1), COUNT(DISTINCT(userid)) FROM claims"

	counter, err := s.db.PrepareContext(ctx, query)
	if err != nil {
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to prepare query")
		return 0, 0, fmt.Errorf("failed to prepare query: %w", err)
	}
	defer counter.Close()

	scanErr := counter.QueryRowContext(ctx).Scan(&total, &uniquePlayers)
	if scanErr != nil {
		span.RecordError(scanErr)
		span.SetStatus(codes.Error, "failed to scan row")
		return 0, 0, fmt.Errorf("failed to scan result: %w", scanErr)
	}

	return total, uniquePlayers, nil
}
// Flush deletes every row from the claims table on behalf of userId.
//
// The attempt is handed to Audit when the method returns, carrying the
// failure if the delete did not succeed.
func (s *Store) Flush(ctx context.Context, userId string) error {
	ctx, span := tracer.Start(ctx, "flush", trace.WithAttributes(
		semconv.DBSystemSqlite,
		semconv.DBSQLTable("claims"),
		semconv.DBOperation("delete"),
		// Recorded for parity with Claim and DeleteClaim, so a flush can be
		// attributed to its initiator in traces as well as in the audit log.
		attribute.String("user_id", userId),
	))
	defer span.End()

	s.logger.Debug().Ctx(ctx).Str("initiated_by", userId).Msg("flushing all currently held claims")

	audit := &AuditableEvent{
		userId:    userId,
		eventType: EventFlush,
	}
	defer s.Audit(ctx, audit)

	if _, err := s.db.ExecContext(ctx, "DELETE FROM claims;"); err != nil {
		audit.err = err
		span.RecordError(err)
		span.SetStatus(codes.Error, "failed to execute query")
		return fmt.Errorf("failed to execute delete query: %w", err)
	}
	return nil
}