Skip to content

Commit

Permalink
Towards bulk.
Browse files Browse the repository at this point in the history
  • Loading branch information
ncruces committed Nov 29, 2024
1 parent a446900 commit d2445c1
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 2 deletions.
88 changes: 88 additions & 0 deletions bulk.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package sqlite3

import "github.com/ncruces/go-sqlite3/internal/util"

type Bulk struct {
c *Conn
prefix string
suffix string
buffer []byte
bufptr uint32
}

const _BULK_SIZE = 1024 * 1024

func (c *Conn) CreateBulk(prefix, suffix string) (*Bulk, error) {
if len(prefix)+len(suffix) > _BULK_SIZE/2 {
return nil, TOOBIG
}
ptr := c.new(_BULK_SIZE)
buf := util.View(c.mod, ptr, _BULK_SIZE)
copy(buf, prefix)
buf = buf[len(prefix):len(prefix)]
return &Bulk{
c: c,
prefix: prefix,
suffix: suffix,
buffer: buf,
bufptr: ptr,
}, nil
}

func (b *Bulk) Close() error {
if b == nil || b.c == nil {
return nil
}

err := b.Flush()
b.c.free(b.bufptr)
b.c = nil
return err
}

func (b *Bulk) Flush() error {
if len(b.buffer) == 0 {
return nil
}
if cap(b.buffer)-len(b.buffer) <= len(b.suffix) {
return TOOBIG
}
b.c.checkInterrupt(b.c.handle)
b.buffer = append(b.buffer, b.suffix...)
b.buffer = append(b.buffer, 0)
b.buffer = b.buffer[:0]
r := b.c.call("sqlite3_exec", uint64(b.c.handle), uint64(b.bufptr), 0, 0, 0)
return b.c.error(r)
}

func (b *Bulk) AppendRow(args ...any) error {
buf := b.buffer

if off := len(buf); off != 0 {
buf = append(buf[off:], ',')
}

buf = append(buf, '(')
for i, arg := range args {
if i != 0 {
buf = append(buf, ',')
}
buf = append(buf, Quote(arg)...)
}
buf = append(buf, ')')

if len(buf)+len(b.suffix) >= cap(b.buffer)-len(b.buffer) {
if err := b.Flush(); err != nil {
return err
}
if buf[0] == ',' {
buf = buf[1:]
}
}
if len(buf)+len(b.suffix) >= cap(b.buffer)-len(b.buffer) {
return TOOBIG
}

b.buffer = append(b.buffer, buf...)
return nil
}
60 changes: 60 additions & 0 deletions bulk_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package sqlite3_test

import (
"fmt"
"log"

"github.com/ncruces/go-sqlite3"
)

func ExampleConn_CreateBulk() {
db, err := sqlite3.Open(":memory:")
if err != nil {
log.Fatal(err)
}
defer db.Close()

err = db.Exec(`CREATE TABLE users (id INT, name VARCHAR(10))`)
if err != nil {
log.Fatal(err)
}

bulk, err := db.CreateBulk(`INSERT INTO users (id, name) VALUES`, ``)
if err != nil {
log.Fatal(err)
}
defer bulk.Close()

for _, row := range [][]any{
{0, "go"},
{1, "zig"},
{2, "whatever"},
} {
err = bulk.AppendRow(row...)
if err != nil {
log.Fatal(err)
}
}

err = bulk.Flush()
if err != nil {
log.Fatal(err)
}

stmt, _, err := db.Prepare(`SELECT id, name FROM users`)
if err != nil {
log.Fatal(err)
}
defer stmt.Close()

for stmt.Step() {
fmt.Println(stmt.ColumnInt(0), stmt.ColumnText(1))
}
if err := stmt.Err(); err != nil {
log.Fatal(err)
}
// Output:
// 0 go
// 1 zig
// 2 whatever
}
2 changes: 1 addition & 1 deletion driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ func (r *rows) Next(dest []driver.Value) error {
}

data := unsafe.Slice((*any)(unsafe.SliceData(dest)), len(dest))
err := r.Stmt.Columns(data)
err := r.Stmt.Columns(data...)
for i := range dest {
if t, ok := r.decodeTime(i, dest[i]); ok {
dest[i] = t
Expand Down
2 changes: 1 addition & 1 deletion stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ func (s *Stmt) ColumnValue(col int) Value {
// [TEXT] as string, and [BLOB] as []byte.
// Any []byte are owned by SQLite and may be invalidated by
// subsequent calls to [Stmt] methods.
func (s *Stmt) Columns(dest []any) error {
func (s *Stmt) Columns(dest ...any) error {
defer s.c.arena.mark()()
count := uint64(len(dest))
typePtr := s.c.arena.new(count)
Expand Down

0 comments on commit d2445c1

Please sign in to comment.