from datetime import datetime, timedelta, timezone

import pytest
from conftest import DB
from sqlalchemy.exc import StatementError
from sqlalchemy.orm import (
    Mapped,
    mapped_column,
)

import hircine.db as database
import hircine.db.models as models
import hircine.db.ops as ops
from hircine.db.models import (
    Artist,
    Base,
    Comic,
    ComicTag,
    DateTimeUTC,
    MixinID,
    Namespace,
    Tag,
    TagNamespaces,
)

# Model/association-model pairs shared by every test that exercises the
# comic <-> artist/circle/character/world association tables.
parametrize_associations = pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)


def _relation_key(modelcls):
    """Return the Comic attribute name used for *modelcls* (e.g. "artists")."""
    return f"{modelcls.__name__.lower()}s"


class Date(MixinID, Base):
    """Minimal mapped model used only to exercise the DateTimeUTC column type."""

    date: Mapped[datetime] = mapped_column(DateTimeUTC)


@pytest.mark.anyio
async def test_db_requires_tzinfo():
    """Storing a naive datetime is rejected with a "tzinfo is required" error."""
    with pytest.raises(StatementError, match="tzinfo is required"):
        await DB.add(Date(date=datetime(2019, 4, 22)))


@pytest.mark.anyio
async def test_db_converts_date_input_to_utc():
    """An aware datetime round-trips as the same instant, expressed in UTC."""
    date = datetime(2019, 4, 22, tzinfo=timezone(timedelta(hours=-4)))
    await DB.add(Date(date=date))

    item = await DB.get(Date, 1)

    assert item.date.tzinfo == timezone.utc
    # Aware datetimes compare by instant, so this holds across the offset change.
    assert item.date == date


@parametrize_associations
@pytest.mark.anyio
async def test_models_retained_when_clearing_association(
    empty_comic, modelcls, assoccls
):
    """Emptying a comic's relation drops the association row, nothing else."""
    model = modelcls(id=1, name="foo")
    key = _relation_key(modelcls)

    comic = empty_comic
    setattr(comic, key, [model])
    comic = await DB.add(comic)

    async with database.session() as s:
        db_comic = await s.get(Comic, comic.id)
        setattr(db_comic, key, [])
        await s.commit()

    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None
    assert await DB.get(modelcls, model.id) is not None


@pytest.mark.anyio
async def test_models_retained_when_clearing_comictag(empty_comic):
    """Emptying a comic's tags drops the ComicTag row but keeps all models."""
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    ct = ComicTag(comic_id=comic.id, namespace=namespace, tag=tag)
    await DB.add(ct)

    async with database.session() as s:
        db_comic = await s.get(Comic, comic.id)
        db_comic.tags = []
        await s.commit()

    assert await DB.get(ComicTag, (comic.id, ct.namespace_id, ct.tag_id)) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Comic, comic.id) is not None


@parametrize_associations
@pytest.mark.anyio
async def test_only_association_cleared_when_deleting(empty_comic, modelcls, assoccls):
    """Deleting the related model removes the association but keeps the comic."""
    model = modelcls(id=1, name="foo")

    comic = empty_comic
    setattr(comic, _relation_key(modelcls), [model])
    comic = await DB.add(comic)

    await DB.delete(modelcls, model.id)

    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None


@pytest.mark.parametrize(
    "deleted",
    [
        "namespace",
        "tag",
    ],
)
@pytest.mark.anyio
async def test_only_comictag_association_cleared_when_deleting(empty_comic, deleted):
    """Deleting one side of a ComicTag removes the ComicTag but not the rest."""
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    if deleted == "namespace":
        await DB.delete(Namespace, namespace.id)
    elif deleted == "tag":
        await DB.delete(Tag, tag.id)

    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None

    # Whichever side was not deleted must still exist.
    if deleted == "namespace":
        assert await DB.get(Tag, tag.id) is not None
    elif deleted == "tag":
        assert await DB.get(Namespace, namespace.id) is not None

    assert await DB.get(Comic, comic.id) is not None


@parametrize_associations
@pytest.mark.anyio
async def test_deleting_comic_only_clears_association(empty_comic, modelcls, assoccls):
    """Deleting a comic removes the association row but keeps the related model."""
    model = modelcls(id=1, name="foo")

    comic = empty_comic
    setattr(comic, _relation_key(modelcls), [model])
    comic = await DB.add(comic)

    await DB.delete(Comic, comic.id)

    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(modelcls, model.id) is not None


@pytest.mark.anyio
async def test_deleting_comic_only_clears_comictag(empty_comic):
    """Deleting a comic removes its ComicTag but keeps the tag and namespace."""
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    await DB.delete(Comic, comic.id)

    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Namespace, namespace.id) is not None


