diff --git a/storage.go b/storage.go index b83fd81..66e2af4 100644 --- a/storage.go +++ b/storage.go @@ -68,9 +68,6 @@ func (sto *Storage) Load(key string) ([]byte, error) { func (sto *Storage) Delete(key string) error { err := sto.DB.Update(func(txn *badger.Txn) error { k := []byte(key) - if _, err := txn.Get(k); err == badger.ErrKeyNotFound { - return err - } return txn.Delete(k) }) if err != nil { @@ -90,24 +87,26 @@ func (sto *Storage) Exists(key string) bool { // List implements certmagic.Storage.List func (sto *Storage) List(prefix string, recursive bool) ([]string, error) { + seen := map[string]bool{} var keys []string err := sto.DB.View(func(txn *badger.Txn) error { - dir := make([]byte, 0, len(prefix)+1) - dir = append(dir, prefix...) - dir = append(dir, '/') - it := txn.NewIterator(badger.IteratorOptions{Prefix: dir}) + pfx := make([]byte, 0, len(prefix)+1) + pfx = append(pfx, prefix...) + pfx = append(pfx, '/') + it := txn.NewIterator(badger.IteratorOptions{Prefix: pfx}) defer it.Close() it.Rewind() if !it.Valid() { return badger.ErrKeyNotFound } for ; it.Valid(); it.Next() { - itm := it.Item() - key := itm.Key() - fn := bytes.TrimPrefix(key, dir) - if len(fn) != 0 && (recursive || !bytes.Contains(fn, []byte{'/'})) { - keys = append(keys, string(key)) - } + walkKey(it.Item().Key(), len(pfx), recursive, func(k []byte) { + if seen[string(k)] { + return + } + seen[string(k)] = true + keys = append(keys, string(k)) + }) } return nil }) @@ -153,3 +152,18 @@ func (sto *Storage) Stat(key string) (certmagic.KeyInfo, error) { func New(db *badger.DB) *Storage { return &Storage{DB: db} } + +func walkKey(k []byte, sp int, recursive bool, f func([]byte)) { + if sp >= len(k) { + return + } + if i := bytes.IndexByte(k[sp:], '/'); i >= 0 { + sp += i + } else { + sp = len(k) + } + f(k[:sp]) + if recursive { + walkKey(k, sp+1, recursive, f) + } +} diff --git a/storage_test.go b/storage_test.go index a77f10d..1599f08 100644 --- a/storage_test.go +++ b/storage_test.go @@ -1,6 +1,7 @@ package badgerstorage import ( + "fmt" "github.com/caddyserver/certmagic" "github.com/dgraph-io/badger/v2" tests "github.com/oyato/certmagic-storage-tests" @@ -26,5 +27,47 @@ func TestStorage(t *testing.T) { if err != nil { t.Fatalf("Cannot open badger memory DB: %s", err) } - tests.NewTestSuite(New(db)).Run(t) + sto := New(db) + tests.NewTestSuite(sto).Run(t) + if err := sto.Delete(""); err == nil { + t.Fatalf("Storage.Delete with empty key should fail") + } +} + +func TestWalkKey(t *testing.T) { + pfx := "dir/" + tbl := []struct { + rec bool + key string + exp []string + }{ + {false, "", []string{}}, + {false, "a/1/2", []string{"a"}}, + {false, "b/3", []string{"b"}}, + {false, "c", []string{"c"}}, + {true, "", []string{}}, + {true, "a/1/2", []string{"a", "a/1", "a/1/2"}}, + {true, "b/3", []string{"b", "b/3"}}, + {true, "c", []string{"c"}}, + } + for _, tst := range tbl { + if tst.key != "" { + tst.key = pfx + tst.key + } + for i, s := range tst.exp { + tst.exp[i] = pfx + s + } + + ls := []string{} + walkKey([]byte(tst.key), len(pfx), tst.rec, func(k []byte) { + ls = append(ls, string(k)) + }) + got := fmt.Sprintf("%#q", ls) + exp := fmt.Sprintf("%#q", tst.exp) + if got != exp { + t.Errorf("walkKey(%#q, %d, %v): should return %s, not %s", + tst.key, len(pfx), tst.rec, exp, got, + ) + } + } }