summaryrefslogblamecommitdiffstatshomepage
path: root/tests/api/test_db.py
blob: 1405c23c8849ab94799ff4dc479737c8050ec103 (plain) (tree)
1
2
3
4
5
6
7
8
9
10

                                                  







                                         


                                  










                               












































































































































































































































































































                                                                                       
from datetime import datetime, timedelta, timezone

import pytest
from conftest import DB
from sqlalchemy.exc import StatementError
from sqlalchemy.orm import (
    Mapped,
    mapped_column,
)

import hircine.db as database
import hircine.db.models as models
import hircine.db.ops as ops
from hircine.db.models import (
    Artist,
    Base,
    Comic,
    ComicTag,
    DateTimeUTC,
    MixinID,
    Namespace,
    Tag,
    TagNamespaces,
)


class Date(MixinID, Base):
    """Minimal test-only model holding a single ``DateTimeUTC`` column.

    The tests below use it to check that naive datetimes are rejected on
    write and that aware datetimes are read back with UTC as their tzinfo.
    """

    # presumably MixinID supplies the integer primary key that
    # ``DB.get(Date, 1)`` relies on — TODO confirm in hircine.db.models
    date: Mapped[datetime] = mapped_column(DateTimeUTC)


@pytest.mark.anyio
async def test_db_requires_tzinfo():
    """Persisting a naive datetime into a DateTimeUTC column must fail."""
    naive = datetime(2019, 4, 22)  # deliberately constructed without tzinfo
    row = Date(date=naive)

    with pytest.raises(StatementError, match="tzinfo is required"):
        await DB.add(row)


@pytest.mark.anyio
async def test_db_converts_date_input_to_utc():
    """An aware datetime is read back as the same instant, expressed in UTC."""
    minus_four = timezone(timedelta(hours=-4))
    original = datetime(2019, 4, 22, tzinfo=minus_four)
    await DB.add(Date(date=original))

    stored = await DB.get(Date, 1)

    # Aware datetimes compare by instant, so equality holds across zones.
    assert stored.date.tzinfo == timezone.utc
    assert stored.date == original


@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_models_retained_when_clearing_association(
    empty_comic, modelcls, assoccls
):
    """Emptying a comic's relationship list removes only the association rows.

    Both the comic and the formerly associated model must still exist.
    """
    model = modelcls(id=1, name="foo")
    # Relationship attribute on Comic, e.g. Artist -> "artists".
    key = f"{modelcls.__name__.lower()}s"

    comic = empty_comic
    setattr(comic, key, [model])
    comic = await DB.add(comic)

    async with database.session() as s:
        # Named db_comic rather than `object` to avoid shadowing the builtin.
        db_comic = await s.get(Comic, comic.id)
        setattr(db_comic, key, [])
        await s.commit()

    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None
    assert await DB.get(modelcls, model.id) is not None


@pytest.mark.anyio
async def test_models_retained_when_clearing_comictag(empty_comic):
    """Emptying ``Comic.tags`` removes the ComicTag row only.

    The namespace, the tag and the comic itself must all survive.
    """
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    ct = ComicTag(comic_id=comic.id, namespace=namespace, tag=tag)

    await DB.add(ct)

    async with database.session() as s:
        # Named db_comic rather than `object` to avoid shadowing the builtin.
        db_comic = await s.get(Comic, comic.id)
        db_comic.tags = []
        await s.commit()

    # Key by the explicitly assigned ids, as the sibling ComicTag tests do.
    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Comic, comic.id) is not None


@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_only_association_cleared_when_deleting(empty_comic, modelcls, assoccls):
    """Deleting a model removes it and its association rows, not the comic."""
    model = modelcls(id=1, name="foo")

    comic = empty_comic
    # Relationship attribute on Comic, e.g. Artist -> "artists".
    setattr(comic, f"{modelcls.__name__.lower()}s", [model])
    comic = await DB.add(comic)

    await DB.delete(modelcls, model.id)

    assert await DB.get(modelcls, model.id) is None
    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None


@pytest.mark.parametrize(
    "deleted",
    [
        "namespace",
        "tag",
    ],
)
@pytest.mark.anyio
async def test_only_comictag_association_cleared_when_deleting(empty_comic, deleted):
    """Deleting a namespace or a tag removes it and the ComicTag row only.

    The other half of the (namespace, tag) pair and the comic must survive.
    """
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")

    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    # Map each parameter value to the (model, id) it deletes. The single
    # leftover entry is the one that must be retained. An unknown parameter
    # raises KeyError instead of silently skipping every assertion.
    targets = {"namespace": (Namespace, namespace.id), "tag": (Tag, tag.id)}
    deleted_cls, deleted_id = targets.pop(deleted)
    [(kept_cls, kept_id)] = targets.values()

    await DB.delete(deleted_cls, deleted_id)

    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(deleted_cls, deleted_id) is None
    assert await DB.get(kept_cls, kept_id) is not None
    assert await DB.get(Comic, comic.id) is not None


@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_deleting_comic_only_clears_association(empty_comic, modelcls, assoccls):
    """Deleting a comic removes it and its association rows, not the model."""
    model = modelcls(id=1, name="foo")

    comic = empty_comic
    # Relationship attribute on Comic, e.g. Artist -> "artists".
    setattr(comic, f"{modelcls.__name__.lower()}s", [model])
    comic = await DB.add(comic)

    await DB.delete(Comic, comic.id)

    assert await DB.get(Comic, comic.id) is None
    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(modelcls, model.id) is not None


