Skip to content

Commit

Permalink
Fix design flaw in category_index (when moving categories)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhekkel committed Jan 4, 2024
1 parent 47e59a5 commit 0f8a7c4
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 58 deletions.
1 change: 1 addition & 0 deletions changelog
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Version 6.0.1
- Change default order to write out categories in a file based on
parent/child relationship
- Added validate_pdbx and recover_pdbx
- Fixed a serious bug in category_index when moving categories

Version 6.0.0
- Drop the use of CCP4's monomer library for compound information
Expand Down
114 changes: 56 additions & 58 deletions src/category.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class row_comparator
{
public:
row_comparator(category &cat)
: m_category(cat)
{
auto cv = cat.get_cat_validator();

Expand All @@ -69,13 +68,13 @@ class row_comparator
}
}

int operator()(const row *a, const row *b) const
int operator()(const category &cat, const row *a, const row *b) const
{
assert(a);
assert(b);

row_handle rha(m_category, *a);
row_handle rhb(m_category, *b);
row_handle rha(cat, *a);
row_handle rhb(cat, *b);

int d = 0;
for (const auto &[k, f] : m_comparator)
Expand All @@ -92,11 +91,11 @@ class row_comparator
return d;
}

int operator()(const row_initializer &a, const row *b) const
int operator()(const category &cat, const row_initializer &a, const row *b) const
{
assert(b);

row_handle rhb(m_category, *b);
row_handle rhb(cat, *b);

int d = 0;
auto ai = a.begin();
Expand Down Expand Up @@ -124,7 +123,6 @@ class row_comparator
using key_comparator = std::tuple<uint16_t, compareFunc>;

std::vector<key_comparator> m_comparator;
category &m_category;
};

// --------------------------------------------------------------------
Expand All @@ -135,18 +133,18 @@ class row_comparator
class category_index
{
public:
category_index(category *cat);
category_index(category &cat);

~category_index()
{
delete m_root;
}

row *find(row *k) const;
row *find_by_value(row_initializer k) const;
row *find(const category &cat, row *k) const;
row *find_by_value(const category &cat, row_initializer k) const;

void insert(row *r);
void erase(row *r);
void insert(category &cat, row *r);
void erase(category &cat, row *r);

// reorder the row's and returns new head and tail
std::tuple<row *, row *> reorder()
Expand Down Expand Up @@ -192,8 +190,8 @@ class category_index
bool m_red;
};

entry *insert(entry *h, row *v);
entry *erase(entry *h, row *k);
entry *insert(category &cat, entry *h, row *v);
entry *erase(category &cat, entry *h, row *k);

// void validate(entry* h, bool isParentRed, uint32_t blackDepth, uint32_t& minBlack, uint32_t& maxBlack) const;

Expand Down Expand Up @@ -324,26 +322,24 @@ class category_index
return result;
}

category &m_category;
row_comparator m_row_comparator;
entry *m_root;
};

category_index::category_index(category *cat)
: m_category(*cat)
, m_row_comparator(m_category)
category_index::category_index(category &cat)
: m_row_comparator(cat)
, m_root(nullptr)
{
for (auto r : m_category)
insert(r.get_row());
for (auto r : cat)
insert(cat, r.get_row());
}

