diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/hircine/api/filters.py | 53 | ||||
-rw-r--r-- | src/hircine/db/models.py | 21 | ||||
-rw-r--r-- | src/hircine/enums.py | 8 |
3 files changed, 69 insertions, 13 deletions
diff --git a/src/hircine/api/filters.py b/src/hircine/api/filters.py index 807178b..7ed5649 100644 --- a/src/hircine/api/filters.py +++ b/src/hircine/api/filters.py @@ -7,7 +7,7 @@ from strawberry import UNSET import hircine.db from hircine.db.models import ComicTag -from hircine.enums import Category, Censorship, Language, Rating +from hircine.enums import Category, Censorship, Language, Operator, Rating T = TypeVar("T") @@ -28,11 +28,23 @@ class Matchable(ABC): @strawberry.input +class CountFilter: + operator: Optional[Operator] = Operator.EQUAL + value: int + + def include(self, column, sql): + return sql.where(self.operator.value(column, self.value)) + + def exclude(self, column, sql): + return sql.where(~self.operator.value(column, self.value)) + + +@strawberry.input class AssociationFilter(Matchable): any: Optional[list[int]] = strawberry.field(default_factory=lambda: None) all: Optional[list[int]] = strawberry.field(default_factory=lambda: None) exact: Optional[list[int]] = strawberry.field(default_factory=lambda: None) - empty: Optional[bool] = None + count: Optional[CountFilter] = UNSET def _exists(self, condition): # The property.primaryjoin expression specifies the primary join path @@ -71,12 +83,6 @@ class AssociationFilter(Matchable): def _where_not_all_exist(self, sql): return sql.where(~self._all_exist(self.all)) - def _empty(self): - if self.empty: - return ~self._exists(True) - else: - return self._exists(True) - def _count_of(self, column): return ( select(func.count(column)) @@ -117,8 +123,8 @@ class AssociationFilter(Matchable): elif self.all == []: sql = sql.where(False) - if self.empty is not None: - sql = sql.where(self._empty()) + if self.count: + sql = self.count.include(self.count_column, sql) if self.exact is not None: sql = sql.where(self._exact()) @@ -134,8 +140,8 @@ class AssociationFilter(Matchable): if self.all: sql = self._where_not_all_exist(sql) - if self.empty is not None: - sql = sql.where(~self._empty()) + if self.count: + sql = self.count.exclude(self.count_column, sql) if self.exact is not None: sql = sql.where(~self._exact()) @@ -160,8 +166,14 @@ class Root: column = getattr(self._model, field, None) + # count columns are historically singular, so we need this hack + singular_field = field[:-1] + count_column = getattr(self._model, f"{singular_field}_count", None) + if issubclass(type(matcher), Matchable): matcher.column = column + matcher.count_column = count_column + if not negate: sql = matcher.include(sql) else: @@ -213,6 +225,17 @@ class StringFilter(Matchable): @strawberry.input +class BasicCountFilter(Matchable): + count: CountFilter + + def include(self, sql): + return self.count.include(self.count_column, sql) + + def exclude(self, sql): + return self.count.exclude(self.count_column, sql) + + +@strawberry.input class TagAssociationFilter(AssociationFilter): """ Tags need special handling since their IDs are strings instead of numbers. @@ -314,24 +337,28 @@ class ArchiveFilter(Root): @strawberry.input class ArtistFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Character") @strawberry.input class CharacterFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Circle") @strawberry.input class CircleFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Namespace") @strawberry.input class NamespaceFilter(Root): name: Optional[StringFilter] = UNSET + tags: Optional[BasicCountFilter] = UNSET @hircine.db.model("Tag") @@ -339,9 +366,11 @@ class NamespaceFilter(Root): class TagFilter(Root): name: Optional[StringFilter] = UNSET namespaces: Optional[AssociationFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("World") @strawberry.input class WorldFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET diff --git a/src/hircine/db/models.py b/src/hircine/db/models.py index f204998..5d1a59a 100644 --- a/src/hircine/db/models.py +++ b/src/hircine/db/models.py @@ -356,7 +356,10 @@ class ComicWorld(Base): def defer_relationship_count(relationship, secondary=False): - left, right = relationship.property.synchronize_pairs[0] + if secondary: + left, right = relationship.property.secondary_synchronize_pairs[0] + else: + left, right = relationship.property.synchronize_pairs[0] return deferred( select(func.count(right)) @@ -366,7 +369,23 @@ def defer_relationship_count(relationship, secondary=False): ) +Comic.artist_count = defer_relationship_count(Comic.artists) +Comic.character_count = defer_relationship_count(Comic.characters) +Comic.circle_count = defer_relationship_count(Comic.circles) Comic.tag_count = defer_relationship_count(Comic.tags) +Comic.world_count = defer_relationship_count(Comic.worlds) + +Artist.comic_count = defer_relationship_count(Comic.artists, secondary=True) +Character.comic_count = defer_relationship_count(Comic.characters, secondary=True) +Circle.comic_count = defer_relationship_count(Comic.circles, secondary=True) +Namespace.tag_count = defer_relationship_count(Tag.namespaces, secondary=True) +Tag.comic_count = deferred( + select(func.count(ComicTag.tag_id)) + .where(Tag.id == ComicTag.tag_id) + .scalar_subquery() +) +Tag.namespace_count = defer_relationship_count(Tag.namespaces) +World.comic_count = defer_relationship_count(Comic.worlds, secondary=True) @event.listens_for(Comic.pages, "bulk_replace") diff --git a/src/hircine/enums.py b/src/hircine/enums.py index 7f95f02..f267270 100644 --- a/src/hircine/enums.py +++ b/src/hircine/enums.py @@ -1,4 +1,5 @@ import enum +import operator import strawberry @@ -57,6 +58,13 @@ class OnMissing(enum.Enum): @strawberry.enum +class Operator(enum.Enum): + GREATER_THAN = operator.gt + LOWER_THAN = operator.lt + EQUAL = operator.eq + + +@strawberry.enum class Language(enum.Enum): AA = "Afar" AB = "Abkhazian" |