summaryrefslogtreecommitdiffstatshomepage
path: root/src/hircine/db/ops.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/hircine/db/ops.py200
1 files changed, 200 insertions, 0 deletions
diff --git a/src/hircine/db/ops.py b/src/hircine/db/ops.py
new file mode 100644
index 0000000..c164cd2
--- /dev/null
+++ b/src/hircine/db/ops.py
@@ -0,0 +1,200 @@
+import random
+from collections import defaultdict
+
+from sqlalchemy import delete, func, null, select, text, tuple_
+from sqlalchemy.orm import contains_eager, undefer
+from sqlalchemy.orm.util import identity_key
+from strawberry import UNSET
+
+from hircine.db.models import (
+ Archive,
+ ComicTag,
+ Image,
+ Namespace,
+ Page,
+ Tag,
+ TagNamespaces,
+)
+
+
+def paginate(sql, pagination):
+ if not pagination:
+ return sql
+
+ if pagination.items < 1 or pagination.page < 1:
+ return sql.limit(0)
+
+ sql = sql.limit(pagination.items)
+
+ if pagination.page > 0:
+ sql = sql.offset((pagination.page - 1) * pagination.items)
+
+ return sql
+
+
+def apply_filter(sql, filter):
+ if not filter:
+ return sql
+
+ if filter.include is not UNSET:
+ sql = filter.include.match(sql, False)
+ if filter.exclude is not UNSET:
+ sql = filter.exclude.match(sql, True)
+
+ return sql
+
+
+def sort_random(seed):
+ if seed:
+ seed = seed % 1000000000
+ else:
+ seed = random.randrange(1000000000)
+
+ # https://www.sqlite.org/forum/forumpost/e2216583a4
+ return text("sin(iid + :seed)").bindparams(seed=seed)
+
+
+def apply_sort(sql, sort, default, tiebreaker):
+ if not sort:
+ return sql.order_by(*default, tiebreaker)
+
+ direction = sort.direction.value
+
+ if sort.on.value == "Random":
+ return sql.order_by(direction(sort_random(sort.seed)))
+
+ sql = sql.options(undefer(sort.on.value))
+
+ return sql.order_by(direction(sort.on.value), tiebreaker)
+
+
+async def query_all(session, model, pagination=None, filter=None, sort=None):
+ sql = select(
+ model, func.count(model.id).over().label("count"), model.id.label("iid")
+ )
+ sql = apply_filter(sql, filter)
+ sql = apply_sort(sql, sort, model.default_order(), model.id)
+ sql = paginate(sql, pagination)
+
+ count = 0
+ objs = []
+
+ for row in await session.execute(sql):
+ if count == 0:
+ count = row.count
+
+ objs.append(row[0])
+
+ return count, objs
+
+
+async def has_with_name(session, model, name):
+ sql = select(model.id).where(model.name == name)
+ return bool((await session.scalars(sql)).unique().first())
+
+
+async def tag_restrictions(session, tuples=None):
+ sql = select(TagNamespaces)
+
+ if tuples:
+ sql = sql.where(
+ tuple_(TagNamespaces.namespace_id, TagNamespaces.tag_id).in_(tuples)
+ )
+
+ namespaces = (await session.scalars(sql)).unique().all()
+
+ ns_map = defaultdict(set)
+
+ for n in namespaces:
+ ns_map[n.tag_id].add(n.namespace_id)
+
+ return ns_map
+
+
+def lookup_identity(session, model, ids):
+ objects = []
+ satisfied = set()
+
+ for id in ids:
+ object = session.identity_map.get(identity_key(model, id), None)
+ if object is not None:
+ objects.append(object)
+ satisfied.add(id)
+
+ return objects, satisfied
+
+
+async def get_all(session, model, ids, options=[], use_identity_map=False):
+ objects = []
+ ids = set(ids)
+
+ if use_identity_map:
+ objects, satisfied = lookup_identity(session, model, ids)
+
+ ids = ids - satisfied
+
+ if not ids:
+ return objects, set()
+
+ sql = select(model).where(model.id.in_(ids)).options(*options)
+
+ objects += (await session.scalars(sql)).unique().all()
+
+ fetched_ids = [object.id for object in objects]
+ missing = set(ids) - set(fetched_ids)
+
+ return objects, missing
+
+
+async def get_all_names(session, model, names, options=[]):
+ names = set(names)
+
+ sql = select(model).where(model.name.in_(names)).options(*options)
+
+ objects = (await session.scalars(sql)).unique().all()
+
+ fetched_names = [object.name for object in objects]
+ missing = set(names) - set(fetched_names)
+
+ return objects, missing
+
+
+async def get_ctag_names(session, comic_id, tuples):
+ sql = (
+ select(ComicTag)
+ .join(ComicTag.namespace)
+ .options(contains_eager(ComicTag.namespace))
+ .join(ComicTag.tag)
+ .options(contains_eager(ComicTag.tag))
+ .where(ComicTag.comic_id == comic_id)
+ .where(tuple_(Namespace.name, Tag.name).in_(tuples))
+ )
+ objects = (await session.scalars(sql)).unique().all()
+
+ fetched_tags = [(o.namespace.name, o.tag.name) for o in objects]
+ missing = set(tuples) - set(fetched_tags)
+
+ return objects, missing
+
+
+async def get_image_orphans(session):
+ sql = select(Image.id, Image.hash).join(Page, isouter=True).where(Page.id == null())
+
+ return (await session.execute(sql)).t
+
+
+async def get_remaining_pages_for(session, archive_id):
+ sql = (
+ select(Page.id)
+ .join(Archive)
+ .where(Archive.id == archive_id)
+ .where(Page.comic_id == null())
+ )
+
+ return (await session.execute(sql)).scalars().all()
+
+
+async def delete_all(session, model, ids):
+ result = await session.execute(delete(model).where(model.id.in_(ids)))
+
+ return result.rowcount