summaryrefslogtreecommitdiffstatshomepage
path: root/tests/api/test_db.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/api/test_db.py')
-rw-r--r--tests/api/test_db.py324
1 files changed, 324 insertions, 0 deletions
diff --git a/tests/api/test_db.py b/tests/api/test_db.py
new file mode 100644
index 0000000..f53b90f
--- /dev/null
+++ b/tests/api/test_db.py
@@ -0,0 +1,324 @@
+from datetime import datetime, timedelta, timezone
+
+import hircine.db as database
+import hircine.db.models as models
+import hircine.db.ops as ops
+import pytest
+from conftest import DB
+from hircine.db.models import (
+ Artist,
+ Base,
+ Comic,
+ ComicTag,
+ DateTimeUTC,
+ MixinID,
+ Namespace,
+ Tag,
+ TagNamespaces,
+)
+from sqlalchemy.exc import StatementError
+from sqlalchemy.orm import (
+ Mapped,
+ mapped_column,
+)
+
+
+class Date(MixinID, Base):
+ date: Mapped[datetime] = mapped_column(DateTimeUTC)
+
+
+@pytest.mark.anyio
+async def test_db_requires_tzinfo():
+ with pytest.raises(StatementError, match="tzinfo is required"):
+ await DB.add(Date(date=datetime(2019, 4, 22)))
+
+
+@pytest.mark.anyio
+async def test_db_converts_date_input_to_utc():
+ date = datetime(2019, 4, 22, tzinfo=timezone(timedelta(hours=-4)))
+ await DB.add(Date(date=date))
+
+ item = await DB.get(Date, 1)
+
+ assert item.date.tzinfo == timezone.utc
+ assert item.date == date
+
+
+@pytest.mark.parametrize(
+ "modelcls,assoccls",
+ [
+ (models.Artist, models.ComicArtist),
+ (models.Circle, models.ComicCircle),
+ (models.Character, models.ComicCharacter),
+ (models.World, models.ComicWorld),
+ ],
+ ids=["artists", "circles", "characters", "worlds"],
+)
+@pytest.mark.anyio
+async def test_models_retained_when_clearing_association(
+ empty_comic, modelcls, assoccls
+):
+ model = modelcls(id=1, name="foo")
+ key = f"{modelcls.__name__.lower()}s"
+
+ comic = empty_comic
+ setattr(comic, key, [model])
+ comic = await DB.add(comic)
+
+ async with database.session() as s:
+ object = await s.get(Comic, comic.id)
+ setattr(object, key, [])
+ await s.commit()
+
+ assert await DB.get(assoccls, (comic.id, model.id)) is None
+ assert await DB.get(Comic, comic.id) is not None
+ assert await DB.get(modelcls, model.id) is not None
+
+
+@pytest.mark.anyio
+async def test_models_retained_when_clearing_comictag(empty_comic):
+ comic = await DB.add(empty_comic)
+
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="bar")
+ ct = ComicTag(comic_id=comic.id, namespace=namespace, tag=tag)
+
+ await DB.add(ct)
+
+ async with database.session() as s:
+ object = await s.get(Comic, comic.id)
+ object.tags = []
+ await s.commit()
+
+ assert await DB.get(ComicTag, (comic.id, ct.namespace_id, ct.tag_id)) is None
+ assert await DB.get(Namespace, namespace.id) is not None
+ assert await DB.get(Tag, tag.id) is not None
+ assert await DB.get(Comic, comic.id) is not None
+
+
+@pytest.mark.parametrize(
+ "modelcls,assoccls",
+ [
+ (models.Artist, models.ComicArtist),
+ (models.Circle, models.ComicCircle),
+ (models.Character, models.ComicCharacter),
+ (models.World, models.ComicWorld),
+ ],
+ ids=["artists", "circles", "characters", "worlds"],
+)
+@pytest.mark.anyio
+async def test_only_association_cleared_when_deleting(empty_comic, modelcls, assoccls):
+ model = modelcls(id=1, name="foo")
+
+ comic = empty_comic
+ setattr(comic, f"{modelcls.__name__.lower()}s", [model])
+ comic = await DB.add(comic)
+
+ await DB.delete(modelcls, model.id)
+ assert await DB.get(assoccls, (comic.id, model.id)) is None
+ assert await DB.get(Comic, comic.id) is not None
+
+
+@pytest.mark.parametrize(
+ "deleted",
+ [
+ "namespace",
+ "tag",
+ ],
+)
+@pytest.mark.anyio
+async def test_only_comictag_association_cleared_when_deleting(empty_comic, deleted):
+ comic = await DB.add(empty_comic)
+
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="bar")
+
+ await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))
+
+ if deleted == "namespace":
+ await DB.delete(Namespace, namespace.id)
+ elif deleted == "tag":
+ await DB.delete(Tag, tag.id)
+
+ assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
+ if deleted == "namespace":
+ assert await DB.get(Tag, tag.id) is not None
+ elif deleted == "tag":
+ assert await DB.get(Namespace, namespace.id) is not None
+ assert await DB.get(Comic, comic.id) is not None
+
+
+@pytest.mark.parametrize(
+ "modelcls,assoccls",
+ [
+ (models.Artist, models.ComicArtist),
+ (models.Circle, models.ComicCircle),
+ (models.Character, models.ComicCharacter),
+ (models.World, models.ComicWorld),
+ ],
+ ids=["artists", "circles", "characters", "worlds"],
+)
+@pytest.mark.anyio
+async def test_deleting_comic_only_clears_association(empty_comic, modelcls, assoccls):
+ model = modelcls(id=1, name="foo")
+
+ comic = empty_comic
+ setattr(comic, f"{modelcls.__name__.lower()}s", [model])
+ comic = await DB.add(comic)
+
+ await DB.delete(Comic, comic.id)
+ assert await DB.get(assoccls, (comic.id, model.id)) is None
+ assert await DB.get(modelcls, model.id) is not None
+
+
+@pytest.mark.anyio
+async def test_deleting_comic_only_clears_comictag(empty_comic):
+ comic = await DB.add(empty_comic)
+
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="bar")
+
+ await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))
+ await DB.delete(Comic, comic.id)
+
+ assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
+ assert await DB.get(Tag, tag.id) is not None
+ assert await DB.get(Namespace, namespace.id) is not None
+
+
+@pytest.mark.anyio
+async def test_models_retained_when_clearing_tagnamespace():
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="foo", namespaces=[namespace])
+
+ tag = await DB.add(tag)
+
+ async with database.session() as s:
+ db_tag = await s.get(Tag, tag.id, options=Tag.load_full())
+ db_tag.namespaces = []
+ await s.commit()
+
+ assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
+ assert await DB.get(Namespace, namespace.id) is not None
+ assert await DB.get(Tag, tag.id) is not None
+
+
+@pytest.mark.anyio
+async def test_only_tagnamespace_cleared_when_deleting_tag():
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="foo", namespaces=[namespace])
+
+ tag = await DB.add(tag)
+
+ await DB.delete(Tag, tag.id)
+
+ assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
+ assert await DB.get(Namespace, namespace.id) is not None
+ assert await DB.get(Tag, tag.id) is None
+
+
+@pytest.mark.anyio
+async def test_only_tagnamespace_cleared_when_deleting_namespace():
+ namespace = Namespace(id=1, name="foo")
+ tag = Tag(id=1, name="foo", namespaces=[namespace])
+
+ tag = await DB.add(tag)
+
+ await DB.delete(Namespace, namespace.id)
+
+ assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
+ assert await DB.get(Namespace, namespace.id) is None
+ assert await DB.get(Tag, tag.id) is not None
+
+
+@pytest.mark.parametrize(
+ "use_identity_map",
+ [False, True],
+ ids=["without identity lookup", "with identity lookup"],
+)
+@pytest.mark.anyio
+async def test_ops_get_all(gen_artist, use_identity_map):
+ artist = await DB.add(next(gen_artist))
+ have = list(await DB.add_all(*gen_artist))
+ have.append(artist)
+
+ missing_ids = [10, 20]
+
+ async with database.session() as s:
+ if use_identity_map:
+ s.add(artist)
+
+ artists, missing = await ops.get_all(
+ s,
+ Artist,
+ [a.id for a in have] + missing_ids,
+ use_identity_map=use_identity_map,
+ )
+
+ assert set([a.id for a in artists]) == set([a.id for a in have])
+ assert missing == set(missing_ids)
+
+
+@pytest.mark.anyio
+async def test_ops_get_all_names(gen_artist):
+ have = await DB.add_all(*gen_artist)
+ missing_names = ["arty", "farty"]
+
+ async with database.session() as s:
+ artists, missing = await ops.get_all_names(
+ s, Artist, [a.name for a in have] + missing_names
+ )
+
+ assert set([a.name for a in artists]) == set([a.name for a in have])
+ assert missing == set(missing_names)
+
+
+@pytest.mark.parametrize(
+ "missing",
+ [[("foo", "bar"), ("qux", "qaz")], []],
+ ids=["missing", "no missing"],
+)
+@pytest.mark.anyio
+async def test_ops_get_ctag_names(gen_comic, gen_tag, gen_namespace, missing):
+ comic = await DB.add(next(gen_comic))
+ have = [(ct.namespace.name, ct.tag.name) for ct in comic.tags]
+
+ async with database.session() as s:
+ cts, missing = await ops.get_ctag_names(s, comic.id, have + missing)
+
+ assert set(have) == set([(ct.namespace.name, ct.tag.name) for ct in cts])
+ assert missing == set(missing)
+
+
+@pytest.mark.anyio
+async def test_ops_lookup_identity(gen_artist):
+ one = await DB.add(next(gen_artist))
+ two = await DB.add(next(gen_artist))
+ rest = await DB.add_all(*gen_artist)
+
+ async with database.session() as s:
+ get_one = await s.get(Artist, one.id)
+ get_two = await s.get(Artist, two.id)
+ s.add(get_one, get_two)
+
+ artists, satisfied = ops.lookup_identity(
+ s, Artist, [a.id for a in [one, two] + list(rest)]
+ )
+
+ assert set([a.name for a in artists]) == set([a.name for a in [one, two]])
+ assert satisfied == set([one.id, two.id])
+
+
+@pytest.mark.anyio
+async def test_ops_get_image_orphans(gen_archive, gen_image):
+ await DB.add(next(gen_archive))
+
+ orphan_one = await DB.add(next(gen_image))
+ orphan_two = await DB.add(next(gen_image))
+
+ async with database.session() as s:
+ orphans = set(await ops.get_image_orphans(s))
+
+ assert orphans == set(
+ [(orphan_one.id, orphan_one.hash), (orphan_two.id, orphan_two.hash)]
+ )