diff options
-rw-r--r-- | frontend/src/gql/graphql.ts | 25 | ||||
-rw-r--r-- | frontend/src/lib/Enums.ts | 7 | ||||
-rw-r--r-- | frontend/src/lib/Filter.svelte.ts | 25 | ||||
-rw-r--r-- | src/hircine/api/filters.py | 53 | ||||
-rw-r--r-- | src/hircine/db/models.py | 21 | ||||
-rw-r--r-- | src/hircine/enums.py | 8 | ||||
-rw-r--r-- | tests/api/test_filter.py | 99 |
7 files changed, 194 insertions, 44 deletions
diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index bd001f3..63f2c55 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -122,6 +122,7 @@ export type Artist = { }; export type ArtistFilter = { + comics?: InputMaybe<BasicCountFilter>; name?: InputMaybe<StringFilter>; }; @@ -164,10 +165,14 @@ export type ArtistsUpsertInput = { export type AssociationFilter = { all?: InputMaybe<Array<Scalars['Int']['input']>>; any?: InputMaybe<Array<Scalars['Int']['input']>>; - empty?: InputMaybe<Scalars['Boolean']['input']>; + count?: InputMaybe<CountFilter>; exact?: InputMaybe<Array<Scalars['Int']['input']>>; }; +export type BasicCountFilter = { + count: CountFilter; +}; + export enum Category { Artbook = 'ARTBOOK', Comic = 'COMIC', @@ -203,6 +208,7 @@ export type Character = { }; export type CharacterFilter = { + comics?: InputMaybe<BasicCountFilter>; name?: InputMaybe<StringFilter>; }; @@ -249,6 +255,7 @@ export type Circle = { }; export type CircleFilter = { + comics?: InputMaybe<BasicCountFilter>; name?: InputMaybe<StringFilter>; }; @@ -396,6 +403,11 @@ export type ComicTotals = { worlds: Scalars['Int']['output']; }; +export type CountFilter = { + operator?: InputMaybe<Operator>; + value: Scalars['Int']['input']; +}; + export type CoverInput = { id: Scalars['Int']['input']; }; @@ -862,6 +874,7 @@ export type Namespace = { export type NamespaceFilter = { name?: InputMaybe<StringFilter>; + tags?: InputMaybe<BasicCountFilter>; }; export type NamespaceFilterInput = { @@ -905,6 +918,12 @@ export enum OnMissing { Ignore = 'IGNORE' } +export enum Operator { + Equal = 'EQUAL', + GreaterThan = 'GREATER_THAN', + LowerThan = 'LOWER_THAN' +} + export type Page = { __typename?: 'Page'; comicId?: Maybe<Scalars['Int']['output']>; @@ -1153,11 +1172,12 @@ export type Tag = { export type TagAssociationFilter = { all?: InputMaybe<Array<Scalars['String']['input']>>; any?: InputMaybe<Array<Scalars['String']['input']>>; - empty?: InputMaybe<Scalars['Boolean']['input']>; + count?: InputMaybe<CountFilter>; exact?: InputMaybe<Array<Scalars['String']['input']>>; }; export type TagFilter = { + comics?: InputMaybe<BasicCountFilter>; name?: InputMaybe<StringFilter>; namespaces?: InputMaybe<AssociationFilter>; }; @@ -1324,6 +1344,7 @@ export type World = { }; export type WorldFilter = { + comics?: InputMaybe<BasicCountFilter>; name?: InputMaybe<StringFilter>; }; diff --git a/frontend/src/lib/Enums.ts b/frontend/src/lib/Enums.ts index 3264de4..db9fb86 100644 --- a/frontend/src/lib/Enums.ts +++ b/frontend/src/lib/Enums.ts @@ -10,6 +10,7 @@ import { Language, Layout, NamespaceSort, + Operator, Rating, TagSort, UpdateMode, @@ -125,6 +126,12 @@ export const UpdateModeLabel: Record<UpdateMode, string> = { [UpdateMode.Replace]: 'Replace' }; +export const OperatorLabel: Record<Operator, string> = { + [Operator.Equal]: 'Equal', + [Operator.GreaterThan]: 'Greater than', + [Operator.LowerThan]: 'Lower than,' +}; + export const LanguageLabel: Record<Language, string> = { [Language.Ab]: 'Abkhazian', [Language.Aa]: 'Afar', diff --git a/frontend/src/lib/Filter.svelte.ts b/frontend/src/lib/Filter.svelte.ts index 6183f06..e73f497 100644 --- a/frontend/src/lib/Filter.svelte.ts +++ b/frontend/src/lib/Filter.svelte.ts @@ -1,4 +1,5 @@ import { + Operator, type ArchiveFilter, type ArchiveFilterInput, type ComicFilter, @@ -30,7 +31,7 @@ type AssocFilter<T, K extends Key> = Filter< any?: T[] | null; all?: T[] | null; exact?: T[] | null; - empty?: boolean | null; + count?: { value: number; operator?: Operator | null } | null; }, K >; @@ -62,10 +63,6 @@ class ComplexMember<K extends Key> { if (this.values.length > 0) { filter[this.key] = { [this.mode]: this.values }; } - - if (this.empty) { - filter[this.key] = { ...filter[this.key], empty: this.empty }; - } } } @@ -80,7 +77,9 @@ export class Association<K extends Key> extends ComplexMember<K> { } const prop = filter[key]; - this.empty = prop?.empty; + this.empty = + prop?.count?.value === 0 && + (prop.count.operator === undefined || prop.count.operator === Operator.Equal); if (prop?.all && prop.all.length > 0) { this.mode = 'all'; @@ -93,6 +92,13 @@ export class Association<K extends Key> extends ComplexMember<K> { this.values = prop.exact; } } + + integrate(filter: AssocFilter<unknown, K>) { + super.integrate(filter); + if (this.empty) { + filter[this.key] = { ...filter[this.key], count: { value: 0, operator: Operator.Equal } }; + } + } } export class Enum<K extends Key> extends ComplexMember<K> { @@ -112,6 +118,13 @@ export class Enum<K extends Key> extends ComplexMember<K> { this.values = prop.any; } } + + integrate(filter: EnumFilter<K>) { + super.integrate(filter); + if (this.empty) { + filter[this.key] = { ...filter[this.key], empty: this.empty }; + } + } } class Bool<K extends Key> { diff --git a/src/hircine/api/filters.py b/src/hircine/api/filters.py index 807178b..7ed5649 100644 --- a/src/hircine/api/filters.py +++ b/src/hircine/api/filters.py @@ -7,7 +7,7 @@ from strawberry import UNSET import hircine.db from hircine.db.models import ComicTag -from hircine.enums import Category, Censorship, Language, Rating +from hircine.enums import Category, Censorship, Language, Operator, Rating T = TypeVar("T") @@ -28,11 +28,23 @@ class Matchable(ABC): @strawberry.input +class CountFilter: + operator: Optional[Operator] = Operator.EQUAL + value: int + + def include(self, column, sql): + return sql.where(self.operator.value(column, self.value)) + + def exclude(self, column, sql): + return sql.where(~self.operator.value(column, self.value)) + + +@strawberry.input class AssociationFilter(Matchable): any: Optional[list[int]] = strawberry.field(default_factory=lambda: None) all: Optional[list[int]] = strawberry.field(default_factory=lambda: None) exact: Optional[list[int]] = strawberry.field(default_factory=lambda: None) - empty: Optional[bool] = None + count: Optional[CountFilter] = UNSET def _exists(self, condition): # The property.primaryjoin expression specifies the primary join path @@ -71,12 +83,6 @@ class AssociationFilter(Matchable): def _where_not_all_exist(self, sql): return sql.where(~self._all_exist(self.all)) - def _empty(self): - if self.empty: - return ~self._exists(True) - else: - return self._exists(True) - def _count_of(self, column): return ( select(func.count(column)) @@ -117,8 +123,8 @@ class AssociationFilter(Matchable): elif self.all == []: sql = sql.where(False) - if self.empty is not None: - sql = sql.where(self._empty()) + if self.count: + sql = self.count.include(self.count_column, sql) if self.exact is not None: sql = sql.where(self._exact()) @@ -134,8 +140,8 @@ class AssociationFilter(Matchable): if self.all: sql = self._where_not_all_exist(sql) - if self.empty is not None: - sql = sql.where(~self._empty()) + if self.count: + sql = self.count.exclude(self.count_column, sql) if self.exact is not None: sql = sql.where(~self._exact()) @@ -160,8 +166,14 @@ class Root: column = getattr(self._model, field, None) + # count columns are historically singular, so we need this hack + singular_field = field[:-1] + count_column = getattr(self._model, f"{singular_field}_count", None) + if issubclass(type(matcher), Matchable): matcher.column = column + matcher.count_column = count_column + if not negate: sql = matcher.include(sql) else: @@ -213,6 +225,17 @@ class StringFilter(Matchable): @strawberry.input +class BasicCountFilter(Matchable): + count: CountFilter + + def include(self, sql): + return self.count.include(self.count_column, sql) + + def exclude(self, sql): + return self.count.exclude(self.count_column, sql) + + +@strawberry.input class TagAssociationFilter(AssociationFilter): """ Tags need special handling since their IDs are strings instead of numbers. @@ -314,24 +337,28 @@ class ArchiveFilter(Root): @strawberry.input class ArtistFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Character") @strawberry.input class CharacterFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Circle") @strawberry.input class CircleFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("Namespace") @strawberry.input class NamespaceFilter(Root): name: Optional[StringFilter] = UNSET + tags: Optional[BasicCountFilter] = UNSET @hircine.db.model("Tag") @@ -339,9 +366,11 @@ class NamespaceFilter(Root): class TagFilter(Root): name: Optional[StringFilter] = UNSET namespaces: Optional[AssociationFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET @hircine.db.model("World") @strawberry.input class WorldFilter(Root): name: Optional[StringFilter] = UNSET + comics: Optional[BasicCountFilter] = UNSET diff --git a/src/hircine/db/models.py b/src/hircine/db/models.py index f204998..5d1a59a 100644 --- a/src/hircine/db/models.py +++ b/src/hircine/db/models.py @@ -356,7 +356,10 @@ class ComicWorld(Base): def defer_relationship_count(relationship, secondary=False): - left, right = relationship.property.synchronize_pairs[0] + if secondary: + left, right = relationship.property.secondary_synchronize_pairs[0] + else: + left, right = relationship.property.synchronize_pairs[0] return deferred( select(func.count(right)) @@ -366,7 +369,23 @@ def defer_relationship_count(relationship, secondary=False): ) +Comic.artist_count = defer_relationship_count(Comic.artists) +Comic.character_count = defer_relationship_count(Comic.characters) +Comic.circle_count = defer_relationship_count(Comic.circles) Comic.tag_count = defer_relationship_count(Comic.tags) +Comic.world_count = defer_relationship_count(Comic.worlds) + +Artist.comic_count = defer_relationship_count(Comic.artists, secondary=True) +Character.comic_count = defer_relationship_count(Comic.characters, secondary=True) +Circle.comic_count = defer_relationship_count(Comic.circles, secondary=True) +Namespace.tag_count = defer_relationship_count(Tag.namespaces, secondary=True) +Tag.comic_count = deferred( + select(func.count(ComicTag.tag_id)) + .where(Tag.id == ComicTag.tag_id) + .scalar_subquery() +) +Tag.namespace_count = defer_relationship_count(Tag.namespaces) +World.comic_count = defer_relationship_count(Comic.worlds, secondary=True) @event.listens_for(Comic.pages, "bulk_replace") diff --git a/src/hircine/enums.py b/src/hircine/enums.py index 7f95f02..f267270 100644 --- a/src/hircine/enums.py +++ b/src/hircine/enums.py @@ -1,4 +1,5 @@ import enum +import operator import strawberry @@ -57,6 +58,13 @@ class OnMissing(enum.Enum): @strawberry.enum +class Operator(enum.Enum): + GREATER_THAN = operator.gt + LOWER_THAN = operator.lt + EQUAL = operator.eq + + +@strawberry.enum class Language(enum.Enum): AA = "Afar" AB = "Abkhazian" diff --git a/tests/api/test_filter.py b/tests/api/test_filter.py index 1438785..6eb2934 100644 --- a/tests/api/test_filter.py +++ b/tests/api/test_filter.py @@ -421,51 +421,59 @@ async def test_field_presence(query_comic_filter, gen_comic, empty_comic, filter "filter,ids", [ ( - {"include": {"artists": {"empty": True}}}, + {"include": {"artists": {"count": {"value": 0}}}}, [100], ), ( - {"include": {"artists": {"empty": False}}}, - [1, 2], + {"include": {"artists": {"count": {"value": 0, "operator": "EQUAL"}}}}, + [100], ), ( - {"exclude": {"artists": {"empty": True}}}, - [1, 2], + { + "include": { + "artists": {"count": {"value": 1, "operator": "GREATER_THAN"}} + } + }, + [1], ), ( - {"exclude": {"artists": {"empty": False}}}, - [100], + {"include": {"artists": {"count": {"value": 3, "operator": "LOWER_THAN"}}}}, + [1, 2, 100], ), ( - {"include": {"tags": {"empty": True}}}, - [100], + {"exclude": {"artists": {"count": {"value": 0}}}}, + [1, 2], ), ( - {"include": {"tags": {"empty": False}}}, + {"exclude": {"artists": {"count": {"value": 0, "operator": "EQUAL"}}}}, [1, 2], ), ( - {"exclude": {"tags": {"empty": True}}}, - [1, 2], + { + "exclude": { + "artists": {"count": {"value": 1, "operator": "GREATER_THAN"}} + } + }, + [2, 100], ), ( - {"exclude": {"tags": {"empty": False}}}, - [100], + {"exclude": {"artists": {"count": {"value": 3, "operator": "LOWER_THAN"}}}}, + [], ), ], ids=[ - "includes artist empty", - "includes artist not empty", - "excludes artist empty", - "excludes artist not empty", - "includes tags empty", - "includes tags not empty", - "excludes tags empty", - "excludes tags not empty", + "include equal (default)", + "include equal (explicit)", + "include greater than", + "include lower than", + "exclude equal (default)", + "exclude equal (explicit)", + "exclude greater than", + "exclude lower than", ], ) @pytest.mark.anyio -async def test_assoc_presence(query_comic_filter, gen_comic, empty_comic, filter, ids): +async def test_assoc_counts(query_comic_filter, gen_comic, empty_comic, filter, ids): await DB.add(next(gen_comic)) await DB.add(next(gen_comic)) await DB.add(empty_comic) @@ -520,3 +528,48 @@ async def test_tag_assoc_filter(query_tag_filter, gen_namespace, gen_tag, filter response.assert_is("TagFilterResult") assert id_list(response.edges) == ids + + +@pytest.mark.parametrize( + "filter,expect", + [ + ({"include": {"comics": {"count": {"value": 1}}}}, [2, 3]), + ({"include": {"comics": {"count": {"value": 2, "operator": "EQUAL"}}}}, [1, 4]), + ( + { + "include": { + "comics": {"count": {"value": 3, "operator": "GREATER_THAN"}} + } + }, + [], + ), + ( + {"include": {"comics": {"count": {"value": 2, "operator": "LOWER_THAN"}}}}, + [2, 3], + ), + ( + {"exclude": {"comics": {"count": {"value": 1}}}}, + [1, 4], + ), + ( + {"exclude": {"comics": {"count": {"value": 1, "operator": "LOWER_THAN"}}}}, + [1, 2, 3, 4], + ), + ], + ids=[ + "include equal (default)", + "include equal (explicit)", + "include greater than", + "include lower than", + "exclude equal (default)", + "exclude lower than", + ], +) +@pytest.mark.anyio +async def test_count_filter(query_string_filter, gen_comic, filter, expect): + await DB.add_all(*gen_comic) + + response = Response(await query_string_filter(filter)) + response.assert_is("ArtistFilterResult") + + assert id_list(response.edges) == expect |