From 0fe236829ccbf3b23c7bd0e87910dd2a33391453 Mon Sep 17 00:00:00 2001 From: Aalekh Patel Date: Wed, 20 Sep 2023 03:40:20 -0500 Subject: [PATCH] Inject a GetOrCreateMixin to Node and Relationship. (#244) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Align develop and main branch (#237) Co-authored-by: Bruno Sačarić * Inject a GetOrCreateMixin to Node and Relationship. * Use the imported GQLAlchemyError instead of a fully qualified path. * Use the imported Tuple instead of fully qualified typing.Tuple * Run black. * suppress flake8's f821 on get_or_create definition * Add tests for get_or_create and refactor the mixin into separate implementations for improved docs. Signed-off-by: Aalekh Patel * Apply black. Signed-off-by: Aalekh Patel * Apply black to tests as well. Signed-off-by: Aalekh Patel * typo fix in function signature. Signed-off-by: Aalekh Patel * Assert the counts of nodes and relationships in the test. Signed-off-by: Aalekh Patel * Apply black. Signed-off-by: Aalekh Patel * Add the .execute() to the query builder in failing tests and provide "name" to the node instantiation because it is a required field. * another attempt to fix the query builder usage in tests. * apply black fixes. * Fix tests by relying on the database identifier `_id` instead of the user-defined `id`. Signed-off-by: Aalekh Patel * Remove unused test functions. Signed-off-by: Aalekh Patel --------- Signed-off-by: Aalekh Patel Co-authored-by: Katarina Supe <61758502+katarinasupe@users.noreply.github.com> Co-authored-by: Bruno Sačarić Co-authored-by: katarinasupe Co-authored-by: Aalekh Patel --- gqlalchemy/models.py | 32 ++++++++++++ tests/ogm/test_get_or_create.py | 88 +++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 tests/ogm/test_get_or_create.py diff --git a/gqlalchemy/models.py b/gqlalchemy/models.py index fde0eaca..7daa92df 100644 --- a/gqlalchemy/models.py +++ b/gqlalchemy/models.py @@ -620,6 +620,22 @@ def load(self, db: "Database") -> "Node": # noqa F821 self._id = node._id return self + def get_or_create(self, db: "Database") -> Tuple["Node", bool]: # noqa F821 + """Return the node and a flag for whether it was created in the database. + + Args: + db: The database instance to operate on. + + Returns: + A tuple with the first component being the created graph node, + and the second being a boolean that is True if the node + was created in the database, and False if it was loaded instead. + """ + try: + return self.load(db=db), False + except GQLAlchemyError: + return self.save(db=db), True + class RelationshipMetaclass(BaseModel.__class__): def __new__(mcs, name, bases, namespace, **kwargs): # noqa C901 @@ -693,6 +709,22 @@ def load(self, db: "Database") -> "Relationship": # noqa F821 self._id = relationship._id return self + def get_or_create(self, db: "Database") -> Tuple["Relationship", bool]: # noqa F821 + """Return the relationship and a flag for whether it was created in the database. + + Args: + db: The database instance to operate on. + + Returns: + A tuple with the first component being the created graph relationship, + and the second being a boolean that is True if the relationship + was created in the database, and False if it was loaded instead. + """ + try: + return self.load(db=db), False + except GQLAlchemyError: + return self.save(db=db), True + class Path(GraphObject): _nodes: Iterable[Node] = PrivateAttr() diff --git a/tests/ogm/test_get_or_create.py b/tests/ogm/test_get_or_create.py new file mode 100644 index 00000000..ff80cdf1 --- /dev/null +++ b/tests/ogm/test_get_or_create.py @@ -0,0 +1,88 @@ +# Copyright (c) 2016-2022 Memgraph Ltd. [https://memgraph.com] +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from gqlalchemy import Node, Field, Relationship, GQLAlchemyError + + +@pytest.mark.parametrize("database", ["neo4j", "memgraph"], indirect=True) +def test_get_or_create_node(database): + class User(Node): + name: str = Field(unique=True, db=database) + + class Streamer(User): + name: str = Field(unique=True, db=database) + id: str = Field(index=True, db=database) + followers: int = Field() + totalViewCount: int = Field() + + # Assert that loading a node that doesn't yet exist raises GQLAlchemyError. + non_existent_streamer = Streamer(name="Mislav", id="7", followers=777, totalViewCount=7777) + with pytest.raises(GQLAlchemyError): + database.load_node(non_existent_streamer) + + streamer, created = non_existent_streamer.get_or_create(database) + assert created is True, "Node.get_or_create should create this node since it doesn't yet exist." + assert streamer.name == "Mislav" + assert streamer.id == "7" + assert streamer.followers == 777 + assert streamer.totalViewCount == 7777 + assert streamer._labels == {"Streamer", "User"} + + assert streamer._id is not None, "Since the streamer was created, it should not have a None _id." + + streamer_other, created = non_existent_streamer.get_or_create(database) + assert created is False, "Node.get_or_create should not create this node but load it instead." + assert streamer_other.name == "Mislav" + assert streamer_other.id == "7" + assert streamer_other.followers == 777 + assert streamer_other.totalViewCount == 7777 + assert streamer_other._labels == {"Streamer", "User"} + + assert ( + streamer_other._id == streamer._id + ), "Since the other streamer wasn't created, it should have the same underlying _id property." + + +@pytest.mark.parametrize("database", ["neo4j", "memgraph"], indirect=True) +def test_get_or_create_relationship(database): + class User(Node): + name: str = Field(unique=True, db=database) + + class Follows(Relationship): + _type = "FOLLOWS" + + node_from, created = User(name="foo").get_or_create(database) + assert created is True + assert node_from.name == "foo" + + node_to, created = User(name="bar").get_or_create(database) + assert created is True + assert node_to.name == "bar" + + assert node_from._id != node_to._id, "Since a new node was created, it should have a different id." + + # Assert that loading a relationship that doesn't yet exist raises GQLAlchemyError. + non_existent_relationship = Follows(_start_node_id=node_from._id, _end_node_id=node_to._id) + with pytest.raises(GQLAlchemyError): + database.load_relationship(non_existent_relationship) + + relationship, created = non_existent_relationship.get_or_create(database) + assert created is True, "Relationship.get_or_create should create this relationship since it doesn't yet exist." + assert relationship._id is not None + created_id = relationship._id + + relationship_loaded, created = non_existent_relationship.get_or_create(database) + assert created is False, "Relationship.get_or_create should not create this relationship but load it instead." + assert relationship_loaded._id == created_id