from datetime import datetime, timedelta, timezone
import hircine.db as database
import hircine.db.models as models
import hircine.db.ops as ops
import pytest
from conftest import DB
from hircine.db.models import (
Artist,
Base,
Comic,
ComicTag,
DateTimeUTC,
MixinID,
Namespace,
Tag,
TagNamespaces,
)
from sqlalchemy.exc import StatementError
from sqlalchemy.orm import (
Mapped,
mapped_column,
)
class Date(MixinID, Base):
    """Minimal test-only model holding a single ``DateTimeUTC`` column.

    The timezone tests in this module store instances of it to exercise
    ``DateTimeUTC`` in isolation.
    """

    # NOTE(review): declared at import time so it is registered on Base's
    # metadata; presumably the DB fixture in conftest creates its table — confirm.
    date: Mapped[datetime] = mapped_column(DateTimeUTC)
@pytest.mark.anyio
async def test_db_requires_tzinfo():
    """Storing a naive datetime in a DateTimeUTC column is rejected."""
    naive = datetime(2019, 4, 22)
    with pytest.raises(StatementError, match="tzinfo is required"):
        await DB.add(Date(date=naive))
@pytest.mark.anyio
async def test_db_converts_date_input_to_utc():
    """A non-UTC aware datetime is read back in UTC, denoting the same instant."""
    minus_four = timezone(timedelta(hours=-4))
    date = datetime(2019, 4, 22, tzinfo=minus_four)

    # Look the row up by the id the database actually assigned rather than
    # assuming it is 1 (which only holds for a freshly created table).
    added = await DB.add(Date(date=date))
    item = await DB.get(Date, added.id)

    assert item.date.tzinfo == timezone.utc
    # Aware datetimes compare by instant, so this checks the value was
    # converted rather than merely relabelled.
    assert item.date == date
@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_models_retained_when_clearing_association(
    empty_comic, modelcls, assoccls
):
    """Emptying a comic's relationship list removes only the association row.

    Both the comic and the previously linked model must still exist.
    """
    model = modelcls(id=1, name="foo")
    # The comic attribute is the lowercased model class name plus "s",
    # e.g. Artist -> comic.artists.
    key = f"{modelcls.__name__.lower()}s"
    comic = empty_comic
    setattr(comic, key, [model])
    comic = await DB.add(comic)

    async with database.session() as s:
        db_comic = await s.get(Comic, comic.id)
        setattr(db_comic, key, [])
        await s.commit()

    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None
    assert await DB.get(modelcls, model.id) is not None
@pytest.mark.anyio
async def test_models_retained_when_clearing_comictag(empty_comic):
    """Clearing ``Comic.tags`` drops the ComicTag row but keeps its endpoints.

    The namespace, the tag and the comic must all survive.
    """
    comic = await DB.add(empty_comic)
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    async with database.session() as s:
        db_comic = await s.get(Comic, comic.id)
        db_comic.tags = []
        await s.commit()

    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(Namespace, namespace.id) is not None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Comic, comic.id) is not None
@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_only_association_cleared_when_deleting(empty_comic, modelcls, assoccls):
    """Deleting a linked model removes it and its association, not the comic."""
    model = modelcls(id=1, name="foo")
    comic = empty_comic
    setattr(comic, f"{modelcls.__name__.lower()}s", [model])
    comic = await DB.add(comic)

    await DB.delete(modelcls, model.id)

    # The delete itself must have happened ...
    assert await DB.get(modelcls, model.id) is None
    # ... taking the association with it, while the comic survives.
    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(Comic, comic.id) is not None
@pytest.mark.parametrize(
    "deleted",
    [
        "namespace",
        "tag",
    ],
)
@pytest.mark.anyio
async def test_only_comictag_association_cleared_when_deleting(empty_comic, deleted):
    """Deleting a namespace or tag removes the ComicTag link and nothing else.

    The other endpoint of the link and the comic must survive.
    """
    comic = await DB.add(empty_comic)
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    # (model class, primary key) for each ComicTag endpoint. Pop the one under
    # test; the remaining entry is the endpoint that must be kept. An unknown
    # parameter value raises KeyError instead of silently testing nothing.
    endpoints = {
        "namespace": (Namespace, namespace.id),
        "tag": (Tag, tag.id),
    }
    gone = endpoints.pop(deleted)
    (kept,) = endpoints.values()

    await DB.delete(*gone)

    assert await DB.get(*gone) is None
    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(*kept) is not None
    assert await DB.get(Comic, comic.id) is not None
@pytest.mark.parametrize(
    "modelcls,assoccls",
    [
        (models.Artist, models.ComicArtist),
        (models.Circle, models.ComicCircle),
        (models.Character, models.ComicCharacter),
        (models.World, models.ComicWorld),
    ],
    ids=["artists", "circles", "characters", "worlds"],
)
@pytest.mark.anyio
async def test_deleting_comic_only_clears_association(empty_comic, modelcls, assoccls):
    """Deleting a comic removes it and its association, not the linked model."""
    model = modelcls(id=1, name="foo")
    comic = empty_comic
    setattr(comic, f"{modelcls.__name__.lower()}s", [model])
    comic = await DB.add(comic)

    await DB.delete(Comic, comic.id)

    assert await DB.get(Comic, comic.id) is None
    assert await DB.get(assoccls, (comic.id, model.id)) is None
    assert await DB.get(modelcls, model.id) is not None
