diff --git a/.gitignore b/.gitignore index 6c3169f..74f1bfa 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,6 @@ src/*.html *.so build/ + +.idea/ +.cache/ \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..527d345 --- /dev/null +++ b/Makefile @@ -0,0 +1,6 @@ +.PHONY: build + +build: + ./update_c.sh + python setup.py build + python setup.py build_ext --inplace \ No newline at end of file diff --git a/libdatrie/datrie/libdatrie.def b/libdatrie/datrie/libdatrie.def index 185bda5..a0f0566 100644 --- a/libdatrie/datrie/libdatrie.def +++ b/libdatrie/datrie/libdatrie.def @@ -10,6 +10,7 @@ trie_fread trie_free trie_save trie_fwrite +trie_size trie_is_dirty trie_retrieve trie_store diff --git a/libdatrie/datrie/trie.c b/libdatrie/datrie/trie.c index 37e95a6..05250e1 100644 --- a/libdatrie/datrie/trie.c +++ b/libdatrie/datrie/trie.c @@ -44,6 +44,7 @@ struct _Trie { DArray *da; Tail *tail; + uint32 size; Bool is_dirty; }; @@ -133,6 +134,7 @@ trie_new (const AlphaMap *alpha_map) if (UNLIKELY (!trie->tail)) goto exit_da_created; + trie->size = 0; trie->is_dirty = TRUE; return trie; @@ -203,6 +205,11 @@ trie_fread (FILE *file) if (NULL == (trie->tail = tail_fread (file))) goto exit_da_created; + uint32 counter = 0; + if (!trie_enumerate (trie, len_enumerator, &counter)) { + goto exit_trie_created; + } + trie->size = counter; trie->is_dirty = FALSE; return trie; @@ -290,6 +297,19 @@ trie_fwrite (Trie *trie, FILE *file) return 0; } +/** + * @brief Check pending changes + * + * @param trie : the trie object + * + * @return total count of trie keys + */ +uint32 +trie_size (const Trie *trie) +{ + return trie->size; +} + /** * @brief Check pending changes * @@ -431,6 +451,9 @@ trie_store_conditionally (Trie *trie, res = trie_branch_in_branch (trie, s, key_str, data); free (key_str); + if (res) { + trie->size++; + } return res; } if (0 == *p) @@ -455,6 +478,9 @@ trie_store_conditionally (Trie *trie, res = trie_branch_in_tail (trie, s, tail_str, data); free (tail_str); + if (res) { + trie->size++; + } return res; } if (0 == *p) @@ -580,6 +606,7 @@ trie_delete (Trie *trie, const AlphaChar *key) da_set_base (trie->da, s, TRIE_INDEX_ERROR); da_prune (trie->da, s); + trie->size--; trie->is_dirty = TRUE; return TRUE; } @@ -630,6 +657,14 @@ trie_enumerate (const Trie *trie, TrieEnumFunc enum_func, void *user_data) } +Bool +len_enumerator (const AlphaChar *key, TrieData key_data, uint32 *counter_ptr) +{ + (*counter_ptr)++; + return TRUE; +} + + /*-------------------------------* * STEPWISE QUERY OPERATIONS * *-------------------------------*/ diff --git a/libdatrie/datrie/trie.h b/libdatrie/datrie/trie.h index da16483..91398f8 100644 --- a/libdatrie/datrie/trie.h +++ b/libdatrie/datrie/trie.h @@ -129,6 +129,8 @@ int trie_save (Trie *trie, const char *path); int trie_fwrite (Trie *trie, FILE *file); +uint32 trie_size (const Trie *trie); + Bool trie_is_dirty (const Trie *trie); @@ -150,6 +152,10 @@ Bool trie_enumerate (const Trie *trie, TrieEnumFunc enum_func, void *user_data); +Bool len_enumerator (const AlphaChar *key, + TrieData key_data, + uint32 *counter_ptr); + /*-------------------------------* * STEPWISE QUERY OPERATIONS * diff --git a/libdatrie/tests/test_store-retrieve.c b/libdatrie/tests/test_store-retrieve.c index fed7213..f967a4d 100644 --- a/libdatrie/tests/test_store-retrieve.c +++ b/libdatrie/tests/test_store-retrieve.c @@ -48,6 +48,34 @@ main () goto err_trie_not_created; } + msg_step ("Check initial trie size"); + if (trie_size(test_trie) != 0) { + printf ("Wrong trie size; expected 0, got %d.\n", trie_size(test_trie)); + goto err_trie_size; + } + + msg_step ("Delete non-existent key from trie and check size"); + trie_delete (test_trie, (AlphaChar *)L"a"); + if (trie_size(test_trie) != 0) { + printf ("Wrong trie size; expected 0, got %d.\n", trie_size(test_trie)); + goto err_trie_size; + } + msg_step ("Add non-existent key with trie_store_if_absent and check size"); + if (!trie_store_if_absent (test_trie, (AlphaChar *)L"a", TRIE_DATA_UNREAD)) { + printf ("Failed to add non-existing key '%ls'.\n", (AlphaChar *)L"a"); + goto err_trie_created; + } + if (trie_size(test_trie) != 1) { + printf ("Wrong trie size; expected 1, got %d.\n", trie_size(test_trie)); + goto err_trie_size; + } + msg_step ("Delete existing key from trie and check size"); + trie_delete (test_trie, (AlphaChar *)L"a"); + if (trie_size(test_trie) != 0) { + printf ("Wrong trie size; expected 0, got %d.\n", trie_size(test_trie)); + goto err_trie_size; + } + /* store */ msg_step ("Adding data to trie"); for (dict_p = dict_src; dict_p->key; dict_p++) { @@ -58,6 +86,51 @@ main () } } + msg_step ("Check trie size"); + if (trie_size(test_trie) != dict_src_n_entries()) { + printf ("Wrong trie size; expected %d, got %d.\n", + dict_src_n_entries(), trie_size(test_trie)); + goto err_trie_size; + } + + msg_step ("Update existing trie element and check trie size"); + if (!trie_store (test_trie, dict_src[1].key, dict_src[1].data)) { + printf ("Failed to add key '%ls', data %d.\n", + dict_src[1].key, dict_src[1].data); + goto err_trie_created; + } + if (trie_size(test_trie) != dict_src_n_entries()) { + printf ("Wrong trie size; expected %d, got %d.\n", + dict_src_n_entries(), trie_size(test_trie)); + goto err_trie_size; + } + + msg_step ("Update existing trie element with trie_store_if_absent and check trie size"); + if (trie_store_if_absent (test_trie, dict_src[1].key, dict_src[1].data)) { + printf ("Value for existing key '%ls' was updated with trie_store_if_absent.\n", + dict_src[1].key); + goto err_trie_created; + } + if (trie_size(test_trie) != dict_src_n_entries()) { + printf ("Wrong trie size; expected %d, got %d.\n", + dict_src_n_entries(), trie_size(test_trie)); + goto err_trie_size; + } + + msg_step ("Add trie element with wrong alphabet and check trie size"); + if (trie_store (test_trie, (AlphaChar *)L"я", TRIE_DATA_UNREAD)) { + printf ("Key '%ls' with wrong alphabet was added.\n", + (AlphaChar *)L"я"); + goto err_trie_created; + } + if (trie_size(test_trie) != dict_src_n_entries()) { + printf ("Wrong trie size; expected %d, got %d.\n", + dict_src_n_entries(), trie_size(test_trie)); + goto err_trie_size; + } + + // TODO: add key with wrong alphabet and check size? + /* retrieve */ msg_step ("Retrieving data from trie"); is_failed = FALSE; @@ -99,6 +172,14 @@ main () goto err_trie_created; } + msg_step ("Check trie size after deleting some entries."); + if (trie_size(test_trie) != (n_entries - (n_entries/3 + 1))) { + printf ("Wrong trie size; expected %d, got %d.\n", + (n_entries - (n_entries/3 + 1)), trie_size(test_trie)); + goto err_trie_size; + } + + /* retrieve */ msg_step ("Retrieving data from trie again after deletions"); for (dict_p = dict_src; dict_p->key; dict_p++) { @@ -192,6 +273,8 @@ main () trie_state_free (trie_root_state); err_trie_created: trie_free (test_trie); +err_trie_size: + trie_free (test_trie); err_trie_not_created: return 1; } diff --git a/src/datrie.pyx b/src/datrie.pyx index b55447c..42bae13 100644 --- a/src/datrie.pyx +++ b/src/datrie.pyx @@ -3,6 +3,7 @@ Cython wrapper for libdatrie. """ +from cpython cimport bool from cpython.version cimport PY_MAJOR_VERSION from cython.operator import dereference as deref from libc.stdlib cimport malloc, free @@ -15,13 +16,18 @@ import itertools import warnings import sys import tempfile -from collections import MutableMapping +from collections import MutableMapping, Set, Sized try: import cPickle as pickle except ImportError: import pickle +try: + base_str = basestring +except NameError: + base_str = str + class DatrieError(Exception): pass @@ -224,18 +230,8 @@ cdef class BaseTrie: if not found: raise KeyError(key) - @staticmethod - cdef int len_enumerator(cdatrie.AlphaChar *key, cdatrie.TrieData key_data, - void *counter_ptr): - (counter_ptr)[0] += 1 - return True - def __len__(self): - cdef int counter = 0 - cdatrie.trie_enumerate(self._c_trie, - (self.len_enumerator), - &counter) - return counter + return cdatrie.trie_size(self._c_trie) def __richcmp__(self, other, int op): if op == 2: # == @@ -554,84 +550,36 @@ cdef class BaseTrie: finally: cdatrie.trie_state_free(state) + def __iter__(self): + cdef BaseIterator iter = BaseIterator(BaseState(self)) + while iter.next(): + yield iter.key() + cpdef items(self, unicode prefix=None): """ - Returns a list of this trie's items (``(key,value)`` tuples). + D.items() -> a set-like object providing a view on D's items. If ``prefix`` is not None, returns only the items associated with keys prefixed by ``prefix``. """ - cdef bint success - cdef list res = [] - cdef BaseState state = BaseState(self) - - if prefix is not None: - success = state.walk(prefix) - if not success: - return res - - cdef BaseIterator iter = BaseIterator(state) - - if prefix is None: - while iter.next(): - res.append((iter.key(), iter.data())) - else: - while iter.next(): - res.append((prefix+iter.key(), iter.data())) - - return res - - def __iter__(self): - cdef BaseIterator iter = BaseIterator(BaseState(self)) - while iter.next(): - yield iter.key() + return BaseTrieItemsView(self, prefix) cpdef keys(self, unicode prefix=None): """ - Returns a list of this trie's keys. + D.keys() -> a set-like object providing a view on D's keys. If ``prefix`` is not None, returns only the keys prefixed by ``prefix``. """ - cdef bint success - cdef list res = [] - cdef BaseState state = BaseState(self) - - if prefix is not None: - success = state.walk(prefix) - if not success: - return res - - cdef BaseIterator iter = BaseIterator(state) - - if prefix is None: - while iter.next(): - res.append(iter.key()) - else: - while iter.next(): - res.append(prefix+iter.key()) - - return res + return BaseTrieKeysView(self, prefix) cpdef values(self, unicode prefix=None): """ - Returns a list of this trie's values. + D.values() -> an object providing a view on D's values If ``prefix`` is not None, returns only the values associated with keys prefixed by ``prefix``. """ - cdef bint success - cdef list res = [] - cdef BaseState state = BaseState(self) - - if prefix is not None: - success = state.walk(prefix) - if not success: - return res - - cdef BaseIterator iter = BaseIterator(state) - while iter.next(): - res.append(iter.data()) - return res + return BaseTrieValuesView(self, prefix) cdef _index_to_value(self, cdatrie.TrieData index): return index @@ -728,67 +676,24 @@ cdef class Trie(BaseTrie): cpdef items(self, unicode prefix=None): """ - Returns a list of this trie's items (``(key,value)`` tuples). + D.items() -> a set-like object providing a view on D's items. If ``prefix`` is not None, returns only the items associated with keys prefixed by ``prefix``. """ - # the following code is - # - # [(k, self._values[v]) for (k,v) in BaseTrie.items(self, prefix)] - # - # but inlined for speed. - - cdef bint success - cdef list res = [] - cdef BaseState state = BaseState(self) - - if prefix is not None: - success = state.walk(prefix) - if not success: - return res - - cdef BaseIterator iter = BaseIterator(state) - - if prefix is None: - while iter.next(): - res.append((iter.key(), self._values[iter.data()])) - else: - while iter.next(): - res.append((prefix+iter.key(), self._values[iter.data()])) - - return res + return TrieItemsView(self, prefix) cpdef values(self, unicode prefix=None): """ - Returns a list of this trie's values. + D.values() -> an object providing a view on D's values If ``prefix`` is not None, returns only the values associated with keys prefixed by ``prefix``. """ - # the following code is - # - # [self._values[v] for v in BaseTrie.values(self, prefix)] - # - # but inlined for speed. - - cdef list res = [] - cdef BaseState state = BaseState(self) - cdef bint success - - if prefix is not None: - success = state.walk(prefix) - if not success: - return res - - cdef BaseIterator iter = BaseIterator(state) + return TrieValuesView(self, prefix) - while iter.next(): - res.append(self._values[iter.data()]) - - return res def longest_prefix_item(self, unicode key, default=RAISE_KEY_ERROR): """ @@ -865,6 +770,9 @@ cdef class _TrieState: if self._state is not NULL: cdatrie.trie_state_free(self._state) + cpdef get_tree(self): + return self._trie + cpdef walk(self, unicode to): cdef bint res for ch in to: @@ -980,6 +888,194 @@ cdef class Iterator(_TrieIterator): return self._root._trie._index_to_value(data) +class BaseTrieMappingView(Sized): + + __slots__ = ('_state', '_prefix') + + def __init__(self, base_trie, prefix=None): + cdef BaseState state = BaseState(base_trie) + self._state = state + self._prefix = prefix + + def _rewind_state(self, new_state): + """ + Reset state to root. Next try to walk to new state, if `new_state` + is not None. + """ + self._state.rewind() + if new_state is not None: + if (not isinstance(new_state, base_str) or + not self._state.walk(new_state)): + return False + return True + + def __len__(self): + """O(n) if prefix is defined""" + if self._prefix is None: + return len(self._state.get_tree()) + cdef int count = 0 + cdef _TrieIterator it + if self._rewind_state(self._prefix): + it = _TrieIterator(self._state) + while it.next(): + count += 1 + return count + + +class BaseTrieKeysView(BaseTrieMappingView, Set): + + __slots__ = () + + @classmethod + def _from_iterable(cls, it): + return set(it) + + def __contains__(self, item): + if self._prefix and not item.startswith(self._prefix): + return False + if self._rewind_state(item) and self._state.is_terminal(): + return True + return False + + def __iter__(self): + if not self._rewind_state(self._prefix): + raise StopIteration + cdef _TrieIterator it = _TrieIterator(self._state) + while it.next(): + if self._prefix is None: + yield it.key() + else: + yield self._prefix + it.key() + + def __eq__(self, other): + # Fail-fast version + if other is self: + return True + elif not isinstance(other, Set): + # No TypeError for equality + return False + count = 0 + for elem in self: + count += 1 + if elem not in other: + return False + return count == len(other) + + +class BaseTrieItemsView(BaseTrieMappingView, Set): + + __slots__ = () + + @classmethod + def _from_iterable(cls, it): + return set(it) + + def __contains__(self, item): + key, value = item + if self._prefix and not key.startswith(self._prefix): + return False + if self._rewind_state(key) and self._state.is_terminal(): + v = self._state.data() + return v is value or v == value + return False + + def __iter__(self): + if not self._rewind_state(self._prefix): + raise StopIteration + cdef BaseIterator it = BaseIterator(self._state) + while it.next(): + if self._prefix is None: + yield (it.key(), it.data()) + else: + yield (self._prefix + it.key(), it.data()) + + +class BaseTrieValuesView(BaseTrieMappingView): + + __slots__ = () + + def __contains__(self, value): + if self._prefix and not value.startswith(self._prefix): + return False + for v in self: + if v is value or v == value: + return True + return False + + def __iter__(self): + if not self._rewind_state(self._prefix): + raise StopIteration + cdef BaseIterator it = BaseIterator(self._state) + while it.next(): + yield it.data() + + +class TrieMappingView(BaseTrieMappingView): + + __slots__ = () + + def __init__(self, base_trie, prefix=None): + cdef State state = State(base_trie) + self._state = state + self._prefix = prefix + + +class TrieKeysView(BaseTrieKeysView): + pass + + +class TrieItemsView(TrieMappingView, Set): + + __slots__ = () + + @classmethod + def _from_iterable(cls, it): + return set(it) + + def __contains__(self, item): + key, value = item + if self._prefix and not key.startswith(self._prefix): + return False + if self._rewind_state(key) and self._state.is_terminal(): + v = self._state.data() + return v is value or v == value + return False + + def __iter__(self): + if not self._rewind_state(self._prefix): + raise StopIteration + cdef Iterator it = Iterator(self._state) + while it.next(): + if self._prefix is None: + yield (it.key(), it.data()) + else: + yield (self._prefix + it.key(), it.data()) + + +class TrieValuesView(TrieMappingView): + + __slots__ = () + + @classmethod + def _from_iterable(cls, it): + return set(it) + + def __contains__(self, value): + if self._prefix and not value.startswith(self._prefix): + return False + for v in self: + if v is value or v == value: + return True + return False + + def __iter__(self): + if not self._rewind_state(self._prefix): + raise StopIteration + cdef Iterator it = Iterator(self._state) + while it.next(): + yield it.data() + + cdef (cdatrie.Trie* ) _load_from_file(f) except NULL: cdef int fd = f.fileno() cdef stdio.FILE* f_ptr = stdio_ext.fdopen(fd, "r") diff --git a/tests/test_trie.py b/tests/test_trie.py index 0b3f039..f6ec663 100644 --- a/tests/test_trie.py +++ b/tests/test_trie.py @@ -141,9 +141,9 @@ def test_trie_items(): trie['foo'] = 10 trie['bar'] = 'foo' trie['foobar'] = 30 - assert trie.values() == ['foo', 10, 30] - assert trie.items() == [('bar', 'foo'), ('foo', 10), ('foobar', 30)] - assert trie.keys() == ['bar', 'foo', 'foobar'] + assert list(trie.values()) == ['foo', 10, 30] + assert list(trie.items()) == [('bar', 'foo'), ('foo', 10), ('foobar', 30)] + assert list(trie.keys()) == ['bar', 'foo', 'foobar'] def test_trie_iter(): @@ -241,37 +241,38 @@ def _trie(self): def test_trie_keys_prefix(self): trie = self._trie() - assert trie.keys('foobarz') == ['foobarzartic'] - assert trie.keys('foobarzart') == ['foobarzartic'] - assert trie.keys('foo') == ['foo', 'foobar', 'foobarzartic', 'foovar'] - assert trie.keys('foobar') == ['foobar', 'foobarzartic'] - assert trie.keys('') == [ + assert list(trie.keys('foobarz')) == ['foobarzartic'] + assert list(trie.keys('foobarzart')) == ['foobarzartic'] + assert list(trie.keys('foo')) == ['foo', 'foobar', 'foobarzartic', 'foovar'] + assert list(trie.keys('foobar')) == ['foobar', 'foobarzartic'] + assert list(trie.keys('')) == [ 'bar', 'foo', 'foobar', 'foobarzartic', 'foovar' ] - assert trie.keys('x') == [] + assert list(trie.keys('x')) == [] def test_trie_items_prefix(self): trie = self._trie() - assert trie.items('foobarz') == [('foobarzartic', None)] - assert trie.items('foobarzart') == [('foobarzartic', None)] - assert trie.items('foo') == [ + assert list(trie.items('foobarz')) == [('foobarzartic', None)] + assert list(trie.items('foobarzart')) == [('foobarzartic', None)] + assert list(trie.items('foo')) == [ ('foo', 10), ('foobar', 30), ('foobarzartic', None), ('foovar', 40) ] - assert trie.items('foobar') == [('foobar', 30), ('foobarzartic', None)] - assert trie.items('') == [ + assert list(trie.items('foobar')) == [ + ('foobar', 30), ('foobarzartic', None)] + assert list(trie.items('')) == [ ('bar', 20), ('foo', 10), ('foobar', 30), ('foobarzartic', None), ('foovar', 40) ] - assert trie.items('x') == [] + assert list(trie.items('x')) == [] def test_trie_values_prefix(self): trie = self._trie() - assert trie.values('foobarz') == [None] - assert trie.values('foobarzart') == [None] - assert trie.values('foo') == [10, 30, None, 40] - assert trie.values('foobar') == [30, None] - assert trie.values('') == [20, 10, 30, None, 40] - assert trie.values('x') == [] + assert list(trie.values('foobarz')) == [None] + assert list(trie.values('foobarzart')) == [None] + assert list(trie.values('foo')) == [10, 30, None, 40] + assert list(trie.values('foobar')) == [30, None] + assert list(trie.values('')) == [20, 10, 30, None, 40] + assert list(trie.values('x')) == [] class TestPrefixSearch(object): diff --git a/tests/test_trieview.py b/tests/test_trieview.py new file mode 100644 index 0000000..201479b --- /dev/null +++ b/tests/test_trieview.py @@ -0,0 +1,287 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import, unicode_literals + +import pytest +import string + +import datrie + + +def test_keys_empty(): + trie = datrie.BaseTrie(string.printable) + keys = trie.keys() + assert len(keys) == 0 + + +def test_keys_iter(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + keys_list = list(keys) + keys_list.sort() + assert keys_list == ["1", "2"] + del trie["2"] + assert list(keys) == ["1"] + + +def test_keys_iter_with_prefix(): + trie = datrie.BaseTrie(string.printable) + keys = trie.keys(prefix="prefix1") + keys_list = list(keys) + assert keys_list == [] + trie["prefix1_1"] = 11 + trie["prefix1_2"] = 12 + trie["prefix2_1"] = 21 + trie["prefix2_2"] = 22 + keys_list = list(keys) + keys_list.sort() + assert keys_list == ["prefix1_1", "prefix1_2"] + del trie["prefix1_1"] + del trie["prefix1_2"] + assert list(keys) == [] + + +def test_keys_contains(): + trie = datrie.BaseTrie(string.printable) + trie["prefix1_1"] = 11 + trie["prefix1_2"] = 12 + trie["prefix2_1"] = 21 + trie["prefix2_2"] = 22 + keys = trie.keys() + assert "prefix1_1" in keys + assert "prefix2_1" in keys + keys = trie.keys(prefix="prefix1") + assert "prefix1_1" in keys + assert "prefix2_1" not in keys + trie["1"] = 1 + keys = trie.keys() + assert "1" in keys + assert 1 not in keys + assert [1] not in keys + items = trie.items() + assert ("1", 1) in items + assert (1, 1) not in items + + +def test_keys_len(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + keys = trie.keys() + assert len(trie) == 1 + assert len(keys) == 1 + trie["1"] = 2 + trie["2"] = 2 + trie["prefix_3"] = 3 + assert len(keys) == 3 + keys = trie.keys(prefix="prefix") + assert len(keys) == 1 + del trie["1"] + del trie["2"] + assert len(keys) == 1 + del trie["prefix_3"] + assert len(keys) == 0 + + +def test_keys_prefix(): + trie = datrie.BaseTrie(string.printable) + trie["prefix1_1"] = 11 + trie["prefix1_2"] = 12 + trie["prefix2_3"] = 21 + keys = trie.keys(prefix="prefix") + assert len(keys) == 3 + keys = trie.keys(prefix="prefix1_") + assert len(keys) == 2 + keys_list = list(keys) + keys_list.sort() + assert keys_list == ["prefix1_1", "prefix1_2"] + del trie["prefix1_1"] + del trie["prefix2_3"] + assert list(keys) == ["prefix1_2"] + + +def test_keys_delete(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + del trie["1"] + assert len(trie) == 1 + assert len(keys) == 1 + + +def test_keys_eq(): + """Test trie.keys() == and != operations""" + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert keys == {"1", "2"} + assert keys == {"2", "1"} + trie["3"] = 3 + assert keys != {"2", "1"} + del trie["1"] + assert keys == {"2", "3"} + trie["prefix_4"] = 4 + keys = trie.keys(prefix="prefix") + assert keys == {"prefix_4"} + assert not keys != {"prefix_4"} + assert keys != {"1", "2", "3"} + + +def test_keys_issuperset(): + """Test trie.keys() >= and > operations""" + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert keys >= {"1"} + with pytest.raises(TypeError): + _ = keys >= 1 # not iterable + assert keys >= {"2"} + assert keys >= {"1", "2"} + assert not keys >= {"1", "2", "3"} + assert not keys >= {"3"} + # Proper superset + assert keys > {"2"} + assert not keys > {"1", "2"} + assert not keys > {"3"} + # Wrong type inside set + assert not keys >= {1, 2} + trie["prefix_3"] = 3 + keys = trie.keys(prefix="prefix") + assert keys >= {"prefix_3"} + assert not keys >= {"prefix_3", "1"} + del trie["prefix_3"] + assert keys >= set() + + +def test_keys_issubset(): + """Test trie.keys() <= and < operations""" + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert not keys <= {"1"} + with pytest.raises(TypeError): + assert not keys <= 1 # not iterable + assert keys <= ["1", "2"] # wrong type + assert keys <= {"1", "2"} + assert keys <= {"1", "2", "3"} + trie["prefix_3"] = 3 + # Proper subset + assert not keys < {"1", "2"} + assert not keys < {"1", "2", "prefix_3"} + assert keys < {"1", "2", "prefix_3", "3"} + keys = trie.keys(prefix="prefix") + assert keys <= {"prefix_3"} + assert keys <= {"prefix_3", "1"} + assert not keys <= {"1", "2"} + del trie["prefix_3"] + assert keys <= {"prefix_3"} + assert keys <= set() + assert keys < {"1", "2", "3"} + + +def test_keys_intersection(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert (keys & keys) == {"1", "2"} + assert (keys & keys) != set() + assert (keys & keys) != {"1"} + assert (keys & keys) != {"2"} + assert (keys & '1') == {"1"} + with pytest.raises(TypeError): + assert (keys & 1) == {"1"} # not iterable + assert (keys & 'ab') == set() + assert (keys & "12") == {"1", "2"} + assert (keys & "1") == {"1"} + trie["prefix_3"] = 3 + keys = trie.keys(prefix="prefix_") + assert (keys & keys) == {"prefix_3"} + assert (keys & keys) == keys + assert (keys & "12") == set() + assert (keys & "") == set() + + +def test_keys_union(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + trie["333"] = 2 + keys = trie.keys() + assert (keys | keys) == keys + assert (keys | set()) == set(keys) + del trie["333"] + assert (keys | {"1"}) == {"1", "2"} + del trie["1"] + assert (keys | {"1"}) == {"1", "2"} + assert (keys | {"2"}) == {"2"} + assert (keys | {"3"}) == {"2", "3"} + keys = trie.keys(prefix="") + assert (keys | {"3"}) == {"2", "3"} + keys = trie.keys(prefix="prefix") + assert (keys | {"3"}) == {"3"} + trie["prefix_3"] = 3 + assert (keys | {"3"}) == {"3", "prefix_3"} + + +def test_keys_difference(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + trie["3"] = 2 + keys = trie.keys() + assert (keys - set()) == set(keys) + assert (keys - {"3"}) == {"1", "2"} + assert (keys - {"2", "3"}) == {"1"} + assert (keys - {"1", "2", "3"}) == set() + assert (keys - {"1", "2", "3", "4"}) == set() + assert (keys - {"4"}) == {"1", "2", "3"} + keys = trie.keys(prefix="prefix") + assert (keys - set()) == set() + assert (keys - {"1"}) == set() + trie["prefix_1"] = 3 + assert (keys - set()) == {"prefix_1"} + assert (keys - {"prefix_1"}) == set() + assert (keys - {"prefix_2"}) == {"prefix_1"} + + +def test_keys_symmetric_difference(): + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert (keys ^ set()) == {"1", "2"} + assert (keys ^ {"1"}) == {"2"} + assert (keys ^ {"1", "2"}) == set() + assert (keys ^ {"1", "2", "3"}) == {"3"} + del trie["1"] + assert (keys ^ {"1"}) == {"1", "2"} + keys = trie.keys(prefix="prefix") + assert (keys ^ {"1"}) == {"1"} + trie["prefix_1"] = 3 + assert (keys ^ {"1"}) == {"prefix_1", "1"} + + +def test_keys_isdisjoint(): + # Return True if null intersection + trie = datrie.BaseTrie(string.printable) + trie["1"] = 1 + trie["2"] = 2 + keys = trie.keys() + assert keys.isdisjoint(set()) + assert not keys.isdisjoint({"1"}) + assert keys.isdisjoint({"3"}) + del trie["1"] + assert keys.isdisjoint({"1"}) + keys = trie.keys(prefix="prefix") + assert keys.isdisjoint({"1"}) + assert keys.isdisjoint({"2"}) + trie["prefix_1"] = 3 + assert keys.isdisjoint({"2"}) + assert not keys.isdisjoint({"prefix_1"}) + assert keys.isdisjoint({"prefix_2"})