Skip to content

Commit

Permalink
Some refactoring, add table listing
Browse files Browse the repository at this point in the history
  • Loading branch information
aidansteele committed Mar 21, 2020
1 parent e7b62ea commit 5c56cb5
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 213 deletions.
85 changes: 85 additions & 0 deletions api.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package main

import (
"fmt"
daxc "github.com/aws/aws-dax-go/dax"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dax"
"github.com/aws/aws-sdk-go/service/dynamodb"
"strings"
)

type Api struct {
dynamo *dynamodb.DynamoDB
dax *daxc.Dax
}

func (a *Api) ListTablesPages(input *dynamodb.ListTablesInput, cb func(*dynamodb.ListTablesOutput, bool) bool) error {
return a.dynamo.ListTablesPages(input, cb)
}

func (a *Api) DescribeTable(input *dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) {
return a.dynamo.DescribeTable(input)
}

func (a *Api) QueryPages(input *dynamodb.QueryInput, cb func(*dynamodb.QueryOutput, bool) bool) error {
if a.dax != nil {
return a.dax.QueryPages(input, cb)
}

return a.dynamo.QueryPages(input, cb)
}

func (a *Api) ScanPages(input *dynamodb.ScanInput, cb func(*dynamodb.ScanOutput, bool) bool) error {
if a.dax != nil {
return a.dax.ScanPages(input, cb)
}

return a.dynamo.ScanPages(input, cb)
}

func apiClient(daxCluster string) *Api {
sess, err := session.NewSession()
if err != nil {
panic(err)
}

api := &Api{dynamo: dynamodb.New(sess)}

if len(daxCluster) == 0 {
return api
}

if !strings.Contains(daxCluster, ".") {
// must be a cluster name rather than domain name
dapi := dax.New(sess)
desc, err := dapi.DescribeClusters(&dax.DescribeClustersInput{ClusterNames: []*string{}})
if err != nil {
panic(err)
}

if len(desc.Clusters) == 0 {
panic("no cluster found by that name")
}

e := desc.Clusters[0].ClusterDiscoveryEndpoint
daxCluster = fmt.Sprintf("%s:%d", *e.Address, *e.Port)
}

if !strings.Contains(daxCluster, ":") {
// missing port, assume default
daxCluster += ":8111"
}

cfg := daxc.DefaultConfig()
cfg.HostPorts = []string{daxCluster}
cfg.Credentials = sess.Config.Credentials
cfg.Region = *sess.Config.Region

api.dax, err = daxc.New(cfg)
if err != nil {
panic(err)
}

return api
}
255 changes: 42 additions & 213 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,16 @@ import (
"encoding/json"
"fmt"
"github.com/TylerBrock/colorjson"
daxc "github.com/aws/aws-dax-go/dax"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/dax"
"github.com/aws/aws-sdk-go/service/dynamodb"
"github.com/aws/aws-sdk-go/service/dynamodb/dynamodbattribute"
"github.com/davecgh/go-spew/spew"
"github.com/mattn/go-colorable"
"github.com/mattn/go-isatty"
"github.com/pkg/errors"
"github.com/spf13/pflag"
"io"
"os"
"regexp"
"sort"
"strings"
)

Expand All @@ -30,36 +26,11 @@ type Dynamo struct {
emitted int
}

type Api struct {
dynamo *dynamodb.DynamoDB
dax *daxc.Dax
}

func (a *Api) DescribeTable(input *dynamodb.DescribeTableInput) (*dynamodb.DescribeTableOutput, error) {
return a.dynamo.DescribeTable(input)
}

func (a *Api) QueryPages(input *dynamodb.QueryInput, cb func(*dynamodb.QueryOutput, bool) bool) error {
if a.dax != nil {
return a.dax.QueryPages(input, cb)
}

return a.dynamo.QueryPages(input, cb)
}

func (a *Api) ScanPages(input *dynamodb.ScanInput, cb func(*dynamodb.ScanOutput, bool) bool) error {
if a.dax != nil {
return a.dax.ScanPages(input, cb)
}

return a.dynamo.ScanPages(input, cb)
}