@pytest.mark.anyio
async def test_deleting_comic_only_clears_comictag(empty_comic):
    """Deleting a comic removes its ComicTag rows but keeps namespace and tag."""
    comic = await DB.add(empty_comic)
    namespace = Namespace(id=1, name="foo")
    tag = Tag(id=1, name="bar")
    await DB.add(ComicTag(comic_id=comic.id, namespace=namespace, tag=tag))

    await DB.delete(Comic, comic.id)

    assert await DB.get(Comic, comic.id) is None
    assert await DB.get(ComicTag, (comic.id, namespace.id, tag.id)) is None
    assert await DB.get(Tag, tag.id) is not None
    assert await DB.get(Namespace, namespace.id) is not None
@pytest.mark.anyio
async def test_models_retained_when_clearing_tagnamespace():
    """Unlinking a tag from its namespaces keeps both the tag and the namespace."""
    ns = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[ns]))

    async with database.session() as session:
        loaded = await session.get(Tag, tag.id, options=Tag.load_full())
        loaded.namespaces = []
        await session.commit()

    assert await DB.get(TagNamespaces, (ns.id, tag.id)) is None
    assert await DB.get(Namespace, ns.id) is not None
    assert await DB.get(Tag, tag.id) is not None
@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_tag():
    """Deleting a tag removes it and its namespace links; the namespace stays."""
    ns = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[ns]))

    await DB.delete(Tag, tag.id)

    assert await DB.get(TagNamespaces, (ns.id, tag.id)) is None
    assert await DB.get(Namespace, ns.id) is not None
    assert await DB.get(Tag, tag.id) is None
@pytest.mark.anyio
async def test_only_tagnamespace_cleared_when_deleting_namespace():
    """Deleting a namespace removes it and its tag links; the tag stays."""
    ns = Namespace(id=1, name="foo")
    tag = await DB.add(Tag(id=1, name="foo", namespaces=[ns]))

    await DB.delete(Namespace, ns.id)

    assert await DB.get(TagNamespaces, (ns.id, tag.id)) is None
    assert await DB.get(Namespace, ns.id) is None
    assert await DB.get(Tag, tag.id) is not None
@pytest.mark.parametrize(
    "use_identity_map",
    [False, True],
    ids=["without identity lookup", "with identity lookup"],
)
@pytest.mark.anyio
async def test_ops_get_all(gen_artist, use_identity_map):
    """ops.get_all returns the stored artists and reports unknown ids as missing."""
    first = await DB.add(next(gen_artist))
    stored = list(await DB.add_all(*gen_artist)) + [first]
    absent_ids = [10, 20]

    async with database.session() as s:
        if use_identity_map:
            # Attach one artist so the identity-map path has something to find.
            s.add(first)
        wanted = [artist.id for artist in stored] + absent_ids
        artists, missing = await ops.get_all(
            s, Artist, wanted, use_identity_map=use_identity_map
        )

    assert {artist.id for artist in artists} == {artist.id for artist in stored}
    assert missing == set(absent_ids)
@pytest.mark.anyio
async def test_ops_get_all_names(gen_artist):
    """ops.get_all_names matches artists by name and reports unknown names."""
    stored = await DB.add_all(*gen_artist)
    unknown = ["arty", "farty"]
    names = [artist.name for artist in stored]

    async with database.session() as s:
        artists, missing = await ops.get_all_names(s, Artist, names + unknown)

    assert {artist.name for artist in artists} == set(names)
    assert missing == set(unknown)
@pytest.mark.parametrize(
    "missing",
    [[("foo", "bar"), ("qux", "qaz")], []],
    ids=["missing", "no missing"],
)
@pytest.mark.anyio
async def test_ops_get_ctag_names(gen_comic, gen_tag, gen_namespace, missing):
    """ops.get_ctag_names finds a comic's (namespace, tag) pairs and reports unknown ones.

    ``missing`` holds (namespace name, tag name) pairs that do not exist.
    """
    # NOTE(review): gen_tag / gen_namespace are requested but not referenced
    # directly; presumably gen_comic relies on them — confirm before removing.
    comic = await DB.add(next(gen_comic))
    have = [(ct.namespace.name, ct.tag.name) for ct in comic.tags]

    async with database.session() as s:
        # Bind the result to a new name: re-binding ``missing`` here made the
        # final assertion compare the result with itself and always pass.
        cts, found_missing = await ops.get_ctag_names(s, comic.id, have + missing)

    assert {(ct.namespace.name, ct.tag.name) for ct in cts} == set(have)
    assert found_missing == set(missing)
@pytest.mark.anyio
async def test_ops_lookup_identity(gen_artist):
    """ops.lookup_identity returns only the artists already held by the session."""
    one = await DB.add(next(gen_artist))
    two = await DB.add(next(gen_artist))
    rest = await DB.add_all(*gen_artist)

    async with database.session() as s:
        get_one = await s.get(Artist, one.id)
        get_two = await s.get(Artist, two.id)
        # Session.add() takes a single instance; its second positional
        # parameter is the internal ``_warn`` flag, so add_all() is the
        # correct call for several instances.
        s.add_all([get_one, get_two])
        artists, satisfied = ops.lookup_identity(
            s, Artist, [a.id for a in [one, two, *rest]]
        )

    assert {a.name for a in artists} == {one.name, two.name}
    assert satisfied == {one.id, two.id}
@pytest.mark.anyio
async def test_ops_get_image_orphans(gen_archive, gen_image):
    """ops.get_image_orphans reports the standalone images as (id, hash) pairs."""
    await DB.add(next(gen_archive))
    first = await DB.add(next(gen_image))
    second = await DB.add(next(gen_image))

    async with database.session() as s:
        found = set(await ops.get_image_orphans(s))

    expected = {(image.id, image.hash) for image in (first, second)}
    assert found == expected