row *category_index::find(row *k) const
row *category_index::find(const category &cat, row *k) const
{
const entry *r = m_root;
while (r != nullptr)
{
int d = m_row_comparator(k, r->m_row);
int d = m_row_comparator(cat, k, r->m_row);
if (d < 0)
r = r->m_left;
else if (d > 0)
Expand All @@ -355,14 +351,14 @@ row *category_index::find(row *k) const
return r ? r->m_row : nullptr;
}

row *category_index::find_by_value(row_initializer k) const
row *category_index::find_by_value(const category &cat, row_initializer k) const
{
// sort the values in k first

row_initializer k2;
for (auto &f : m_category.key_field_indices())
for (auto &f : cat.key_field_indices())
{
auto fld = m_category.get_column_name(f);
auto fld = cat.get_column_name(f);

auto ki = find_if(k.begin(), k.end(), [&fld](auto &i) { return i.name() == fld; });
if (ki == k.end())
Expand All @@ -374,7 +370,7 @@ row *category_index::find_by_value(row_initializer k) const
const entry *r = m_root;
while (r != nullptr)
{
int d = m_row_comparator(k2, r->m_row);
int d = m_row_comparator(cat, k2, r->m_row);
if (d < 0)
r = r->m_left;
else if (d > 0)
Expand All @@ -386,34 +382,34 @@ row *category_index::find_by_value(row_initializer k) const
return r ? r->m_row : nullptr;
}

void category_index::insert(row *k)
void category_index::insert(category &cat, row *k)
{
m_root = insert(m_root, k);
m_root = insert(cat, m_root, k);
m_root->m_red = false;
}

category_index::entry *category_index::insert(entry *h, row *v)
category_index::entry *category_index::insert(category &cat, entry *h, row *v)
{
if (h == nullptr)
return new entry(v);

int d = m_row_comparator(v, h->m_row);
int d = m_row_comparator(cat, v, h->m_row);
if (d < 0)
h->m_left = insert(h->m_left, v);
h->m_left = insert(cat, h->m_left, v);
else if (d > 0)
h->m_right = insert(h->m_right, v);
h->m_right = insert(cat, h->m_right, v);
else
{
row_handle rh(m_category, *v);
row_handle rh(cat, *v);

std::ostringstream os;
for (auto col : m_category.key_fields())
for (auto col : cat.key_fields())
{
if (rh[col])
os << col << ": " << std::quoted(rh[col].text()) << "; ";
}

throw duplicate_key_error("Duplicate Key violation, cat: " + m_category.name() + " values: " + os.str());
throw duplicate_key_error("Duplicate Key violation, cat: " + cat.name() + " values: " + os.str());
}

if (is_red(h->m_right) and not is_red(h->m_left))
Expand All @@ -428,33 +424,33 @@ category_index::entry *category_index::insert(entry *h, row *v)
return h;
}

void category_index::erase(row *k)
void category_index::erase(category &cat, row *k)
{
assert(find(k) == k);
assert(find(cat, k) == k);

m_root = erase(m_root, k);
m_root = erase(cat, m_root, k);
if (m_root != nullptr)
m_root->m_red = false;
}

category_index::entry *category_index::erase(entry *h, row *k)
category_index::entry *category_index::erase(category &cat, entry *h, row *k)
{
if (m_row_comparator(k, h->m_row) < 0)
if (m_row_comparator(cat, k, h->m_row) < 0)
{
if (h->m_left != nullptr)
{
if (not is_red(h->m_left) and not is_red(h->m_left->m_left))
h = move_red_left(h);

h->m_left = erase(h->m_left, k);
h->m_left = erase(cat, h->m_left, k);
}
}
else
{
if (is_red(h->m_left))
h = rotateRight(h);

if (m_row_comparator(k, h->m_row) == 0 and h->m_right == nullptr)
if (m_row_comparator(cat, k, h->m_row) == 0 and h->m_right == nullptr)
{
delete h;
return nullptr;
Expand All @@ -465,13 +461,13 @@ category_index::entry *category_index::erase(entry *h, row *k)
if (not is_red(h->m_right) and not is_red(h->m_right->m_left))
h = move_red_right(h);

if (m_row_comparator(k, h->m_row) == 0)
if (m_row_comparator(cat, k, h->m_row) == 0)
{
h->m_row = find_min(h->m_right)->m_row;
h->m_right = erase_min(h->m_right);
}
else
h->m_right = erase(h->m_right, k);
h->m_right = erase(cat, h->m_right, k);
}
}

Expand Down Expand Up @@ -520,7 +516,7 @@ category::category(const category &rhs)
insert_impl(end(), clone_row(*r));

if (m_cat_validator != nullptr and m_index == nullptr)
m_index = new category_index(this);
m_index = new category_index(*this);
}

category::category(category &&rhs)
Expand Down Expand Up @@ -564,7 +560,7 @@ category &category::operator=(const category &rhs)
m_cat_validator = rhs.m_cat_validator;

if (m_cat_validator != nullptr and m_index == nullptr)
m_index = new category_index(this);
m_index = new category_index(*this);
}

return *this;
Expand Down Expand Up @@ -669,7 +665,7 @@ void category::set_validator(const validator *v, datablock &db)
}

if (missing.empty())
m_index = new category_index(this);
m_index = new category_index(*this);
else
{
std::ostringstream msg;
Expand Down Expand Up @@ -782,7 +778,7 @@ bool category::is_valid() const
for (auto r : *this)
{
auto p = r.get_row();
if (m_index->find(p) != p)
if (m_index->find(*this, p) != p)
m_validator->report_error("Key not found in index for category " + m_name, true);
}
}
Expand Down Expand Up @@ -904,7 +900,7 @@ row_handle category::operator[](const key_type &key)
if (m_index == nullptr)
throw std::logic_error("Category " + m_name + " does not have an index");

auto row = m_index->find_by_value(key);
auto row = m_index->find_by_value(*this, key);
if (row != nullptr)
result = { *this, *row };
}
Expand Down Expand Up @@ -1078,7 +1074,7 @@ category::iterator category::erase(iterator pos)
throw std::runtime_error("erase");

