diff --git a/pkgs/additional/sane-scripts/src/sane-bt-search b/pkgs/additional/sane-scripts/src/sane-bt-search index 8f4209a3..6a13b3a3 100755 --- a/pkgs/additional/sane-scripts/src/sane-bt-search +++ b/pkgs/additional/sane-scripts/src/sane-bt-search @@ -84,6 +84,18 @@ def is_cat(cats: list[str], wanted_cats: list[str], default: bool = False) -> bo else: return any(c in wanted_cats for c in cats) +class Filter: + def __init__(self, manga: bool=False, h265: bool=False): + self.manga = manga + self.h265 = h265 + + def filter(self, t: 'Torrent', default: bool = False) -> bool: + valid = True + valid = valid and (not self.manga or t.is_manga(default)) + valid = valid and (not self.h265 or t.is_h265()) + return valid + + @dataclass(eq=True, order=True, unsafe_hash=True) class Torrent: seeders: int @@ -184,18 +196,26 @@ class Client: return sorted(torrents, reverse=True) -def filter_results(results: list[Torrent], full: bool, top: int, manga: bool, h265: bool) -> list[Torrent]: +def filter_results(results: list[Torrent], filter: Filter, top: int | None) -> list[Torrent]: """ take the complete query and filter further based on CLI options """ - if manga: - results = [t for t in results if t.is_manga(default=True)] - if h265: - results = [t for t in results if t.is_h265()] - if not full: + results = [t for t in results if filter.filter(t)] + if top is not None: results = results[:top] return results +def format_results(all_results: list[Torrent], filtered_results: list[Torrent], json: bool): + if json: + dumpable = [t.to_dict() for t in filtered_results] + print(json.dumps(dumpable)) + else: + num_total = len(all_results) + num_filtered = len(filtered_results) + print(f"found {num_total} result(s) filtered to {num_filtered}") + for r in filtered_results: + print(r) + def main(args: list[str]): logging.basicConfig() logging.getLogger().setLevel(logging.WARNING) @@ -215,18 +235,16 @@ def main(args: list[str]): logging.getLogger().setLevel(logging.DEBUG) client = Client() - results = client.query(args.query) - num_results = len(results) + all_results = client.query(args.query) - results = filter_results(results, args.full, int(args.top or "5"), args.manga, args.h265) + filter = Filter(manga=args.manga, h265=args.h265) + filtered_results = filter_results( + all_results, + filter, + None if args.full else int(args.top or "5"), + ) - if args.json: - dumpable = [t.to_dict() for t in results] - print(json.dumps(dumpable)) - else: - print(f"found {num_results} result(s)") - for r in results: - print(r) + format_results(all_results, filtered_results, args.json) if __name__ == "__main__": main(sys.argv[1:])