diff options
author | Wolfgang Müller | 2024-03-05 18:08:09 +0100 |
---|---|---|
committer | Wolfgang Müller | 2024-03-05 19:25:59 +0100 |
commit | d1d654ebac2d51e3841675faeb56480e440f622f (patch) | |
tree | 56ef123c1a15a10dfd90836e4038e27efde950c6 /src | |
download | hircine-0.1.0.tar.gz |
Initial commit0.1.0
Diffstat (limited to 'src')
36 files changed, 4821 insertions, 0 deletions
diff --git a/src/hircine/__init__.py b/src/hircine/__init__.py new file mode 100644 index 0000000..38b969d --- /dev/null +++ b/src/hircine/__init__.py @@ -0,0 +1 @@ +codename = "Satanic Satyr" diff --git a/src/hircine/api/__init__.py b/src/hircine/api/__init__.py new file mode 100644 index 0000000..951f375 --- /dev/null +++ b/src/hircine/api/__init__.py @@ -0,0 +1,28 @@ +import strawberry + +int = strawberry.scalar(int, name="int") + + +class APIException(Exception): + def __init__(self, graphql_error): + self.graphql_error = graphql_error + + +class MutationContext: + """ + Relevant information and data for mutations as they are being resolved. + + Attributes: + input: The strawberry input object for the mutation + root: The root object being modified by the mutation + session: The active SQLAlchemy session + model: The SQLAlchemy modelclass of the object being modified by the mutation + multiple: True if multiple objects in the database are being modified + """ + + def __init__(self, input, root, session, multiple=False): + self.session = session + self.input = input + self.root = root + self.model = type(root) + self.multiple = multiple diff --git a/src/hircine/api/filters.py b/src/hircine/api/filters.py new file mode 100644 index 0000000..ab44cf9 --- /dev/null +++ b/src/hircine/api/filters.py @@ -0,0 +1,347 @@ +from abc import ABC, abstractmethod +from typing import Generic, List, Optional, TypeVar + +import strawberry +from sqlalchemy import and_, func, or_, select +from strawberry import UNSET + +import hircine.db +from hircine.db.models import ComicTag +from hircine.enums import Category, Censorship, Language, Rating + +T = TypeVar("T") + + +class Matchable(ABC): + """ + The filter interface is comprised of two methods, include and exclude, that + can freely modify an SQL statement passed to them. + """ + + @abstractmethod + def include(self, sql): + return sql + + @abstractmethod + def exclude(self, sql): + return sql + + +@strawberry.input +class AssociationFilter(Matchable): + any: Optional[List[int]] = strawberry.field(default_factory=lambda: None) + all: Optional[List[int]] = strawberry.field(default_factory=lambda: None) + exact: Optional[List[int]] = strawberry.field(default_factory=lambda: None) + empty: Optional[bool] = None + + def _exists(self, condition): + # The property.primaryjoin expression specifies the primary join path + # between the parent object of the column that was handed to the + # Matchable instance and the associated object. + # + # For example, if this AssociationFilter is parametrized as + # AssociationFilter[World], and is present on an input class that is + # mapped to the Comic model, the primaryjoin expression is as follows: + # + # comic.id = comic_worlds.comic_id + # + # This expression is used to correlate the subquery with the main query + # for the parent object. + # + # condition specifies any additional conditions we should match on. + # Usually these will come from the where generator, which correlates + # the secondary objects with the user-supplied ids. + return select(1).where((self.column.property.primaryjoin) & condition).exists() + + def _any_exist(self, items): + return self._exists(or_(*self._collect(items))) + + def _where_any_exist(self, sql): + return sql.where(self._any_exist(self.any)) + + def _where_none_exist(self, sql): + return sql.where(~self._any_exist(self.any)) + + def _all_exist(self, items): + return and_(self._exists(c) for c in self._collect(items)) + + def _where_all_exist(self, sql): + return sql.where(self._all_exist(self.all)) + + def _where_not_all_exist(self, sql): + return sql.where(~self._all_exist(self.all)) + + def _empty(self): + if self.empty: + return ~self._exists(True) + else: + return self._exists(True) + + def _count_of(self, column): + return ( + select(func.count(column)) + .where(self.column.property.primaryjoin) + .scalar_subquery() + ) + + def _exact(self): + return and_( + self._all_exist(self.exact), + self._count_of(self.remote_column) == len(self.exact), + ) + + def _collect(self, ids): + for id in ids: + yield from self.where(id) + + @property + def remote_column(self): + _, remote = self.column.property.local_remote_pairs + _, remote_column = remote + + return remote_column + + def where(self, id): + yield self.remote_column == id + + def include(self, sql): + # ignore if any/all is None, but when the user specifically includes an + # empty list, make sure to return no items + if self.any: + sql = self._where_any_exist(sql) + elif self.any == []: + sql = sql.where(False) + + if self.all: + sql = self._where_all_exist(sql) + elif self.all == []: + sql = sql.where(False) + + if self.empty is not None: + sql = sql.where(self._empty()) + + if self.exact is not None: + sql = sql.where(self._exact()) + + return sql + + def exclude(self, sql): + # in contrast to include() we can fully ignore if any/all is None or + # the empty list and just return all items, since the user effectively + # asks to exclude "nothing" + if self.any: + sql = self._where_none_exist(sql) + if self.all: + sql = self._where_not_all_exist(sql) + + if self.empty is not None: + sql = sql.where(~self._empty()) + + if self.exact is not None: + sql = sql.where(~self._exact()) + + return sql + + +@strawberry.input +class Root: + def match(self, sql, negate): + """ + Collect all relevant matchers from the input and construct the final + SQL statement. + + If the matcher is a boolean value (like favourite, organized, etc), use + it directly. Otherwise consult a Matchable's include or exclude method. + """ + + for field, matcher in self.__dict__.items(): + if matcher is UNSET: + continue + + column = getattr(self._model, field, None) + + if issubclass(type(matcher), Matchable): + matcher.column = column + if not negate: + sql = matcher.include(sql) + else: + sql = matcher.exclude(sql) + + if isinstance(matcher, bool): + if not negate: + sql = sql.where(column == matcher) + else: + sql = sql.where(column != matcher) + + return sql + + +# When resolving names for types that extend Generic, strawberry prepends the +# name of the type variable to the name of the generic class. Since all classes +# that extend this class already end in "Filter", we have to make sure not to +# name it "FilterInput" lest we end up with "ComicFilterFilterInput". +# +# For now, use the very generic "Input" name so that we end up with sane +# GraphQL type names like "ComicFilterInput". +@strawberry.input +class Input(Generic[T]): + include: Optional["T"] = UNSET + exclude: Optional["T"] = UNSET + + +@strawberry.input +class StringFilter(Matchable): + contains: Optional[str] = UNSET + + def _conditions(self): + if self.contains is not UNSET: + yield self.column.contains(self.contains) + + def include(self, sql): + conditions = list(self._conditions()) + if not conditions: + return sql + + return sql.where(and_(*conditions)) + + def exclude(self, sql): + conditions = [~c for c in self._conditions()] + if not conditions: + return sql + + return sql.where(and_(*conditions)) + + +@strawberry.input +class TagAssociationFilter(AssociationFilter): + """ + Tags need special handling since their IDs are strings instead of numbers. + We can keep the full logic of AssociationFilter and only need to make sure + we unpack the database IDs from the input IDs. + """ + + any: Optional[List[str]] = strawberry.field(default_factory=lambda: None) + all: Optional[List[str]] = strawberry.field(default_factory=lambda: None) + exact: Optional[List[str]] = strawberry.field(default_factory=lambda: None) + + def where(self, id): + try: + nid, tid = id.split(":") + except ValueError: + # invalid specification, force False and stop generator + yield False + return + + predicates = [] + if nid: + predicates.append(ComicTag.namespace_id == nid) + if tid: + predicates.append(ComicTag.tag_id == tid) + + if not predicates: + # empty specification, force False and stop generator + yield False + return + + yield and_(*predicates) + + @property + def remote_column(self): + return ComicTag.comic_id + + +@strawberry.input +class Filter(Matchable, Generic[T]): + any: Optional[List["T"]] = strawberry.field(default_factory=lambda: None) + empty: Optional[bool] = None + + def _empty(self): + if self.empty: + return self.column.is_(None) + else: + return ~self.column.is_(None) + + def _any_exist(self): + return self.column.in_(self.any) + + def include(self, sql): + if self.any: + sql = sql.where(self._any_exist()) + + if self.empty is not None: + sql = sql.where(self._empty()) + + return sql + + def exclude(self, sql): + if self.any: + sql = sql.where(~self._any_exist()) + + if self.empty is not None: + sql = sql.where(~self._empty()) + + return sql + + +@hircine.db.model("Comic") +@strawberry.input +class ComicFilter(Root): + title: Optional[StringFilter] = UNSET + original_title: Optional[StringFilter] = UNSET + url: Optional[StringFilter] = UNSET + language: Optional[Filter[Language]] = UNSET + tags: Optional[TagAssociationFilter] = UNSET + artists: Optional[AssociationFilter] = UNSET + characters: Optional[AssociationFilter] = UNSET + circles: Optional[AssociationFilter] = UNSET + worlds: Optional[AssociationFilter] = UNSET + category: Optional[Filter[Category]] = UNSET + censorship: Optional[Filter[Censorship]] = UNSET + rating: Optional[Filter[Rating]] = UNSET + favourite: Optional[bool] = UNSET + organized: Optional[bool] = UNSET + bookmarked: Optional[bool] = UNSET + + +@hircine.db.model("Archive") +@strawberry.input +class ArchiveFilter(Root): + path: Optional[StringFilter] = UNSET + organized: Optional[bool] = UNSET + + +@hircine.db.model("Artist") +@strawberry.input +class ArtistFilter(Root): + name: Optional[StringFilter] = UNSET + + +@hircine.db.model("Character") +@strawberry.input +class CharacterFilter(Root): + name: Optional[StringFilter] = UNSET + + +@hircine.db.model("Circle") +@strawberry.input +class CircleFilter(Root): + name: Optional[StringFilter] = UNSET + + +@hircine.db.model("Namespace") +@strawberry.input +class NamespaceFilter(Root): + name: Optional[StringFilter] = UNSET + + +@hircine.db.model("Tag") +@strawberry.input +class TagFilter(Root): + name: Optional[StringFilter] = UNSET + namespaces: Optional[AssociationFilter] = UNSET + + +@hircine.db.model("World") +@strawberry.input +class WorldFilter(Root): + name: Optional[StringFilter] = UNSET diff --git a/src/hircine/api/inputs.py b/src/hircine/api/inputs.py new file mode 100644 index 0000000..c88bcce --- /dev/null +++ b/src/hircine/api/inputs.py @@ -0,0 +1,578 @@ +import datetime +from abc import ABC, abstractmethod +from typing import List, Optional, Type + +import strawberry +from sqlalchemy.orm.util import identity_key +from strawberry import UNSET + +import hircine.db +import hircine.db.ops as ops +from hircine.api import APIException, MutationContext +from hircine.api.responses import ( + IDNotFoundError, + InvalidParameterError, + PageClaimedError, + PageRemoteError, +) +from hircine.db.models import Archive, Base, Comic, ComicTag, Namespace, Tag +from hircine.enums import ( + Category, + Censorship, + Direction, + Language, + Layout, + OnMissing, + Rating, + UpdateMode, +) + + +def add_input_cls(modelcls): + return globals().get(f"Add{modelcls.__name__}Input") + + +def update_input_cls(modelcls): + return globals().get(f"Update{modelcls.__name__}Input") + + +def upsert_input_cls(modelcls): + return globals().get(f"Upsert{modelcls.__name__}Input") + + +class Fetchable(ABC): + """ + When mutating a model's associations, the API requires the user to pass + referential IDs. These may be referencing new items to add to a list of + associations, or items that should be removed. + + For example, the updateTags mutation requires as its input for the + "namespaces" field a list of numerical IDs: + + mutation updateTags { + updateTags(ids: 1, input: {namespaces: {ids: [1, 2]}}) { [...] } + } + + Mutations make heavy use of SQLAlchemy's ORM features to reconcile changes + between related objects (like Tags and Namespaces). In the example above, + to reconcile the changes made to a Tag's valid Namespaces, SQLAlchemy needs + to know about three objects: the Tag that is being modified and the two + Namespaces being added to it. + + This way SQLAlchemy can figure out whether it needs to add those Namespaces + to the Tag (or whether they're already there and can be skipped) and will, + upon commit, update the relevant tables automatically without us having to + emit custom SQL. + + SQLAlchemy cannot know about an object's relationships by ID alone, so it + needs to be fetched from the database first. The Fetchable class + facilitates this. It provides an abstract "fetch" method that, given a + MutationContext, will return any relevant objects from the database. + + Additionally, fetched items can be "constrained" to enforce API rules. + """ + + _model: Type[Base] + + @abstractmethod + async def fetch(self, ctx: MutationContext): + pass + + def update_mode(self): + try: + return self.options.mode + except AttributeError: + return UpdateMode.REPLACE + + @classmethod + async def constrain_item(cls, item, ctx: MutationContext): + pass + + +class FetchableID(Fetchable): + """ + A Fetchable for numerical IDs. Database queries are batched to avoid an + excess amount of SQL queries. + """ + + @classmethod + async def get_from_id(cls, id, ctx: MutationContext): + item, *_ = await cls.get_from_ids([id], ctx) + + return item + + @classmethod + async def get_from_ids(cls, ids, ctx: MutationContext): + items, missing = await ops.get_all( + ctx.session, cls._model, ids, use_identity_map=True + ) + + if missing: + raise APIException(IDNotFoundError(cls._model, missing.pop())) + + for item in items: + await cls.constrain_item(item, ctx) + + return items + + +class FetchableName(Fetchable): + """ + A Fetchable for textual IDs (used only for Tags). As with FetchableID, + queries are batched. + """ + + @classmethod + async def get_from_names(cls, names, ctx: MutationContext, on_missing: OnMissing): + for name in names: + if not name: + raise APIException( + InvalidParameterError( + parameter=f"{cls._model.__name__}.name", text="cannot be empty" + ) + ) + + items, missing = await ops.get_all_names(ctx.session, cls._model, names) + + if on_missing == OnMissing.CREATE: + for m in missing: + items.append(cls._model(name=m)) + + return items + + +@strawberry.input +class Input(FetchableID): + id: int + + async def fetch(self, ctx: MutationContext): + return await self.get_from_id(self.id, ctx) + + +@strawberry.input +class InputList(FetchableID): + ids: List[int] + + async def fetch(self, ctx: MutationContext): + if not self.ids: + return [] + + return await self.get_from_ids(self.ids, ctx) + + +@strawberry.input +class UpdateOptions: + mode: UpdateMode = UpdateMode.REPLACE + + +@strawberry.input +class UpdateInputList(InputList): + options: Optional[UpdateOptions] = UNSET + + +@strawberry.input +class Pagination: + page: int = 1 + items: int = 40 + + +@hircine.db.model("Archive") +@strawberry.input +class ArchiveInput(Input): + pass + + +@hircine.db.model("Page") +@strawberry.input +class UniquePagesInput(InputList): + @classmethod + async def constrain_item(cls, page, ctx): + if page.comic_id: + raise APIException(PageClaimedError(id=page.id, comic_id=page.comic_id)) + + if page.archive_id != ctx.input.archive.id: + raise APIException(PageRemoteError(id=page.id, archive_id=page.archive_id)) + + +@hircine.db.model("Page") +@strawberry.input +class UniquePagesUpdateInput(UpdateInputList): + @classmethod + async def constrain_item(cls, page, ctx): + if page.comic_id and page.comic_id != ctx.root.id: + raise APIException(PageClaimedError(id=page.id, comic_id=page.comic_id)) + + if page.archive_id != ctx.root.archive_id: + raise APIException(PageRemoteError(id=page.id, archive_id=page.archive_id)) + + +@hircine.db.model("Namespace") +@strawberry.input +class NamespacesInput(InputList): + pass + + +@hircine.db.model("Namespace") +@strawberry.input +class NamespacesUpdateInput(UpdateInputList): + pass + + +@hircine.db.model("Page") +@strawberry.input +class CoverInput(Input): + async def fetch(self, ctx: MutationContext): + page = await self.get_from_id(self.id, ctx) + return page.image + + @classmethod + async def constrain_item(cls, page, ctx): + if page.archive_id != ctx.input.archive.id: + raise APIException(PageRemoteError(id=page.id, archive_id=page.archive_id)) + + +@hircine.db.model("Page") +@strawberry.input +class CoverUpdateInput(CoverInput): + @classmethod + async def constrain_item(cls, page, ctx): + if ctx.model == Comic: + id = ctx.root.archive_id + elif ctx.model == Archive: + id = ctx.root.id + + if page.archive_id != id: + raise APIException(PageRemoteError(id=page.id, archive_id=page.archive_id)) + + +@hircine.db.model("Character") +@strawberry.input +class CharactersUpdateInput(UpdateInputList): + pass + + +@hircine.db.model("Artist") +@strawberry.input +class ArtistsUpdateInput(UpdateInputList): + pass + + +@hircine.db.model("Circle") +@strawberry.input +class CirclesUpdateInput(UpdateInputList): + pass + + +@hircine.db.model("World") +@strawberry.input +class WorldsUpdateInput(UpdateInputList): + pass + + +@strawberry.input +class ComicTagsUpdateInput(UpdateInputList): + ids: List[str] = strawberry.field(default_factory=lambda: []) + + @classmethod + def parse_input(cls, id): + try: + return [int(i) for i in id.split(":")] + except ValueError: + raise APIException( + InvalidParameterError( + parameter="id", + text="ComicTag ID must be specified as <namespace_id>:<tag_id>", + ) + ) + + @classmethod + async def get_from_ids(cls, ids, ctx: MutationContext): + comic = ctx.root + + ctags = [] + remaining = set() + + for id in ids: + nid, tid = cls.parse_input(id) + + key = identity_key(ComicTag, (comic.id, nid, tid)) + item = ctx.session.identity_map.get(key, None) + + if item is not None: + ctags.append(item) + else: + remaining.add((nid, tid)) + + if not remaining: + return ctags + + nids, tids = zip(*remaining) + + namespaces, missing = await ops.get_all( + ctx.session, Namespace, nids, use_identity_map=True + ) + if missing: + raise APIException(IDNotFoundError(Namespace, missing.pop())) + + tags, missing = await ops.get_all(ctx.session, Tag, tids, use_identity_map=True) + if missing: + raise APIException(IDNotFoundError(Tag, missing.pop())) + + for nid, tid in remaining: + namespace = ctx.session.identity_map.get(identity_key(Namespace, nid)) + tag = ctx.session.identity_map.get(identity_key(Tag, tid)) + + ctags.append(ComicTag(namespace=namespace, tag=tag)) + + return ctags + + +@strawberry.input +class UpsertOptions: + on_missing: OnMissing = OnMissing.IGNORE + + +@strawberry.input +class UpsertInputList(FetchableName): + names: List[str] = strawberry.field(default_factory=lambda: []) + options: Optional[UpsertOptions] = UNSET + + async def fetch(self, ctx: MutationContext): + if not self.names: + return [] + + options = self.options or UpsertOptions() + return await self.get_from_names(self.names, ctx, on_missing=options.on_missing) + + def update_mode(self): + return UpdateMode.ADD + + +@hircine.db.model("Character") +@strawberry.input +class CharactersUpsertInput(UpsertInputList): + pass + + +@hircine.db.model("Artist") +@strawberry.input +class ArtistsUpsertInput(UpsertInputList): + pass + + +@hircine.db.model("Circle") +@strawberry.input +class CirclesUpsertInput(UpsertInputList): + pass + + +@hircine.db.model("World") +@strawberry.input +class WorldsUpsertInput(UpsertInputList): + pass + + +@strawberry.input +class ComicTagsUpsertInput(UpsertInputList): + @classmethod + def parse_input(cls, name): + try: + namespace, tag = name.split(":") + + if not namespace or not tag: + raise ValueError() + + return namespace, tag + except ValueError: + raise APIException( + InvalidParameterError( + parameter="name", + text="ComicTag name must be specified as <namespace>:<tag>", + ) + ) + + @classmethod + async def get_from_names(cls, input, ctx: MutationContext, on_missing: OnMissing): + comic = ctx.root + + names = set() + for name in input: + names.add(cls.parse_input(name)) + + ctags, missing = await ops.get_ctag_names(ctx.session, comic.id, names) + + if not missing: + return ctags + + async def lookup(names, model): + have, missing = await ops.get_all_names( + ctx.session, model, names, options=model.load_full() + ) + dict = {} + + for item in have: + dict[item.name] = (item, True) + for item in missing: + dict[item] = (model(name=item), False) + + return dict + + remaining_ns, remaining_tags = zip(*missing) + + namespaces = await lookup(remaining_ns, Namespace) + tags = await lookup(remaining_tags, Tag) + + if on_missing == OnMissing.CREATE: + for ns, tag in missing: + namespace, _ = namespaces[ns] + tag, _ = tags[tag] + + tag.namespaces.append(namespace) + + ctags.append(ComicTag(namespace=namespace, tag=tag)) + + elif on_missing == OnMissing.IGNORE: + resident = [] + + for ns, tag in missing: + namespace, namespace_resident = namespaces[ns] + tag, tag_resident = tags[tag] + + if namespace_resident and tag_resident: + resident.append((namespace, tag)) + + restrictions = await ops.tag_restrictions( + ctx.session, [(ns.id, tag.id) for ns, tag in resident] + ) + + for namespace, tag in resident: + if namespace.id in restrictions[tag.id]: + ctags.append(ComicTag(namespace=namespace, tag=tag)) + + return ctags + + +@strawberry.input +class UpdateArchiveInput: + cover: Optional[CoverUpdateInput] = UNSET + organized: Optional[bool] = UNSET + + +@strawberry.input +class AddComicInput: + title: str + archive: ArchiveInput + pages: UniquePagesInput + cover: CoverInput + + +@strawberry.input +class UpdateComicInput: + title: Optional[str] = UNSET + original_title: Optional[str] = UNSET + cover: Optional[CoverUpdateInput] = UNSET + pages: Optional[UniquePagesUpdateInput] = UNSET + url: Optional[str] = UNSET + language: Optional[Language] = UNSET + date: Optional[datetime.date] = UNSET + direction: Optional[Direction] = UNSET + layout: Optional[Layout] = UNSET + rating: Optional[Rating] = UNSET + category: Optional[Category] = UNSET + censorship: Optional[Censorship] = UNSET + tags: Optional[ComicTagsUpdateInput] = UNSET + artists: Optional[ArtistsUpdateInput] = UNSET + characters: Optional[CharactersUpdateInput] = UNSET + circles: Optional[CirclesUpdateInput] = UNSET + worlds: Optional[WorldsUpdateInput] = UNSET + favourite: Optional[bool] = UNSET + organized: Optional[bool] = UNSET + bookmarked: Optional[bool] = UNSET + + +@strawberry.input +class UpsertComicInput: + title: Optional[str] = UNSET + original_title: Optional[str] = UNSET + url: Optional[str] = UNSET + language: Optional[Language] = UNSET + date: Optional[datetime.date] = UNSET + direction: Optional[Direction] = UNSET + layout: Optional[Layout] = UNSET + rating: Optional[Rating] = UNSET + category: Optional[Category] = UNSET + censorship: Optional[Censorship] = UNSET + tags: Optional[ComicTagsUpsertInput] = UNSET + artists: Optional[ArtistsUpsertInput] = UNSET + characters: Optional[CharactersUpsertInput] = UNSET + circles: Optional[CirclesUpsertInput] = UNSET + worlds: Optional[WorldsUpsertInput] = UNSET + favourite: Optional[bool] = UNSET + organized: Optional[bool] = UNSET + bookmarked: Optional[bool] = UNSET + + +@strawberry.input +class AddNamespaceInput: + name: str + sort_name: Optional[str] = UNSET + + +@strawberry.input +class UpdateNamespaceInput: + name: Optional[str] = UNSET + sort_name: Optional[str] = UNSET + + +@strawberry.input +class AddTagInput: + name: str + description: Optional[str] = None + namespaces: Optional[NamespacesInput] = UNSET + + +@strawberry.input +class UpdateTagInput: + name: Optional[str] = UNSET + description: Optional[str] = UNSET + namespaces: Optional[NamespacesUpdateInput] = UNSET + + +@strawberry.input +class AddArtistInput: + name: str + + +@strawberry.input +class UpdateArtistInput: + name: Optional[str] = UNSET + + +@strawberry.input +class AddCharacterInput: + name: str + + +@strawberry.input +class UpdateCharacterInput: + name: Optional[str] = UNSET + + +@strawberry.input +class AddCircleInput: + name: str + + +@strawberry.input +class UpdateCircleInput: + name: Optional[str] = UNSET + + +@strawberry.input +class AddWorldInput: + name: str + + +@strawberry.input +class UpdateWorldInput: + name: Optional[str] = UNSET diff --git a/src/hircine/api/mutation/__init__.py b/src/hircine/api/mutation/__init__.py new file mode 100644 index 0000000..93c2b4a --- /dev/null +++ b/src/hircine/api/mutation/__init__.py @@ -0,0 +1,69 @@ +import strawberry + +from hircine.api.responses import ( + AddComicResponse, + AddResponse, + DeleteResponse, + UpdateResponse, + UpsertResponse, +) +from hircine.db.models import ( + Archive, + Artist, + Character, + Circle, + Comic, + Namespace, + Tag, + World, +) + +from .resolvers import ( + add, + delete, + post_add_comic, + post_delete_archive, + update, + upsert, +) + + +def mutate(resolver): + return strawberry.mutation(resolver=resolver) + + +@strawberry.type +class Mutation: + update_archives: UpdateResponse = mutate(update(Archive)) + delete_archives: DeleteResponse = mutate( + delete(Archive, post_delete=post_delete_archive) + ) + + add_comic: AddComicResponse = mutate(add(Comic, post_add=post_add_comic)) + delete_comics: DeleteResponse = mutate(delete(Comic)) + update_comics: UpdateResponse = mutate(update(Comic)) + upsert_comics: UpsertResponse = mutate(upsert(Comic)) + + add_namespace: AddResponse = mutate(add(Namespace)) + delete_namespaces: DeleteResponse = mutate(delete(Namespace)) + update_namespaces: UpdateResponse = mutate(update(Namespace)) + + add_tag: AddResponse = mutate(add(Tag)) + delete_tags: DeleteResponse = mutate(delete(Tag)) + update_tags: UpdateResponse = mutate(update(Tag)) + + add_circle: AddResponse = mutate(add(Circle)) + delete_circles: DeleteResponse = mutate(delete(Circle)) + update_circles: UpdateResponse = mutate(update(Circle)) + + add_artist: AddResponse = mutate(add(Artist)) + delete_artists: DeleteResponse = mutate(delete(Artist)) + update_artists: UpdateResponse = mutate(update(Artist)) + + add_character: AddResponse = mutate(add(Character)) + delete_characters: DeleteResponse = mutate(delete(Character)) + update_characters: UpdateResponse = mutate(update(Character)) + + add_world: AddResponse = mutate(add(World)) + delete_worlds: DeleteResponse = mutate(delete(World)) + update_worlds: UpdateResponse = mutate(update(World)) diff --git a/src/hircine/api/mutation/resolvers.py b/src/hircine/api/mutation/resolvers.py new file mode 100644 index 0000000..069669e --- /dev/null +++ b/src/hircine/api/mutation/resolvers.py @@ -0,0 +1,217 @@ +from datetime import datetime, timezone +from pathlib import Path +from typing import List + +from strawberry import UNSET + +import hircine.db as db +import hircine.db.ops as ops +import hircine.thumbnailer as thumb +from hircine.api import APIException, MutationContext +from hircine.api.inputs import ( + Fetchable, + add_input_cls, + update_input_cls, + upsert_input_cls, +) +from hircine.api.responses import ( + AddComicSuccess, + AddSuccess, + DeleteSuccess, + IDNotFoundError, + InvalidParameterError, + NameExistsError, + UpdateSuccess, + UpsertSuccess, +) +from hircine.config import get_dir_structure +from hircine.db.models import Comic, Image, MixinModifyDates +from hircine.enums import UpdateMode + + +async def fetch_fields(input, ctx: MutationContext): + """ + Given a mutation input and a context, fetch and yield all relevant objects + from the database. + + If the item requested is a Fetchable input, await its resolution, otherwise + use the item "verbatim" after checking any API restrictions. + """ + + for field, value in input.__dict__.items(): + if field == "id" or value == UNSET: + continue + + if issubclass(type(value), Fetchable): + yield field, await value.fetch(ctx), value.update_mode() + else: + if isinstance(value, str) and not value: + value = None + + await check_constraints(ctx, field, value) + yield field, value, UpdateMode.REPLACE + + +async def check_constraints(ctx, field, value): + column = getattr(ctx.model.__table__.c, field) + + if value is None and not column.nullable: + raise APIException( + InvalidParameterError(parameter=field, text="cannot be empty") + ) + + if column.unique and ctx.multiple: + raise APIException( + InvalidParameterError( + parameter="name", text="Cannot bulk-update unique fields" + ) + ) + + if column.unique and field == "name": + if value != ctx.root.name: + if await ops.has_with_name(ctx.session, ctx.model, value): + raise APIException(NameExistsError(ctx.model)) + + +# Mutation resolvers use the factory pattern. Given a modelcls, the factory +# will return a strawberry resolver that is passed the corresponding Input +# type. + + +def add(modelcls, post_add=None): + async def inner(input: add_input_cls(modelcls)): + returnval = None + + async with db.session() as s: + try: + object = modelcls() + ctx = MutationContext(input, object, s) + + async for field, value, _ in fetch_fields(input, ctx): + setattr(object, field, value) + except APIException as e: + return e.graphql_error + + s.add(object) + await s.flush() + + if post_add: + returnval = await post_add(s, input, object) + + await s.commit() + + if returnval: + return returnval + else: + return AddSuccess(modelcls, object.id) + + return inner + + +async def post_add_comic(session, input, comic): + remaining_pages = await ops.get_remaining_pages_for(session, input.archive.id) + has_remaining = len(remaining_pages) > 0 + + if not has_remaining: + comic.archive.organized = True + + return AddComicSuccess(Comic, comic.id, has_remaining) + + +def update_attr(object, field, value, mode): + if mode != UpdateMode.REPLACE and isinstance(value, list): + attr = getattr(object, field) + match mode: + case UpdateMode.ADD: + value.extend(attr) + case UpdateMode.REMOVE: + value = list(set(attr) - set(value)) + + setattr(object, field, value) + + +async def _update(ids: List[int], modelcls, input, successcls): + multiple = len(ids) > 1 + + async with db.session() as s: + needed = [k for k, v in input.__dict__.items() if v is not UNSET] + + objects, missing = await ops.get_all( + s, modelcls, ids, modelcls.load_update(needed) + ) + + if missing: + return IDNotFoundError(modelcls, missing.pop()) + + for object in objects: + s.add(object) + + try: + ctx = MutationContext(input, object, s, multiple=multiple) + + async for field, value, mode in fetch_fields(input, ctx): + update_attr(object, field, value, mode) + except APIException as e: + return e.graphql_error + + if isinstance(object, MixinModifyDates) and s.is_modified(object): + object.updated_at = datetime.now(tz=timezone.utc) + + await s.commit() + + return successcls() + + +def update(modelcls): + async def inner(ids: List[int], input: update_input_cls(modelcls)): + return await _update(ids, modelcls, input, UpdateSuccess) + + return inner + + +def upsert(modelcls): + async def inner(ids: List[int], input: upsert_input_cls(modelcls)): + return await _update(ids, modelcls, input, UpsertSuccess) + + return inner + + +def delete(modelcls, post_delete=None): + async def inner(ids: List[int]): + async with db.session() as s: + objects, missing = await ops.get_all(s, modelcls, ids) + if missing: + return IDNotFoundError(modelcls, missing.pop()) + + for object in objects: + await s.delete(object) + + await s.flush() + + if post_delete: + await post_delete(s, objects) + + await s.commit() + + return DeleteSuccess() + + return inner + + +async def post_delete_archive(session, objects): + for archive in objects: + Path(archive.path).unlink(missing_ok=True) + + dirs = get_dir_structure() + orphans = await ops.get_image_orphans(session) + + ids = [] + for id, hash in orphans: + ids.append(id) + for suffix in ["full", "thumb"]: + Path(thumb.object_path(dirs.objects, hash, suffix)).unlink(missing_ok=True) + + if not ids: + return + + await ops.delete_all(session, Image, ids) diff --git a/src/hircine/api/query/__init__.py b/src/hircine/api/query/__init__.py new file mode 100644 index 0000000..9d81989 --- /dev/null +++ b/src/hircine/api/query/__init__.py @@ -0,0 +1,54 @@ +from typing import List + +import strawberry + +import hircine.api.responses as rp +import hircine.db.models as models +from hircine.api.types import ( + Archive, + Artist, + Character, + Circle, + Comic, + ComicScraper, + ComicTag, + FilterResult, + Namespace, + Tag, + World, +) + +from .resolvers import ( + all, + comic_scrapers, + comic_tags, + scrape_comic, + single, +) + + +def query(resolver): + return strawberry.field(resolver=resolver) + + +@strawberry.type +class Query: + archive: rp.ArchiveResponse = query(single(models.Archive, full=True)) + archives: FilterResult[Archive] = query(all(models.Archive)) + artist: rp.ArtistResponse = query(single(models.Artist)) + artists: FilterResult[Artist] = query(all(models.Artist)) + character: rp.CharacterResponse = query(single(models.Character)) + characters: FilterResult[Character] = query(all(models.Character)) + circle: rp.CircleResponse = query(single(models.Circle)) + circles: FilterResult[Circle] = query(all(models.Circle)) + comic: rp.ComicResponse = query(single(models.Comic, full=True)) + comic_scrapers: List[ComicScraper] = query(comic_scrapers) + comic_tags: FilterResult[ComicTag] = query(comic_tags) + comics: FilterResult[Comic] = query(all(models.Comic)) + namespace: rp.NamespaceResponse = query(single(models.Namespace)) + namespaces: FilterResult[Namespace] = query(all(models.Namespace)) + tag: rp.TagResponse = query(single(models.Tag, full=True)) + tags: FilterResult[Tag] = query(all(models.Tag)) + world: rp.WorldResponse = query(single(models.World)) + worlds: FilterResult[World] = query(all(models.World)) + scrape_comic: rp.ScrapeComicResponse = query(scrape_comic) diff --git a/src/hircine/api/query/resolvers.py b/src/hircine/api/query/resolvers.py new file mode 100644 index 0000000..a18e63e --- /dev/null +++ b/src/hircine/api/query/resolvers.py @@ -0,0 +1,146 @@ +from typing import Optional + +import hircine.api.filters as filters +import hircine.api.sort as sort +import hircine.api.types as types +import hircine.db as db +import hircine.db.models as models +import hircine.db.ops as ops +import hircine.plugins as plugins +from hircine.api.filters import Input as FilterInput +from hircine.api.inputs import Pagination +from hircine.api.responses import ( + IDNotFoundError, + ScraperError, + ScraperNotAvailableError, + ScraperNotFoundError, +) +from hircine.api.sort import Input as SortInput +from hircine.api.types import ( + ComicScraper, + ComicTag, + FilterResult, + FullComic, + ScrapeComicResult, + ScrapedComic, +) +from hircine.scraper import ScrapeError + +# Query resolvers use the factory pattern. Given a model, the factory will +# return a strawberry resolver that is passed the corresponding IDs + + +def single(model, full=False): + modelname = model.__name__ + if full: + modelname = f"Full{modelname}" + + typecls = getattr(types, modelname) + + async def inner(id: int): + async with db.session() as s: + options = model.load_full() if full else [] + obj = await s.get(model, id, options=options) + + if not obj: + return IDNotFoundError(model, id) + + return typecls(obj) + + return inner + + +def all(model): + typecls = getattr(types, model.__name__) + filtercls = getattr(filters, f"{model.__name__}Filter") + sortcls = getattr(sort, f"{model.__name__}Sort") + + async def inner( + pagination: Optional[Pagination] = None, + filter: Optional[FilterInput[filtercls]] = None, + sort: Optional[SortInput[sortcls]] = None, + ): + async with db.session() as s: + count, objs = await ops.query_all( + s, model, pagination=pagination, filter=filter, sort=sort + ) + + return FilterResult(count=count, edges=[typecls(obj) for obj in objs]) + + return inner + + +def namespace_tag_combinations_for(namespaces, tags, restrictions): + for namespace in namespaces: + for tag in tags: + valid_ids = restrictions.get(tag.id, []) + + if namespace.id in valid_ids: + yield ComicTag(namespace=namespace, tag=tag) + + +async def comic_tags(for_filter: bool = False): + async with db.session() as s: + _, tags = await ops.query_all(s, models.Tag) + _, namespaces = await ops.query_all(s, models.Namespace) + restrictions = await ops.tag_restrictions(s) + + combinations = list(namespace_tag_combinations_for(namespaces, tags, restrictions)) + + if not for_filter: + return FilterResult(count=len(combinations), edges=combinations) + + matchers = [] + + for namespace in namespaces: + matchers.append(ComicTag(namespace=namespace)) + for tag in tags: + matchers.append(ComicTag(tag=tag)) + + matchers.extend(combinations) + + return FilterResult(count=len(matchers), edges=matchers) + + +async def comic_scrapers(id: int): + async with db.session() as s: + comic = await s.get(models.Comic, id, options=models.Comic.load_full()) + + if not comic: + return [] + + scrapers = [] + for id, cls in sorted(plugins.get_scrapers(), key=lambda p: p[1].name): + scraper = cls(comic) + if scraper.is_available: + scrapers.append(ComicScraper(id, scraper)) + + return scrapers + + +async def scrape_comic(id: int, scraper: str): + scrapercls = plugins.get_scraper(scraper) + + if not scrapercls: + return ScraperNotFoundError(name=scraper) + + async with db.session() as s: + comic = await s.get(models.Comic, id, options=models.Comic.load_full()) + + if not comic: + return IDNotFoundError(models.Comic, id) + + instance = scrapercls(FullComic(comic)) + + if not instance.is_available: + return ScraperNotAvailableError(scraper=scraper, comic_id=id) + + gen = instance.collect(plugins.transformers) + + try: + return ScrapeComicResult( + data=ScrapedComic.from_generator(gen), + warnings=instance.get_warnings(), + ) + except ScrapeError as e: + return ScraperError(error=str(e)) diff --git a/src/hircine/api/responses.py b/src/hircine/api/responses.py new file mode 100644 index 0000000..99d5113 --- /dev/null +++ b/src/hircine/api/responses.py @@ -0,0 +1,219 @@ +from typing import Annotated, Union + +import strawberry + +from hircine.api.types import ( + Artist, + Character, + Circle, + FullArchive, + FullComic, + FullTag, + Namespace, + ScrapeComicResult, + World, +) + + +@strawberry.interface +class Success: + message: str + + +@strawberry.type +class AddSuccess(Success): + id: int + + def __init__(self, modelcls, id): + self.id = id + self.message = f"{modelcls.__name__} added" + + +@strawberry.type +class AddComicSuccess(AddSuccess): + archive_pages_remaining: bool + + def __init__(self, modelcls, id, archive_pages_remaining): + super().__init__(modelcls, id) + self.archive_pages_remaining = archive_pages_remaining + + +@strawberry.type +class UpdateSuccess(Success): + def __init__(self): + self.message = "Changes saved" + + +@strawberry.type +class UpsertSuccess(Success): + def __init__(self): + self.message = "Changes saved" + + +@strawberry.type +class DeleteSuccess(Success): + def __init__(self): + self.message = "Deletion successful" + + +@strawberry.interface +class Error: + @strawberry.field + def message(self) -> str: # pragma: no cover + return "An error occurred" + + +@strawberry.type +class InvalidParameterError(Error): + parameter: str + text: strawberry.Private[str] + + @strawberry.field + def message(self) -> str: + return f"Invalid parameter '{self.parameter}': {self.text}" + + +@strawberry.type +class IDNotFoundError(Error): + id: int + model: strawberry.Private[str] + + def __init__(self, modelcls, id): + self.id = id + self.model = modelcls.__name__ + + @strawberry.field + def message(self) -> str: + return f"{self.model} ID not found: '{self.id}'" + + +@strawberry.type +class ScraperNotFoundError(Error): + name: str + + @strawberry.field + def message(self) -> str: + return f"Scraper not found: '{self.name}'" + + +@strawberry.type +class NameExistsError(Error): + model: strawberry.Private[str] + + def __init__(self, modelcls): + self.model = modelcls.__name__ + + @strawberry.field + def message(self) -> str: + return f"Another {self.model} with this name exists" + + +@strawberry.type +class PageClaimedError(Error): + id: int + comic_id: int + + @strawberry.field + def message(self) -> str: + return f"Page ID {self.id} is already claimed by comic ID {self.comic_id}" + + +@strawberry.type +class PageRemoteError(Error): + id: int + archive_id: int + + @strawberry.field + def message(self) -> str: + return f"Page ID {self.id} comes from remote archive ID {self.archive_id}" + + +@strawberry.type +class ScraperError(Error): + error: str + + @strawberry.field + def message(self) -> str: + return f"Scraping failed: {self.error}" + + +@strawberry.type +class ScraperNotAvailableError(Error): + scraper: str + comic_id: int + + @strawberry.field + def message(self) -> str: + return f"Scraper {self.scraper} not available for comic ID {self.comic_id}" + + +AddComicResponse = Annotated[ + Union[ + AddComicSuccess, + IDNotFoundError, + PageClaimedError, + PageRemoteError, + InvalidParameterError, + ], + strawberry.union("AddComicResponse"), +] +AddResponse = Annotated[ + Union[AddSuccess, IDNotFoundError, NameExistsError, InvalidParameterError], + strawberry.union("AddResponse"), +] +ArchiveResponse = Annotated[ + Union[FullArchive, IDNotFoundError], strawberry.union("ArchiveResponse") +] +ArtistResponse = Annotated[ + Union[Artist, IDNotFoundError], strawberry.union("ArtistResponse") +] +CharacterResponse = Annotated[ + Union[Character, IDNotFoundError], strawberry.union("CharacterResponse") +] +CircleResponse = Annotated[ + Union[Circle, IDNotFoundError], strawberry.union("CircleResponse") +] +ComicResponse = Annotated[ + Union[FullComic, IDNotFoundError], strawberry.union("ComicResponse") +] +DeleteResponse = Annotated[ + Union[DeleteSuccess, IDNotFoundError], strawberry.union("DeleteResponse") +] +NamespaceResponse = Annotated[ + Union[Namespace, IDNotFoundError], strawberry.union("NamespaceResponse") +] +ScrapeComicResponse = Annotated[ + Union[ + ScrapeComicResult, + ScraperNotFoundError, + ScraperNotAvailableError, + IDNotFoundError, + ScraperError, + ], + strawberry.union("ScrapeComicResponse"), +] +TagResponse = Annotated[ + Union[FullTag, IDNotFoundError], strawberry.union("TagResponse") +] +UpdateResponse = Annotated[ + Union[ + UpdateSuccess, + NameExistsError, + IDNotFoundError, + InvalidParameterError, + PageRemoteError, + PageClaimedError, + ], + strawberry.union("UpdateResponse"), +] +UpsertResponse = Annotated[ + Union[ + UpsertSuccess, + NameExistsError, + InvalidParameterError, + ], + strawberry.union("UpsertResponse"), +] +WorldResponse = Annotated[ + Union[World, IDNotFoundError], strawberry.union("WorldResponse") +] diff --git a/src/hircine/api/sort.py b/src/hircine/api/sort.py new file mode 100644 index 0000000..17043a6 --- /dev/null +++ b/src/hircine/api/sort.py @@ -0,0 +1,94 @@ +import enum +from typing import Generic, Optional, TypeVar + +import sqlalchemy +import strawberry + +import hircine.db.models as models + +T = TypeVar("T") + + +@strawberry.enum +class SortDirection(enum.Enum): + ASCENDING = strawberry.enum_value(sqlalchemy.asc) + DESCENDING = strawberry.enum_value(sqlalchemy.desc) + + +@strawberry.enum +class ComicSort(enum.Enum): + TITLE = strawberry.enum_value(models.Comic.title) + ORIGINAL_TITLE = strawberry.enum_value(models.Comic.original_title) + DATE = strawberry.enum_value(models.Comic.date) + CREATED_AT = strawberry.enum_value(models.Comic.created_at) + UPDATED_AT = strawberry.enum_value(models.Comic.updated_at) + TAG_COUNT = strawberry.enum_value(models.Comic.tag_count) + PAGE_COUNT = strawberry.enum_value(models.Comic.page_count) + RANDOM = "Random" + + +@strawberry.enum +class ArchiveSort(enum.Enum): + PATH = strawberry.enum_value(models.Archive.path) + SIZE = strawberry.enum_value(models.Archive.size) + CREATED_AT = strawberry.enum_value(models.Archive.created_at) + PAGE_COUNT = strawberry.enum_value(models.Archive.page_count) + RANDOM = "Random" + + +@strawberry.enum +class ArtistSort(enum.Enum): + NAME = strawberry.enum_value(models.Artist.name) + CREATED_AT = strawberry.enum_value(models.Artist.created_at) + UPDATED_AT = strawberry.enum_value(models.Artist.updated_at) + RANDOM = "Random" + + +@strawberry.enum +class CharacterSort(enum.Enum): + NAME = strawberry.enum_value(models.Character.name) + CREATED_AT = strawberry.enum_value(models.Character.created_at) + UPDATED_AT = strawberry.enum_value(models.Character.updated_at) + RANDOM = "Random" + + +@strawberry.enum +class CircleSort(enum.Enum): + NAME = strawberry.enum_value(models.Circle.name) + CREATED_AT = strawberry.enum_value(models.Circle.created_at) + UPDATED_AT = strawberry.enum_value(models.Circle.updated_at) + RANDOM = "Random" + + +@strawberry.enum +class NamespaceSort(enum.Enum): + SORT_NAME = strawberry.enum_value(models.Namespace.sort_name) + NAME = strawberry.enum_value(models.Namespace.name) + CREATED_AT = strawberry.enum_value(models.Namespace.created_at) + UPDATED_AT = strawberry.enum_value(models.Namespace.updated_at) + RANDOM = "Random" + + +@strawberry.enum +class TagSort(enum.Enum): + NAME = strawberry.enum_value(models.Tag.name) + CREATED_AT = strawberry.enum_value(models.Tag.created_at) + UPDATED_AT = strawberry.enum_value(models.Tag.updated_at) + RANDOM = "Random" + + +@strawberry.enum +class WorldSort(enum.Enum): + NAME = strawberry.enum_value(models.World.name) + CREATED_AT = strawberry.enum_value(models.World.created_at) + UPDATED_AT = strawberry.enum_value(models.World.updated_at) + RANDOM = "Random" + + +# Use a generic "Input" name so that we end up with sane GraphQL type names +# See also: filter.py +@strawberry.input +class Input(Generic[T]): + on: T + direction: Optional[SortDirection] = SortDirection.ASCENDING + seed: Optional[int] = strawberry.UNSET diff --git a/src/hircine/api/types.py b/src/hircine/api/types.py new file mode 100644 index 0000000..b9fe0e7 --- /dev/null +++ b/src/hircine/api/types.py @@ -0,0 +1,337 @@ +import datetime +from typing import Generic, List, Optional, TypeVar + +import strawberry + +import hircine.scraper.types as scraped +from hircine.enums import Category, Censorship, Direction, Language, Layout, Rating + +T = TypeVar("T") + + +@strawberry.type +class Base: + id: int + + def __init__(self, model): + self.id = model.id + + +@strawberry.type +class MixinName: + name: str + + def __init__(self, model): + self.name = model.name + super().__init__(model) + + +@strawberry.type +class MixinFavourite: + favourite: bool + + def __init__(self, model): + self.favourite = model.favourite + super().__init__(model) + + +@strawberry.type +class MixinOrganized: + organized: bool + + def __init__(self, model): + self.organized = model.organized + super().__init__(model) + + +@strawberry.type +class MixinBookmarked: + bookmarked: bool + + def __init__(self, model): + self.bookmarked = model.bookmarked + super().__init__(model) + + +@strawberry.type +class MixinCreatedAt: + created_at: datetime.datetime + + def __init__(self, model): + self.created_at = model.created_at + super().__init__(model) + + +@strawberry.type +class MixinModifyDates(MixinCreatedAt): + updated_at: datetime.datetime + + def __init__(self, model): + self.updated_at = model.updated_at + super().__init__(model) + + +@strawberry.type +class FilterResult(Generic[T]): + count: int + edges: List["T"] + + +@strawberry.type +class Archive(MixinName, MixinOrganized, Base): + cover: "Image" + path: str + size: int + page_count: int + + def __init__(self, model): + super().__init__(model) + self.path = model.path + self.cover = Image(model.cover) + self.size = model.size + self.page_count = model.page_count + + +@strawberry.type +class FullArchive(MixinCreatedAt, Archive): + pages: List["Page"] + comics: List["Comic"] + mtime: datetime.datetime + + def __init__(self, model): + super().__init__(model) + self.mtime = model.mtime + self.pages = [Page(p) for p in model.pages] + self.comics = [Comic(c) for c in model.comics] + + +@strawberry.type +class Page(Base): + path: str + image: "Image" + comic_id: Optional[int] + + def __init__(self, model): + super().__init__(model) + self.path = model.path + self.image = Image(model.image) + self.comic_id = model.comic_id + + +@strawberry.type +class Image(Base): + hash: str + width: int + height: int + aspect_ratio: float + + def __init__(self, model): + super().__init__(model) + self.hash = model.hash + self.width = model.width + self.height = model.height + self.aspect_ratio = model.aspect_ratio + + +@strawberry.type +class Comic(MixinFavourite, MixinOrganized, MixinBookmarked, Base): + title: str + original_title: Optional[str] + language: Optional[Language] + date: Optional[datetime.date] + cover: "Image" + rating: Optional[Rating] + category: Optional[Category] + censorship: Optional[Censorship] + tags: List["ComicTag"] + artists: List["Artist"] + characters: List["Character"] + circles: List["Circle"] + worlds: List["World"] + page_count: int + + def __init__(self, model): + super().__init__(model) + self.title = model.title + self.original_title = model.original_title + self.language = model.language + self.date = model.date + self.cover = Image(model.cover) + self.rating = model.rating + self.category = model.category + self.censorship = model.censorship + self.tags = [ComicTag(t.namespace, t.tag) for t in model.tags] + self.artists = [Artist(a) for a in model.artists] + self.characters = [Character(c) for c in model.characters] + self.worlds = [World(w) for w in model.worlds] + self.circles = [Circle(g) for g in model.circles] + self.page_count = model.page_count + + +@strawberry.type +class FullComic(MixinModifyDates, Comic): + archive: "Archive" + url: Optional[str] + pages: List["Page"] + direction: Direction + layout: Layout + + def __init__(self, model): + super().__init__(model) + self.direction = model.direction + self.layout = model.layout + self.archive = Archive(model.archive) + self.pages = [Page(p) for p in model.pages] + self.url = model.url + + +@strawberry.type +class Tag(MixinName, Base): + description: Optional[str] + + def __init__(self, model): + super().__init__(model) + self.description = model.description + + +@strawberry.type +class FullTag(Tag): + namespaces: List["Namespace"] + + def __init__(self, model): + super().__init__(model) + self.namespaces = [Namespace(n) for n in model.namespaces] + + +@strawberry.type +class Namespace(MixinName, Base): + sort_name: Optional[str] + + def __init__(self, model): + super().__init__(model) + self.sort_name = model.sort_name + + +@strawberry.type +class ComicTag: + id: str + name: str + description: Optional[str] + + def __init__(self, namespace=None, tag=None): + tag_name, tag_id = ("", "") + ns_name, ns_id = ("", "") + + if tag: + tag_name, tag_id = (tag.name, tag.id) + if namespace: + ns_name, ns_id = (namespace.name, namespace.id) + + self.name = f"{ns_name}:{tag_name}" + self.id = f"{ns_id}:{tag_id}" + if tag: + self.description = tag.description + + +@strawberry.type +class Artist(MixinName, Base): + def __init__(self, model): + super().__init__(model) + + +@strawberry.type +class Character(MixinName, Base): + def __init__(self, model): + super().__init__(model) + + +@strawberry.type +class Circle(MixinName, Base): + def __init__(self, model): + super().__init__(model) + + +@strawberry.type +class World(MixinName, Base): + def __init__(self, model): + super().__init__(model) + + +@strawberry.type +class ComicScraper: + id: str + name: str + + def __init__(self, id, scraper): + self.id = id + self.name = scraper.name + + +@strawberry.type +class ScrapeComicResult: + data: "ScrapedComic" + warnings: List[str] = strawberry.field(default_factory=lambda: []) + + +@strawberry.type +class ScrapedComic: + title: Optional[str] = None + original_title: Optional[str] = None + url: Optional[str] = None + language: Optional[Language] = None + date: Optional[datetime.date] = None + rating: Optional[Rating] = None + category: Optional[Category] = None + censorship: Optional[Censorship] = None + direction: Optional[Direction] = None + layout: Optional[Layout] = None + tags: List[str] = strawberry.field(default_factory=lambda: []) + artists: List[str] = strawberry.field(default_factory=lambda: []) + characters: List[str] = strawberry.field(default_factory=lambda: []) + circles: List[str] = strawberry.field(default_factory=lambda: []) + worlds: List[str] = strawberry.field(default_factory=lambda: []) + + @classmethod + def from_generator(cls, generator): + data = cls() + + seen = set() + for item in generator: + if not item or item in seen: + continue + + seen.add(item) + + match item: + case scraped.Title(): + data.title = item.value + case scraped.OriginalTitle(): + data.original_title = item.value + case scraped.URL(): + data.url = item.value + case scraped.Language(): + data.language = item.value + case scraped.Date(): + data.date = item.value + case scraped.Rating(): + data.rating = item.value + case scraped.Category(): + data.category = item.value + case scraped.Censorship(): + data.censorship = item.value + case scraped.Direction(): + data.direction = item.value + case scraped.Layout(): + data.layout = item.value + case scraped.Tag(): + data.tags.append(item.to_string()) + case scraped.Artist(): + data.artists.append(item.name) + case scraped.Character(): + data.characters.append(item.name) + case scraped.Circle(): + data.circles.append(item.name) + case scraped.World(): + data.worlds.append(item.name) + + return data diff --git a/src/hircine/app.py b/src/hircine/app.py new file mode 100644 index 0000000..f22396b --- /dev/null +++ b/src/hircine/app.py @@ -0,0 +1,79 @@ +import asyncio +import os + +import strawberry +import uvicorn +from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.cors import CORSMiddleware +from starlette.routing import Mount, Route +from starlette.staticfiles import StaticFiles +from strawberry.asgi import GraphQL + +import hircine.db as db +from hircine.api.mutation import Mutation +from hircine.api.query import Query +from hircine.config import init_dir_structure + +schema = strawberry.Schema(query=Query, mutation=Mutation) +graphql: GraphQL = GraphQL(schema) + + +class SinglePageApplication(StaticFiles): # pragma: no cover + def __init__(self, index="index.html"): + self.index = index + super().__init__( + packages=[("hircine", "static/app")], html=True, check_dir=True + ) + + def lookup_path(self, path): + full_path, stat_result = super().lookup_path(path) + + if stat_result is None: + return super().lookup_path(self.index) + + return (full_path, stat_result) + + +class HelpFiles(StaticFiles): # pragma: no cover + def __init__(self, index="index.html"): + self.index = index + super().__init__( + packages=[("hircine", "static/help")], html=True, check_dir=True + ) + + +def app(): # pragma: no cover + dirs = init_dir_structure() + db.configure(dirs) + + routes = [ + Route("/graphql", endpoint=graphql), + Mount("/objects", app=StaticFiles(directory=dirs.objects), name="objects"), + Mount("/help", app=HelpFiles(), name="help"), + Mount("/", app=SinglePageApplication(), name="app"), + ] + + middleware = [] + + if "HIRCINE_DEV" in os.environ: + middleware = [ + Middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + ) + ] + + return Starlette(routes=routes, middleware=middleware) + + +if __name__ == "__main__": + dirs = init_dir_structure() + db.ensuredb(dirs) + + engine = db.configure(dirs) + asyncio.run(db.ensure_current_revision(engine)) + + uvicorn.run("hircine.app:app", host="::", reload=True, factory=True, lifespan="on") diff --git a/src/hircine/cli.py b/src/hircine/cli.py new file mode 100644 index 0000000..6941e2c --- /dev/null +++ b/src/hircine/cli.py @@ -0,0 +1,128 @@ +import argparse +import asyncio +import configparser +import importlib.metadata +import os +import sys +from datetime import datetime, timezone + +import alembic.config + +import hircine.db as db +from hircine import codename +from hircine.config import init_dir_structure +from hircine.scanner import Scanner + + +class SubcommandHelpFormatter(argparse.RawDescriptionHelpFormatter): + def _format_action(self, action): + parts = super(argparse.RawDescriptionHelpFormatter, self)._format_action(action) + if action.nargs == argparse.PARSER: + parts = "\n".join(parts.split("\n")[1:]) + return parts + + +def init(config, dirs, engine, args): + if os.path.exists(dirs.database): + sys.exit("Database already initialized.") + + dirs.mkdirs() + + print("Initializing database...") + asyncio.run(db.initialize(engine)) + print("Done.") + + +def scan(config, dirs, engine, args): + db.ensuredb(dirs) + + asyncio.run(db.ensure_current_revision(engine)) + + scanner = Scanner(config, dirs, reprocess=args.reprocess) + asyncio.run(scanner.scan()) + scanner.report() + + +def backup(config, dirs, engine, args, tag="manual"): + db.ensuredb(dirs) + + os.makedirs(dirs.backups, exist_ok=True) + + date = datetime.now(tz=timezone.utc).strftime("%Y-%m-%d_%H%M%S") + filename = f"{os.path.basename(dirs.database)}.{tag}.{date}" + path = os.path.join(dirs.backups, filename) + + asyncio.run(db.backup(engine, path)) + + +def migrate(config, dirs, engine, args): + db.ensuredb(dirs) + + backup(config, dirs, engine, args, tag="pre-migrate") + alembic.config.main(argv=["--config", db.alembic_ini, "upgrade", "head"]) + + +def vacuum(config, dirs, engine, args): + db.ensuredb(dirs) + + asyncio.run(db.vacuum(engine)) + + +def version(config, dirs, engine, args): + version = importlib.metadata.metadata("hircine")["Version"] + print(f'hircine {version} "{codename}"') + + +def main(): + parser = argparse.ArgumentParser( + prog="hircine", formatter_class=SubcommandHelpFormatter + ) + parser.add_argument("-C", dest="dir", help="run as if launched in DIR") + + subparsers = parser.add_subparsers(title="commands", required=True) + + parser_init = subparsers.add_parser("init", help="initialize a database") + parser_init.set_defaults(func=init) + + parser_import = subparsers.add_parser("import", help="import archives") + parser_import.set_defaults(func=scan) + parser_import.add_argument( + "-r", "--reprocess", action="store_true", help="reprocess all image files" + ) + + parser_migrate = subparsers.add_parser("migrate", help="run database migrations") + parser_migrate.set_defaults(func=migrate) + + parser_backup = subparsers.add_parser( + "backup", help="create a backup of the database" + ) + parser_backup.set_defaults(func=backup) + + parser_vacuum = subparsers.add_parser( + "vacuum", help="repack and optimize the database" + ) + parser_vacuum.set_defaults(func=vacuum) + + parser_version = subparsers.add_parser("version", help="show version and exit") + parser_version.set_defaults(func=version) + + args = parser.parse_args() + + if args.dir: + try: + os.chdir(args.dir) + except OSError as e: + sys.exit(e) + + dirs = init_dir_structure() + + config = configparser.ConfigParser() + config.read(dirs.config) + + engine = db.configure(dirs) + + args.func(config, dirs, engine, args) + + +if __name__ == "__main__": + main() diff --git a/src/hircine/config.py b/src/hircine/config.py new file mode 100644 index 0000000..fda783e --- /dev/null +++ b/src/hircine/config.py @@ -0,0 +1,38 @@ +import os + +dir_structure = None + + +class DirectoryStructure: + def __init__( + self, + database="hircine.db", + scan="content/", + objects="objects/", + backups="backups/", + config="hircine.ini", + ): + self.database = database + self.scan = scan + self.objects = objects + self.backups = backups + self.config = config + + def mkdirs(self): # pragma: no cover + os.makedirs(self.objects, exist_ok=True) + os.makedirs(self.scan, exist_ok=True) + os.makedirs(self.backups, exist_ok=True) + + +def init_dir_structure(): # pragma: no cover + global dir_structure + + dir_structure = DirectoryStructure() + + return dir_structure + + +def get_dir_structure(): + global dir_structure + + return dir_structure diff --git a/src/hircine/db/__init__.py b/src/hircine/db/__init__.py new file mode 100644 index 0000000..493bd91 --- /dev/null +++ b/src/hircine/db/__init__.py @@ -0,0 +1,99 @@ +import os +import sys +from pathlib import Path + +from alembic import command as alembic_command +from alembic import script as alembic_script +from alembic.config import Config as AlembicConfig +from alembic.runtime import migration +from sqlalchemy import event, text +from sqlalchemy.engine import Engine +from sqlalchemy.ext.asyncio import ( + async_sessionmaker, + create_async_engine, +) + +from . import models + +alembic_ini = f"{Path(__file__).parent.parent}/migrations/alembic.ini" +session = async_sessionmaker(expire_on_commit=False, autoflush=False) + + +def ensuredb(dirs): # pragma: no cover + if not os.path.exists(dirs.database): + sys.exit("No database found.") + + +def sqlite_url(path): + return f"sqlite+aiosqlite:///{path}" + + +def model(model): + def decorator(cls): + cls._model = getattr(models, model) + return cls + + return decorator + + +def stamp_alembic(connection): + cfg = AlembicConfig(alembic_ini) + cfg.attributes["connection"] = connection + cfg.attributes["silent"] = True + + alembic_command.stamp(cfg, "head") + + +def check_current_head(connection): # pragma: no cover + directory = alembic_script.ScriptDirectory.from_config(AlembicConfig(alembic_ini)) + + context = migration.MigrationContext.configure(connection) + return set(context.get_current_heads()) == set(directory.get_heads()) + + +async def ensure_current_revision(engine): # pragma: no cover + async with engine.begin() as conn: + if not await conn.run_sync(check_current_head): + sys.exit("Database is not up to date, please run 'hircine migrate'.") + + +async def initialize(engine): + async with engine.begin() as conn: + await conn.run_sync(models.Base.metadata.drop_all) + await conn.run_sync(models.Base.metadata.create_all) + await conn.run_sync(stamp_alembic) + + +async def backup(engine, path): # pragma: no cover + async with engine.connect() as conn: + await conn.execute(text("VACUUM INTO :path"), {"path": path}) + + +async def vacuum(engine): # pragma: no cover + async with engine.connect() as conn: + await conn.execute(text("VACUUM")) + + +@event.listens_for(Engine, "connect") +def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.execute("PRAGMA journal_mode=WAL") + cursor.close() + + +def create_engine(path, echo=False): + return create_async_engine( + sqlite_url(path), + connect_args={"check_same_thread": False}, + echo=echo, + ) + + +def configure(dirs): # pragma: no cover + echo = "HIRCINE_DEV" in os.environ + + engine = create_engine(dirs.database, echo=echo) + session.configure(bind=engine) + + return engine diff --git a/src/hircine/db/models.py b/src/hircine/db/models.py new file mode 100644 index 0000000..575771b --- /dev/null +++ b/src/hircine/db/models.py @@ -0,0 +1,379 @@ +import os +from datetime import date, datetime, timezone +from typing import List, Optional + +from sqlalchemy import ( + DateTime, + ForeignKey, + MetaData, + TypeDecorator, + event, + func, + select, +) +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + declared_attr, + deferred, + joinedload, + mapped_column, + relationship, + selectinload, +) + +from hircine.api import APIException +from hircine.api.responses import InvalidParameterError +from hircine.enums import Category, Censorship, Direction, Language, Layout, Rating + +naming_convention = { + "ix": "ix_%(column_0_label)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", +} + + +class DateTimeUTC(TypeDecorator): + impl = DateTime + cache_ok = True + + def process_bind_param(self, value, dialect): + if value is not None: + if not value.tzinfo: + raise TypeError("tzinfo is required") + value = value.astimezone(timezone.utc).replace(tzinfo=None) + return value + + def process_result_value(self, value, dialect): + if value is not None: + value = value.replace(tzinfo=timezone.utc) + return value + + +class Base(DeclarativeBase): + metadata = MetaData(naming_convention=naming_convention) + + @declared_attr.directive + def __tablename__(cls) -> str: + return cls.__name__.lower() + + __mapper_args__ = {"eager_defaults": True} + + @classmethod + def load_update(cls, fields): + return [] + + +class MixinID: + id: Mapped[int] = mapped_column(primary_key=True) + + +class MixinName: + name: Mapped[str] = mapped_column(unique=True) + + @classmethod + def default_order(cls): + return [cls.name] + + +class MixinFavourite: + favourite: Mapped[bool] = mapped_column(insert_default=False) + + +class MixinOrganized: + organized: Mapped[bool] = mapped_column(insert_default=False) + + +class MixinBookmarked: + bookmarked: Mapped[bool] = mapped_column(insert_default=False) + + +class MixinCreatedAt: + created_at: Mapped[datetime] = mapped_column(DateTimeUTC, server_default=func.now()) + + +class MixinModifyDates(MixinCreatedAt): + updated_at: Mapped[datetime] = mapped_column(DateTimeUTC, server_default=func.now()) + + +class Archive(MixinID, MixinCreatedAt, MixinOrganized, Base): + hash: Mapped[str] = mapped_column(unique=True) + path: Mapped[str] = mapped_column(unique=True) + size: Mapped[int] + mtime: Mapped[datetime] = mapped_column(DateTimeUTC) + + cover_id: Mapped[int] = mapped_column(ForeignKey("image.id")) + cover: Mapped["Image"] = relationship(lazy="joined", innerjoin=True) + + pages: Mapped[List["Page"]] = relationship( + back_populates="archive", + order_by="(Page.index)", + cascade="save-update, merge, expunge, delete, delete-orphan", + ) + comics: Mapped[List["Comic"]] = relationship( + back_populates="archive", + cascade="save-update, merge, expunge, delete, delete-orphan", + ) + + page_count: Mapped[int] + + @property + def name(self): + return os.path.basename(self.path) + + @classmethod + def default_order(cls): + return [cls.path] + + @classmethod + def load_full(cls): + return [ + joinedload(cls.pages, innerjoin=True), + selectinload(cls.comics), + ] + + +class Image(MixinID, Base): + hash: Mapped[str] = mapped_column(unique=True) + width: Mapped[int] + height: Mapped[int] + + @property + def aspect_ratio(self): + return self.width / self.height + + +class Page(MixinID, Base): + path: Mapped[str] + index: Mapped[int] + + archive_id: Mapped[int] = mapped_column(ForeignKey("archive.id")) + archive: Mapped["Archive"] = relationship(back_populates="pages") + + image_id: Mapped[int] = mapped_column(ForeignKey("image.id")) + image: Mapped["Image"] = relationship(lazy="joined", innerjoin=True) + + comic_id: Mapped[Optional[int]] = mapped_column(ForeignKey("comic.id")) + + +class Comic( + MixinID, MixinModifyDates, MixinFavourite, MixinOrganized, MixinBookmarked, Base +): + title: Mapped[str] + original_title: Mapped[Optional[str]] + url: Mapped[Optional[str]] + language: Mapped[Optional[Language]] + date: Mapped[Optional[date]] + + direction: Mapped[Direction] = mapped_column(insert_default=Direction.LEFT_TO_RIGHT) + layout: Mapped[Layout] = mapped_column(insert_default=Layout.SINGLE) + rating: Mapped[Optional[Rating]] + category: Mapped[Optional[Category]] + censorship: Mapped[Optional[Censorship]] + + cover_id: Mapped[int] = mapped_column(ForeignKey("image.id")) + cover: Mapped["Image"] = relationship(lazy="joined", innerjoin=True) + + archive_id: Mapped[int] = mapped_column(ForeignKey("archive.id")) + archive: Mapped["Archive"] = relationship(back_populates="comics") + + pages: Mapped[List["Page"]] = relationship(order_by="(Page.index)") + page_count: Mapped[int] + + tags: Mapped[List["ComicTag"]] = relationship( + lazy="selectin", + cascade="save-update, merge, expunge, delete, delete-orphan", + passive_deletes=True, + ) + + artists: Mapped[List["Artist"]] = relationship( + secondary="comicartist", + lazy="selectin", + order_by="(Artist.name, Artist.id)", + passive_deletes=True, + ) + + characters: Mapped[List["Character"]] = relationship( + secondary="comiccharacter", + lazy="selectin", + order_by="(Character.name, Character.id)", + passive_deletes=True, + ) + + circles: Mapped[List["Circle"]] = relationship( + secondary="comiccircle", + lazy="selectin", + order_by="(Circle.name, Circle.id)", + passive_deletes=True, + ) + + worlds: Mapped[List["World"]] = relationship( + secondary="comicworld", + lazy="selectin", + order_by="(World.name, World.id)", + passive_deletes=True, + ) + + @classmethod + def default_order(cls): + return [cls.title] + + @classmethod + def load_full(cls): + return [ + joinedload(cls.archive, innerjoin=True), + joinedload(cls.pages, innerjoin=True), + ] + + @classmethod + def load_update(cls, fields): + if "pages" in fields: + return [joinedload(cls.pages, innerjoin=True)] + return [] + + +class Tag(MixinID, MixinModifyDates, MixinName, Base): + description: Mapped[Optional[str]] + namespaces: Mapped[List["Namespace"]] = relationship( + secondary="tagnamespaces", + passive_deletes=True, + order_by="(Namespace.sort_name, Namespace.name, Namespace.id)", + ) + + @classmethod + def load_full(cls): + return [selectinload(cls.namespaces)] + + @classmethod + def load_update(cls, fields): + if "namespaces" in fields: + return cls.load_full() + return [] + + +class Namespace(MixinID, MixinModifyDates, MixinName, Base): + sort_name: Mapped[Optional[str]] + + @classmethod + def default_order(cls): + return [cls.sort_name, cls.name] + + @classmethod + def load_full(cls): + return [] + + +class TagNamespaces(Base): + namespace_id: Mapped[int] = mapped_column( + ForeignKey("namespace.id", ondelete="CASCADE"), primary_key=True + ) + tag_id: Mapped[int] = mapped_column( + ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True + ) + + +class ComicTag(Base): + comic_id: Mapped[int] = mapped_column( + ForeignKey("comic.id", ondelete="CASCADE"), primary_key=True + ) + namespace_id: Mapped[int] = mapped_column( + ForeignKey("namespace.id", ondelete="CASCADE"), primary_key=True + ) + tag_id: Mapped[int] = mapped_column( + ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True + ) + + namespace: Mapped["Namespace"] = relationship( + lazy="joined", + innerjoin=True, + order_by="(Namespace.sort_name, Namespace.name, Namespace.id)", + ) + + tag: Mapped["Tag"] = relationship( + lazy="joined", + innerjoin=True, + order_by="(Tag.name, Tag.id)", + ) + + @property + def name(self): + return f"{self.namespace.name}:{self.tag.name}" + + @property + def id(self): + return f"{self.namespace.id}:{self.tag.id}" + + +class Artist(MixinID, MixinModifyDates, MixinName, Base): + pass + + +class ComicArtist(Base): + comic_id: Mapped[int] = mapped_column( + ForeignKey("comic.id", ondelete="CASCADE"), primary_key=True + ) + artist_id: Mapped[int] = mapped_column( + ForeignKey("artist.id", ondelete="CASCADE"), primary_key=True + ) + + +class Character(MixinID, MixinModifyDates, MixinName, Base): + pass + + +class ComicCharacter(Base): + comic_id: Mapped[int] = mapped_column( + ForeignKey("comic.id", ondelete="CASCADE"), primary_key=True + ) + character_id: Mapped[int] = mapped_column( + ForeignKey("character.id", ondelete="CASCADE"), primary_key=True + ) + + +class Circle(MixinID, MixinModifyDates, MixinName, Base): + pass + + +class ComicCircle(Base): + comic_id: Mapped[int] = mapped_column( + ForeignKey("comic.id", ondelete="CASCADE"), primary_key=True + ) + circle_id: Mapped[int] = mapped_column( + ForeignKey("circle.id", ondelete="CASCADE"), primary_key=True + ) + + +class World(MixinID, MixinModifyDates, MixinName, Base): + pass + + +class ComicWorld(Base): + comic_id: Mapped[int] = mapped_column( + ForeignKey("comic.id", ondelete="CASCADE"), primary_key=True + ) + world_id: Mapped[int] = mapped_column( + ForeignKey("world.id", ondelete="CASCADE"), primary_key=True + ) + + +def defer_relationship_count(relationship, secondary=False): + left, right = relationship.property.synchronize_pairs[0] + + return deferred( + select(func.count(right)) + .select_from(right.table) + .where(left == right) + .scalar_subquery() + ) + + +Comic.tag_count = defer_relationship_count(Comic.tags) + + +@event.listens_for(Comic.pages, "bulk_replace") +def on_comic_pages_bulk_replace(target, values, initiator): + if not values: + raise APIException( + InvalidParameterError(parameter="pages", text="cannot be empty") + ) + + target.page_count = len(values) diff --git a/src/hircine/db/ops.py b/src/hircine/db/ops.py new file mode 100644 index 0000000..c164cd2 --- /dev/null +++ b/src/hircine/db/ops.py @@ -0,0 +1,200 @@ +import random +from collections import defaultdict + +from sqlalchemy import delete, func, null, select, text, tuple_ +from sqlalchemy.orm import contains_eager, undefer +from sqlalchemy.orm.util import identity_key +from strawberry import UNSET + +from hircine.db.models import ( + Archive, + ComicTag, + Image, + Namespace, + Page, + Tag, + TagNamespaces, +) + + +def paginate(sql, pagination): + if not pagination: + return sql + + if pagination.items < 1 or pagination.page < 1: + return sql.limit(0) + + sql = sql.limit(pagination.items) + + if pagination.page > 0: + sql = sql.offset((pagination.page - 1) * pagination.items) + + return sql + + +def apply_filter(sql, filter): + if not filter: + return sql + + if filter.include is not UNSET: + sql = filter.include.match(sql, False) + if filter.exclude is not UNSET: + sql = filter.exclude.match(sql, True) + + return sql + + +def sort_random(seed): + if seed: + seed = seed % 1000000000 + else: + seed = random.randrange(1000000000) + + # https://www.sqlite.org/forum/forumpost/e2216583a4 + return text("sin(iid + :seed)").bindparams(seed=seed) + + +def apply_sort(sql, sort, default, tiebreaker): + if not sort: + return sql.order_by(*default, tiebreaker) + + direction = sort.direction.value + + if sort.on.value == "Random": + return sql.order_by(direction(sort_random(sort.seed))) + + sql = sql.options(undefer(sort.on.value)) + + return sql.order_by(direction(sort.on.value), tiebreaker) + + +async def query_all(session, model, pagination=None, filter=None, sort=None): + sql = select( + model, func.count(model.id).over().label("count"), model.id.label("iid") + ) + sql = apply_filter(sql, filter) + sql = apply_sort(sql, sort, model.default_order(), model.id) + sql = paginate(sql, pagination) + + count = 0 + objs = [] + + for row in await session.execute(sql): + if count == 0: + count = row.count + + objs.append(row[0]) + + return count, objs + + +async def has_with_name(session, model, name): + sql = select(model.id).where(model.name == name) + return bool((await session.scalars(sql)).unique().first()) + + +async def tag_restrictions(session, tuples=None): + sql = select(TagNamespaces) + + if tuples: + sql = sql.where( + tuple_(TagNamespaces.namespace_id, TagNamespaces.tag_id).in_(tuples) + ) + + namespaces = (await session.scalars(sql)).unique().all() + + ns_map = defaultdict(set) + + for n in namespaces: + ns_map[n.tag_id].add(n.namespace_id) + + return ns_map + + +def lookup_identity(session, model, ids): + objects = [] + satisfied = set() + + for id in ids: + object = session.identity_map.get(identity_key(model, id), None) + if object is not None: + objects.append(object) + satisfied.add(id) + + return objects, satisfied + + +async def get_all(session, model, ids, options=[], use_identity_map=False): + objects = [] + ids = set(ids) + + if use_identity_map: + objects, satisfied = lookup_identity(session, model, ids) + + ids = ids - satisfied + + if not ids: + return objects, set() + + sql = select(model).where(model.id.in_(ids)).options(*options) + + objects += (await session.scalars(sql)).unique().all() + + fetched_ids = [object.id for object in objects] + missing = set(ids) - set(fetched_ids) + + return objects, missing + + +async def get_all_names(session, model, names, options=[]): + names = set(names) + + sql = select(model).where(model.name.in_(names)).options(*options) + + objects = (await session.scalars(sql)).unique().all() + + fetched_names = [object.name for object in objects] + missing = set(names) - set(fetched_names) + + return objects, missing + + +async def get_ctag_names(session, comic_id, tuples): + sql = ( + select(ComicTag) + .join(ComicTag.namespace) + .options(contains_eager(ComicTag.namespace)) + .join(ComicTag.tag) + .options(contains_eager(ComicTag.tag)) + .where(ComicTag.comic_id == comic_id) + .where(tuple_(Namespace.name, Tag.name).in_(tuples)) + ) + objects = (await session.scalars(sql)).unique().all() + + fetched_tags = [(o.namespace.name, o.tag.name) for o in objects] + missing = set(tuples) - set(fetched_tags) + + return objects, missing + + +async def get_image_orphans(session): + sql = select(Image.id, Image.hash).join(Page, isouter=True).where(Page.id == null()) + + return (await session.execute(sql)).t + + +async def get_remaining_pages_for(session, archive_id): + sql = ( + select(Page.id) + .join(Archive) + .where(Archive.id == archive_id) + .where(Page.comic_id == null()) + ) + + return (await session.execute(sql)).scalars().all() + + +async def delete_all(session, model, ids): + result = await session.execute(delete(model).where(model.id.in_(ids))) + + return result.rowcount diff --git a/src/hircine/enums.py b/src/hircine/enums.py new file mode 100644 index 0000000..7f95f02 --- /dev/null +++ b/src/hircine/enums.py @@ -0,0 +1,244 @@ +import enum + +import strawberry + + +@strawberry.enum +class Direction(enum.Enum): + LEFT_TO_RIGHT = "Left to Right" + RIGHT_TO_LEFT = "Right to Left" + + +@strawberry.enum +class Layout(enum.Enum): + SINGLE = "Single Page" + DOUBLE = "Double Page" + DOUBLE_OFFSET = "Double Page, offset" + + +@strawberry.enum +class Rating(enum.Enum): + SAFE = "Safe" + QUESTIONABLE = "Questionable" + EXPLICIT = "Explicit" + + +@strawberry.enum +class Category(enum.Enum): + MANGA = "Manga" + DOUJINSHI = "Doujinshi" + COMIC = "Comic" + GAME_CG = "Game CG" + IMAGE_SET = "Image Set" + ARTBOOK = "Artbook" + VARIANT_SET = "Variant Set" + WEBTOON = "Webtoon" + + +@strawberry.enum +class Censorship(enum.Enum): + NONE = "None" + BAR = "Bars" + MOSAIC = "Mosaic" + FULL = "Full" + + +@strawberry.enum +class UpdateMode(enum.Enum): + REPLACE = "Replace" + ADD = "Add" + REMOVE = "Remove" + + +@strawberry.enum +class OnMissing(enum.Enum): + IGNORE = "Ignore" + CREATE = "Create" + + +@strawberry.enum +class Language(enum.Enum): + AA = "Afar" + AB = "Abkhazian" + AE = "Avestan" + AF = "Afrikaans" + AK = "Akan" + AM = "Amharic" + AN = "Aragonese" + AR = "Arabic" + AS = "Assamese" + AV = "Avaric" + AY = "Aymara" + AZ = "Azerbaijani" + BA = "Bashkir" + BE = "Belarusian" + BG = "Bulgarian" + BH = "Bihari languages" + BI = "Bislama" + BM = "Bambara" + BN = "Bengali" + BO = "Tibetan" + BR = "Breton" + BS = "Bosnian" + CA = "Catalan" + CE = "Chechen" + CH = "Chamorro" + CO = "Corsican" + CR = "Cree" + CS = "Czech" + CU = "Church Slavic" + CV = "Chuvash" + CY = "Welsh" + DA = "Danish" + DE = "German" + DV = "Divehi" + DZ = "Dzongkha" + EE = "Ewe" + EL = "Modern Greek" + EN = "English" + EO = "Esperanto" + ES = "Spanish" + ET = "Estonian" + EU = "Basque" + FA = "Persian" + FF = "Fulah" + FI = "Finnish" + FJ = "Fijian" + FO = "Faroese" + FR = "French" + FY = "Western Frisian" + GA = "Irish" + GD = "Gaelic" + GL = "Galician" + GN = "Guarani" + GU = "Gujarati" + GV = "Manx" + HA = "Hausa" + HE = "Hebrew" + HI = "Hindi" + HO = "Hiri Motu" + HR = "Croatian" + HT = "Haitian" + HU = "Hungarian" + HY = "Armenian" + HZ = "Herero" + IA = "Interlingua" + ID = "Indonesian" + IE = "Interlingue" + IG = "Igbo" + II = "Sichuan Yi" + IK = "Inupiaq" + IO = "Ido" + IS = "Icelandic" + IT = "Italian" + IU = "Inuktitut" + JA = "Japanese" + JV = "Javanese" + KA = "Georgian" + KG = "Kongo" + KI = "Kikuyu" + KJ = "Kuanyama" + KK = "Kazakh" + KL = "Kalaallisut" + KM = "Central Khmer" + KN = "Kannada" + KO = "Korean" + KR = "Kanuri" + KS = "Kashmiri" + KU = "Kurdish" + KV = "Komi" + KW = "Cornish" + KY = "Kirghiz" + LA = "Latin" + LB = "Luxembourgish" + LG = "Ganda" + LI = "Limburgan" + LN = "Lingala" + LO = "Lao" + LT = "Lithuanian" + LU = "Luba-Katanga" + LV = "Latvian" + MG = "Malagasy" + MH = "Marshallese" + MI = "Maori" + MK = "Macedonian" + ML = "Malayalam" + MN = "Mongolian" + MR = "Marathi" + MS = "Malay" + MT = "Maltese" + MY = "Burmese" + NA = "Nauru" + NB = "Norwegian Bokmål" + ND = "North Ndebele" + NE = "Nepali" + NG = "Ndonga" + NL = "Dutch" + NN = "Norwegian Nynorsk" + NO = "Norwegian" + NR = "South Ndebele" + NV = "Navajo" + NY = "Chichewa" + OC = "Occitan" + OJ = "Ojibwa" + OM = "Oromo" + OR = "Oriya" + OS = "Ossetian" + PA = "Panjabi" + PI = "Pali" + PL = "Polish" + PS = "Pushto" + PT = "Portuguese" + QU = "Quechua" + RM = "Romansh" + RN = "Rundi" + RO = "Romanian" + RU = "Russian" + RW = "Kinyarwanda" + SA = "Sanskrit" + SC = "Sardinian" + SD = "Sindhi" + SE = "Northern Sami" + SG = "Sango" + SI = "Sinhala" + SK = "Slovak" + SL = "Slovenian" + SM = "Samoan" + SN = "Shona" + SO = "Somali" + SQ = "Albanian" + SR = "Serbian" + SS = "Swati" + ST = "Southern Sotho" + SU = "Sundanese" + SV = "Swedish" + SW = "Swahili" + TA = "Tamil" + TE = "Telugu" + TG = "Tajik" + TH = "Thai" + TI = "Tigrinya" + TK = "Turkmen" + TL = "Tagalog" + TN = "Tswana" + TO = "Tonga" + TR = "Turkish" + TS = "Tsonga" + TT = "Tatar" + TW = "Twi" + TY = "Tahitian" + UG = "Uighur" + UK = "Ukrainian" + UR = "Urdu" + UZ = "Uzbek" + VE = "Venda" + VI = "Vietnamese" + VO = "Volapük" + WA = "Walloon" + WO = "Wolof" + XH = "Xhosa" + YI = "Yiddish" + YO = "Yoruba" + ZA = "Zhuang" + ZH = "Chinese" + ZU = "Zulu" diff --git a/src/hircine/migrations/alembic.ini b/src/hircine/migrations/alembic.ini new file mode 100644 index 0000000..4e2bfca --- /dev/null +++ b/src/hircine/migrations/alembic.ini @@ -0,0 +1,37 @@ +[alembic] +script_location = %(here)s +sqlalchemy.url = sqlite+aiosqlite:///hircine.db + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/src/hircine/migrations/env.py b/src/hircine/migrations/env.py new file mode 100644 index 0000000..6df03ec --- /dev/null +++ b/src/hircine/migrations/env.py @@ -0,0 +1,96 @@ +import asyncio +from logging.config import fileConfig + +from alembic import context +from hircine.db.models import Base +from sqlalchemy import pool +from sqlalchemy.engine import Connection +from sqlalchemy.ext.asyncio import async_engine_from_config + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. +if config.config_file_name is not None and not config.attributes.get("silent", False): + fileConfig(config.config_file_name) + +# add your model's MetaData object here +# for 'autogenerate' support +# from myapp import mymodel +# target_metadata = mymodel.Base.metadata +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option("sqlalchemy.url") + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection: Connection) -> None: + context.configure( + connection=connection, target_metadata=target_metadata, render_as_batch=True + ) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations() -> None: + """In this scenario we need to create an Engine + and associate a connection with the context. + + """ + + connectable = async_engine_from_config( + config.get_section(config.config_ini_section, {}), + prefix="sqlalchemy.", + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online() -> None: + """Run migrations in 'online' mode.""" + + connectable = config.attributes.get("connection", None) + + if connectable is None: + asyncio.run(run_async_migrations()) + else: + do_run_migrations(connectable) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/src/hircine/migrations/script.py.mako b/src/hircine/migrations/script.py.mako new file mode 100644 index 0000000..fbc4b07 --- /dev/null +++ b/src/hircine/migrations/script.py.mako @@ -0,0 +1,26 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/src/hircine/plugins/__init__.py b/src/hircine/plugins/__init__.py new file mode 100644 index 0000000..27e55a7 --- /dev/null +++ b/src/hircine/plugins/__init__.py @@ -0,0 +1,49 @@ +from importlib.metadata import entry_points +from typing import Dict, Type + +from hircine.scraper import Scraper + +scraper_registry: Dict[str, Type[Scraper]] = {} +transformers = [] + + +def get_scraper(name): + return scraper_registry.get(name, None) + + +def get_scrapers(): + return scraper_registry.items() + + +def register_scraper(name, cls): + scraper_registry[name] = cls + + +def transformer(function): + """ + Marks the decorated function as a transformer. + + The decorated function must be a generator function that yields + :ref:`scraped-data`. The following parameters will be available to the + decorated function: + + :param generator: The scraper's generator function. + :param ScraperInfo info: Information on the scraper. + """ + + def _decorate(function): + transformers.append(function) + return function + + return _decorate(function) + + +def load(): # pragma: nocover + for entry in entry_points(group="hircine.scraper"): + register_scraper(entry.name, entry.load()) + + for entry in entry_points(group="hircine.transformer"): + entry.load() + + +load() diff --git a/src/hircine/plugins/scrapers/__init__.py b/src/hircine/plugins/scrapers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/hircine/plugins/scrapers/__init__.py diff --git a/src/hircine/plugins/scrapers/anchira.py b/src/hircine/plugins/scrapers/anchira.py new file mode 100644 index 0000000..aa224b9 --- /dev/null +++ b/src/hircine/plugins/scrapers/anchira.py @@ -0,0 +1,101 @@ +import re + +import yaml + +import hircine.enums as enums +from hircine.scraper import Scraper +from hircine.scraper.types import ( + URL, + Artist, + Censorship, + Circle, + Date, + Direction, + Language, + Rating, + Tag, + Title, + World, +) +from hircine.scraper.utils import open_archive_file + +URL_REGEX = re.compile(r"^https?://anchira\.to/g/") + + +class AnchiraYamlScraper(Scraper): + """ + A scraper for ``info.yaml`` files found in archives downloaded from + *anchira.to*. + + .. list-table:: + :align: left + + * - **Requires** + - ``info.yaml`` in the archive or as a sidecar. + * - **Source** + - ``anchira.to`` + """ + + name = "anchira.to info.yaml" + source = "anchira.to" + + def __init__(self, comic): + super().__init__(comic) + + self.data = self.load() + source = self.data.get("Source") + + if source and re.match(URL_REGEX, source): + self.is_available = True + + def load(self): + try: + with open_archive_file(self.comic.archive, "info.yaml") as yif: + return yaml.safe_load(yif) + except Exception: + return {} + + def scrape(self): + parsers = { + "Title": Title, + "Artist": Artist, + "URL": URL, + "Released": Date.from_timestamp, + "Circle": Circle, + "Parody": self.parse_world, + "Tags": self.parse_tag, + } + + for field, parser in parsers.items(): + if field not in self.data: + continue + + value = self.data[field] + + if isinstance(value, list): + yield from [lambda i=x: parser(i) for x in value] + else: + yield lambda: parser(value) + + yield Language(enums.Language.EN) + yield Direction(enums.Direction.RIGHT_TO_LEFT) + + def parse_world(self, input): + match input: + case "Original Work": + return + + return World(input) + + def parse_tag(self, input): + match input: + case "Unlimited": + return + case "Hentai": + return Rating(value=enums.Rating.EXPLICIT) + case "Non-H" | "Ecchi": + return Rating(value=enums.Rating.QUESTIONABLE) + case "Uncensored": + return Censorship(value=enums.Censorship.NONE) + case _: + return Tag.from_string(input) diff --git a/src/hircine/plugins/scrapers/ehentai_api.py b/src/hircine/plugins/scrapers/ehentai_api.py new file mode 100644 index 0000000..70fcf57 --- /dev/null +++ b/src/hircine/plugins/scrapers/ehentai_api.py @@ -0,0 +1,75 @@ +import html +import json +import re + +import requests + +from hircine.scraper import ScrapeError, Scraper + +from .handlers.exhentai import ExHentaiHandler + +API_URL = "https://api.e-hentai.org/api.php" +URL_REGEX = re.compile( + r"^https?://(?:exhentai|e-hentai).org/g/(?P<id>\d+)/(?P<token>[0-9a-fA-F]+).*" +) + + +class EHentaiAPIScraper(Scraper): + """ + A scraper for the `E-Hentai API <https://ehwiki.org/wiki/API>`_. + + .. list-table:: + :align: left + + * - **Requires** + - The comic :attr:`URL <hircine.api.types.FullComic.url>` pointing to + a gallery on *e-hentai.org* or *exhentai.org* + * - **Source** + - ``exhentai`` + + """ + + name = "e-hentai.org API" + source = "exhentai" + + def __init__(self, comic): + super().__init__(comic) + + if self.comic.url: + match = re.fullmatch(URL_REGEX, self.comic.url) + + if match: + self.is_available = True + self.id = int(match.group("id")) + self.token = match.group("token") + + def scrape(self): + data = json.dumps( + { + "method": "gdata", + "gidlist": [[self.id, self.token]], + "namespace": 1, + }, + separators=(",", ":"), + ) + + request = requests.post(API_URL, data=data) + + if request.status_code == requests.codes.ok: + try: + response = json.loads(request.text)["gmetadata"][0] + + title = response.get("title") + if title: + response["title"] = html.unescape(title) + + title_jpn = response.get("title_jpn") + if title_jpn: + response["title_jpn"] = html.unescape(title_jpn) + + handler = ExHentaiHandler() + yield from handler.scrape(response) + except json.JSONDecodeError: + raise ScrapeError("Could not parse JSON response") + else: + raise ScrapeError(f"Request failed with status code {request.status_code}'") diff --git a/src/hircine/plugins/scrapers/gallery_dl.py b/src/hircine/plugins/scrapers/gallery_dl.py new file mode 100644 index 0000000..a6cebc4 --- /dev/null +++ b/src/hircine/plugins/scrapers/gallery_dl.py @@ -0,0 +1,54 @@ +import json + +from hircine.scraper import Scraper +from hircine.scraper.utils import open_archive_file + +from .handlers.dynastyscans import DynastyScansHandler +from .handlers.e621 import E621Handler +from .handlers.exhentai import ExHentaiHandler +from .handlers.mangadex import MangadexHandler + +HANDLERS = { + "dynastyscans": DynastyScansHandler, + "e621": E621Handler, + "exhentai": ExHentaiHandler, + "mangadex": MangadexHandler, +} + + +class GalleryDLScraper(Scraper): + """ + A scraper for `gallery-dl's <https://github.com/mikf/gallery-dl>`_ + ``info.json`` files. For now supports only a select subset of extractors. + + .. list-table:: + :align: left + + * - **Requires** + - ``info.json`` in the archive or as a sidecar. + * - **Source** + - ``dynastyscans``, ``e621``, ``exhentai``, ``mangadex`` + """ + + def __init__(self, comic): + super().__init__(comic) + + self.data = self.load() + category = self.data.get("category") + + if category in HANDLERS.keys(): + self.is_available = True + + self.handler = HANDLERS.get(category)() + self.source = self.handler.source + self.name = f"gallery-dl info.json ({self.source})" + + def load(self): + try: + with open_archive_file(self.comic.archive, "info.json") as jif: + return json.load(jif) + except Exception: + return {} + + def scrape(self): + yield from self.handler.scrape(self.data) diff --git a/src/hircine/plugins/scrapers/handlers/__init__.py b/src/hircine/plugins/scrapers/handlers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/src/hircine/plugins/scrapers/handlers/__init__.py diff --git a/src/hircine/plugins/scrapers/handlers/dynastyscans.py b/src/hircine/plugins/scrapers/handlers/dynastyscans.py new file mode 100644 index 0000000..ded015b --- /dev/null +++ b/src/hircine/plugins/scrapers/handlers/dynastyscans.py @@ -0,0 +1,41 @@ +import hircine.enums as enums +from hircine.scraper import ScrapeWarning +from hircine.scraper.types import ( + Artist, + Circle, + Date, + Language, + Title, +) +from hircine.scraper.utils import parse_dict + + +class DynastyScansHandler: + source = "dynastyscans" + + def scrape(self, data): + parsers = { + "date": Date.from_iso, + "lang": self.parse_language, + "author": Artist, + "group": Circle, + } + + yield from parse_dict(parsers, data) + + if manga := data.get("manga"): + title = manga + + if chapter := data.get("chapter"): + title = title + f" Ch. {chapter}" + + if subtitle := data.get("title"): + title = title + f": {subtitle}" + + yield Title(title) + + def parse_language(self, input): + try: + return Language(value=enums.Language[input.upper()]) + except (KeyError, ValueError) as e: + raise ScrapeWarning(f"Could not parse language: '{input}'") from e diff --git a/src/hircine/plugins/scrapers/handlers/e621.py b/src/hircine/plugins/scrapers/handlers/e621.py new file mode 100644 index 0000000..6b798fd --- /dev/null +++ b/src/hircine/plugins/scrapers/handlers/e621.py @@ -0,0 +1,81 @@ +import hircine.enums as enums +from hircine.scraper import ScrapeWarning +from hircine.scraper.types import ( + URL, + Artist, + Category, + Censorship, + Character, + Date, + Language, + Rating, + Tag, + Title, + World, +) +from hircine.scraper.utils import parse_dict + + +def replace_underscore(fun): + return lambda input: fun(input.replace("_", " ")) + + +class E621Handler: + source = "e621" + + ratings = { + "e": Rating(enums.Rating.EXPLICIT), + "q": Rating(enums.Rating.QUESTIONABLE), + "s": Rating(enums.Rating.SAFE), + } + + def scrape(self, data): + match data.get("subcategory"): + case "pool": + yield from self.scrape_pool(data) + + def scrape_pool(self, data): + parsers = { + "date": Date.from_iso, + "rating": self.ratings.get, + "pool": { + "id": lambda pid: URL(f"https://e621.net/pools/{pid}"), + "name": Title, + }, + "tags": { + "general": replace_underscore(Tag.from_string), + "artist": replace_underscore(Artist), + "character": replace_underscore(Character), + "copyright": replace_underscore(World), + "species": replace_underscore(Tag.from_string), + "meta": self.parse_meta, + }, + } + + self.is_likely_uncensored = True + + yield from parse_dict(parsers, data) + + if self.is_likely_uncensored: + yield Censorship(enums.Censorship.NONE) + + def parse_meta(self, input): + match input: + case "comic": + return Category(enums.Category.COMIC) + case "censor_bar": + self.is_likely_uncensored = False + return Censorship(enums.Censorship.BAR) + case "mosaic_censorship": + self.is_likely_uncensored = False + return Censorship(enums.Censorship.MOSAIC) + case "uncensored": + return Censorship(enums.Censorship.NONE) + + if input.endswith("_text"): + lang, _ = input.split("_text", 1) + + try: + return Language(value=enums.Language(lang.capitalize())) + except ValueError as e: + raise ScrapeWarning(f"Could not parse language: '{input}'") from e diff --git a/src/hircine/plugins/scrapers/handlers/exhentai.py b/src/hircine/plugins/scrapers/handlers/exhentai.py new file mode 100644 index 0000000..12c22d7 --- /dev/null +++ b/src/hircine/plugins/scrapers/handlers/exhentai.py @@ -0,0 +1,139 @@ +import re + +import hircine.enums as enums +from hircine.scraper import ScrapeWarning +from hircine.scraper.types import ( + URL, + Artist, + Category, + Censorship, + Character, + Circle, + Date, + Direction, + Language, + OriginalTitle, + Rating, + Tag, + Title, + World, +) +from hircine.scraper.utils import parse_dict + + +def sanitize(title, split=False): + text = re.sub(r"\[[^\]]+\]|{[^}]+}|=[^=]+=|^\([^)]+\)", "", title) + if "|" in text and split: + orig, text = text.split("|", 1) + + return re.sub(r"\s{2,}", " ", text).strip() + + +class ExHentaiHandler: + source = "exhentai" + + def scrape(self, data): + category_field = "eh_category" if "eh_category" in data else "category" + + parsers = { + category_field: self.parse_category, + "posted": Date.from_timestamp, + "date": Date.from_iso, + "lang": self.parse_language, + "tags": self.parse_tag, + "title": lambda t: Title(sanitize(t, split=True)), + "title_jpn": lambda t: OriginalTitle(sanitize(t)), + } + + self.is_likely_pornographic = True + self.is_likely_rtl = False + self.has_censorship_tag = False + self.is_western = False + + yield from parse_dict(parsers, data) + + if self.is_likely_pornographic: + yield Rating(enums.Rating.EXPLICIT) + + if not self.has_censorship_tag: + if self.is_western: + yield Censorship(enums.Censorship.NONE) + else: + yield Censorship(enums.Censorship.BAR) + + if self.is_likely_rtl: + yield Direction(enums.Direction.RIGHT_TO_LEFT) + + if (gid := data["gid"]) and (token := data["token"]): + yield URL(f"https://exhentai.org/g/{gid}/{token}") + + def parse_category(self, input): + match input.lower(): + case "doujinshi": + self.is_likely_rtl = True + return Category(value=enums.Category.DOUJINSHI) + case "manga": + self.is_likely_rtl = True + return Category(value=enums.Category.MANGA) + case "western": + self.is_western = True + case "artist cg": + return Category(value=enums.Category.COMIC) + case "game cg": + return Category(value=enums.Category.GAME_CG) + case "image set": + return Category(value=enums.Category.IMAGE_SET) + case "non-h": + self.is_likely_pornographic = False + return Rating(value=enums.Rating.QUESTIONABLE) + + def parse_tag(self, input): + match input.split(":"): + case ["parody", value]: + return World(value) + case ["group", value]: + return Circle(value) + case ["artist", value]: + return Artist(value) + case ["character", value]: + return Character(value) + case ["language", value]: + return self.parse_language(value, from_value=True) + case ["other", "artbook"]: + return Category(enums.Category.ARTBOOK) + case ["other", "full censorship"]: + self.has_censorship_tag = True + return Censorship(enums.Censorship.FULL) + case ["other", "mosaic censorship"]: + self.has_censorship_tag = True + return Censorship(enums.Censorship.MOSAIC) + case ["other", "uncensored"]: + self.has_censorship_tag = True + return Censorship(enums.Censorship.NONE) + case ["other", "non-h imageset" | "western imageset"]: + return Category(value=enums.Category.IMAGE_SET) + case ["other", "western non-h"]: + self.is_likely_pornographic = False + return Rating(value=enums.Rating.QUESTIONABLE) + case ["other", "comic"]: + return Category(value=enums.Category.COMIC) + case ["other", "variant set"]: + return Category(value=enums.Category.VARIANT_SET) + case ["other", "webtoon"]: + return Category(value=enums.Category.WEBTOON) + case [namespace, tag]: + return Tag(namespace=namespace, tag=tag) + case [tag]: + return Tag(namespace=None, tag=tag) + + def parse_language(self, input, from_value=False): + if not input or input in ["translated", "speechless", "N/A"]: + return + + try: + if from_value: + return Language(value=enums.Language(input.capitalize())) + else: + return Language(value=enums.Language[input.upper()]) + except (KeyError, ValueError) as e: + raise ScrapeWarning(f"Could not parse language: '{input}'") from e diff --git a/src/hircine/plugins/scrapers/handlers/mangadex.py b/src/hircine/plugins/scrapers/handlers/mangadex.py new file mode 100644 index 0000000..7bc371d --- /dev/null +++ b/src/hircine/plugins/scrapers/handlers/mangadex.py @@ -0,0 +1,54 @@ +import hircine.enums as enums +from hircine.scraper import ScrapeWarning +from hircine.scraper.types import ( + URL, + Artist, + Circle, + Date, + Language, + Tag, + Title, +) +from hircine.scraper.utils import parse_dict + + +class MangadexHandler: + source = "mangadex" + + def scrape(self, data): + parsers = { + "date": Date.from_iso, + "lang": self.parse_language, + "tags": Tag.from_string, + "artist": Artist, + "author": Artist, + "group": Circle, + } + + yield from parse_dict(parsers, data) + + if chapter_id := data.get("chapter_id"): + yield URL(f"https://mangadex.org/chapter/{chapter_id}") + + if manga := data.get("manga"): + title = manga + + if volume := data.get("volume"): + title = title + f" Vol. {volume}" + + if chapter := data.get("chapter"): + if volume: + title = title + f", Ch. {chapter}" + else: + title = title + f"Ch. {chapter}" + + if subtitle := data.get("title"): + title = title + f": {subtitle}" + + yield Title(title) + + def parse_language(self, input): + try: + return Language(value=enums.Language[input.upper()]) + except (KeyError, ValueError) as e: + raise ScrapeWarning(f"Could not parse language: '{input}'") from e diff --git a/src/hircine/scanner.py b/src/hircine/scanner.py new file mode 100644 index 0000000..162e1f0 --- /dev/null +++ b/src/hircine/scanner.py @@ -0,0 +1,320 @@ +import asyncio +import multiprocessing +import os +import platform +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor +from datetime import datetime, timezone +from enum import Enum +from hashlib import file_digest +from typing import List, NamedTuple +from zipfile import ZipFile, is_zipfile + +from blake3 import blake3 +from natsort import natsorted, ns +from sqlalchemy import insert, select, update +from sqlalchemy.dialects.sqlite import insert as sqlite_upsert +from sqlalchemy.orm import raiseload + +import hircine.db as db +from hircine.db.models import Archive, Image, Page +from hircine.thumbnailer import Thumbnailer, params_from + + +class Status(Enum): + NEW = "+" + UNCHANGED = "=" + UPDATED = "*" + RENAMED = ">" + IGNORED = "I" + CONFLICT = "!" + MISSING = "?" + REIMAGE = "~" + + +def log(status, path, renamed_to=None): + if status == Status.UNCHANGED: + return + + print(f"[{status.value}]", end=" ") + print(f"{os.path.basename(path)}", end=" " if renamed_to else "\n") + + if renamed_to: + print(f"-> {os.path.basename(renamed_to)}", end="\n") + + +class Registry: + def __init__(self): + self.paths = set() + self.orphans = {} + self.conflicts = {} + self.marked = defaultdict(list) + + def mark(self, status, hash, path, renamed_to=None): + log(status, path, renamed_to) + self.marked[hash].append((path, status)) + + @property + def duplicates(self): + for hash, value in self.marked.items(): + if len(value) > 1: + yield value + + +class Member(NamedTuple): + path: str + hash: str + width: int + height: int + + +class UpdateArchive(NamedTuple): + id: int + path: str + mtime: datetime + + async def execute(self, session): + await session.execute( + update(Archive) + .values(path=self.path, mtime=self.mtime) + .where(Archive.id == self.id) + ) + + +class AddArchive(NamedTuple): + hash: str + path: str + size: int + mtime: datetime + members: List[Member] + + async def upsert_images(self, session): + input = [ + { + "hash": member.hash, + "width": member.width, + "height": member.height, + } + for member in self.members + ] + + images = { + image.hash: image.id + for image in await session.scalars( + sqlite_upsert(Image) + .returning(Image) + .on_conflict_do_nothing(index_elements=["hash"]), + input, + ) + } + + missing = [member.hash for member in self.members if member.hash not in images] + if missing: + for image in await session.scalars( + select(Image).where(Image.hash.in_(missing)) + ): + images[image.hash] = image.id + + return images + + async def execute(self, session): + images = await self.upsert_images(session) + + archive = ( + await session.scalars( + insert(Archive).returning(Archive), + { + "hash": self.hash, + "path": self.path, + "size": self.size, + "mtime": self.mtime, + "cover_id": images[self.members[0].hash], + "page_count": len(self.members), + }, + ) + ).one() + + await session.execute( + insert(Page), + [ + { + "index": index, + "path": member.path, + "image_id": images[member.hash], + "archive_id": archive.id, + } + for index, member in enumerate(self.members) + ], + ) + + +class Scanner: + def __init__(self, config, dirs, reprocess=False): + self.directory = dirs.scan + self.thumbnailer = Thumbnailer(dirs.objects, params_from(config)) + self.registry = Registry() + + self.reprocess = reprocess + + async def scan(self): + if platform.system() == "Windows": + ctx = multiprocessing.get_context("spawn") # pragma: no cover + else: + ctx = multiprocessing.get_context("forkserver") + + workers = multiprocessing.cpu_count() // 2 + + with ProcessPoolExecutor(max_workers=workers, mp_context=ctx) as pool: + async with db.session() as s: + sql = select(Archive).options(raiseload(Archive.cover)) + + for archive in await s.scalars(sql): + action = await self.scan_existing(archive, pool) + + if action: + await action.execute(s) + + async for action in self.scan_dir(self.directory, pool): + await action.execute(s) + + await s.commit() + + def report(self): # pragma: no cover + if self.registry.orphans: + print() + print( + "WARNING: The following paths are referenced in the DB, but do not exist in the file system:" # noqa: E501 + ) + for orphan in self.registry.orphans.values(): + _, path = orphan + log(Status.MISSING, path) + + for duplicate in self.registry.duplicates: + print() + print("WARNING: The following archives contain the same data:") + for path, status in duplicate: + log(status, path) + + for path, conflict in self.registry.conflicts.items(): + db_hash, fs_hash = conflict + print() + print("ERROR: The contents of the following archive have changed:") + log(Status.CONFLICT, path) + print(f" Database: {db_hash}") + print(f" File system: {fs_hash}") + + async def scan_existing(self, archive, pool): + try: + stat = os.stat(archive.path, follow_symlinks=False) + except FileNotFoundError: + self.registry.orphans[archive.hash] = (archive.id, archive.path) + return None + + self.registry.paths.add(archive.path) + + mtime = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc) + + if mtime == archive.mtime: + if self.reprocess: + await self.process_zip(archive.path, pool) + + self.registry.mark(Status.REIMAGE, archive.hash, archive.path) + return None + else: + self.registry.mark(Status.UNCHANGED, archive.hash, archive.path) + return None + + hash, _ = await self.process_zip(archive.path, pool) + + if archive.hash == hash: + self.registry.mark(Status.UPDATED, archive.hash, archive.path) + return UpdateArchive(id=archive.id, path=archive.path, mtime=mtime) + else: + log(Status.CONFLICT, archive.path) + self.registry.conflicts[archive.path] = (archive.hash, hash) + + return None + + async def scan_dir(self, path, pool): + path = os.path.realpath(path) + + for root, dirs, files in os.walk(path): + for file in files: + absolute = os.path.join(path, root, file) + + if os.path.islink(absolute): + continue + + if not is_zipfile(absolute): + continue + + if absolute in self.registry.paths: + continue + + async for result in self.scan_zip(absolute, pool): + yield result + + async def scan_zip(self, path, pool): + stat = os.stat(path, follow_symlinks=False) + mtime = datetime.fromtimestamp(stat.st_mtime, tz=timezone.utc) + + hash, members = await self.process_zip(path, pool) + + if hash in self.registry.marked: + self.registry.mark(Status.IGNORED, hash, path) + return + + if hash in self.registry.orphans: + id, old_path = self.registry.orphans[hash] + del self.registry.orphans[hash] + + self.registry.mark(Status.RENAMED, hash, old_path, renamed_to=path) + yield UpdateArchive(id=id, path=path, mtime=mtime) + return + elif members: + self.registry.mark(Status.NEW, hash, path) + yield AddArchive( + hash=hash, + path=path, + size=stat.st_size, + mtime=mtime, + members=natsorted(members, key=lambda m: m.path, alg=ns.P | ns.IC), + ) + + async def process_zip(self, path, pool): + members = [] + hash = blake3() + + with ZipFile(path, mode="r") as z: + input = [(path, info.filename) for info in z.infolist()] + + loop = asyncio.get_event_loop() + + tasks = [loop.run_in_executor(pool, self.process_member, i) for i in input] + results = await asyncio.gather(*tasks) + for digest, entry in results: + hash.update(digest) + if entry: + members.append(entry) + + return hash.hexdigest(), members + + def process_member(self, input): + path, name = input + + with ZipFile(path, mode="r") as zip: + with zip.open(name, mode="r") as member: + _, ext = os.path.splitext(name) + digest = file_digest(member, blake3).digest() + + if self.thumbnailer.can_process(ext): + hash = digest.hex() + + width, height = self.thumbnailer.process( + member, hash, reprocess=self.reprocess + ) + return digest, Member( + path=member.name, hash=hash, width=width, height=height + ) + + return digest, None diff --git a/src/hircine/scraper/__init__.py b/src/hircine/scraper/__init__.py new file mode 100644 index 0000000..c04265a --- /dev/null +++ b/src/hircine/scraper/__init__.py @@ -0,0 +1,108 @@ +from abc import ABC, abstractmethod + + +class ScraperInfo: + """ + A class containing informational data on a scraper. + + :param str name: The name of the scraper. + :param str source: The data source, usually a well-defined name. + For names used by built-in plugins, refer to the :ref:`Scrapers + reference <builtin-scrapers>`. + :param FullComic comic: The comic being scraped. + """ + + def __init__(self, name, source, comic): + self.name = name + self.source = source + self.comic = comic + + +class ScrapeWarning(Exception): + """ + An exception signalling a non-fatal error. Its message will be shown to the + user once the scraping process concludes. + + This is usually raised within a callable yielded by + :meth:`~hircine.scraper.Scraper.scrape` and should generally only be used + to notify the user that a piece of metadata was ignored because it was + malformed. + """ + + pass + + +class ScrapeError(Exception): + """ + An exception signalling a fatal error, stopping the scraping process + immediately. + + This should only be raised if it is impossible for the scraping process to + continue, for example if a file or URL is inaccessible. + """ + + pass + + +class Scraper(ABC): + """ + The abstract base class for scrapers. + + The following variables **must** be accessible after the instance is initialized: + + :var str name: The name of the scraper (displayed in the scraper dropdown). + :var str source: The data source. Usually a well-defined name. + :var bool is_available: Whether this scraper is available for the given comic. + """ + + name = "Abstract Scraper" + + source = None + is_available = False + + def __init__(self, comic): + """ + Initializes a scraper with the instance of the comic it is scraping. + + :param FullComic comic: The comic being scraped. + """ + self.comic = comic + self.warnings = [] + + @abstractmethod + def scrape(self): + """ + A generator function that yields :ref:`scraped-data` or a callable + returning such data. + + A callable may raise the :exc:`~hircine.scraper.ScrapeWarning` + exception. This exception will be caught automatically and its message + will be collected for display to the user after the scraping process concludes. + """ + pass + + def collect(self, transformers=[]): + def generator(): + for result in self.scrape(): + if callable(result): + try: + yield result() + except ScrapeWarning as e: + self.log_warning(e) + else: + yield result + + gen = generator() + + info = ScraperInfo(name=self.name, source=self.source, comic=self.comic) + + for fun in transformers: + gen = fun(gen, info) + + return gen + + def log_warning(self, warning): + self.warnings.append(warning) + + def get_warnings(self): + return list(map(str, self.warnings)) diff --git a/src/hircine/scraper/types.py b/src/hircine/scraper/types.py new file mode 100644 index 0000000..534792b --- /dev/null +++ b/src/hircine/scraper/types.py @@ -0,0 +1,246 @@ +from dataclasses import dataclass +from datetime import date, datetime + +import hircine.enums + +from . import ScrapeWarning + + +@dataclass(frozen=True) +class Tag: + """ + A :term:`qualified tag`, represented by strings. + + :param str namespace: The namespace. + :param str tag: The tag. + """ + + namespace: str + tag: str + + @classmethod + def from_string(cls, string, delimiter=":"): + """ + Returns a new instance of this class given a textual representation, + usually a qualified tag in the format ``<namespace>:<tag>``. If no + delimiter is found, the namespace is assumed to be ``none`` and the + given string is used as a tag instead. + + :param str string: The string of text representing a qualified tag. + :param str delimiter: The string with which the namespace is delimited + from the tag. + """ + match string.split(delimiter, 1): + case [namespace, tag]: + return cls(namespace=namespace, tag=tag) + return cls(namespace="none", tag=string) + + def to_string(self): + return f"{self.namespace}:{self.tag}" + + def __bool__(self): + return bool(self.namespace) and bool(self.tag) + + +@dataclass(frozen=True) +class Date: + """ + A scraped date. + + :param :class:`~datetime.date` value: The date. + """ + + value: date + + @classmethod + def from_iso(cls, datestring): + """ + Returns a new instance of this class given a textual representation of + a date in the format ``YYYY-MM-DD``. See :meth:`datetime.date.fromisoformat`. + + :param str datestring: The string of text representing a date. + :raise: :exc:`~hircine.scraper.ScrapeWarning` if the date string could + not be parsed. + """ + try: + return cls(value=datetime.fromisoformat(datestring).date()) + except ValueError as e: + raise ScrapeWarning( + f"Could not parse date: '{datestring}' as ISO 8601" + ) from e + + @classmethod + def from_timestamp(cls, timestamp): + """ + Returns a new instance of this class given a textual representation of + a POSIX timestamp. See :meth:`datetime.date.fromtimestamp`. + + :param str timestamp: The string of text representing a POSIX timestamp. + :raise: :exc:`~hircine.scraper.ScrapeWarning` if the timestamp could + not be parsed. + """ + try: + return cls(value=datetime.fromtimestamp(int(timestamp)).date()) + except (OverflowError, OSError, ValueError) as e: + raise ScrapeWarning( + f"Could not parse date: '{timestamp}' as POSIX timestamp" + ) from e + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Rating: + """ + A scraped rating, represented by an enum. + """ + + value: hircine.enums.Rating + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Category: + """ + A scraped category, represented by an enum. + """ + + value: hircine.enums.Category + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Censorship: + """ + A scraped censorship specifier, represented by an enum. + """ + + value: hircine.enums.Censorship + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Language: + """ + A scraped language, represented by an enum. + """ + + value: hircine.enums.Language + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Direction: + """ + A scraped direction, represented by an enum. + """ + + value: hircine.enums.Direction + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Layout: + """ + A scraped layout, represented by an enum. + """ + + value: hircine.enums.Layout + + def __bool__(self): + return self.value is not None + + +@dataclass(frozen=True) +class Title: + """ + A scraped comic title. + """ + + value: str + + def __bool__(self): + return bool(self.value) + + +@dataclass(frozen=True) +class OriginalTitle: + """ + A scraped original title. + """ + + value: str + + def __bool__(self): + return bool(self.value) + + +@dataclass(frozen=True) +class Artist: + """ + A scraped artist. + """ + + name: str + + def __bool__(self): + return bool(self.name) + + +@dataclass(frozen=True) +class Character: + """ + A scraped character. + """ + + name: str + + def __bool__(self): + return bool(self.name) + + +@dataclass(frozen=True) +class Circle: + """ + A scraped circle. + """ + + name: str + + def __bool__(self): + return bool(self.name) + + +@dataclass(frozen=True) +class World: + """ + A scraped world. + """ + + name: str + + def __bool__(self): + return bool(self.name) + + +@dataclass(frozen=True) +class URL: + """ + A scraped URL. + """ + + value: str + + def __bool__(self): + return bool(self.value) diff --git a/src/hircine/scraper/utils.py b/src/hircine/scraper/utils.py new file mode 100644 index 0000000..6afa2ed --- /dev/null +++ b/src/hircine/scraper/utils.py @@ -0,0 +1,62 @@ +import os +from contextlib import contextmanager +from zipfile import ZipFile + + +def parse_dict(parsers, data): + """ + Make a generator that yields callables applying parser functions to their + matching input data. *parsers* and *data* must both be dictionaries. Parser + functions are matched to input data using their dictionary keys. If a + parser's key is not present in *data*, it is ignored. + + A key in *parsers* may map to another dictionary of parsers. In this case, + this function will be applied recursively to the matching value in *data*, + which is assumed to be a dictionary as well. + + If a parser is matched to a list type, one callable for each list item is + yielded. + + :param dict parsers: A mapping of parsers. + :param dict data: A mapping of data to be parsed. + """ + for field, parser in parsers.items(): + if field not in data: + continue + + value = data[field] + + if isinstance(value, list): + yield from [lambda i=x: parser(i) for x in value] + elif isinstance(value, dict): + yield from parse_dict(parser, value) + else: + yield lambda: parser(value) + + +@contextmanager +def open_archive_file(archive, member, check_sidecar=True): # pragma: no cover + """ + Open an archive file for use with the :ref:`with <with>` statement. Yields + a :term:`file object` obtained from: + + 1. The archive's :ref:`sidecar file <sidecar-files>`, if it exists and + *check_sidecar* is ``True``. + 2. Otherwise, the archive itself. + + :param Archive archive: The archive. + :param str member: The name of the file within the archive (or its sidecar suffix). + :param bool check_sidecar: Whether to check for the sidecar file. + """ + if check_sidecar: + sidecar = f"{archive.path}.{member}" + + if os.path.exists(sidecar): + with open(sidecar, "r") as file: + yield file + + return + + with ZipFile(archive.path, "r") as zip: + with zip.open(member, "r") as file: + yield file diff --git a/src/hircine/thumbnailer.py b/src/hircine/thumbnailer.py new file mode 100644 index 0000000..ed565d5 --- /dev/null +++ b/src/hircine/thumbnailer.py @@ -0,0 +1,75 @@ +import os +from typing import NamedTuple + +from PIL import Image + +pillow_extensions = { + ext for ext, f in Image.registered_extensions().items() if f in Image.OPEN +} + + +class ThumbnailParameters(NamedTuple): + bounds: tuple[int, int] + options: dict + + +def params_from(config): + return { + "full": ThumbnailParameters( + bounds=( + config.getint("import.scale.full", "width", fallback=4200), + config.getint("import.scale.full", "height", fallback=2000), + ), + options={"quality": 82, "method": 5}, + ), + "thumb": ThumbnailParameters( + bounds=( + config.getint("import.scale.thumb", "width", fallback=1680), + config.getint("import.scale.thumb", "height", fallback=800), + ), + options={"quality": 75, "method": 5}, + ), + } + + +def object_path(directory, hash, suffix): + return os.path.join(directory, hash[:2], f"{hash[2:]}_{suffix}.webp") + + +class Thumbnailer: + def __init__(self, directory, params): + self.directory = directory + self.params = params + + @classmethod + def can_process(cls, extension): + return extension in pillow_extensions + + def object(self, hash, suffix): + return object_path(self.directory, hash, suffix) + + def process(self, handle, hash, reprocess=False): + size = None + + for suffix, parameters in self.params.items(): + source = Image.open(handle, mode="r") + + if not size: + size = source.size + + output = self.object(hash, suffix) + + if os.path.exists(output) and not reprocess: + continue + else: + os.makedirs(os.path.dirname(output), exist_ok=True) + + if source.mode != "RGB": + target = source.convert() + else: + target = source + + target.thumbnail(parameters.bounds, resample=Image.Resampling.LANCZOS) + target.save(output, **parameters.options) + + return size |