From 5ffd50bdea6ccdcfeb2656d73439d25a01487bd0 Mon Sep 17 00:00:00 2001 From: jdl Date: Tue, 19 Dec 2023 13:59:57 +0100 Subject: [PATCH] UpsertFunc --- mdb/collection.go | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/mdb/collection.go b/mdb/collection.go index 27d06c4..40b603d 100644 --- a/mdb/collection.go +++ b/mdb/collection.go @@ -287,6 +287,36 @@ func (c *Collection[T]) upsert(tx *Snapshot, item *T) error { return err } +func (c *Collection[T]) UpsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error { + if tx == nil { + c.db.Update(func(tx *Snapshot) error { + return c.upsertFunc(tx, id, update) + }) + } + return c.upsertFunc(tx, id, update) +} + +func (c *Collection[T]) upsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error { + insert := false + + item := c.Get(tx, id) + if item == nil { + item = new(T) + insert = true + } + + if err := update(item); err != nil { + return err + } + + c.setID(item, id) // Don't allow the ID to change. + + if insert { + return c.insert(tx, item) + } + return c.update(tx, item) +} + func (c *Collection[T]) Delete(tx *Snapshot, itemID uint64) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error {