diff options
Diffstat (limited to '')
-rw-r--r-- | src/hircine/db/ops.py | 200 |
1 files changed, 200 insertions, 0 deletions
diff --git a/src/hircine/db/ops.py b/src/hircine/db/ops.py new file mode 100644 index 0000000..c164cd2 --- /dev/null +++ b/src/hircine/db/ops.py @@ -0,0 +1,200 @@ +import random +from collections import defaultdict + +from sqlalchemy import delete, func, null, select, text, tuple_ +from sqlalchemy.orm import contains_eager, undefer +from sqlalchemy.orm.util import identity_key +from strawberry import UNSET + +from hircine.db.models import ( + Archive, + ComicTag, + Image, + Namespace, + Page, + Tag, + TagNamespaces, +) + + +def paginate(sql, pagination): + if not pagination: + return sql + + if pagination.items < 1 or pagination.page < 1: + return sql.limit(0) + + sql = sql.limit(pagination.items) + + if pagination.page > 0: + sql = sql.offset((pagination.page - 1) * pagination.items) + + return sql + + +def apply_filter(sql, filter): + if not filter: + return sql + + if filter.include is not UNSET: + sql = filter.include.match(sql, False) + if filter.exclude is not UNSET: + sql = filter.exclude.match(sql, True) + + return sql + + +def sort_random(seed): + if seed: + seed = seed % 1000000000 + else: + seed = random.randrange(1000000000) + + # https://www.sqlite.org/forum/forumpost/e2216583a4 + return text("sin(iid + :seed)").bindparams(seed=seed) + + +def apply_sort(sql, sort, default, tiebreaker): + if not sort: + return sql.order_by(*default, tiebreaker) + + direction = sort.direction.value + + if sort.on.value == "Random": + return sql.order_by(direction(sort_random(sort.seed))) + + sql = sql.options(undefer(sort.on.value)) + + return sql.order_by(direction(sort.on.value), tiebreaker) + + +async def query_all(session, model, pagination=None, filter=None, sort=None): + sql = select( + model, func.count(model.id).over().label("count"), model.id.label("iid") + ) + sql = apply_filter(sql, filter) + sql = apply_sort(sql, sort, model.default_order(), model.id) + sql = paginate(sql, pagination) + + count = 0 + objs = [] + + for row in await session.execute(sql): + if count == 0: + count = row.count + + objs.append(row[0]) + + return count, objs + + +async def has_with_name(session, model, name): + sql = select(model.id).where(model.name == name) + return bool((await session.scalars(sql)).unique().first()) + + +async def tag_restrictions(session, tuples=None): + sql = select(TagNamespaces) + + if tuples: + sql = sql.where( + tuple_(TagNamespaces.namespace_id, TagNamespaces.tag_id).in_(tuples) + ) + + namespaces = (await session.scalars(sql)).unique().all() + + ns_map = defaultdict(set) + + for n in namespaces: + ns_map[n.tag_id].add(n.namespace_id) + + return ns_map + + +def lookup_identity(session, model, ids): + objects = [] + satisfied = set() + + for id in ids: + object = session.identity_map.get(identity_key(model, id), None) + if object is not None: + objects.append(object) + satisfied.add(id) + + return objects, satisfied + + +async def get_all(session, model, ids, options=[], use_identity_map=False): + objects = [] + ids = set(ids) + + if use_identity_map: + objects, satisfied = lookup_identity(session, model, ids) + + ids = ids - satisfied + + if not ids: + return objects, set() + + sql = select(model).where(model.id.in_(ids)).options(*options) + + objects += (await session.scalars(sql)).unique().all() + + fetched_ids = [object.id for object in objects] + missing = set(ids) - set(fetched_ids) + + return objects, missing + + +async def get_all_names(session, model, names, options=[]): + names = set(names) + + sql = select(model).where(model.name.in_(names)).options(*options) + + objects = (await session.scalars(sql)).unique().all() + + fetched_names = [object.name for object in objects] + missing = set(names) - set(fetched_names) + + return objects, missing + + +async def get_ctag_names(session, comic_id, tuples): + sql = ( + select(ComicTag) + .join(ComicTag.namespace) + .options(contains_eager(ComicTag.namespace)) + .join(ComicTag.tag) + .options(contains_eager(ComicTag.tag)) + .where(ComicTag.comic_id == comic_id) + .where(tuple_(Namespace.name, Tag.name).in_(tuples)) + ) + objects = (await session.scalars(sql)).unique().all() + + fetched_tags = [(o.namespace.name, o.tag.name) for o in objects] + missing = set(tuples) - set(fetched_tags) + + return objects, missing + + +async def get_image_orphans(session): + sql = select(Image.id, Image.hash).join(Page, isouter=True).where(Page.id == null()) + + return (await session.execute(sql)).t + + +async def get_remaining_pages_for(session, archive_id): + sql = ( + select(Page.id) + .join(Archive) + .where(Archive.id == archive_id) + .where(Page.comic_id == null()) + ) + + return (await session.execute(sql)).scalars().all() + + +async def delete_all(session, model, ids): + result = await session.execute(delete(model).where(model.id.in_(ids))) + + return result.rowcount |