diff --git a/engine_test.go b/engine_test.go index b022a3e6b..707ff1e54 100644 --- a/engine_test.go +++ b/engine_test.go @@ -2,6 +2,8 @@ package sqle_test import ( "context" + "fmt" + "github.com/src-d/go-mysql-server/sql/expression" "io" "math" "strings" @@ -59,6 +61,14 @@ var queries = []struct { "SELECT i FROM mytable WHERE i <> 2;", []sql.Row{{int64(1)}, {int64(3)}}, }, + { + "SELECT i FROM mytable WHERE i in (1, 3)", + []sql.Row{{int64(1)}, {int64(3)}}, + }, + { + "SELECT i FROM mytable WHERE i = 1 or i = 3", + []sql.Row{{int64(1)}, {int64(3)}}, + }, { "SELECT f32 FROM floattable WHERE f64 = 2.0;", []sql.Row{{float32(2.0)}}, @@ -1589,21 +1599,146 @@ var queries = []struct { } func TestQueries(t *testing.T) { - e := newEngine(t) - t.Run("sequential", func(t *testing.T) { - for _, tt := range queries { - testQuery(t, e, tt.query, tt.expected) + type indexDriverInitalizer func(map[string]*memory.Table) sql.IndexDriver + type indexDriverTestCase struct { + name string + initializer indexDriverInitalizer + } + + // Test all queries with these combinations, for a total of 12 runs: + // 1) Partitioned tables / non partitioned tables + // 2) Mergeable / unmergeable / no indexes + // 3) Parallelism on / off + numPartitionsVals := []int{ + 1, + testNumPartitions, + } + indexDrivers := []*indexDriverTestCase{ + nil, + {"unmergableIndexes", unmergableIndexDriver}, + {"mergableIndexes", mergableIndexDriver}, + } + parallelVals := []int{ + 1, + 2, + } + for _, numPartitions := range numPartitionsVals { + for _, indexDriverInit := range indexDrivers { + for _, parallelism := range parallelVals { + tables := allTestTables(t, numPartitions) + + var indexDriver sql.IndexDriver + if indexDriverInit != nil { + indexDriver = indexDriverInit.initializer(tables) + } + engine := newEngineWithParallelism(t, parallelism, tables, indexDriver) + + indexDriverName := "none" + if indexDriverInit != nil { + indexDriverName = indexDriverInit.name + } + testName := fmt.Sprintf("partitions=%d,indexes=%v,parallelism=%v", numPartitions, indexDriverName, parallelism) + t.Run(testName, func(t *testing.T) { + for _, tt := range queries { + testQuery(t, engine, tt.query, tt.expected) + } + }) + } } + } +} + +func unmergableIndexDriver(tables map[string]*memory.Table) sql.IndexDriver { + return memory.NewIndexDriver("mydb", map[string][]sql.Index{ + "mytable": { + newUnmergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false)), + newUnmergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(1, sql.Text, "mytable", "s", false)), + newUnmergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetFieldWithTable(1, sql.Text, "mytable", "s", false)), + }, + "othertable": { + newUnmergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(0, sql.Text, "othertable", "s2", false)), + newUnmergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(1, sql.Text, "othertable", "i2", false)), + newUnmergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(0, sql.Text, "othertable", "s2", false), + expression.NewGetFieldWithTable(1, sql.Text, "othertable", "i2", false)), + }, + "bigtable": { + newUnmergableIndex(tables, "bigtable", + expression.NewGetFieldWithTable(0, sql.Text, "bigtable", "t", false)), + }, + "floattable": { + newUnmergableIndex(tables, "floattable", + expression.NewGetFieldWithTable(2, sql.Text, "floattable", "f64", false)), + }, + "niltable": { + newUnmergableIndex(tables, "niltable", + expression.NewGetFieldWithTable(0, sql.Int64, "niltable", "i", false)), + }, }) +} - ep := newEngineWithParallelism(t, 2) - t.Run("parallel", func(t *testing.T) { - for _, tt := range queries { - testQuery(t, ep, tt.query, tt.expected) - } +func mergableIndexDriver(tables map[string]*memory.Table) sql.IndexDriver { + return memory.NewIndexDriver("mydb", map[string][]sql.Index{ + "mytable": { + newMergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false)), + newMergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(1, sql.Text, "mytable", "s", false)), + newMergableIndex(tables, "mytable", + expression.NewGetFieldWithTable(0, sql.Int64, "mytable", "i", false), + expression.NewGetFieldWithTable(1, sql.Text, "mytable", "s", false)), + }, + "othertable": { + newMergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(0, sql.Text, "othertable", "s2", false)), + newMergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(1, sql.Text, "othertable", "i2", false)), + newMergableIndex(tables, "othertable", + expression.NewGetFieldWithTable(0, sql.Text, "othertable", "s2", false), + expression.NewGetFieldWithTable(1, sql.Text, "othertable", "i2", false)), + }, + "bigtable": { + newMergableIndex(tables, "bigtable", + expression.NewGetFieldWithTable(0, sql.Text, "bigtable", "t", false)), + }, + "floattable": { + newMergableIndex(tables, "floattable", + expression.NewGetFieldWithTable(2, sql.Text, "floattable", "f64", false)), + }, + "niltable": { + newMergableIndex(tables, "niltable", + expression.NewGetFieldWithTable(0, sql.Int64, "niltable", "i", false)), + }, }) } + +func newUnmergableIndex(tables map[string]*memory.Table, tableName string, exprs ...sql.Expression) *memory.UnmergeableIndex { + return &memory.UnmergeableIndex{ + DB: "mydb", + DriverName: memory.IndexDriverId, + TableName: tableName, + Tbl: tables[tableName], + Exprs: exprs, + } +} + +func newMergableIndex(tables map[string]*memory.Table, tableName string, exprs ...sql.Expression) *memory.MergeableIndex { + return &memory.MergeableIndex{ + DB: "mydb", + DriverName: memory.IndexDriverId, + TableName: tableName, + Tbl: tables[tableName], + Exprs: exprs, + } +} + func TestSessionSelectLimit(t *testing.T) { ctx := newCtx() ctx.Session.Set("sql_select_limit", sql.Int64, int64(1)) @@ -1752,7 +1887,7 @@ func TestWarnings(t *testing.T) { } e := newEngine(t) - ep := newEngineWithParallelism(t, 2) + ep := newEngineWithParallelism(t, 2, allTestTables(t, testNumPartitions), nil) t.Run("sequential", func(t *testing.T) { for _, tt := range queries { @@ -1816,7 +1951,7 @@ func TestClearWarnings(t *testing.T) { func TestDescribe(t *testing.T) { e := newEngine(t) - ep := newEngineWithParallelism(t, 2) + ep := newEngineWithParallelism(t, 2, allTestTables(t, testNumPartitions), nil) query := `DESCRIBE FORMAT=TREE SELECT * FROM mytable` expectedSeq := []sql.Row{ @@ -2515,6 +2650,27 @@ func TestCreateTable(t *testing.T) { } require.Equal(s, testTable.Schema()) + + testQuery(t, e, + "CREATE TABLE t4(a INTEGER,"+ + "b TEXT NOT NULL,"+ + "c bool, primary key (a))", + []sql.Row(nil), + ) + + db, err = e.Catalog.Database("mydb") + require.NoError(err) + + testTable, ok = db.Tables()["t4"] + require.True(ok) + + s = sql.Schema{ + {Name: "a", Type: sql.Int32, Nullable: false, PrimaryKey: true, Source: "t4"}, + {Name: "b", Type: sql.Text, Nullable: false, PrimaryKey: false, Source: "t4"}, + {Name: "c", Type: sql.Uint8, Nullable: true, Source: "t4"}, + } + + require.Equal(s, testTable.Schema()) } func TestDropTable(t *testing.T) { @@ -2801,66 +2957,64 @@ func testQueryWithContext(ctx *sql.Context, t *testing.T, e *sqle.Engine, q stri }) } -func newEngine(t *testing.T) *sqle.Engine { - return newEngineWithParallelism(t, 1) -} +func allTestTables(t *testing.T, numPartitions int) map[string]*memory.Table { + tables := make(map[string]*memory.Table) -func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { - table := memory.NewPartitionedTable("mytable", sql.Schema{ + tables["mytable"] = memory.NewPartitionedTable("mytable", sql.Schema{ {Name: "i", Type: sql.Int64, Source: "mytable"}, {Name: "s", Type: sql.Text, Source: "mytable"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, table, + t, tables["mytable"], sql.NewRow(int64(1), "first row"), sql.NewRow(int64(2), "second row"), sql.NewRow(int64(3), "third row"), ) - table2 := memory.NewPartitionedTable("othertable", sql.Schema{ + tables["othertable"] = memory.NewPartitionedTable("othertable", sql.Schema{ {Name: "s2", Type: sql.Text, Source: "othertable"}, {Name: "i2", Type: sql.Int64, Source: "othertable"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, table2, + t, tables["othertable"], sql.NewRow("first", int64(3)), sql.NewRow("second", int64(2)), sql.NewRow("third", int64(1)), ) - table3 := memory.NewPartitionedTable("tabletest", sql.Schema{ + tables["tabletest"] = memory.NewPartitionedTable("tabletest", sql.Schema{ {Name: "i", Type: sql.Int32, Source: "tabletest"}, {Name: "s", Type: sql.Text, Source: "tabletest"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, table3, + t, tables["tabletest"], sql.NewRow(int64(1), "first row"), sql.NewRow(int64(2), "second row"), sql.NewRow(int64(3), "third row"), ) - table4 := memory.NewPartitionedTable("other_table", sql.Schema{ + tables["other_table"] = memory.NewPartitionedTable("other_table", sql.Schema{ {Name: "text", Type: sql.Text, Source: "tabletest"}, {Name: "number", Type: sql.Int32, Source: "tabletest"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, table4, + t, tables["other_table"], sql.NewRow("a", int32(4)), sql.NewRow("b", int32(2)), sql.NewRow("c", int32(0)), ) - bigtable := memory.NewPartitionedTable("bigtable", sql.Schema{ + tables["bigtable"] = memory.NewPartitionedTable("bigtable", sql.Schema{ {Name: "t", Type: sql.Text, Source: "bigtable"}, {Name: "n", Type: sql.Int64, Source: "bigtable"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, bigtable, + t, tables["bigtable"], sql.NewRow("a", int64(1)), sql.NewRow("s", int64(2)), sql.NewRow("f", int64(3)), @@ -2877,14 +3031,14 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { sql.NewRow("b", int64(9)), ) - floatTable := memory.NewPartitionedTable("floattable", sql.Schema{ + tables["floattable"] = memory.NewPartitionedTable("floattable", sql.Schema{ {Name: "i", Type: sql.Int64, Source: "floattable"}, {Name: "f32", Type: sql.Float32, Source: "floattable"}, {Name: "f64", Type: sql.Float64, Source: "floattable"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, floatTable, + t, tables["floattable"], sql.NewRow(int64(1), float32(1.0), float64(1.0)), sql.NewRow(int64(2), float32(1.5), float64(1.5)), sql.NewRow(int64(3), float32(2.0), float64(2.0)), @@ -2893,14 +3047,14 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { sql.NewRow(int64(-2), float32(-1.5), float64(-1.5)), ) - nilTable := memory.NewPartitionedTable("niltable", sql.Schema{ + tables["niltable"] = memory.NewPartitionedTable("niltable", sql.Schema{ {Name: "i", Type: sql.Int64, Source: "niltable", Nullable: true}, {Name: "b", Type: sql.Boolean, Source: "niltable", Nullable: true}, {Name: "f", Type: sql.Float64, Source: "niltable", Nullable: true}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, nilTable, + t, tables["niltable"], sql.NewRow(int64(1), true, float64(1.0)), sql.NewRow(int64(2), nil, float64(2.0)), sql.NewRow(nil, false, float64(3.0)), @@ -2908,13 +3062,13 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { sql.NewRow(nil, nil, nil), ) - newlineTable := memory.NewPartitionedTable("newlinetable", sql.Schema{ + tables["newlinetable"] = memory.NewPartitionedTable("newlinetable", sql.Schema{ {Name: "i", Type: sql.Int64, Source: "newlinetable"}, {Name: "s", Type: sql.Text, Source: "newlinetable"}, - }, testNumPartitions) + }, numPartitions) insertRows( - t, newlineTable, + t, tables["newlinetable"], sql.NewRow(int64(1), "\nthere is some text in here"), sql.NewRow(int64(2), "there is some\ntext in here"), sql.NewRow(int64(3), "there is some text\nin here"), @@ -2922,7 +3076,7 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { sql.NewRow(int64(5), "there is some text in here"), ) - typestable := memory.NewPartitionedTable("typestable", sql.Schema{ + tables["typestable"] = memory.NewPartitionedTable("typestable", sql.Schema{ {Name: "id", Type: sql.Int64, Source: "typestable"}, {Name: "i8", Type: sql.Int8, Source: "typestable", Nullable: true}, {Name: "i16", Type: sql.Int16, Source: "typestable", Nullable: true}, @@ -2940,20 +3094,25 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { {Name: "bo", Type: sql.Boolean, Source: "typestable", Nullable: true}, {Name: "js", Type: sql.JSON, Source: "typestable", Nullable: true}, {Name: "bl", Type: sql.Blob, Source: "typestable", Nullable: true}, - }, testNumPartitions) + }, numPartitions) + + return tables +} +func newEngine(t *testing.T) *sqle.Engine { + return newEngineWithParallelism(t, 1, allTestTables(t, testNumPartitions), nil) +} + +func newEngineWithParallelism(t *testing.T, parallelism int, tables map[string]*memory.Table, driver sql.IndexDriver) *sqle.Engine { db := memory.NewDatabase("mydb") - db.AddTable("mytable", table) - db.AddTable("othertable", table2) - db.AddTable("tabletest", table3) - db.AddTable("bigtable", bigtable) - db.AddTable("floattable", floatTable) - db.AddTable("niltable", nilTable) - db.AddTable("newlinetable", newlineTable) - db.AddTable("typestable", typestable) + for name, table := range tables { + if name != "other_table" { + db.AddTable(name, table) + } + } db2 := memory.NewDatabase("foo") - db2.AddTable("other_table", table4) + db2.AddTable("other_table", tables["other_table"]) catalog := sql.NewCatalog() catalog.AddDatabase(db) @@ -2967,7 +3126,14 @@ func newEngineWithParallelism(t *testing.T, parallelism int) *sqle.Engine { a = analyzer.NewDefault(catalog) } - return sqle.New(catalog, a, new(sqle.Config)) + if driver != nil { + catalog.RegisterIndexDriver(driver) + } + + engine := sqle.New(catalog, a, new(sqle.Config)) + require.NoError(t, engine.Init()) + + return engine } const expectedTree = `Limit(5) diff --git a/memory/ascend_index.go b/memory/ascend_index.go new file mode 100755 index 000000000..71b74e884 --- /dev/null +++ b/memory/ascend_index.go @@ -0,0 +1,63 @@ +package memory + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +type AscendIndexLookup struct { + id string + Gte []interface{} + Lt []interface{} + Index ExpressionsIndex +} + +var _ memoryIndexLookup = (*AscendIndexLookup)(nil) + +func (l *AscendIndexLookup) ID() string { return l.id } + +func (l *AscendIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + return &indexValIter{ + tbl: l.Index.MemTable(), + partition: p, + matchExpression: l.EvalExpression(), + }, nil +} + +func (l *AscendIndexLookup) Indexes() []string { + return []string{l.id} +} + +func (l *AscendIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(MergeableLookup) + return ok +} + +func (l *AscendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return union(l.Index, l, lookups...) +} + +func (l *AscendIndexLookup) EvalExpression() sql.Expression { + if len(l.Index.ColumnExpressions()) > 1 { + panic("Ascend index unsupported for multi-column indexes") + } + + lt, typ := getType(l.Lt[0]) + ltexpr := expression.NewLessThan(l.Index.ColumnExpressions()[0], expression.NewLiteral(lt, typ)) + if len(l.Gte) > 0 { + gte, _ := getType(l.Gte[0]) + return and( + ltexpr, + expression.NewGreaterThanOrEqual(l.Index.ColumnExpressions()[0], expression.NewLiteral(gte, typ)), + ) + } + return ltexpr +} + +func (*AscendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("ascendIndexLookup.Difference is not implemented") +} + +func (l *AscendIndexLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + return intersection(l.Index, l, lookups...) +} diff --git a/memory/descend_index.go b/memory/descend_index.go new file mode 100755 index 000000000..3a8bead19 --- /dev/null +++ b/memory/descend_index.go @@ -0,0 +1,64 @@ +package memory + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +type DescendIndexLookup struct { + id string + Gt []interface{} + Lte []interface{} + Index ExpressionsIndex +} + +var _ memoryIndexLookup = (*DescendIndexLookup)(nil) + +func (l *DescendIndexLookup) ID() string { return l.id } + +func (l *DescendIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + return &indexValIter{ + tbl: l.Index.MemTable(), + partition: p, + matchExpression: l.EvalExpression(), + }, nil +} + +func (l *DescendIndexLookup) EvalExpression() sql.Expression { + if len(l.Index.ColumnExpressions()) > 1 { + panic("Descend index unsupported for multi-column indexes") + } + + gt, typ := getType(l.Gt[0]) + gtexpr := expression.NewGreaterThan(l.Index.ColumnExpressions()[0], expression.NewLiteral(gt, typ)) + if len(l.Lte) > 0 { + lte, _ := getType(l.Lte[0]) + return and( + gtexpr, + expression.NewLessThanOrEqual(l.Index.ColumnExpressions()[0], expression.NewLiteral(lte, typ)), + ) + } + return gtexpr +} + +func (l *DescendIndexLookup) Indexes() []string { + return []string{l.id} +} + +func (l *DescendIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(MergeableLookup) + return ok +} + +func (l *DescendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return union(l.Index, l, lookups...) +} + +func (*DescendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("descendIndexLookup.Difference is not implemented") +} + +func (l *DescendIndexLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + return intersection(l.Index, l, lookups...) +} + diff --git a/memory/index_driver.go b/memory/index_driver.go new file mode 100755 index 000000000..9fea48bfa --- /dev/null +++ b/memory/index_driver.go @@ -0,0 +1,43 @@ +package memory + +import ( + "github.com/src-d/go-mysql-server/sql" +) + +const IndexDriverId = "MemoryIndexDriver" + +// TestIndexDriver is a non-performant index driver meant to aid in verification of engine correctness. It can not +// create or delete indexes, but will use the index types defined in this package to alter how queries are executed, +// retrieving values from the indexes rather than from the tables directly. +type TestIndexDriver struct { + db string + indexes map[string][]sql.Index +} + +// NewIndexDriver returns a new index driver for database and the indexes given, keyed by the table name. +func NewIndexDriver(db string, indexes map[string][]sql.Index) *TestIndexDriver { + return &TestIndexDriver{db: db, indexes: indexes} +} + +func (d *TestIndexDriver) ID() string { + return IndexDriverId +} + +func (d *TestIndexDriver) LoadAll(db, table string) ([]sql.Index, error) { + if d.db != db { + return nil, nil + } + return d.indexes[table], nil +} + +func (d *TestIndexDriver) Save(*sql.Context, sql.Index, sql.PartitionIndexKeyValueIter) error { + panic("not implemented") +} + +func (d *TestIndexDriver) Delete(sql.Index, sql.PartitionIter) error { + panic("not implemented") +} + +func (d *TestIndexDriver) Create(db, table, id string, expressions []sql.Expression, config map[string]string) (sql.Index, error) { + panic("not implemented") +} \ No newline at end of file diff --git a/memory/mergeable_index.go b/memory/mergeable_index.go new file mode 100755 index 000000000..e50268ba1 --- /dev/null +++ b/memory/mergeable_index.go @@ -0,0 +1,289 @@ +package memory + +import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "strings" +) + +type MergeableIndex struct { + DB string // required for engine tests with driver + DriverName string // required for engine tests with driver + Tbl *Table // required for engine tests with driver + TableName string + Exprs []sql.Expression +} + +var _ sql.Index = (*MergeableIndex)(nil) +var _ sql.AscendIndex = (*MergeableIndex)(nil) +var _ sql.DescendIndex = (*MergeableIndex)(nil) +var _ sql.NegateIndex = (*MergeableIndex)(nil) + +func (i *MergeableIndex) Database() string { return i.DB } +func (i *MergeableIndex) Driver() string { return i.DriverName } +func (i *MergeableIndex) MemTable() *Table { return i.Tbl } +func (i *MergeableIndex) ColumnExpressions() []sql.Expression { return i.Exprs } + +func (i *MergeableIndex) Expressions() []string { + var exprs []string + for _, e := range i.Exprs { + exprs = append(exprs, e.String()) + } + return exprs +} + +func (i *MergeableIndex) AscendGreaterOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + return &AscendIndexLookup{Gte: keys, Index: i}, nil +} + +func (i *MergeableIndex) AscendLessThan(keys ...interface{}) (sql.IndexLookup, error) { + return &AscendIndexLookup{Lt: keys, Index: i}, nil +} + +func (i *MergeableIndex) AscendRange(greaterOrEqual, lessThan []interface{}) (sql.IndexLookup, error) { + return &AscendIndexLookup{Gte: greaterOrEqual, Lt: lessThan, Index: i}, nil +} + +func (i *MergeableIndex) DescendGreater(keys ...interface{}) (sql.IndexLookup, error) { + return &DescendIndexLookup{Gt: keys, Index: i}, nil +} + +func (i *MergeableIndex) DescendLessOrEqual(keys ...interface{}) (sql.IndexLookup, error) { + return &DescendIndexLookup{Lte: keys, Index: i}, nil +} + +func (i *MergeableIndex) DescendRange(lessOrEqual, greaterThan []interface{}) (sql.IndexLookup, error) { + return &DescendIndexLookup{Gt: greaterThan, Lte: lessOrEqual, Index: i}, nil +} + +func (i *MergeableIndex) Not(keys ...interface{}) (sql.IndexLookup, error) { + lookup, err := i.Get(keys...) + if err != nil { + return nil, err + } + + mergeable, _ := lookup.(*MergeableIndexLookup) + return &NegateIndexLookup{Lookup: mergeable, Index: mergeable.Index}, nil +} + +func (i *MergeableIndex) Get(key ...interface{}) (sql.IndexLookup, error) { + return &MergeableIndexLookup{Key: key, Index: i}, nil +} + +func (i *MergeableIndex) Has(sql.Partition, ...interface{}) (bool, error) { + panic("not implemented") +} + +func (i *MergeableIndex) ID() string { + if len(i.Exprs) == 1 { + return i.Exprs[0].String() + } + var parts = make([]string, len(i.Exprs)) + for i, e := range i.Exprs { + parts[i] = e.String() + } + + return "(" + strings.Join(parts, ", ") + ")" +} + +func (i *MergeableIndex) Table() string { return i.TableName } + +// All lookups in this package, except for UnmergeableLookup, are MergeableLookups. The IDs are mostly for testing / +// verification purposes. +type MergeableLookup interface { + ID() string +} + +// ExpressionsIndex is an index made out of one or more expressions (usually field expressions), linked to a Table. +type ExpressionsIndex interface { + MemTable() *Table + ColumnExpressions() []sql.Expression +} + +// MergeableIndexLookup is a lookup linked to an ExpressionsIndex. It can be merged with any other MergeableIndexLookup. All lookups in this package are Merge +type MergeableIndexLookup struct { + Key []interface{} + Index ExpressionsIndex +} + +// memoryIndexLookup is a lookup that defines an expression to evaluate which rows are part of the index values +type memoryIndexLookup interface { + EvalExpression() sql.Expression +} + +var _ sql.Mergeable = (*MergeableIndexLookup)(nil) +var _ sql.SetOperations = (*MergeableIndexLookup)(nil) +var _ memoryIndexLookup = (*MergeableIndexLookup)(nil) + +func (i *MergeableIndexLookup) ID() string { return strings.Join(i.Indexes(), ",") } + +func (i *MergeableIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(MergeableLookup) + return ok +} + +func (i *MergeableIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + var exprs []sql.Expression + for exprI, expr := range i.Index.ColumnExpressions() { + lit, typ := getType(i.Key[exprI]) + exprs = append(exprs, expression.NewEquals(expr, expression.NewLiteral(lit, typ))) + } + + return &indexValIter{ + tbl: i.Index.MemTable(), + partition: p, + matchExpression: and(exprs...), + }, nil +} + +func (i *MergeableIndexLookup) EvalExpression() sql.Expression { + var exprs []sql.Expression + for exprI, expr := range i.Index.ColumnExpressions() { + lit, typ := getType(i.Key[exprI]) + exprs = append(exprs, expression.NewEquals(expr, expression.NewLiteral(lit, typ))) + } + return and(exprs...) +} + +func (i *MergeableIndexLookup) Indexes() []string { + var idxes = make([]string, len(i.Key)) + for i, e := range i.Key { + idxes[i] = fmt.Sprint(e) + } + return idxes +} + +func (i *MergeableIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("not implemented") +} + +func (i *MergeableIndexLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + return intersection(i.Index, i, lookups...) +} + +// Intersects the lookups given together, collapsing redundant layers of intersections for lookups that have previously +// been merged. E.g. merging a MergeableIndexLookup with a MergedIndexLookup that has 2 intersections will return a +// MergedIndexLookup with 3 lookups intersected: the left param and the two intersected lookups from the +// MergedIndexLookup. +func intersection(idx ExpressionsIndex, left sql.IndexLookup, lookups ...sql.IndexLookup) sql.IndexLookup { + var merged []sql.IndexLookup + var allLookups []sql.IndexLookup + allLookups = append(allLookups, left) + allLookups = append(allLookups, lookups...) + for _, lookup := range allLookups { + if mil, ok := lookup.(*MergedIndexLookup); ok && len(mil.Intersections) > 0 { + merged = append(merged, mil.Intersections...) + } else { + merged = append(merged, lookup) + } + } + + return &MergedIndexLookup{ + Intersections: merged, + Index: idx, + } +} + +func (i *MergeableIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return union(i.Index, i, lookups...) +} + +// Unions the lookups given together, collapsing redundant layers of unions for lookups that have previously been +// merged. E.g. merging a MergeableIndexLookup with a MergedIndexLookup that has 2 unions will return a +// MergedIndexLookup with 3 lookups unioned: the left param and the two unioned lookups from the MergedIndexLookup. +func union(idx ExpressionsIndex, left sql.IndexLookup, lookups ...sql.IndexLookup) sql.IndexLookup { + var merged []sql.IndexLookup + var allLookups []sql.IndexLookup + allLookups = append(allLookups, left) + allLookups = append(allLookups, lookups...) + for _, lookup := range allLookups { + if mil, ok := lookup.(*MergedIndexLookup); ok && len(mil.Unions) > 0 { + merged = append(merged, mil.Unions...) + } else { + merged = append(merged, lookup) + } + } + + return &MergedIndexLookup{ + Unions: merged, + Index: idx, + } +} + +// MergedIndexLookup is an index lookup that has been merged with another. +// Exactly one of the Unions or Intersections fields should be set, and correspond to a logical AND or OR operation, +// respectively. +type MergedIndexLookup struct { + Unions []sql.IndexLookup + Intersections []sql.IndexLookup + Index ExpressionsIndex +} + +var _ sql.Mergeable = (*MergedIndexLookup)(nil) +var _ sql.SetOperations = (*MergedIndexLookup)(nil) +var _ memoryIndexLookup = (*MergedIndexLookup)(nil) + +func (m *MergedIndexLookup) EvalExpression() sql.Expression { + var exprs []sql.Expression + if m.Intersections != nil { + for _, lookup := range m.Intersections { + exprs = append(exprs, lookup.(memoryIndexLookup).EvalExpression()) + } + return and(exprs...) + } + if m.Unions != nil { + for _, lookup := range m.Unions { + exprs = append(exprs, lookup.(memoryIndexLookup).EvalExpression()) + } + return or(exprs...) + } + panic("either Unions or Intersections must be non-nil") +} + +func (m *MergedIndexLookup) Intersection(lookups ...sql.IndexLookup) sql.IndexLookup { + return intersection(m.Index, m, lookups...) +} + +func (m *MergedIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return union(m.Index, m, lookups...) +} + +func (m *MergedIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("not implemented") +} + +func (m *MergedIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(MergeableLookup) + return ok +} + +func (m *MergedIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + return &indexValIter{ + tbl: m.Index.MemTable(), + partition: p, + matchExpression: m.EvalExpression(), + }, nil +} + +func (m *MergedIndexLookup) Indexes() []string { + panic("not implemented") +} + +func (m *MergedIndexLookup) ID() string { + panic("not implemented") +} + +func or(expressions ...sql.Expression) sql.Expression { + if len(expressions) == 1 { + return expressions[0] + } + return expression.NewOr(expressions[0], or(expressions[1:]...)) +} + +func and(expressions ...sql.Expression) sql.Expression { + if len(expressions) == 1 { + return expressions[0] + } + return expression.NewAnd(expressions[0], and(expressions[1:]...)) +} \ No newline at end of file diff --git a/memory/negative_index.go b/memory/negative_index.go new file mode 100755 index 000000000..830b75454 --- /dev/null +++ b/memory/negative_index.go @@ -0,0 +1,48 @@ +package memory + +import ( + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" +) + +type NegateIndexLookup struct { + Lookup MergeableLookup + Index ExpressionsIndex +} + +var _ memoryIndexLookup = (*NegateIndexLookup)(nil) + +func (l *NegateIndexLookup) ID() string { return "not " + l.Lookup.ID() } + +func (l *NegateIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + return &indexValIter{ + tbl: l.Index.MemTable(), + partition: p, + matchExpression: l.EvalExpression(), + }, nil +} + +func (l *NegateIndexLookup) EvalExpression() sql.Expression { + return expression.NewNot(l.Lookup.(memoryIndexLookup).EvalExpression()) +} + +func (l *NegateIndexLookup) Indexes() []string { + return []string{l.ID()} +} + +func (*NegateIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { + _, ok := lookup.(MergeableLookup) + return ok +} + +func (l *NegateIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { + return union(l.Index, l, lookups...) +} + +func (*NegateIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { + panic("negateIndexLookup.Difference is not implemented") +} + +func (l *NegateIndexLookup) Intersection(indexes ...sql.IndexLookup) sql.IndexLookup { + return intersection(l.Index, l, indexes...) +} \ No newline at end of file diff --git a/memory/unmergeable_index.go b/memory/unmergeable_index.go new file mode 100755 index 000000000..b9c9bcb81 --- /dev/null +++ b/memory/unmergeable_index.go @@ -0,0 +1,179 @@ +package memory + +import ( + "fmt" + "github.com/src-d/go-mysql-server/sql" + "github.com/src-d/go-mysql-server/sql/expression" + "io" + "strings" +) + +// A very dumb index that iterates over the rows of a table, evaluates its matching expressions against each row, and +// stores those values to be later retrieved. Only here to test the functionality of indexed queries. This kind of index +// cannot be merged with any other index. +type UnmergeableIndex struct { + DB string // required for engine tests with driver + DriverName string // required for engine tests with driver + Tbl *Table // required for engine tests with driver + TableName string + Exprs []sql.Expression +} + +var _ sql.Index = (*UnmergeableIndex)(nil) + +func (u *UnmergeableIndex) Database() string { return u.DB } +func (u *UnmergeableIndex) Driver() string { return u.DriverName } + +func (u *UnmergeableIndex) Expressions() []string { + var exprs []string + for _, e := range u.Exprs { + exprs = append(exprs, e.String()) + } + return exprs +} + +func (u *UnmergeableIndex) Get(key ...interface{}) (sql.IndexLookup, error) { + return &UnmergeableIndexLookup{ + key: key, + idx: u, + }, nil +} + +// UnmergeableIndexLookup is the only IndexLookup in this package that doesn't implement Mergeable, and therefore +// can't be merged with other lookups. +type UnmergeableIndexLookup struct { + key []interface{} + idx *UnmergeableIndex +} + +// dummyIndexValueIter does a very simple and verifiable iteration over the table values for a given index. It does this +// by iterating over all the table rows for a partition and evaluating each of them for inclusion in the index. This is +// not an efficient way to store an index, and is only suitable for testing the correctness of index code in the engine. +type indexValIter struct { + tbl *Table + partition sql.Partition + matchExpression sql.Expression + values [][]byte + i int +} + +func (u *indexValIter) Next() ([]byte, error) { + err := u.initValues() + if err != nil { + return nil, err + } + + if u.i < len(u.values) { + valBytes := u.values[u.i] + u.i++ + return valBytes, nil + } + + return nil, io.EOF +} + +func (u *indexValIter) initValues() error { + if u.values == nil { + rows, ok := u.tbl.partitions[string(u.partition.Key())] + if !ok { + return fmt.Errorf( + "partition not found: %q", u.partition.Key(), + ) + } + + for i, row := range rows { + ok, err := sql.EvaluateCondition(sql.NewEmptyContext(), u.matchExpression, row) + if err != nil { + return err + } + + if ok { + encoded, err := encodeIndexValue(&indexValue{ + Pos: i, + }) + + if err != nil { + return err + } + + u.values = append(u.values, encoded) + } + } + } + + return nil +} + +func getType(val interface{}) (interface{}, sql.Type) { + switch val := val.(type) { + case int8: + return int64(val), sql.Int64 + case uint8: + return int64(val), sql.Int64 + case int16: + return int64(val), sql.Int64 + case uint16: + return int64(val), sql.Int64 + case int32: + return int64(val), sql.Int64 + case uint32: + return int64(val), sql.Int64 + case int64: + return int64(val), sql.Int64 + case uint64: + return int64(val), sql.Int64 + case float32: + return float64(val), sql.Float64 + case float64: + return float64(val), sql.Float64 + case string: + return val, sql.Text + default:panic(fmt.Sprintf("Unsupported type for %v of type %T", val, val)) + } +} + +func (u *indexValIter) Close() error { + return nil +} + +func (u *UnmergeableIndexLookup) Values(p sql.Partition) (sql.IndexValueIter, error) { + var exprs []sql.Expression + for exprI, expr := range u.idx.Exprs { + lit, typ := getType(u.key[exprI]) + exprs = append(exprs, expression.NewEquals(expr, expression.NewLiteral(lit, typ))) + } + + return &indexValIter{ + tbl: u.idx.Tbl, + partition: p, + matchExpression: and(exprs...), + }, nil +} + +func (u *UnmergeableIndexLookup) Indexes() []string { + var idxes = make([]string, len(u.key)) + for i, e := range u.idx.Exprs { + idxes[i] = fmt.Sprint(e) + } + return idxes +} + +func (u *UnmergeableIndex) Has(partition sql.Partition, key ...interface{}) (bool, error) { + panic("not implemented") +} + +func (u *UnmergeableIndex) ID() string { + if len(u.Exprs) == 1 { + return u.Exprs[0].String() + } + var parts = make([]string, len(u.Exprs)) + for i, e := range u.Exprs { + parts[i] = e.String() + } + + return "(" + strings.Join(parts, ", ") + ")" +} + +func (u *UnmergeableIndex) Table() string { + return u.TableName +} diff --git a/sql/analyzer/assign_indexes.go b/sql/analyzer/assign_indexes.go index b6bcb1327..18220a0cc 100644 --- a/sql/analyzer/assign_indexes.go +++ b/sql/analyzer/assign_indexes.go @@ -101,12 +101,24 @@ func getIndexes(e sql.Expression, aliases map[string]sql.Expression, a *Analyzer return nil, err } - for table, idx := range leftIndexes { - if idx2, ok := rightIndexes[table]; ok && canMergeIndexes(idx.lookup, idx2.lookup) { - idx.lookup = idx.lookup.(sql.SetOperations).Union(idx2.lookup) - idx.indexes = append(idx.indexes, idx2.indexes...) + for table, leftIdx := range leftIndexes { + result[table] = leftIdx + } + + // Merge any indexes for the same table on the left and right sides. + for table, leftIdx := range leftIndexes { + if rightIdx, ok := rightIndexes[table]; ok { + if canMergeIndexes(leftIdx.lookup, rightIdx.lookup) { + leftIdx.lookup = leftIdx.lookup.(sql.SetOperations).Union(rightIdx.lookup) + leftIdx.indexes = append(leftIdx.indexes, rightIdx.indexes...) + result[table] = leftIdx + } else { + // Since we can return one index per table, if we can't merge the second index from this table, return no + // indexes. Returning a single one will lead to incorrect results from e.g. pushdown operations when only one + // side of the OR expression is used to index the table. + return nil, nil + } } - result[table] = idx } // Put in the result map the indexes for tables we don't have indexes yet. @@ -161,7 +173,6 @@ func getIndexes(e sql.Expression, aliases map[string]sql.Expression, a *Analyzer lookup, errLookup = nidx.Not(values[0]) } else { lookup, errLookup = idx.Get(values[0]) - } if errLookup != nil { @@ -175,16 +186,15 @@ func getIndexes(e sql.Expression, aliases map[string]sql.Expression, a *Analyzer lookup2, errLookup = nidx.Not(v) } else { lookup2, errLookup = idx.Get(v) - } if errLookup != nil { return nil, err } - // if one of the indexes cannot be merged, return already + // if one of the indexes cannot be merged, return a nil result for this table if !canMergeIndexes(lookup, lookup2) { - return result, nil + return nil, nil } if negate { @@ -312,6 +322,11 @@ func unifyExpressions(aliases map[string]sql.Expression, expr ...sql.Expression) } func betweenIndexLookup(index sql.Index, upper, lower []interface{}) (sql.IndexLookup, error) { + // TODO: two bugs here + // 1) Mergeable and SetOperations are separate interfaces, so a naive integrator could generate a type assertion + // error in this method + // 2) Since AscendRange and DescendRange both accept an upper and lower bound, there is no good reason to require + // both implementations from an index. One will do fine, no need to require both and merge them. ai, isAscend := index.(sql.AscendIndex) di, isDescend := index.(sql.DescendIndex) if isAscend && isDescend { @@ -493,7 +508,6 @@ func getNegatedIndexes(a *Analyzer, not *expression.Not, aliases map[string]sql. return getIndexes(or, aliases, a) default: return nil, nil - } } diff --git a/sql/analyzer/assign_indexes_test.go b/sql/analyzer/assign_indexes_test.go index d8e61f05a..58e45487a 100644 --- a/sql/analyzer/assign_indexes_test.go +++ b/sql/analyzer/assign_indexes_test.go @@ -1,8 +1,6 @@ package analyzer import ( - "fmt" - "strings" "testing" "github.com/src-d/go-mysql-server/memory" @@ -16,9 +14,9 @@ func TestNegateIndex(t *testing.T) { require := require.New(t) catalog := sql.NewCatalog() - idx1 := &dummyIndex{ - "t1", - []sql.Expression{ + idx1 := &memory.MergeableIndex{ + TableName: "t1", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), }, } @@ -52,18 +50,18 @@ func TestNegateIndex(t *testing.T) { lookupIdxs, ok := result["t1"] require.True(ok) - negate, ok := lookupIdxs.lookup.(*negateIndexLookup) + negate, ok := lookupIdxs.lookup.(*memory.NegateIndexLookup) require.True(ok) - require.True(negate.value == "1") + require.Equal("not 1", negate.ID()) } func TestAssignIndexes(t *testing.T) { require := require.New(t) catalog := sql.NewCatalog() - idx1 := &dummyIndex{ - "t2", - []sql.Expression{ + idx1 := &memory.MergeableIndex{ + TableName: "t2", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int64, "t2", "bar", false), }, } @@ -72,9 +70,9 @@ func TestAssignIndexes(t *testing.T) { close(done) <-ready - idx2 := &dummyIndex{ - "t1", - []sql.Expression{ + idx2 := &memory.MergeableIndex{ + TableName: "t1", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int64, "t1", "foo", false), }, } @@ -84,6 +82,18 @@ func TestAssignIndexes(t *testing.T) { close(done) <-ready + idx3 := &memory.UnmergeableIndex{ + TableName: "t1", + Exprs: []sql.Expression{ + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "bar", false), + }, + } + + done, ready, err = catalog.AddIndex(idx3) + require.NoError(err) + close(done) + <-ready + a := NewDefault(catalog) t1 := memory.NewTable("t1", sql.Schema{ @@ -125,39 +135,135 @@ func TestAssignIndexes(t *testing.T) { lookupIdxs, ok := result["t1"] require.True(ok) - mergeable, ok := lookupIdxs.lookup.(*mergeableIndexLookup) + mergeable, ok := lookupIdxs.lookup.(*memory.MergeableIndexLookup) require.True(ok) - require.True(mergeable.id == "2") + require.Equal("2", mergeable.ID()) lookupIdxs, ok = result["t2"] require.True(ok) - mergeable, ok = lookupIdxs.lookup.(*mergeableIndexLookup) + mergeable, ok = lookupIdxs.lookup.(*memory.MergeableIndexLookup) require.True(ok) - require.True(mergeable.id == "1") + require.Equal("1", mergeable.ID()) + + node = plan.NewProject( + []sql.Expression{}, + plan.NewFilter( + expression.NewOr( + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "bar", false), + expression.NewLiteral(int64(1), sql.Int64), + ), + expression.NewEquals( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "bar", false), + expression.NewLiteral(int64(2), sql.Int64), + ), + ), + plan.NewResolvedTable(t1), + ), + ) + + result, err = assignIndexes(a, node) + require.NoError(err) + + _, ok = result["t1"] + require.False(ok) + + node = plan.NewProject( + []sql.Expression{}, + plan.NewFilter( + expression.NewIn( + expression.NewGetFieldWithTable(0, sql.Int64, "t1", "bar", false), + expression.NewTuple(expression.NewLiteral(int64(1), sql.Int64), expression.NewLiteral(int64(2), sql.Int64)), + ), + plan.NewResolvedTable(t1), + ), + ) + + result, err = assignIndexes(a, node) + require.NoError(err) + + _, ok = result["t1"] + require.False(ok) +} + +func intersectionLookupWithKeys(table string, column string, colIdx int, keys ...interface{}) *memory.MergedIndexLookup { + var lookups []sql.IndexLookup + for _, key := range keys { + lookups = append(lookups, mergeableIndexLookup(table, column, colIdx, key)) + } + return &memory.MergedIndexLookup{ + Intersections: lookups, + Index: mergeableIndex(table, column, colIdx), + } +} + +func unionLookupWithKeys(table string, column string, colIdx int, keys ...interface{}) *memory.MergedIndexLookup { + var lookups []sql.IndexLookup + for _, key := range keys { + lookups = append(lookups, mergeableIndexLookup(table, column, colIdx, key)) + } + return &memory.MergedIndexLookup{ + Unions: lookups, + Index: mergeableIndex(table, column, colIdx), + } +} + +func unionLookup(table string, column string, colIdx int, lookups ...sql.IndexLookup) *memory.MergedIndexLookup { + return &memory.MergedIndexLookup{ + Unions: lookups, + Index: mergeableIndex(table, column, colIdx), + } +} + +func intersectionLookup(table string, column string, colIdx int, lookups ...sql.IndexLookup) *memory.MergedIndexLookup { + return &memory.MergedIndexLookup{ + Intersections: lookups, + Index: mergeableIndex(table, column, colIdx), + } +} + +func mergeableIndexLookup(table string, column string, colIdx int, key ...interface{}) *memory.MergeableIndexLookup { + return &memory.MergeableIndexLookup{ + Key: key, + Index: mergeableIndex(table, column, colIdx), + } +} + +func mergeableIndex(table string, column string, colIdx int) *memory.MergeableIndex { + return &memory.MergeableIndex{ + TableName: table, + Exprs: []sql.Expression{col(colIdx, table, column)}, + } } func TestGetIndexes(t *testing.T) { - indexes := []*dummyIndex{ - { - "t1", - []sql.Expression{ + indexes := []sql.Index { + &memory.MergeableIndex{ + TableName: "t1", + Exprs: []sql.Expression{ col(0, "t1", "bar"), }, }, - { - "t2", - []sql.Expression{ + &memory.MergeableIndex{ + TableName: "t2", + Exprs: []sql.Expression{ col(0, "t2", "foo"), col(0, "t2", "bar"), }, }, - { - "t2", - []sql.Expression{ + &memory.MergeableIndex{ + TableName: "t2", + Exprs: []sql.Expression{ col(0, "t2", "bar"), }, }, + &memory.UnmergeableIndex{ + TableName: "t3", + Exprs: []sql.Expression{ + col(0, "t3", "foo"), + }, + }, } testCases := []struct { @@ -180,7 +286,7 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1"}, + mergeableIndexLookup("t1", "bar", 0, int64(1)), []sql.Index{indexes[0]}, }, }, @@ -199,7 +305,13 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1", unions: []string{"2"}}, + &memory.MergedIndexLookup{ + Unions: []sql.IndexLookup{ + mergeableIndexLookup("t1", "bar", 0, int64(1)), + mergeableIndexLookup("t1", "bar", 0, int64(2)), + }, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{ indexes[0], indexes[0], @@ -208,6 +320,43 @@ func TestGetIndexes(t *testing.T) { }, true, }, + { + or( + eq( + col(0, "t3", "foo"), + lit(1), + ), + eq( + col(0, "t3", "foo"), + lit(2), + ), + ), + nil, + true, + }, + { + in( + col(0, "t3", "foo"), + tuple(lit(1), lit(2)), + ), + nil, + true, + }, + { + in( + col(0, "t1", "bar"), + tuple(lit(1), lit(2)), + ), + map[string]*indexLookup{ + "t1": &indexLookup{ + unionLookupWithKeys("t1", "bar", 0, int64(1), int64(2)), + []sql.Index{ + indexes[0], + }, + }, + }, + true, + }, { and( eq( @@ -221,7 +370,7 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1", intersections: []string{"2"}}, + intersectionLookupWithKeys("t1", "bar", 0, int64(1), int64(2)), []sql.Index{ indexes[0], indexes[0], @@ -255,7 +404,16 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1", unions: []string{"2", "4"}, intersections: []string{"3"}}, + intersectionLookup("t1", "bar", 0, + unionLookup("t1", "bar", 0, + mergeableIndexLookup("t1", "bar", 0, int64(1)), + mergeableIndexLookup("t1", "bar", 0, int64(2)), + ), + unionLookup("t1", "bar", 0, + mergeableIndexLookup("t1", "bar", 0, int64(3)), + mergeableIndexLookup("t1", "bar", 0, int64(4)), + ), + ), []sql.Index{ indexes[0], indexes[0], @@ -291,7 +449,7 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1", unions: []string{"2", "3", "4"}}, + unionLookupWithKeys("t1", "bar", 0, int64(1), int64(2), int64(3), int64(4)), []sql.Index{ indexes[0], indexes[0], @@ -303,18 +461,13 @@ func TestGetIndexes(t *testing.T) { true, }, { - expression.NewIn( + in( col(0, "t1", "bar"), - expression.NewTuple( - lit(1), - lit(2), - lit(3), - lit(4), - ), + tuple(lit(1), lit(2), lit(3), lit(4)), ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "1", unions: []string{"2", "3", "4"}}, + unionLookupWithKeys("t1", "bar", 0, int64(1), int64(2), int64(3), int64(4)), []sql.Index{indexes[0]}, }, }, @@ -345,11 +498,20 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "3"}, + mergeableIndexLookup("t1", "bar", 0, int64(3)), []sql.Index{indexes[0]}, }, "t2": &indexLookup{ - &mergeableIndexLookup{id: "1, 2"}, + &memory.MergeableIndexLookup{ + Key: []interface{}{int64(1), int64(2)}, + Index: &memory.MergeableIndex{ + TableName: "t2", + Exprs: []sql.Expression{ + col(0, "t2", "foo"), + col(0, "t2", "bar"), + }, + }, + }, []sql.Index{indexes[1]}, }, }, @@ -386,11 +548,22 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "3"}, + mergeableIndexLookup("t1", "bar", 0, int64(3)), []sql.Index{indexes[0]}, }, "t2": &indexLookup{ - &mergeableIndexLookup{id: "5", unions: []string{"1, 2"}}, + unionLookup("t2", "bar", 0, + mergeableIndexLookup("t2", "bar", 0, int64(5)), + &memory.MergeableIndexLookup{ + Key: []interface{}{int64(1), int64(2)}, + Index: &memory.MergeableIndex{ + TableName: "t2", + Exprs: []sql.Expression{ + col(0, "t2", "foo"), + col(0, "t2", "bar"), + }, + }, + }), []sql.Index{ indexes[2], indexes[1], @@ -406,7 +579,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &descendIndexLookup{gt: []interface{}{int64(1)}}, + &memory.DescendIndexLookup{ + Gt: []interface{}{int64(1)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -419,7 +595,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &ascendIndexLookup{lt: []interface{}{int64(1)}}, + &memory.AscendIndexLookup{ + Lt: []interface{}{int64(1)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -432,7 +611,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &ascendIndexLookup{gte: []interface{}{int64(1)}}, + &memory.AscendIndexLookup{ + Gte: []interface{}{int64(1)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -445,7 +627,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &descendIndexLookup{lte: []interface{}{int64(1)}}, + &memory.DescendIndexLookup{ + Lte: []interface{}{int64(1)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -459,18 +644,18 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergedIndexLookup{ - []sql.IndexLookup{ - &ascendIndexLookup{ - gte: []interface{}{int64(1)}, - lt: []interface{}{int64(5)}, - }, - &descendIndexLookup{ - gt: []interface{}{int64(1)}, - lte: []interface{}{int64(5)}, - }, + unionLookup("t1", "bar", 0, + &memory.AscendIndexLookup{ + Gte: []interface{}{int64(1)}, + Lt: []interface{}{int64(5)}, + Index: mergeableIndex("t1", "bar", 0), }, - }, + &memory.DescendIndexLookup{ + Gt: []interface{}{int64(1)}, + Lte: []interface{}{int64(5)}, + Index: mergeableIndex("t1", "bar", 0), + }, + ), []sql.Index{indexes[0]}, }, }, @@ -485,8 +670,9 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &negateIndexLookup{ - value: "1", + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(1)), + Index: mergeableIndex("t1", "bar", 0), }, []sql.Index{indexes[0]}, }, @@ -503,7 +689,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &descendIndexLookup{lte: []interface{}{int64(10)}}, + &memory.DescendIndexLookup{ + Lte: []interface{}{int64(10)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -519,7 +708,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &ascendIndexLookup{lt: []interface{}{int64(10)}}, + &memory.AscendIndexLookup{ + Lt: []interface{}{int64(10)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -535,7 +727,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &descendIndexLookup{gt: []interface{}{int64(10)}}, + &memory.DescendIndexLookup{ + Gt: []interface{}{int64(10)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -551,7 +746,10 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &ascendIndexLookup{gte: []interface{}{int64(10)}}, + &memory.AscendIndexLookup{ + Gte: []interface{}{int64(10)}, + Index: mergeableIndex("t1", "bar", 0), + }, []sql.Index{indexes[0]}, }, }, @@ -572,16 +770,16 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergedIndexLookup{ - children: []sql.IndexLookup{ - &negateIndexLookup{ - value: "10", - }, - &negateIndexLookup{ - value: "11", - }, + unionLookup("t1", "bar", 0, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(10)), + Index: mergeableIndex("t1", "bar", 0), }, - }, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(11)), + Index: mergeableIndex("t1", "bar", 0), + }, + ), []sql.Index{ indexes[0], indexes[0], @@ -605,10 +803,16 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{ - id: "not 10", - intersections: []string{"not 11"}, - }, + intersectionLookup("t1", "bar", 0, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(10)), + Index: mergeableIndex("t1", "bar", 0), + }, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(11)), + Index: mergeableIndex("t1", "bar", 0), + }, + ), []sql.Index{ indexes[0], indexes[0], @@ -635,8 +839,9 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t2": &indexLookup{ - &negateIndexLookup{ - value: "110", + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t2", "bar", 0, int64(110)), + Index: mergeableIndex("t2", "bar", 0), }, []sql.Index{ indexes[2], @@ -657,10 +862,24 @@ func TestGetIndexes(t *testing.T) { ), map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{ - id: "not 1", - intersections: []string{"not 2", "not 3", "not 4"}, - }, + intersectionLookup("t1", "bar", 0, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(1)), + Index: mergeableIndex("t1", "bar", 0), + }, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(2)), + Index: mergeableIndex("t1", "bar", 0), + }, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(3)), + Index: mergeableIndex("t1", "bar", 0), + }, + &memory.NegateIndexLookup{ + Lookup: mergeableIndexLookup("t1", "bar", 0, int64(4)), + Index: mergeableIndex("t1", "bar", 0), + }, + ), []sql.Index{indexes[0]}, }, }, @@ -678,6 +897,7 @@ func TestGetIndexes(t *testing.T) { a := NewDefault(catalog) + var i int for _, tt := range testCases { t.Run(tt.expr.String(), func(t *testing.T) { require := require.New(t) @@ -689,6 +909,7 @@ func TestGetIndexes(t *testing.T) { } else { require.Error(err) } + i++ }) } } @@ -697,36 +918,36 @@ func TestGetMultiColumnIndexes(t *testing.T) { require := require.New(t) catalog := sql.NewCatalog() - indexes := []*dummyIndex{ + indexes := []*memory.MergeableIndex{ { - "t1", - []sql.Expression{ + TableName: "t1", + Exprs: []sql.Expression{ col(1, "t1", "foo"), col(2, "t1", "bar"), }, }, { - "t2", - []sql.Expression{ + TableName: "t2", + Exprs: []sql.Expression{ col(0, "t2", "foo"), col(1, "t2", "bar"), col(2, "t2", "baz"), }, }, { - "t2", - []sql.Expression{ + TableName: "t2", + Exprs: []sql.Expression{ col(0, "t2", "foo"), col(0, "t2", "bar"), }, }, { - "t3", - []sql.Expression{col(0, "t3", "foo")}, + TableName: "t3", + Exprs: []sql.Expression{col(0, "t3", "foo")}, }, { - "t4", - []sql.Expression{ + TableName: "t4", + Exprs: []sql.Expression{ col(1, "t4", "foo"), col(2, "t4", "bar"), }, @@ -784,24 +1005,35 @@ func TestGetMultiColumnIndexes(t *testing.T) { expected := map[string]*indexLookup{ "t1": &indexLookup{ - &mergeableIndexLookup{id: "5, 6"}, + &memory.MergeableIndexLookup{ + Key: []interface{}{int64(5), int64(6)}, + Index: indexes[0], + }, []sql.Index{indexes[0]}, }, "t2": &indexLookup{ - &mergeableIndexLookup{id: "1, 2, 3"}, + &memory.MergeableIndexLookup{ + Key: []interface{}{int64(1), int64(2), int64(3)}, + Index: indexes[1], + }, []sql.Index{indexes[1]}, }, "t4": &indexLookup{ - &mergedIndexLookup{[]sql.IndexLookup{ - &ascendIndexLookup{ - gte: []interface{}{int64(1), int64(2)}, - lt: []interface{}{int64(6), int64(5)}, - }, - &descendIndexLookup{ - gt: []interface{}{int64(1), int64(2)}, - lte: []interface{}{int64(6), int64(5)}, + &memory.MergedIndexLookup{ + Unions: []sql.IndexLookup{ + &memory.AscendIndexLookup{ + Gte: []interface{}{int64(1), int64(2)}, + Lt: []interface{}{int64(6), int64(5)}, + Index: indexes[4], + }, + &memory.DescendIndexLookup{ + Gt: []interface{}{int64(1), int64(2)}, + Lte: []interface{}{int64(6), int64(5)}, + Index: indexes[4], + }, }, - }}, + Index: indexes[4], + }, []sql.Index{indexes[4]}, }, } @@ -894,261 +1126,49 @@ func TestExpressionSources(t *testing.T) { require.Equal(t, expected, sources) } -type dummyIndexLookup struct{} +type DummyIndexLookup struct{} -func (dummyIndexLookup) Indexes() []string { return nil } +func (DummyIndexLookup) Indexes() []string { return nil } -func (dummyIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { +func (DummyIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { return nil, nil } -type dummyIndex struct { - table string - expr []sql.Expression -} - -var _ sql.Index = (*dummyIndex)(nil) -var _ sql.AscendIndex = (*dummyIndex)(nil) -var _ sql.DescendIndex = (*dummyIndex)(nil) -var _ sql.NegateIndex = (*dummyIndex)(nil) - -func (dummyIndex) Database() string { return "" } -func (dummyIndex) Driver() string { return "" } -func (i dummyIndex) Expressions() []string { - var exprs []string - for _, e := range i.expr { - exprs = append(exprs, e.String()) - } - return exprs -} - -func (i dummyIndex) AscendGreaterOrEqual(keys ...interface{}) (sql.IndexLookup, error) { - return &ascendIndexLookup{gte: keys}, nil -} - -func (i dummyIndex) AscendLessThan(keys ...interface{}) (sql.IndexLookup, error) { - return &ascendIndexLookup{lt: keys}, nil -} - -func (i dummyIndex) AscendRange(greaterOrEqual, lessThan []interface{}) (sql.IndexLookup, error) { - return &ascendIndexLookup{gte: greaterOrEqual, lt: lessThan}, nil -} - -func (i dummyIndex) DescendGreater(keys ...interface{}) (sql.IndexLookup, error) { - return &descendIndexLookup{gt: keys}, nil -} - -func (i dummyIndex) DescendLessOrEqual(keys ...interface{}) (sql.IndexLookup, error) { - return &descendIndexLookup{lte: keys}, nil -} - -func (i dummyIndex) DescendRange(lessOrEqual, greaterThan []interface{}) (sql.IndexLookup, error) { - return &descendIndexLookup{gt: greaterThan, lte: lessOrEqual}, nil -} - -func (i dummyIndex) Not(keys ...interface{}) (sql.IndexLookup, error) { - lookup, err := i.Get(keys...) - if err != nil { - return nil, err - } - - mergeable, _ := lookup.(*mergeableIndexLookup) - return &negateIndexLookup{value: mergeable.id}, nil -} - -func (i dummyIndex) Get(key ...interface{}) (sql.IndexLookup, error) { - if len(key) != 1 { - var parts = make([]string, len(key)) - for i, p := range key { - parts[i] = fmt.Sprint(p) - } - - return &mergeableIndexLookup{id: strings.Join(parts, ", ")}, nil - } - - return &mergeableIndexLookup{id: fmt.Sprint(key[0])}, nil -} -func (i dummyIndex) Has(sql.Partition, ...interface{}) (bool, error) { - panic("not implemented") -} -func (i dummyIndex) ID() string { - if len(i.expr) == 1 { - return i.expr[0].String() - } - var parts = make([]string, len(i.expr)) - for i, e := range i.expr { - parts[i] = e.String() - } - - return "(" + strings.Join(parts, ", ") + ")" -} -func (i dummyIndex) Table() string { return i.table } - -type mergedIndexLookup struct { - children []sql.IndexLookup -} - -func (mergedIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { - panic("mergedIndexLookup.Values is a placeholder") -} - -func (i *mergedIndexLookup) Indexes() []string { - var indexes []string - for _, c := range i.children { - indexes = append(indexes, c.Indexes()...) - } - return indexes -} - -func (i *mergedIndexLookup) IsMergeable(sql.IndexLookup) bool { - return true -} - -func (i *mergedIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { - return &mergedIndexLookup{append(i.children, lookups...)} -} - -func (mergedIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { - panic("mergedIndexLookup.Difference is not implemented") -} - -func (mergedIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { - panic("mergedIndexLookup.Intersection is not implemented") -} - -type negateIndexLookup struct { - value string - intersections []string - unions []string -} - -func (l *negateIndexLookup) ID() string { return "not " + l.value } -func (l *negateIndexLookup) Unions() []string { return l.unions } -func (l *negateIndexLookup) Intersections() []string { return l.intersections } - -func (*negateIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { - panic("negateIndexLookup.Values is a placeholder") -} - -func (l *negateIndexLookup) Indexes() []string { - return []string{l.ID()} -} - -func (*negateIndexLookup) IsMergeable(sql.IndexLookup) bool { - return true -} - -func (l *negateIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { - return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} -} - -func (*negateIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { - panic("negateIndexLookup.Difference is not implemented") -} - -func (l *negateIndexLookup) Intersection(indexes ...sql.IndexLookup) sql.IndexLookup { - var intersections, unions []string - for _, idx := range indexes { - intersections = append(intersections, idx.(mergeableLookup).ID()) - intersections = append(intersections, idx.(mergeableLookup).Intersections()...) - unions = append(unions, idx.(mergeableLookup).Unions()...) - } - return &mergeableIndexLookup{ - l.ID(), - append(l.unions, unions...), - append(l.intersections, intersections...), - } -} - -type ascendIndexLookup struct { - id string - gte []interface{} - lt []interface{} -} - -func (ascendIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { - panic("ascendIndexLookup.Values is a placeholder") -} - -func (l *ascendIndexLookup) Indexes() []string { - return []string{l.id} -} - -func (l *ascendIndexLookup) IsMergeable(sql.IndexLookup) bool { - return true -} - -func (l *ascendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { - return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} -} - -func (ascendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { - panic("ascendIndexLookup.Difference is not implemented") -} - -func (ascendIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { - panic("ascendIndexLookup.Intersection is not implemented") -} - -type descendIndexLookup struct { - id string - gt []interface{} - lte []interface{} -} - -func (descendIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { - panic("descendIndexLookup.Values is a placeholder") -} - -func (l *descendIndexLookup) Indexes() []string { - return []string{l.id} -} - -func (l *descendIndexLookup) IsMergeable(sql.IndexLookup) bool { - return true -} - -func (l *descendIndexLookup) Union(lookups ...sql.IndexLookup) sql.IndexLookup { - return &mergedIndexLookup{append([]sql.IndexLookup{l}, lookups...)} -} - -func (descendIndexLookup) Difference(...sql.IndexLookup) sql.IndexLookup { - panic("descendIndexLookup.Difference is not implemented") -} - -func (descendIndexLookup) Intersection(...sql.IndexLookup) sql.IndexLookup { - panic("descendIndexLookup.Intersection is not implemented") -} - func TestIndexesIntersection(t *testing.T) { require := require.New(t) - idx1, idx2 := &dummyIndex{table: "bar"}, &dummyIndex{table: "foo"} + idx1, idx2 := &memory.MergeableIndex{TableName: "bar"}, &memory.MergeableIndex{TableName: "foo"} left := map[string]*indexLookup{ - "a": &indexLookup{&mergeableIndexLookup{id: "a"}, nil}, - "b": &indexLookup{&mergeableIndexLookup{id: "b"}, []sql.Index{idx1}}, - "c": &indexLookup{new(dummyIndexLookup), nil}, + "a": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"a"}}, nil}, + "b": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"b"}}, []sql.Index{idx1}}, + "c": &indexLookup{new(DummyIndexLookup), nil}, } right := map[string]*indexLookup{ - "b": &indexLookup{&mergeableIndexLookup{id: "b2"}, []sql.Index{idx2}}, - "c": &indexLookup{&mergeableIndexLookup{id: "c"}, nil}, - "d": &indexLookup{&mergeableIndexLookup{id: "d"}, nil}, + "b": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"b2"}}, []sql.Index{idx2}}, + "c": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"c"}}, nil}, + "d": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"d"}}, nil}, } require.Equal( map[string]*indexLookup{ - "a": &indexLookup{&mergeableIndexLookup{id: "a"}, nil}, + "a": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"a"}}, nil}, "b": &indexLookup{ - &mergeableIndexLookup{ - id: "b", - intersections: []string{"b2"}, + &memory.MergedIndexLookup { + Intersections: []sql.IndexLookup { + &memory.MergeableIndexLookup{ + Key: []interface{}{"b"}, + }, + &memory.MergeableIndexLookup{ + Key: []interface{}{"b2"}, + }, + }, }, []sql.Index{idx1, idx2}, }, - "c": &indexLookup{new(dummyIndexLookup), nil}, - "d": &indexLookup{&mergeableIndexLookup{id: "d"}, nil}, + "c": &indexLookup{new(DummyIndexLookup), nil}, + "d": &indexLookup{&memory.MergeableIndexLookup{Key: []interface{}{"d"}}, nil}, }, indexesIntersection(NewDefault(sql.NewCatalog()), left, right), ) @@ -1157,70 +1177,6 @@ func TestIndexesIntersection(t *testing.T) { func TestCanMergeIndexes(t *testing.T) { require := require.New(t) - require.False(canMergeIndexes(new(mergeableIndexLookup), new(dummyIndexLookup))) - require.True(canMergeIndexes(new(mergeableIndexLookup), new(mergeableIndexLookup))) -} - -type mergeableLookup interface { - ID() string - Unions() []string - Intersections() []string -} - -type mergeableIndexLookup struct { - id string - unions []string - intersections []string -} - -var _ sql.Mergeable = (*mergeableIndexLookup)(nil) -var _ sql.SetOperations = (*mergeableIndexLookup)(nil) - -func (i *mergeableIndexLookup) ID() string { return i.id } -func (i *mergeableIndexLookup) Unions() []string { return i.unions } -func (i *mergeableIndexLookup) Intersections() []string { return i.intersections } - -func (i *mergeableIndexLookup) IsMergeable(lookup sql.IndexLookup) bool { - _, ok := lookup.(mergeableLookup) - return ok -} - -func (i *mergeableIndexLookup) Values(sql.Partition) (sql.IndexValueIter, error) { - panic("not implemented") -} - -func (i *mergeableIndexLookup) Indexes() []string { - return []string{i.ID()} -} - -func (i *mergeableIndexLookup) Difference(indexes ...sql.IndexLookup) sql.IndexLookup { - panic("not implemented") -} - -func (i *mergeableIndexLookup) Intersection(indexes ...sql.IndexLookup) sql.IndexLookup { - var intersections, unions []string - for _, idx := range indexes { - intersections = append(intersections, idx.(mergeableLookup).ID()) - intersections = append(intersections, idx.(mergeableLookup).Intersections()...) - unions = append(unions, idx.(mergeableLookup).Unions()...) - } - return &mergeableIndexLookup{ - i.id, - append(i.unions, unions...), - append(i.intersections, intersections...), - } -} - -func (i *mergeableIndexLookup) Union(indexes ...sql.IndexLookup) sql.IndexLookup { - var intersections, unions []string - for _, idx := range indexes { - unions = append(unions, idx.(*mergeableIndexLookup).id) - unions = append(unions, idx.(*mergeableIndexLookup).unions...) - intersections = append(intersections, idx.(*mergeableIndexLookup).intersections...) - } - return &mergeableIndexLookup{ - i.id, - append(i.unions, unions...), - append(i.intersections, intersections...), - } -} + require.False(canMergeIndexes(new(memory.MergeableIndexLookup), new(DummyIndexLookup))) + require.True(canMergeIndexes(new(memory.MergeableIndexLookup), new(memory.MergeableIndexLookup))) +} \ No newline at end of file diff --git a/sql/analyzer/common_test.go b/sql/analyzer/common_test.go index a5beb26c9..377e9d10b 100644 --- a/sql/analyzer/common_test.go +++ b/sql/analyzer/common_test.go @@ -29,6 +29,14 @@ func or(left, right sql.Expression) sql.Expression { return expression.NewOr(left, right) } +func in(col sql.Expression, tuple sql.Expression) sql.Expression { + return expression.NewIn(col, tuple) +} + +func tuple(vals ...sql.Expression) sql.Expression { + return expression.NewTuple(vals...) +} + func and(left, right sql.Expression) sql.Expression { return expression.NewAnd(left, right) } diff --git a/sql/analyzer/pushdown_test.go b/sql/analyzer/pushdown_test.go index 818eb40a0..c49839159 100644 --- a/sql/analyzer/pushdown_test.go +++ b/sql/analyzer/pushdown_test.go @@ -105,9 +105,9 @@ func TestPushdownIndexable(t *testing.T) { catalog := sql.NewCatalog() catalog.AddDatabase(db) - idx1 := &dummyIndex{ - "mytable", - []sql.Expression{ + idx1 := &memory.MergeableIndex{ + TableName: "mytable", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int32, "mytable", "i", false), }, } @@ -116,9 +116,9 @@ func TestPushdownIndexable(t *testing.T) { close(done) <-ready - idx2 := &dummyIndex{ - "mytable", - []sql.Expression{ + idx2 := &memory.MergeableIndex{ + TableName: "mytable", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(1, sql.Float64, "mytable", "f", false), }, } @@ -127,9 +127,9 @@ func TestPushdownIndexable(t *testing.T) { close(done) <-ready - idx3 := &dummyIndex{ - "mytable2", - []sql.Expression{ + idx3 := &memory.MergeableIndex{ + TableName: "mytable2", + Exprs: []sql.Expression{ expression.NewGetFieldWithTable(0, sql.Int32, "mytable2", "i2", false), }, } @@ -189,7 +189,23 @@ func TestPushdownIndexable(t *testing.T) { ), }).(*memory.Table). WithProjection([]string{"i", "f"}).(*memory.Table). - WithIndexLookup(&mergeableIndexLookup{id: "3.14"}), + WithIndexLookup( + // TODO: These two indexes should not be mergeable, and fetching the values of + // them will not yield correct results with the current implementation of these indexes. + &memory.MergedIndexLookup{ + Intersections: []sql.IndexLookup{ + &memory.MergeableIndexLookup{ + Key: []interface{}{float64(3.14)}, + Index: idx2, + }, + &memory.DescendIndexLookup{ + Gt: []interface{}{1}, + Index: idx1, + }, + }, + Index: idx2, + }, + ), ), plan.NewResolvedTable( table2.WithFilters([]sql.Expression{ @@ -201,7 +217,13 @@ func TestPushdownIndexable(t *testing.T) { ), }).(*memory.Table). WithProjection([]string{"i2"}).(*memory.Table). - WithIndexLookup(&negateIndexLookup{value: "2"}), + WithIndexLookup(&memory.NegateIndexLookup{ + Lookup: &memory.MergeableIndexLookup{ + Key: []interface{}{2}, + Index: idx3, + }, + Index: idx3, + }), ), ), ), diff --git a/sql/expression/function/time_test.go b/sql/expression/function/time_test.go index b9762ec1f..204a3348a 100644 --- a/sql/expression/function/time_test.go +++ b/sql/expression/function/time_test.go @@ -368,7 +368,7 @@ func TestDate(t *testing.T) { {"null date", sql.NewRow(nil), nil, false}, {"invalid type", sql.NewRow([]byte{0, 1, 2}), nil, false}, {"date as string", sql.NewRow(stringDate), "2007-01-02", false}, - {"date as time", sql.NewRow(time.Now()), time.Now().Format("2006-01-02"), false}, + {"date as time", sql.NewRow(time.Now().UTC()), time.Now().UTC().Format("2006-01-02"), false}, {"date as unix timestamp", sql.NewRow(int64(tsDate)), "2009-11-22", false}, } diff --git a/sql/parse/parse.go b/sql/parse/parse.go index f4dfe6f0a..517337ef9 100644 --- a/sql/parse/parse.go +++ b/sql/parse/parse.go @@ -545,7 +545,7 @@ func getColumn(cd *sqlparser.ColumnDefinition, indexes []*sqlparser.IndexDefinit } return &sql.Column{ - Nullable: !bool(typ.NotNull), + Nullable: !isPkey && !bool(typ.NotNull), Type: internalTyp, Name: cd.Name.String(), PrimaryKey: isPkey, diff --git a/sql/parse/parse_test.go b/sql/parse/parse_test.go index b66dca943..ab58fe7ff 100644 --- a/sql/parse/parse_test.go +++ b/sql/parse/parse_test.go @@ -72,7 +72,7 @@ var fixtures = map[string]sql.Node{ sql.Schema{{ Name: "a", Type: sql.Int32, - Nullable: true, + Nullable: false, PrimaryKey: true, }, { Name: "b", @@ -87,12 +87,12 @@ var fixtures = map[string]sql.Node{ sql.Schema{{ Name: "a", Type: sql.Int32, - Nullable: true, + Nullable: false, PrimaryKey: true, }, { Name: "b", Type: sql.Text, - Nullable: true, + Nullable: false, PrimaryKey: true, }}, ),