From 1677b97fa4e58527898cf8014a8ae83302eacb3a Mon Sep 17 00:00:00 2001 From: Nuno Cruces Date: Sun, 19 Jan 2025 01:28:53 +0000 Subject: [PATCH] Fix #215. --- ext/stats/mode.go | 112 +++++++++++++++++++++++++++++++++++++++++ ext/stats/mode_test.go | 85 +++++++++++++++++++++++++++++++ ext/stats/stats.go | 11 ++-- 3 files changed, 204 insertions(+), 4 deletions(-) create mode 100644 ext/stats/mode.go create mode 100644 ext/stats/mode_test.go diff --git a/ext/stats/mode.go b/ext/stats/mode.go new file mode 100644 index 00000000..21eedbaf --- /dev/null +++ b/ext/stats/mode.go @@ -0,0 +1,112 @@ +package stats + +import ( + "unsafe" + + "github.com/ncruces/go-sqlite3" +) + +func newMode() sqlite3.AggregateFunction { + return &mode{} +} + +type mode struct { + ints counter[int64] + reals counter[float64] + texts counter[string] + blobs counter[string] +} + +func (m mode) Value(ctx sqlite3.Context) { + var ( + max = 0 + typ = sqlite3.NULL + i64 int64 + f64 float64 + str string + ) + for k, v := range m.ints { + if v > max || v == max && k < i64 { + typ = sqlite3.INTEGER + max = v + i64 = k + } + } + f64 = float64(i64) + for k, v := range m.reals { + if v > max || v == max && k < f64 { + typ = sqlite3.FLOAT + max = v + f64 = k + } + } + for k, v := range m.texts { + if v > max || v == max && typ == sqlite3.TEXT && k < str { + typ = sqlite3.TEXT + max = v + str = k + } + } + for k, v := range m.blobs { + if v > max || v == max && typ == sqlite3.BLOB && k < str { + typ = sqlite3.BLOB + max = v + str = k + } + } + switch typ { + case sqlite3.INTEGER: + ctx.ResultInt64(i64) + case sqlite3.FLOAT: + ctx.ResultFloat(f64) + case sqlite3.TEXT: + ctx.ResultText(str) + case sqlite3.BLOB: + ctx.ResultBlob(unsafe.Slice(unsafe.StringData(str), len(str))) + } +} + +func (b *mode) Step(ctx sqlite3.Context, arg ...sqlite3.Value) { + switch arg[0].Type() { + case sqlite3.INTEGER: + b.ints.add(arg[0].Int64()) + case sqlite3.FLOAT: + b.reals.add(arg[0].Float()) + case sqlite3.TEXT: + b.texts.add(arg[0].Text()) + case sqlite3.BLOB: + b.blobs.add(string(arg[0].RawBlob())) + } +} + +func (b *mode) Inverse(ctx sqlite3.Context, arg ...sqlite3.Value) { + switch arg[0].Type() { + case sqlite3.INTEGER: + b.ints.del(arg[0].Int64()) + case sqlite3.FLOAT: + b.reals.del(arg[0].Float()) + case sqlite3.TEXT: + b.texts.del(arg[0].Text()) + case sqlite3.BLOB: + b.blobs.del(string(arg[0].RawBlob())) + } +} + +type counter[T comparable] map[T]int + +func (c *counter[T]) add(k T) { + if (*c) == nil { + (*c) = make(counter[T]) + } + (*c)[k]++ +} + +func (c counter[T]) del(k T) { + switch n := c[k]; n { + default: + c[k] = n - 1 + case 1: + delete(c, k) + case 0: + } +} diff --git a/ext/stats/mode_test.go b/ext/stats/mode_test.go new file mode 100644 index 00000000..3e2bfae7 --- /dev/null +++ b/ext/stats/mode_test.go @@ -0,0 +1,85 @@ +package stats_test + +import ( + "testing" + + "github.com/ncruces/go-sqlite3" + _ "github.com/ncruces/go-sqlite3/embed" + _ "github.com/ncruces/go-sqlite3/internal/testcfg" +) + +func TestRegister_mode(t *testing.T) { + t.Parallel() + + db, err := sqlite3.Open(":memory:") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + stmt, _, err := db.Prepare(`SELECT mode(column1) FROM (VALUES (NULL), (1), (NULL), (2), (NULL), (3), (3))`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnInt(0); got != 3 { + t.Errorf("got %v, want 3", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (1), (1), (2), (2), (3))`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnInt(0); got != 1 { + t.Errorf("got %v, want 1", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (0.5), (1), (2.5), (2), (2.5))`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnFloat(0); got != 2.5 { + t.Errorf("got %v, want 2.5", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES ('red'), ('green'), ('blue'), ('red'))`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnText(0); got != "red" { + t.Errorf("got %q, want red", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(`SELECT mode(column1) FROM (VALUES (X'cafebabe'), ('green'), ('blue'), (X'cafebabe'))`) + if err != nil { + t.Fatal(err) + } + if stmt.Step() { + if got := stmt.ColumnText(0); got != "\xca\xfe\xba\xbe" { + t.Errorf("got %q, want cafebabe", got) + } + } + stmt.Close() + + stmt, _, err = db.Prepare(` + SELECT mode(column1) OVER (ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) + FROM (VALUES (1), (1), (2.5), ('blue'), (X'cafebabe'), (1), (1)) + `) + if err != nil { + t.Fatal(err) + } + for stmt.Step() { + } + stmt.Close() +} diff --git a/ext/stats/stats.go b/ext/stats/stats.go index 4edfee33..2110a52c 100644 --- a/ext/stats/stats.go +++ b/ext/stats/stats.go @@ -18,9 +18,11 @@ // - regr_slope: slope of the least-squares-fit linear equation // - regr_intercept: y-intercept of the least-squares-fit linear equation // - regr_json: all regr stats in a JSON object -// - percentile_disc: discrete percentile -// - percentile_cont: continuous percentile -// - median: median value +// - percentile_disc: discrete quantile +// - percentile_cont: continuous quantile +// - percentile: continuous percentile +// - median: middle value +// - mode: most frequent value // - every: boolean and // - some: boolean or // @@ -77,7 +79,8 @@ func Register(db *sqlite3.Conn) error { db.CreateWindowFunction("percentile_cont", 2, order, newPercentile(percentile_cont)), db.CreateWindowFunction("percentile_disc", 2, order, newPercentile(percentile_disc)), db.CreateWindowFunction("every", 1, flags, newBoolean(every)), - db.CreateWindowFunction("some", 1, flags, newBoolean(some))) + db.CreateWindowFunction("some", 1, flags, newBoolean(some)), + db.CreateWindowFunction("mode", 1, order, newMode)) } const (