package mdb

import (
	"errors"
	"net/http"
	"net/http/httptest"
	"reflect"
	"testing"
	"time"
)

// TestDBRunTests runs every case in testDBTestCases as its own parallel
// subtest.
func TestDBRunTests(t *testing.T) {
	for _, testCase := range testDBTestCases {
		// Re-bind the loop variable: the parallel subtest closure runs after
		// the loop has moved on, and before Go 1.22 every iteration shared a
		// single loop variable.
		testCase := testCase
		t.Run(testCase.Name, func(t *testing.T) {
			t.Parallel()
			testRunner_testCase(t, testCase)
		})
	}
}

// testRunner_testCase drives a single DBTestCase against a primary database
// and a secondary that replicates from the primary's "/rep/" handler over an
// httptest server.
//
// It:
//  1. applies each step's Update on the primary and checks the resulting state;
//  2. re-checks every step's expected state against the snapshot taken right
//     after that step;
//  3. waits for the secondary's SeqNum to match the primary's and checks the
//     secondary against the final step's state;
//  4. closes and reopens the primary from the same directory and checks the
//     final state again;
//  5. asserts that the reopened primary and the secondary are equal.
func testRunner_testCase(t *testing.T, testCase DBTestCase) {
	// Upper bound on how long to wait for the secondary to catch up, so a
	// replication bug fails the test instead of hanging it.
	const replicationTimeout = 30 * time.Second

	if len(testCase.Steps) == 0 {
		// The final-state checks below need at least one step.
		t.Fatal("test case has no steps")
	}

	rootDir := t.TempDir()
	db := NewTestDBPrimary(t, rootDir)

	mux := http.NewServeMux()
	mux.HandleFunc("/rep/", db.Handle)
	testServer := httptest.NewServer(mux)
	defer testServer.Close()

	rootDir2 := t.TempDir()
	db2 := NewTestSecondaryDB(t, rootDir2, testServer.URL+"/rep/")
	defer db2.Close()

	// snapshots[i] holds the snapshot taken after step i. It stays nil if the
	// step failed before a snapshot was taken, so later entries stay aligned
	// with their own steps.
	snapshots := make([]*Snapshot, len(testCase.Steps))

	// Run each step and its associated check function. The subtests are not
	// parallel, so capturing i and step in the closure is safe.
	for i, step := range testCase.Steps {
		t.Run(step.Name, func(t *testing.T) {
			err := db.Update(func(tx *Snapshot) error {
				return step.Update(t, db, tx)
			})
			if !errors.Is(err, step.ExpectedUpdateError) {
				t.Fatalf("Update error = %v, want %v", err, step.ExpectedUpdateError)
			}
			snapshots[i] = db.Snapshot()
			testRunner_checkState(t, db, snapshots[i], step.State)
		})
	}

	// Run each step's check function again with its stored snapshot.
	for i, step := range testCase.Steps {
		snapshot := snapshots[i]
		if snapshot == nil {
			// The step already failed (and was reported) before producing a
			// snapshot; there is nothing to re-check.
			continue
		}
		t.Run(step.Name+"-checkSnapshot", func(t *testing.T) {
			testRunner_checkState(t, db, snapshot, step.State)
		})
	}

	// Wait for the secondary to reach the primary's sequence number.
	pInfo := db.Info()
	deadline := time.Now().Add(replicationTimeout)
	for db2.Info().SeqNum != pInfo.SeqNum {
		if time.Now().After(deadline) {
			_ = db.Close() // Best effort: the test is failing anyway.
			t.Fatalf("secondary did not reach primary SeqNum %v within %v (secondary at %v)",
				pInfo.SeqNum, replicationTimeout, db2.Info().SeqNum)
		}
		time.Sleep(time.Millisecond)
	}

	// TODO: Why is this necessary?
	// NOTE(review): the SeqNum match above is apparently not sufficient on
	// its own for the secondary's state to be checkable — confirm why.
	time.Sleep(time.Second)

	finalStep := testCase.Steps[len(testCase.Steps)-1]

	secondarySnapshot := db2.Snapshot()
	t.Run("Check secondary", func(t *testing.T) {
		testRunner_checkState(t, db2, secondarySnapshot, finalStep.State)
	})

	if err := db.Close(); err != nil {
		t.Fatal(err)
	}

	// Run the final step's check function again with a newly loaded db.
	db = NewTestDBPrimary(t, rootDir)
	snapshot := db.Snapshot()

	t.Run("Check after reload", func(t *testing.T) {
		testRunner_checkState(t, db, snapshot, finalStep.State)
	})

	t.Run("Check that primary and secondary are equal", func(t *testing.T) {
		db.AssertEqual(t, db2.Database)
	})

	if err := db.Close(); err != nil {
		t.Fatal(err)
	}
}

