diff options
Diffstat (limited to 'tests/api/test_db.py')
-rw-r--r-- | tests/api/test_db.py | 324 |
1 files changed, 324 insertions, 0 deletions
diff --git a/tests/api/test_db.py b/tests/api/test_db.py new file mode 100644 index 0000000..f53b90f --- /dev/null +++ b/tests/api/test_db.py @@ -0,0 +1,324 @@ +from datetime import datetime, timedelta, timezone + +import hircine.db as database +import hircine.db.models as models +import hircine.db.ops as ops +import pytest +from conftest import DB +from hircine.db.models import ( + Artist, + Base, + Comic, + ComicTag, + DateTimeUTC, + MixinID, + Namespace, + Tag, + TagNamespaces, +) +from sqlalchemy.exc import StatementError +from sqlalchemy.orm import ( + Mapped, + mapped_column, +) + + +class Date(MixinID, Base): + date: Mapped[datetime] = mapped_column(DateTimeUTC) + + +@pytest.mark.anyio +async def test_db_requires_tzinfo(): + with pytest.raises(StatementError, match="tzinfo is required"): + await DB.add(Date(date=datetime(2019, 4, 22))) + + +@pytest.mark.anyio +async def test_db_converts_date_input_to_utc(): + date = datetime(2019, 4, 22, tzinfo=timezone(timedelta(hours=-4))) + await DB.add(Date(date=date)) + + item = await DB.get(Date, 1) + + assert item.date.tzinfo == timezone.utc + assert item.date == date + + +@pytest.mark.parametrize( + "modelcls,assoccls", + [ + (models.Artist, models.ComicArtist), + (models.Circle, models.ComicCircle), + (models.Character, models.ComicCharacter), + (models.World, models.ComicWorld), + ], + ids=["artists", "circles", "characters", "worlds"], +) +@pytest.mark.anyio +async def test_models_retained_when_clearing_association( + empty_comic, modelcls, assoccls +): + model = modelcls(id=1, name="foo") + key = f"{modelcls.__name__.lower()}s" + + comic = empty_comic + setattr(comic, key, [model]) + comic = await DB.add(comic) + + async with database.session() as s: + object = await s.get(Comic, comic.id) + setattr(object, key, []) + await s.commit() + + assert await DB.get(assoccls, (comic.id, model.id)) is None + assert await DB.get(Comic, comic.id) is not None + assert await DB.get(modelcls, model.id) is not None + + +@pytest.mark.anyio +async def test_models_retained_when_clearing_comictag(empty_comic): + comic = await DB.add(empty_comic) + + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="bar") + ct = ComicTag(comic_id=comic.id, namespace=namespace, tag=tag) + + await DB.add(ct) + + async with database.session() as s: + object = await s.get(Comic, comic.id) + object.tags = [] + await s.commit() + + assert await DB.get(ComicTag, (comic.id, ct.namespace_id, ct.tag_id)) is None + assert await DB.get(Namespace, namespace.id) is not None + assert await DB.get(Tag, tag.id) is not None + assert await DB.get(Comic, comic.id) is not None + + +@pytest.mark.parametrize( + "modelcls,assoccls", + [ + (models.Artist, models.ComicArtist), + (models.Circle, models.ComicCircle), + (models.Character, models.ComicCharacter), + (models.World, models.ComicWorld), + ], + ids=["artists", "circles", "characters", "worlds"], +) +@pytest.mark.anyio +async def test_only_association_cleared_when_deleting(empty_comic, modelcls, assoccls): + model = modelcls(id=1, name="foo") + + comic = empty_comic + setattr(comic, f"{modelcls.__name__.lower()}s", [model]) + comic = await DB.add(comic) + + await DB.delete(modelcls, model.id) + assert await DB.get(assoccls, (comic.id, model.id)) is None + assert await DB.get(Comic, comic.id) is not None + + +@pytest.mark.parametrize( + "deleted", + [ + "namespace", + "tag", + ], +) +@pytest.mark.anyio +async def test_only_comictag_association_cleared_when_deleting(empty_comic, deleted): + comic = await DB.add(empty_comic) + + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="bar") + + await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag)) + + if deleted == "namespace": + await DB.delete(Namespace, namespace.id) + elif deleted == "tag": + await DB.delete(Tag, tag.id) + + assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None + if deleted == "namespace": + assert await DB.get(Tag, tag.id) is not None + elif deleted == "tag": + assert await DB.get(Namespace, namespace.id) is not None + assert await DB.get(Comic, comic.id) is not None + + +@pytest.mark.parametrize( + "modelcls,assoccls", + [ + (models.Artist, models.ComicArtist), + (models.Circle, models.ComicCircle), + (models.Character, models.ComicCharacter), + (models.World, models.ComicWorld), + ], + ids=["artists", "circles", "characters", "worlds"], +) +@pytest.mark.anyio +async def test_deleting_comic_only_clears_association(empty_comic, modelcls, assoccls): + model = modelcls(id=1, name="foo") + + comic = empty_comic + setattr(comic, f"{modelcls.__name__.lower()}s", [model]) + comic = await DB.add(comic) + + await DB.delete(Comic, comic.id) + assert await DB.get(assoccls, (comic.id, model.id)) is None + assert await DB.get(modelcls, model.id) is not None + + +@pytest.mark.anyio +async def test_deleting_comic_only_clears_comictag(empty_comic): + comic = await DB.add(empty_comic) + + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="bar") + + await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag)) + await DB.delete(Comic, comic.id) + + assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None + assert await DB.get(Tag, tag.id) is not None + assert await DB.get(Namespace, namespace.id) is not None + + +@pytest.mark.anyio +async def test_models_retained_when_clearing_tagnamespace(): + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="foo", namespaces=[namespace]) + + tag = await DB.add(tag) + + async with database.session() as s: + db_tag = await s.get(Tag, tag.id, options=Tag.load_full()) + db_tag.namespaces = [] + await s.commit() + + assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None + assert await DB.get(Namespace, namespace.id) is not None + assert await DB.get(Tag, tag.id) is not None + + +@pytest.mark.anyio +async def test_only_tagnamespace_cleared_when_deleting_tag(): + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="foo", namespaces=[namespace]) + + tag = await DB.add(tag) + + await DB.delete(Tag, tag.id) + + assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None + assert await DB.get(Namespace, namespace.id) is not None + assert await DB.get(Tag, tag.id) is None + + +@pytest.mark.anyio +async def test_only_tagnamespace_cleared_when_deleting_namespace(): + namespace = Namespace(id=1, name="foo") + tag = Tag(id=1, name="foo", namespaces=[namespace]) + + tag = await DB.add(tag) + + await DB.delete(Namespace, namespace.id) + + assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None + assert await DB.get(Namespace, namespace.id) is None + assert await DB.get(Tag, tag.id) is not None + + +@pytest.mark.parametrize( + "use_identity_map", + [False, True], + ids=["without identity lookup", "with identity lookup"], +) +@pytest.mark.anyio +async def test_ops_get_all(gen_artist, use_identity_map): + artist = await DB.add(next(gen_artist)) + have = list(await DB.add_all(*gen_artist)) + have.append(artist) + + missing_ids = [10, 20] + + async with database.session() as s: + if use_identity_map: + s.add(artist) + + artists, missing = await ops.get_all( + s, + Artist, + [a.id for a in have] + missing_ids, + use_identity_map=use_identity_map, + ) + + assert set([a.id for a in artists]) == set([a.id for a in have]) + assert missing == set(missing_ids) + + +@pytest.mark.anyio +async def test_ops_get_all_names(gen_artist): + have = await DB.add_all(*gen_artist) + missing_names = ["arty", "farty"] + + async with database.session() as s: + artists, missing = await ops.get_all_names( + s, Artist, [a.name for a in have] + missing_names + ) + + assert set([a.name for a in artists]) == set([a.name for a in have]) + assert missing == set(missing_names) + + +@pytest.mark.parametrize( + "missing", + [[("foo", "bar"), ("qux", "qaz")], []], + ids=["missing", "no missing"], +) +@pytest.mark.anyio +async def test_ops_get_ctag_names(gen_comic, gen_tag, gen_namespace, missing): + comic = await DB.add(next(gen_comic)) + have = [(ct.namespace.name, ct.tag.name) for ct in comic.tags] + + async with database.session() as s: + cts, missing = await ops.get_ctag_names(s, comic.id, have + missing) + + assert set(have) == set([(ct.namespace.name, ct.tag.name) for ct in cts]) + assert missing == set(missing) + + +@pytest.mark.anyio +async def test_ops_lookup_identity(gen_artist): + one = await DB.add(next(gen_artist)) + two = await DB.add(next(gen_artist)) + rest = await DB.add_all(*gen_artist) + + async with database.session() as s: + get_one = await s.get(Artist, one.id) + get_two = await s.get(Artist, two.id) + s.add(get_one, get_two) + + artists, satisfied = ops.lookup_identity( + s, Artist, [a.id for a in [one, two] + list(rest)] + ) + + assert set([a.name for a in artists]) == set([a.name for a in [one, two]]) + assert satisfied == set([one.id, two.id]) + + +@pytest.mark.anyio +async def test_ops_get_image_orphans(gen_archive, gen_image): + await DB.add(next(gen_archive)) + + orphan_one = await DB.add(next(gen_image)) + orphan_two = await DB.add(next(gen_image)) + + async with database.session() as s: + orphans = set(await ops.get_image_orphans(s)) + + assert orphans == set( + [(orphan_one.id, orphan_one.hash), (orphan_two.id, orphan_two.hash)] + ) |