func main() {
/*
len(os.Args)
1 => just app name, do help
1 => just app name, list tables
2 => app and table, do a scan
3 => app, table, pkey => do a query
4 => app, table, pkey, skey => implies equality if no operator in skey
Expand All @@ -82,80 +53,60 @@ func main() {
d.Run(args)
}

func apiClient(daxCluster string) *Api {
sess, err := session.NewSessionWithOptions(session.Options{
SharedConfigState: session.SharedConfigEnable,
})
if err != nil {
panic(err)
func (d *Dynamo) Run(args []string) {
var err error
switch len(args) {
case 0:
err = d.tables()
case 1:
err = d.scan(args[0])
default:
err = d.query(args)
}

api := &Api{dynamo: dynamodb.New(sess)}

if len(daxCluster) == 0 {
return api
if err != nil {
spew.Dump(err)
os.Exit(1)
}
}

if !strings.Contains(daxCluster, ".") {
// must be a cluster name rather than domain name
dapi := dax.New(sess)
desc, err := dapi.DescribeClusters(&dax.DescribeClustersInput{ClusterNames: []*string{}})
if err != nil {
panic(err)
}

if len(desc.Clusters) == 0 {
panic("no cluster found by that name")
func (d *Dynamo) tables() error {
names := []string{}
err := d.api.ListTablesPages(&dynamodb.ListTablesInput{}, func(page *dynamodb.ListTablesOutput, lastPage bool) bool {
for _, name := range page.TableNames {
names = append(names, *name)
}
return !lastPage
})

e := desc.Clusters[0].ClusterDiscoveryEndpoint
daxCluster = fmt.Sprintf("%s:%d", *e.Address, *e.Port)
}

if !strings.Contains(daxCluster, ":") {
// missing port, assume default
daxCluster += ":8111"
}
sort.Slice(names, func(i, j int) bool {
return strings.ToLower(names[i]) < strings.ToLower(names[j])
})

cfg := daxc.DefaultConfig()
cfg.HostPorts = []string{daxCluster}
cfg.Credentials = sess.Config.Credentials
cfg.Region = *sess.Config.Region
fmt.Println(strings.Join(names, "\n"))
return err
}

api.dax, err = daxc.New(cfg)
func (d *Dynamo) query(args []string) error {
input, err := queryForArgs(d.api, args)
if err != nil {
panic(err)
return err
}

return api
return d.api.QueryPages(input, func(page *dynamodb.QueryOutput, lastPage bool) bool {
return d.write(convert(page.Items)) || lastPage
})
}

func (d *Dynamo) Run(args []string) {
if len(args) == 1 { // only table name passed in
input := &dynamodb.ScanInput{
TableName: aws.String(args[0]),
Limit: aws.Int64(100),
}
err := d.api.ScanPages(input, func(page *dynamodb.ScanOutput, lastPage bool) bool {
return d.write(convert(page.Items)) || lastPage
})
if err != nil {
spew.Dump(err)
}
} else {
input, err := queryForArgs(d.api, args)
if err != nil {
spew.Dump(err)
os.Exit(1)
}

err = d.api.QueryPages(input, func(page *dynamodb.QueryOutput, lastPage bool) bool {
return d.write(convert(page.Items)) || lastPage
})
if err != nil {
spew.Dump(err)
}
func (d *Dynamo) scan(table string) error {
input := &dynamodb.ScanInput{
TableName: aws.String(table),
Limit: aws.Int64(100),
}

return d.api.ScanPages(input, func(page *dynamodb.ScanOutput, lastPage bool) bool {
return d.write(convert(page.Items)) || lastPage
})
}

func convert(items []map[string]*dynamodb.AttributeValue) []interface{} {
Expand Down Expand Up @@ -193,125 +144,3 @@ func (d *Dynamo) write(jsonItems []interface{}) bool {

return true
}

func queryForArgs(api *Api, args []string) (*dynamodb.QueryInput, error) {
table := args[0]
tableDescription, err := tableDescription(api, table)
if err != nil {
return nil, err
}

attrType := func(name string) string {
for _, def := range tableDescription.AttributeDefinitions {
if name == *def.AttributeName {
return *def.AttributeType
}
}
panic(fmt.Sprintf("unknown key: %s", name))
}

partitionKeyValue := args[1]
partitionKeyName := *tableDescription.KeySchema[0].AttributeName

expression := ""
names := map[string]*string{}
values := map[string]*dynamodb.AttributeValue{}

setValue := func(values map[string]*dynamodb.AttributeValue, name, key, value string) {
typ := attrType(name)
switch typ {
case dynamodb.ScalarAttributeTypeS:
values[key] = &dynamodb.AttributeValue{S: &value}
case dynamodb.ScalarAttributeTypeB:
values[key] = &dynamodb.AttributeValue{B: []byte(value)}
case dynamodb.ScalarAttributeTypeN:
values[key] = &dynamodb.AttributeValue{N: &value}
}
}

setValue(values, partitionKeyName, ":partitionKey", partitionKeyValue)

if len(args) == 2 { // table, partition value
expression = "#partitionKey = :partitionKey"
names = map[string]*string{
"#partitionKey": &partitionKeyName,
}
} else if len(args) == 3 { // table, partition value, sort (value|expression)
sortKeyName := *tableDescription.KeySchema[1].AttributeName
expr := parseSortExpr(args[2])
expression = fmt.Sprintf("#partitionKey = :partitionKey and %s", expr.expression)
for k, v := range expr.values {
setValue(values, sortKeyName, k, v)
}
names = map[string]*string{
"#partitionKey": &partitionKeyName,
"#skey": &sortKeyName,
}
}

input := &dynamodb.QueryInput{
TableName: &table,
KeyConditionExpression: &expression,
ExpressionAttributeValues: values,
ExpressionAttributeNames: names,
}

return input, nil
}

func tableDescription(api *Api, table string) (*dynamodb.TableDescription, error) {
describeResp, err := api.DescribeTable(&dynamodb.DescribeTableInput{TableName: &table})
if err != nil {
return nil, errors.WithStack(err)
}

tableDescription := describeResp.Table
return tableDescription, nil
}

type parsedExpr struct {
expression string
values map[string]string
}

func parseSortExpr(input string) *parsedExpr {
exprs := []struct {
re *regexp.Regexp
expr string
}{
{re: regexp.MustCompile(`^\s*<\s*=\s*(\S+)`), expr: "#skey <= :skey"},
{re: regexp.MustCompile(`^\s*>\s*=\s*(\S+)`), expr: "#skey >= :skey"},
{re: regexp.MustCompile(`^\s*<\s*(\S+)`), expr: "#skey < :skey"},
{re: regexp.MustCompile(`^\s*>\s*(\S+)`), expr: "#skey > :skey"},
{re: regexp.MustCompile(`^\s*=\s*(\S+)`), expr: "#skey = :skey"},
{re: regexp.MustCompile(`begins_with\s*\(?\s*([^)\s]+)\s*\)?`), expr: "begins_with(#skey, :skey)"},
{re: regexp.MustCompile(`\s*([^*]+)\*`), expr: "begins_with(#skey, :skey)"},
}

for _, expr := range exprs {
if m := expr.re.FindStringSubmatch(input); len(m) > 0 {
return &parsedExpr{
expression: expr.expr,
values: map[string]string{":skey": m[1]},
}
}
}

between := regexp.MustCompile(`\s*between\s+(\S+)\s+(\S+)`)
if m := between.FindStringSubmatch(input); len(m) > 0 {
return &parsedExpr{
expression: "#skey between :skey and :skeyb",
values: map[string]string{
":skey": m[1],
":skeyb": m[2],
},
}
}

return &parsedExpr{
expression: "#skey = :skey",
values: map[string]string{":skey": strings.TrimSpace(input)},
}

return nil
}
Loading

0 comments on commit 5c56cb5

Please sign in to comment.