// testRunner_checkState compares every index of db, read through tx, against
// the expected DBState. It checks both the full dump and the Min/Max of each
// index.
func testRunner_checkState(
	t *testing.T,
	db TestDB,
	tx *Snapshot,
	state DBState,
) {
	t.Helper()

	checkSlicesEqual(t, "UsersByID", db.Users.ByID.Dump(tx), state.UsersByID)
	checkSlicesEqual(t, "UsersByEmail", db.Users.ByEmail.Dump(tx), state.UsersByEmail)
	checkSlicesEqual(t, "UsersByName", db.Users.ByName.Dump(tx), state.UsersByName)
	checkSlicesEqual(t, "UsersByBlocked", db.Users.ByBlocked.Dump(tx), state.UsersByBlocked)
	checkSlicesEqual(t, "DataByID", db.UserData.ByID.Dump(tx), state.DataByID)
	checkSlicesEqual(t, "DataByName", db.UserData.ByName.Dump(tx), state.DataByName)

	checkMinMaxEqual(t, "UsersByID", tx, db.Users.ByID, state.UsersByID)
	checkMinMaxEqual(t, "UsersByEmail", tx, db.Users.ByEmail, state.UsersByEmail)
	checkMinMaxEqual(t, "UsersByName", tx, db.Users.ByName, state.UsersByName)
	checkMinMaxEqual(t, "UsersByBlocked", tx, db.Users.ByBlocked, state.UsersByBlocked)
	checkMinMaxEqual(t, "DataByID", tx, db.UserData.ByID, state.DataByID)
	checkMinMaxEqual(t, "DataByName", tx, db.UserData.ByName, state.DataByName)
}

// checkSlicesEqual fails the test if actual and expected differ in length or
// in any element, comparing elements with reflect.DeepEqual. name labels the
// failure message.
func checkSlicesEqual[T any](t *testing.T, name string, actual, expected []T) {
	t.Helper()
	if len(actual) != len(expected) {
		t.Fatalf("%s: len = %d, want %d", name, len(actual), len(expected))
	}
	for i := range actual {
		if !reflect.DeepEqual(actual[i], expected[i]) {
			t.Fatalf("%s[%d] = %+v, want %+v", name, i, actual[i], expected[i])
		}
	}
}

// checkMinMaxEqual fails the test unless index.Min and index.Max, read through
// tx, equal the first and last elements of expected (via reflect.DeepEqual).
// When expected is empty, both Min and Max must be nil. name labels the
// failure message.
func checkMinMaxEqual[T any](t *testing.T, name string, tx *Snapshot, index *Index[T], expected []T) {
	t.Helper()

	if len(expected) == 0 {
		if gotMin := index.Min(tx); gotMin != nil {
			t.Fatalf("%s: Min = %+v, want nil", name, *gotMin)
		}
		if gotMax := index.Max(tx); gotMax != nil {
			t.Fatalf("%s: Max = %+v, want nil", name, *gotMax)
		}
		return
	}

	wantMin := expected[0]
	wantMax := expected[len(expected)-1]

	gotMin := index.Min(tx)
	if gotMin == nil {
		t.Fatalf("%s: Min = nil, want %+v", name, wantMin)
	}
	gotMax := index.Max(tx)
	if gotMax == nil {
		t.Fatalf("%s: Max = nil, want %+v", name, wantMax)
	}
	if !reflect.DeepEqual(*gotMin, wantMin) {
		t.Fatalf("%s: Min = %+v, want %+v", name, *gotMin, wantMin)
	}
	if !reflect.DeepEqual(*gotMax, wantMax) {
		t.Fatalf("%s: Max = %+v, want %+v", name, *gotMax, wantMax)
	}
}