WIP: command to run arbitrary selects

absences
William Perron 2 years ago
parent bd490bc18f
commit b2d5332f67

@ -2,6 +2,7 @@ package main
import (
"context"
"database/sql"
"errors"
"flag"
"fmt"
@ -13,6 +14,7 @@ import (
"syscall"
"github.com/bwmarrin/discordgo"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
"go.wperron.io/themis"
@ -47,6 +49,7 @@ func main() {
if err != nil {
log.Fatal().Err(err).Msg("failed to initialize database")
}
defer store.Close()
authToken, ok := os.LookupEnv("DISCORD_TOKEN")
if !ok {
@ -133,6 +136,18 @@ func main() {
Description: "Remove all claims from the database and prepare for the next game!",
Type: discordgo.ChatApplicationCommand,
},
{
Name: "query",
Description: "Run a raw SQL query on the database",
Type: discordgo.ChatApplicationCommand,
Options: []*discordgo.ApplicationCommandOption{
{
Name: "query",
Description: "Raw SQL query",
Type: discordgo.ApplicationCommandOptionString,
},
},
},
}
handlers := map[string]Handler{
"ping": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
@ -356,6 +371,35 @@ func main() {
log.Error().Err(err).Msg("failed to respond to interaction")
}
},
"query": func(s *discordgo.Session, i *discordgo.InteractionCreate) {
roDB, err := sql.Open("sqlite3", fmt.Sprintf("file:%s?cache=private&mode=ro", *dbFile))
if err != nil {
log.Error().Err(err).Msg("failed to open read-only copy of databse")
}
q := i.ApplicationCommandData().Options[0].StringValue()
rows, err := roDB.Query(q)
if err != nil {
log.Error().Err(err).Msg("failed to exec user-provided query")
return
}
fmtd, err := themis.FormatRows(rows)
if err != nil {
log.Error().Err(err).Msg("failed to format rows")
}
table := fmt.Sprintf("```\n%s\n```", fmtd[:min(len(fmtd), 1990)]) // TODO(wperron) find a better way to cutover
if err := s.InteractionRespond(i.Interaction, &discordgo.InteractionResponse{
Type: discordgo.InteractionResponseChannelMessageWithSource,
Data: &discordgo.InteractionResponseData{
Content: table,
},
}); err != nil {
log.Error().Err(err).Msg("failed to respond to interaction")
}
},
}
registerHandlers(discord, handlers)

@ -0,0 +1,69 @@
package themis
import (
"database/sql"
"fmt"
"strings"
)
func FormatRows(rows *sql.Rows) (string, error) {
sb := strings.Builder{}
cols, err := rows.Columns()
if err != nil {
return "", fmt.Errorf("failed to get rows columns: %w", err)
}
c := make([]string, len(cols))
for i := range c {
c[i] = " %-*s "
}
pattern := fmt.Sprintf("|%s|\n", strings.Join(c, "|"))
lengths := make([]int, len(cols))
for i := range lengths {
lengths[i] = len(cols[i])
}
scanned := make([][]any, 0)
for rows.Next() {
row := make([]interface{}, len(cols))
for i := range row {
row[i] = new(string)
}
rows.Scan(row...)
scanned = append(scanned, row) // keep track of row for later
for i, a := range row {
s := a.(*string)
if len(*s) > lengths[i] {
lengths[i] = len(*s)
}
}
}
// Write column names
curr := make([]any, 0, 2*len(cols))
for i := range lengths {
curr = append(curr, lengths[i], cols[i])
}
sb.WriteString(fmt.Sprintf(pattern, curr...))
// Write header separator row
curr = curr[:0] // empty slice but preserve capacity
for i := range lengths {
curr = append(curr, lengths[i], strings.Repeat("-", lengths[i]))
}
sb.WriteString(fmt.Sprintf(pattern, curr...))
// iterate rows and write each one
for _, r := range scanned {
curr = curr[:0] // empty slice but preserve capacity
for i := range lengths {
s := r[i].(*string)
curr = append(curr, lengths[i], *s)
}
sb.WriteString(fmt.Sprintf(pattern, curr...))
}
return sb.String(), nil
}

@ -0,0 +1,53 @@
package themis
import (
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestFormatRows(t *testing.T) {
store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows"))
assert.NoError(t, err)
rows, err := store.db.Query("SELECT provinces.name, provinces.region, provinces.area, provinces.trade_node FROM provinces WHERE area = 'Gascony'")
assert.NoError(t, err)
fmtd, err := FormatRows(rows)
assert.NoError(t, err)
assert.Equal(t, `| name | region | area | trade_node |
| -------- | ------ | ------- | ---------- |
| Labourd | France | Gascony | Bordeaux |
| Armagnac | France | Gascony | Bordeaux |
| Béarn | France | Gascony | Bordeaux |
| Foix | France | Gascony | Bordeaux |
`, fmtd)
}
func TestFormatRowsAggregated(t *testing.T) {
store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows"))
assert.NoError(t, err)
rows, err := store.db.Query("SELECT count(1) as total, trade_node from provinces where region = 'France' group by trade_node")
assert.NoError(t, err)
fmtd, err := FormatRows(rows)
assert.NoError(t, err)
assert.Equal(t, `| total | trade_node |
| ----- | --------------- |
| 25 | Bordeaux |
| 24 | Champagne |
| 8 | English Channel |
| 4 | Genoa |
| 5 | Valencia |
`, fmtd)
}
func TestFormatRowsInvalidQuery(t *testing.T) {
store, err := NewStore(fmt.Sprintf(TEST_CONN_STRING_PATTERN, "format-rows"))
assert.NoError(t, err)
_, err = store.db.Query("SELECT count(name), distinct(trade_node) from provinces where region = 'France'")
assert.Error(t, err)
}

@ -91,6 +91,10 @@ func NewStore(conn string) (*Store, error) {
}, nil
}
func (s *Store) Close() error {
return s.db.Close()
}
func (s *Store) Claim(ctx context.Context, userId, player, province string, claimType ClaimType) (int, error) {
tx, err := s.db.Begin()
if err != nil {

Loading…
Cancel
Save