From 6b0b7408bc1635b3f817740edbd407c65cedf718 Mon Sep 17 00:00:00 2001 From: jdl Date: Sun, 24 Dec 2023 20:41:43 +0100 Subject: [PATCH] iterfunc experiment --- go.mod | 2 +- mdb/db-testcases_test.go | 31 ++++++++++++------------------ mdb/equality_test.go | 13 +++++-------- mdb/index.go | 41 ++++++++++++++++++++++++---------------- mdb/index_test.go | 5 ++--- 5 files changed, 45 insertions(+), 47 deletions(-) diff --git a/go.mod b/go.mod index 31caf80..3b5ce6d 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module git.crumpington.com/public/jldb -go 1.21.1 +go 1.22 require ( github.com/google/btree v1.1.2 diff --git a/mdb/db-testcases_test.go b/mdb/db-testcases_test.go index d5af3e8..5f52cb9 100644 --- a/mdb/db-testcases_test.go +++ b/mdb/db-testcases_test.go @@ -743,29 +743,25 @@ var testDBTestCases = []DBTestCase{{ first := true pivot := User{Name: "User1"} - db.Users.ByName.AscendAfter(tx, &pivot, func(u *User) bool { + for u := range db.Users.ByName.AscendAfter(tx, &pivot) { u.Name += "Mod" if err = db.Users.Update(tx, u); err != nil { - return false + return err } if first { first = false - return true + continue } prev := db.Users.ByID.Get(tx, &User{ID: u.ID - 1}) if prev == nil { - err = errors.New("Previous user not found") - return false + return errors.New("Previous user not found") } if !strings.HasSuffix(prev.Name, "Mod") { - err = errors.New("Incorrect user name: " + prev.Name) - return false + return errors.New("Incorrect user name: " + prev.Name) } - - return true - }) + } return nil }, @@ -801,29 +797,26 @@ var testDBTestCases = []DBTestCase{{ } first := true - db.Users.ByName.DescendAfter(tx, &User{Name: "User5Mod"}, func(u *User) bool { + for u := range db.Users.ByName.DescendAfter(tx, &User{Name: "User5Mod"}) { u.Name = strings.TrimSuffix(u.Name, "Mod") if err = db.Users.Update(tx, u); err != nil { - return false + return err } if first { first = false - return true + continue } prev := db.Users.ByID.Get(tx, &User{ID: u.ID + 1}) if prev == nil { - err = errors.New("Previous user not found") - return false + return errors.New("Previous user not found") } if strings.HasSuffix(prev.Name, "Mod") { - err = errors.New("Incorrect user name: " + prev.Name) - return false + return errors.New("Incorrect user name: " + prev.Name) } + } - return true - }) return nil }, diff --git a/mdb/equality_test.go b/mdb/equality_test.go index 1291e31..faa3777 100644 --- a/mdb/equality_test.go +++ b/mdb/equality_test.go @@ -1,7 +1,6 @@ package mdb import ( - "fmt" "reflect" "testing" ) @@ -20,18 +19,16 @@ func (i Index[T]) AssertEqual(t *testing.T, tx1, tx2 *Snapshot) { } errStr := "" - i.Ascend(tx1, func(item1 *T) bool { + iter := i.Ascend(tx1) + for item1 := range iter { item2 := i.Get(tx2, item1) if item2 == nil { - errStr = fmt.Sprintf("Indices don't match. %v not found.", item1) - return false + t.Fatalf("Indices don't match. %v not found.", item1) } if !reflect.DeepEqual(item1, item2) { - errStr = fmt.Sprintf("%v != %v", item1, item2) - return false + t.Fatalf("%v != %v", item1, item2) } - return true - }) + } if errStr != "" { t.Fatal(errStr) diff --git a/mdb/index.go b/mdb/index.go index 431af93..153a241 100644 --- a/mdb/index.go +++ b/mdb/index.go @@ -1,6 +1,7 @@ package mdb import ( + "iter" "unsafe" "github.com/google/btree" @@ -111,32 +112,40 @@ func (i *Index[T]) Max(tx *Snapshot) *T { return nil } -func (i *Index[T]) Ascend(tx *Snapshot, each func(*T) bool) { +func (i *Index[T]) Ascend(tx *Snapshot) iter.Seq[*T] { tx = i.ensureSnapshot(tx) - i.btreeForIter(tx).Ascend(func(t *T) bool { - return each(i.copy(t)) - }) + return func(yield func(*T) bool) { + i.btreeForIter(tx).Ascend(func(t *T) bool { + return yield(i.copy(t)) + }) + } } -func (i *Index[T]) AscendAfter(tx *Snapshot, after *T, each func(*T) bool) { +func (i *Index[T]) AscendAfter(tx *Snapshot, after *T) iter.Seq[*T] { tx = i.ensureSnapshot(tx) - i.btreeForIter(tx).AscendGreaterOrEqual(after, func(t *T) bool { - return each(i.copy(t)) - }) + return func(yield func(*T) bool) { + i.btreeForIter(tx).AscendGreaterOrEqual(after, func(t *T) bool { + return yield(i.copy(t)) + }) + } } -func (i *Index[T]) Descend(tx *Snapshot, each func(*T) bool) { +func (i *Index[T]) Descend(tx *Snapshot) iter.Seq[*T] { tx = i.ensureSnapshot(tx) - i.btreeForIter(tx).Descend(func(t *T) bool { - return each(i.copy(t)) - }) + return func(yield func(*T) bool) { + i.btreeForIter(tx).Descend(func(t *T) bool { + return yield(i.copy(t)) + }) + } } -func (i *Index[T]) DescendAfter(tx *Snapshot, after *T, each func(*T) bool) { +func (i *Index[T]) DescendAfter(tx *Snapshot, after *T) iter.Seq[*T] { tx = i.ensureSnapshot(tx) - i.btreeForIter(tx).DescendLessOrEqual(after, func(t *T) bool { - return each(i.copy(t)) - }) + return func(yield func(*T) bool) { + i.btreeForIter(tx).DescendLessOrEqual(after, func(t *T) bool { + return yield(i.copy(t)) + }) + } } func (i *Index[T]) Count(tx *Snapshot) int { diff --git a/mdb/index_test.go b/mdb/index_test.go index df7576c..a370b9a 100644 --- a/mdb/index_test.go +++ b/mdb/index_test.go @@ -1,9 +1,8 @@ package mdb func (i Index[T]) Dump(tx *Snapshot) (l []T) { - i.Ascend(tx, func(t *T) bool { + for t := range i.Ascend(tx) { l = append(l, *t) - return true - }) + } return l }