@pytest.mark.anyio
async def test_deleting_comic_only_clears_comictag(empty_comic):
    """Deleting a comic removes it and its ComicTag rows, not tags/namespaces."""
    comic = await DB.add(empty_comic)

    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")

    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))
    await DB.delete(Comic, comic.id)

    assert await DB.get(Comic, comic.id) is None
    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Namespace, namespace.id) is not None


@pytest.mark.anyio
async def test_models_retained_when_clearing_tagnamespace():
    """Emptying ``Tag.namespaces`` drops only the TagNamespaces link row."""
    namespace = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[namespace]))

    async with database.session() as s:
        loaded = await s.get(Tag, tag.id, options=Tag.load_full())
        loaded.namespaces = []
        await s.commit()

    link_key = (namespace.id, tag.id)
    assert await DB.get(TagNamespaces, link_key) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is not None


@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_tag():
    """Deleting a tag removes it and its link row but keeps the namespace."""
    namespace = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[namespace]))

    await DB.delete(Tag, tag.id)

    link_key = (namespace.id, tag.id)
    assert await DB.get(TagNamespaces, link_key) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is None


@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_namespace():
    """Deleting a namespace removes it and its link row but keeps the tag."""
    namespace = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[namespace]))

    await DB.delete(Namespace, namespace.id)

    link_key = (namespace.id, tag.id)
    assert await DB.get(TagNamespaces, link_key) is None
    assert await DB.get(Namespace, namespace.id) is None
    assert await DB.get(Tag, tag.id) is not None


@pytest.mark.parametrize(
    "use_identity_map",
    [False, True],
    ids=["without identity lookup", "with identity lookup"],
)
@pytest.mark.anyio
async def test_ops_get_all(gen_artist, use_identity_map):
    """``ops.get_all`` returns the existing rows and the set of missing ids.

    With ``use_identity_map`` one artist is first placed in the session so the
    identity-map lookup path is exercised as well.
    """
    artist = await DB.add(next(gen_artist))
    have = list(await DB.add_all(*gen_artist))
    have.append(artist)

    missing_ids = [10, 20]

    async with database.session() as s:
        if use_identity_map:
            s.add(artist)

        artists, missing = await ops.get_all(
            s,
            Artist,
            [a.id for a in have] + missing_ids,
            use_identity_map=use_identity_map,
        )

    assert {a.id for a in artists} == {a.id for a in have}
    assert missing == set(missing_ids)


@pytest.mark.anyio
async def test_ops_get_all_names(gen_artist):
    """``ops.get_all_names`` returns existing rows and the set of unknown names."""
    have = await DB.add_all(*gen_artist)
    missing_names = ["arty", "farty"]

    async with database.session() as s:
        artists, missing = await ops.get_all_names(
            s, Artist, [a.name for a in have] + missing_names
        )

    assert {a.name for a in artists} == {a.name for a in have}
    assert missing == set(missing_names)


@pytest.mark.parametrize(
    "missing",
    [[("foo", "bar"), ("qux", "qaz")], []],
    ids=["missing", "no missing"],
)
@pytest.mark.anyio
async def test_ops_get_ctag_names(gen_comic, gen_tag, gen_namespace, missing):
    """``ops.get_ctag_names`` resolves known pairs and reports unknown ones.

    ``missing`` is the parametrized list of (namespace, tag) name pairs that do
    not exist. The result is stored under a different name so the final
    assertion compares against the expected input, not against itself.
    """
    comic = await DB.add(next(gen_comic))
    have = [(ct.namespace.name, ct.tag.name) for ct in comic.tags]

    async with database.session() as s:
        cts, got_missing = await ops.get_ctag_names(s, comic.id, have + missing)

    assert set(have) == {(ct.namespace.name, ct.tag.name) for ct in cts}
    assert got_missing == set(missing)


@pytest.mark.anyio
async def test_ops_lookup_identity(gen_artist):
    """``ops.lookup_identity`` finds only the artists already in the session."""
    one = await DB.add(next(gen_artist))
    two = await DB.add(next(gen_artist))
    rest = await DB.add_all(*gen_artist)

    async with database.session() as s:
        get_one = await s.get(Artist, one.id)
        get_two = await s.get(Artist, two.id)
        # Session.add takes a single instance; the second positional argument
        # is a private flag. add_all states the intent correctly.
        s.add_all([get_one, get_two])

        artists, satisfied = ops.lookup_identity(
            s, Artist, [a.id for a in [one, two] + list(rest)]
        )

    assert {a.name for a in artists} == {a.name for a in [one, two]}
    assert satisfied == {one.id, two.id}


@pytest.mark.anyio
async def test_ops_get_image_orphans(gen_archive, gen_image):
    """``ops.get_image_orphans`` yields (id, hash) of images outside any archive.

    The archive is added first; the two images added afterwards on their own
    are the expected orphans.
    """
    await DB.add(next(gen_archive))

    orphan_one = await DB.add(next(gen_image))
    orphan_two = await DB.add(next(gen_image))

    async with database.session() as s:
        orphans = set(await ops.get_image_orphans(s))

    assert orphans == {
        (orphan_one.id, orphan_one.hash),
        (orphan_two.id, orphan_two.hash),
    }