@pytest.mark.anyio
async def test_models_retained_when_clearing_tagnamespace():
    """Emptying a tag's namespaces drops the TagNamespaces row only."""
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="foo", namespaces=[namespace])
    tag = await DB.add(tag)

    async with database.session() as s:
        db_tag = await s.get(Tag, tag.id, options=Tag.load_full())
        db_tag.namespaces = []
        await s.commit()

    assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is not None


@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_tag():
    """Deleting a tag removes it and its TagNamespaces row, not the namespace."""
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="foo", namespaces=[namespace])
    tag = await DB.add(tag)

    await DB.delete(Tag, tag.id)

    assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is None


@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_namespace():
    """Deleting a namespace removes it and its TagNamespaces row, not the tag."""
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="foo", namespaces=[namespace])
    tag = await DB.add(tag)

    await DB.delete(Namespace, namespace.id)

    assert await DB.get(TagNamespaces, (namespace.id, tag.id)) is None
    assert await DB.get(Namespace, namespace.id) is None
    assert await DB.get(Tag, tag.id) is not None


@pytest.mark.parametrize(
    "use_identity_map",
    [False, True],
    ids=["without identity lookup", "with identity lookup"],
)
@pytest.mark.anyio
async def test_ops_get_all(gen_artist, use_identity_map):
    """ops.get_all returns the stored artists and reports the unknown ids."""
    artist = await DB.add(next(gen_artist))
    have = list(await DB.add_all(*gen_artist))
    have.append(artist)

    missing_ids = [10, 20]

    async with database.session() as s:
        if use_identity_map:
            # Attach one artist to the session before the lookup runs.
            s.add(artist)

        artists, missing = await ops.get_all(
            s,
            Artist,
            [a.id for a in have] + missing_ids,
            use_identity_map=use_identity_map,
        )

    assert {a.id for a in artists} == {a.id for a in have}
    assert missing == set(missing_ids)


@pytest.mark.anyio
async def test_ops_get_all_names(gen_artist):
    """ops.get_all_names returns the stored artists and the unknown names."""
    have = await DB.add_all(*gen_artist)
    missing_names = ["arty", "farty"]

    async with database.session() as s:
        artists, missing = await ops.get_all_names(
            s, Artist, [a.name for a in have] + missing_names
        )

    assert {a.name for a in artists} == {a.name for a in have}
    assert missing == set(missing_names)


@pytest.mark.parametrize(
    "missing",
    [[("foo", "bar"), ("qux", "qaz")], []],
    ids=["missing", "no missing"],
)
@pytest.mark.anyio
async def test_ops_get_ctag_names(gen_comic, gen_tag, gen_namespace, missing):
    """ops.get_ctag_names returns the comic's tags and the unknown name pairs.

    gen_tag and gen_namespace are requested but not referenced in the body.
    """
    comic = await DB.add(next(gen_comic))
    have = [(ct.namespace.name, ct.tag.name) for ct in comic.tags]

    async with database.session() as s:
        # Bind the result to its own name: reusing ``missing`` here would
        # overwrite the parametrized input and make the final assertion
        # compare the result with itself.
        cts, not_found = await ops.get_ctag_names(s, comic.id, have + missing)

    assert set(have) == {(ct.namespace.name, ct.tag.name) for ct in cts}
    assert not_found == set(missing)


@pytest.mark.anyio
async def test_ops_lookup_identity(gen_artist):
    """ops.lookup_identity only yields the artists loaded into the session."""
    one = await DB.add(next(gen_artist))
    two = await DB.add(next(gen_artist))
    rest = await DB.add_all(*gen_artist)

    async with database.session() as s:
        get_one = await s.get(Artist, one.id)
        get_two = await s.get(Artist, two.id)
        # add() takes a single instance; add_all() is the call for several.
        s.add_all([get_one, get_two])

        artists, satisfied = ops.lookup_identity(
            s, Artist, [a.id for a in [one, two] + list(rest)]
        )

    assert {a.name for a in artists} == {a.name for a in [one, two]}
    assert satisfied == {one.id, two.id}


@pytest.mark.anyio
async def test_ops_get_image_orphans(gen_archive, gen_image):
    """ops.get_image_orphans reports exactly the two images added on their own."""
    await DB.add(next(gen_archive))

    orphan_one = await DB.add(next(gen_image))
    orphan_two = await DB.add(next(gen_image))

    async with database.session() as s:
        orphans = set(await ops.get_image_orphans(s))

    assert orphans == {
        (orphan_one.id, orphan_one.hash),
        (orphan_two.id, orphan_two.hash),
    }