From b2d5332f67f8e20f5e8e7716f63c89e0b19a021a Mon Sep 17 00:00:00 2001 From: William Perron Date: Thu, 22 Sep 2022 21:15:21 +0000 Subject: [PATCH] WIP: command to run arbitrary selects --- cmd/themis-server/main.go | 44 +++++++++++++++++++++++++ fmt.go | 69 +++++++++++++++++++++++++++++++++++++++ fmt_test.go | 53 ++++++++++++++++++++++++++++++ store.go | 4 +++ 4 files changed, 170 insertions(+) create mode 100644 fmt.go create mode 100644 fmt_test.go diff --git a/cmd/themis-server/main.go b/cmd/themis-server/main.go index 5fb5301..61d1421 100644 --- a/cmd/themis-server/main.go +++ b/cmd/themis-server/main.go @@ -2,6 +2,7 @@ package main import ( "context" + "database/sql" "errors" "flag" "fmt" @@ -13,6 +14,7 @@ import ( "syscall" "github.com/bwmarrin/discordgo" + _ "github.com/mattn/go-sqlite3" "github.com/rs/zerolog/log" "go.wperron.io/themis" @@ -47,6 +49,7 @@ func main() { if err != nil { log.Fatal().Err(err).Msg("failed to initialize database") } + defer store.Close() authToken, ok := os.LookupEnv("DISCORD_TOKEN") if !ok { @@ -133,6 +136,18 @@ func main() { Description: "Remove all claims from the database and prepare for the next game!", Type: discordgo.ChatApplicationCommand, }, + { + Name: "query", + Description: "Run a raw SQL query on the database", + Type: discordgo.ChatApplicationCommand, + Options: []*discordgo.ApplicationCommandOption{ + { + Name: "query", + Description: "Raw SQL query", + Type: discordgo.ApplicationCommandOptionString, + }, + }, + }, } handlers := map[string]Handler{ "ping": func(s *discordgo.Session, i *discordgo.InteractionCreate) { @@ -356,6 +371,35 @@ func main() { log.Error().Err(err).Msg("failed to respond to interaction") } }, + "query": func(s *discordgo.Session, i *discordgo.InteractionCreate) { + roDB, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?cache=private&mode=ro", *dbFile)) + if err != nil { + log.Error().Err(err).Msg("failed to open read-only copy of databse") + } + + q := i.ApplicationCommandData().Options[0].StringValue() + rows, err := roDB.Query(q) + if err != nil { + log.Error().Err(err).Msg("failed to exec user-provided query") + return + } + + fmtd, err := themis.FormatRows(rows) + if err != nil { + log.Error().Err(err).Msg("failed to format rows") + } + + table := fmt.Sprintf("```\n%s\n```", fmtd[:min(len(fmtd), 1990)]) // TODO(wperron) find a better way to cutover + + if err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{ + Type: discordgo.InteractionResponseChannelMessageWithSource, + Data: &discordgo.InteractionResponseData{ + Content: table, + }, + }); err != nil { + log.Error().Err(err).Msg("failed to respond to interaction") + } + }, } registerHandlers(discord, handlers) diff --git a/fmt.go b/fmt.go new file mode 100644 index 0000000..a747f13 --- /dev/null +++ b/fmt.go @@ -0,0 +1,69 @@ +package themis + +import ( + "database/sql" + "fmt" + "strings" +) + +func FormatRows(rows *sql.Rows) (string, error) { + sb := strings.Builder{} + + cols, err := rows.Columns() + if err != nil { + return "", fmt.Errorf("failed to get rows columns: %w", err) + } + + c := make([]string, len(cols)) + for i := range c { + c[i] = " %-*s " + } + pattern := fmt.Sprintf("|%s|\n", strings.Join(c, "|")) + + lengths := make([]int, len(cols)) + for i := range lengths { + lengths[i] = len(cols[i]) + } + + scanned := make([][]any, 0) + for rows.Next() { + row := make([]interface{}, len(cols)) + for i := range row { + row[i] = new(string) + } + rows.Scan(row...) + scanned = append(scanned, row) // keep track of row for later + for i, a := range row { + s := a.(*string) + if len(*s) > lengths[i] { + lengths[i] = len(*s) + } + } + } + + // Write column names + curr := make([]any, 0, 2*len(cols)) + for i := range lengths { + curr = append(curr, lengths[i], cols[i]) + } + sb.WriteString(fmt.Sprintf(pattern, curr...)) + + // Write header separator row + curr = curr[:0] // empty slice but preserve capacity + for i := range lengths { + curr = append(curr, lengths[i], strings.Repeat("-", lengths[i])) + } + sb.WriteString(fmt.Sprintf(pattern, curr...)) + + // iterate rows and write each one + for _, r := range scanned { + curr = curr[:0] // empty slice but preserve capacity + for i := range lengths { + s := r[i].(*string) + curr = append(curr, lengths[i], *s) + } + sb.WriteString(fmt.Sprintf(pattern, curr...)) + } + + return sb.String(), nil +} diff --git a/fmt_test.go b/fmt_test.go new file mode 100644 index 0000000..0529da4 --- /dev/null +++ b/fmt_test.go @@ -0,0 +1,53 @@ +package themis + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFormatRows(t *testing.T) { + store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows")) + assert.NoError(t, err) + + rows, err := store.db.Query("SELECT provinces.name, provinces.region, provinces.area, provinces.trade_node FROM provinces WHERE area = 'Gascony'") + assert.NoError(t, err) + + fmtd, err := FormatRows(rows) + assert.NoError(t, err) + assert.Equal(t, `| name | region | area | trade_node | +| -------- | ------ | ------- | ---------- | +| Labourd | France | Gascony | Bordeaux | +| Armagnac | France | Gascony | Bordeaux | +| BĂ©arn | France | Gascony | Bordeaux | +| Foix | France | Gascony | Bordeaux | +`, fmtd) +} + +func TestFormatRowsAggregated(t *testing.T) { + store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows")) + assert.NoError(t, err) + + rows, err := store.db.Query("SELECT count(1) as total, trade_node from provinces where region = 'France' group by trade_node") + assert.NoError(t, err) + + fmtd, err := FormatRows(rows) + assert.NoError(t, err) + assert.Equal(t, `| total | trade_node | +| ----- | --------------- | +| 25 | Bordeaux | +| 24 | Champagne | +| 8 | English Channel | +| 4 | Genoa | +| 5 | Valencia | +`, fmtd) +} + +func TestFormatRowsInvalidQuery(t *testing.T) { + store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows")) + assert.NoError(t, err) + + _, err = store.db.Query("SELECT count(name), distinct(trade_node) from provinces where region = 'France'") + assert.Error(t, err) +} diff --git a/store.go b/store.go index faeab8e..b5c7a9c 100644 --- a/store.go +++ b/store.go @@ -91,6 +91,10 @@ func NewStore(conn string) (*Store, error) { }, nil } +func (s *Store) Close() error { + return s.db.Close() +} + func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) { tx, err := s.db.Begin() if err != nil {