if (m_index != nullptr)
m_index->erase(r);
m_index->erase(*this, r);

if (r == m_head)
{
Expand Down Expand Up @@ -1250,12 +1246,14 @@ std::string category::get_unique_id(std::function<std::string(int)> generator)
std::string id_tag = "id";
if (m_cat_validator != nullptr and m_cat_validator->m_keys.size() == 1)
{
id_tag = m_cat_validator->m_keys.front();

if (m_index == nullptr and m_cat_validator != nullptr)
m_index = new category_index(this);
m_index = new category_index(*this);

for (;;)
{
if (m_index->find_by_value({{ id_tag, result }}) == nullptr)
if (m_index->find_by_value(*this, {{ id_tag, result }}) == nullptr)
break;
result = generator(static_cast<int>(m_last_unique_num++));
}
Expand Down Expand Up @@ -1407,7 +1405,7 @@ void category::update_value(row *row, uint16_t column, std::string_view value, b
{
// make sure we have an index, if possible
if (m_index == nullptr and m_cat_validator != nullptr)
m_index = new category_index(this);
m_index = new category_index(*this);

auto &col = m_columns[column];

Expand All @@ -1433,9 +1431,9 @@ void category::update_value(row *row, uint16_t column, std::string_view value, b
if (updateLinked and // an update of an Item's value
m_index != nullptr and key_field_indices().count(column))
{
reinsert = m_index->find(row);
reinsert = m_index->find(*this, row);
if (reinsert)
m_index->erase(row);
m_index->erase(*this, row);
}

// first remove old value with cix
Expand All @@ -1446,7 +1444,7 @@ void category::update_value(row *row, uint16_t column, std::string_view value, b
row->append(column, { value });

if (reinsert)
m_index->insert(row);
m_index->insert(*this, row);

// see if we need to update any child categories that depend on this value
auto iv = col.m_validator;
Expand Down Expand Up @@ -1602,7 +1600,7 @@ row_handle category::create_copy(row_handle r)
category::iterator category::insert_impl(const_iterator pos, row *n)
{
if (m_index == nullptr and m_cat_validator != nullptr)
m_index = new category_index(this);
m_index = new category_index(*this);

assert(n != nullptr);
assert(n->m_next == nullptr);
Expand Down Expand Up @@ -1642,7 +1640,7 @@ category::iterator category::insert_impl(const_iterator pos, row *n)
}

if (m_index != nullptr)
m_index->insert(n);
m_index->insert(*this, n);

// insert at end, most often this is the case
if (pos.m_current == nullptr)
Expand Down

0 comments on commit 0f8a7c4

Please sign in to comment.