From 0e68533776259609d8996001133220cdaeca7e27 Mon Sep 17 00:00:00 2001
From: Colin
Date: Fri, 12 Jan 2024 23:30:52 +0000
Subject: [PATCH] servo: clightning-sane: introduce parallelism

---
 .../clightning-sane/clightning-sane           | 169 +++++++++++++++---
 1 file changed, 144 insertions(+), 25 deletions(-)

diff --git a/hosts/by-name/servo/services/cryptocurrencies/clightning-sane/clightning-sane b/hosts/by-name/servo/services/cryptocurrencies/clightning-sane/clightning-sane
index a5d499cb..6cf9c541 100755
--- a/hosts/by-name/servo/services/cryptocurrencies/clightning-sane/clightning-sane
+++ b/hosts/by-name/servo/services/cryptocurrencies/clightning-sane/clightning-sane
@@ -12,6 +12,7 @@ import math
 import sys
 import time
 
+from concurrent.futures import ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum
 
@@ -24,6 +25,9 @@ RPC_FILE = "/var/lib/clightning/bitcoin/lightning-rpc"
 # set this too low and you might get inadvertent channel closures (?)
 CLTV = 18
 
+MAX_ROUTE_ATTEMPTS = 20
+MAX_SEQUENTIAL_JOB_FAILURES = 100
+
 class LoopError(Enum):
     """ error when trying to loop sats, or when unable to calculate a route for the loop """
     FAIL_TEMPORARY = "FAIL_TEMPORARY"  # try again, we'll maybe find a different route
@@ -36,8 +40,8 @@ class RouteError(Enum):
 
 @dataclass
 class TxBounds:
-    min_msat: int
     max_msat: int
+    min_msat: int = 0
 
     def __repr__(self) -> str:
         return f"TxBounds({self.min_msat} <= msat <= {self.max_msat})"
@@ -45,6 +49,12 @@ class TxBounds:
     def is_satisfiable(self) -> bool:
         return self.min_msat <= self.max_msat
 
+    def intersect(self, other: "TxBounds") -> "Self":
+        return TxBounds(
+            min_msat=max(self.min_msat, other.min_msat),
+            max_msat=min(self.max_msat, other.max_msat),
+        )
+
     def restrict_to_htlc(self, ch: "LocalChannel", why: str = "") -> "Self":
         """
         apply min/max HTLC size restrictions of the given channel.
@@ -255,10 +265,10 @@ class LoopRouter:
         assert len(channels) == 1, f"expected exactly 1 channel: {channels}"
         return channels[0]
 
-    def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = 20) -> int:
+    def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = MAX_ROUTE_ATTEMPTS) -> int:
         for i in range(retries):
             if i != 0:
-                logger.info(f"retrying loop: {i} of {retries}\n")
+                logger.info(f"retrying loop: {i} of {retries}")
             res = self.loop_once(out_scid, in_scid, tx)
             if res == LoopError.FAIL_PERMANENT:
                 logger.info(f"loop {out_scid} -> {in_scid} is impossible (likely no route)")
@@ -394,6 +404,132 @@ class LoopRouter:
     def _add_route_delay(self, route: list[dict], delay: int) -> list[dict]:
         return [ dict(hop, delay=hop["delay"] + delay) for hop in route ]
 
+@dataclass
+class LoopJob:
+    """ one attempt to loop up to `amount` msat from `out` to `in_` """
+    out: str
+    in_: str
+    amount: int
+
+
+class AbstractLoopRunner:
+    """
+    repeatedly takes a job from `pop_job`, executes it, and hands the result to `finished_job`,
+    until `pop_job` returns None. runs inline when parallelism is 1, else on a pool of that many threads.
+    NOTE(review): with parallelism > 1, `bounds_map` and the subclass's counters are touched
+    from every worker with no locking -- confirm that's acceptable.
+    """
+    def __init__(self, looper: LoopRouter, bounds: TxBounds, parallelism: int):
+        self.looper = looper
+        self.bounds = bounds
+        self.parallelism = parallelism
+        self.bounds_map = {}  # map (out:str, in_:str) -> TxBounds. it's a cache so we don't have to try 10 routes every time.
+
+    def pop_job(self) -> LoopJob | None:
+        """ return the next job to run, or None once there's nothing left to do """
+        raise NotImplementedError  # abstract method
+
+    def finished_job(self, job: LoopJob, progress: int|LoopError) -> None:
+        """ called with the result of each job: the msat looped, or the LoopError it failed with """
+        raise NotImplementedError  # abstract method
+
+    def run_to_completion(self) -> None:
+        """ run jobs until every worker has seen `pop_job` return None; blocks until then """
+        if self.parallelism == 1:
+            self._worker_thread()
+        else:
+            with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
+                _ = list(executor.map(lambda _i: self._try_invoke(self._worker_thread), range(self.parallelism)))
+
+
+    def _try_invoke(self, f, *args) -> None:
+        """
+        try to invoke `f` with the provided `args`, and log if it fails.
+        this overcomes the issue that background tasks which fail via Exception otherwise do so silently.
+        """
+        try:
+            f(*args)
+        except Exception as e:
+            logger.exception(f"task failed: {e}")
+
+
+    def _worker_thread(self) -> None:
+        """ execute jobs one after another until `pop_job` returns None """
+        while True:
+            job = self.pop_job()
+            logger.debug(f"popped job: {job}")
+            if job is None: return
+
+            result = self._execute_job(job)
+            logger.debug(f"finishing job {job} with {result}")
+            self.finished_job(job, result)
+
+    def _execute_job(self, job: LoopJob) -> LoopError|int:
+        """
+        attempt to loop `job`, returning what `loop_once` returned (or FAIL_PERMANENT if no tx size fits the bounds).
+        on success, the cached bounds for this (out, in_) pair are capped at the amount which just succeeded.
+        """
+        bounds = self.bounds_map.get((job.out, job.in_), self.bounds)
+        bounds = bounds.intersect(TxBounds(max_msat=job.amount))
+        if not bounds.is_satisfiable():
+            logger.debug(f"TxBounds for job are unsatisfiable; skipping: {bounds} {job}")
+            return LoopError.FAIL_PERMANENT
+
+        amt_looped = self.looper.loop_once(job.out, job.in_, bounds)
+        if amt_looped in (0, LoopError.FAIL_PERMANENT, LoopError.FAIL_TEMPORARY):
+            return amt_looped
+
+        logger.info(f"looped {amt_looped} from {job.out} -> {job.in_}")
+        bounds = bounds.intersect(TxBounds(max_msat=amt_looped))
+
+        self.bounds_map[(job.out, job.in_)] = bounds
+        return amt_looped
+
+class LoopBalancer(AbstractLoopRunner):
+    """
+    loops `amount` msat from `out` to `in_`, as jobs of at most `bounds.max_msat` each.
+    stops handing out jobs once less than `bounds.min_msat` is left unassigned, or after one permanent failure or MAX_SEQUENTIAL_JOB_FAILURES temporary failures in a row.
+    """
+    def __init__(self, out: str, in_: str, amount: int, looper: LoopRouter, bounds: TxBounds, parallelism: int=1):
+        super().__init__(looper, bounds, parallelism)
+        self.job = LoopJob(out=out, in_=in_, amount=amount)
+        self.fail_count = 0
+        self.out = out
+        self.in_ = in_
+        self.amount_target = amount
+        self.amount_looped = 0
+        self.amount_outstanding = 0
+
+    def pop_job(self) -> LoopJob | None:
+        """ reserve the next chunk of the not-yet-looped amount as a job, or return None when it's time to stop """
+        if self.fail_count >= MAX_SEQUENTIAL_JOB_FAILURES: return None
+
+        amount_avail = self.amount_target - self.amount_looped - self.amount_outstanding
+        if amount_avail < self.bounds.min_msat: return None
+        amount_this_job = min(amount_avail, self.bounds.max_msat)
+
+        self.amount_outstanding += amount_this_job
+        return LoopJob(out=self.out, in_=self.in_, amount=amount_this_job)
+
+    def finished_job(self, job: LoopJob, progress: int|LoopError) -> None:
+        # TODO: drop bad_channels cache and bounds_map cache after so many errors
+        self.amount_outstanding -= job.amount  # the job is no longer in flight, however it ended; failed amounts must become available for retry
+        if progress == LoopError.FAIL_PERMANENT: self.fail_count += MAX_SEQUENTIAL_JOB_FAILURES
+        elif progress == LoopError.FAIL_TEMPORARY: self.fail_count += 1
+        else:  # NOTE(review): a result of 0 also lands here and resets fail_count -- confirm loop_once can't return 0 indefinitely
+            self.fail_count = 0
+            self.amount_looped += progress
+            logger.info(f"loop progressed {progress}: {self.amount_looped} of {self.amount_target}")
+
+
+def balance_loop(rpc: RpcHelper, out: str, in_: str, amount_msat: int, min_msat: int, max_msat: int, parallelism: int):
+    """ loop `amount_msat` from `out` to `in_`, in transactions of `min_msat`..`max_msat`, using `parallelism` workers """
+    looper = LoopRouter(rpc)
+    bounds = TxBounds(min_msat=min_msat, max_msat=max_msat)
+    balancer = LoopBalancer(out, in_, amount_msat, looper, bounds, parallelism)
+
+    balancer.run_to_completion()
+
 def show_status(rpc: RpcHelper, full: bool=False):
     """
     show a table of channel balances between peers.
@@ -406,24 +542,6 @@ def show_status(rpc: RpcHelper, full: bool=False):
         else:
             print(ch.to_str(with_scid=True, with_bal_ratio=True, with_cost=True, with_ppm_theirs=True, with_ppm_mine=full, with_peer_id=full))
 
-def balance_loop(rpc: RpcHelper, out: str, in_: str, min_msat: int, max_msat: int, max_tx: int):
-    looper = LoopRouter(rpc)
-    bounds = TxBounds(min_msat=min_msat, max_msat=max_msat)
-
-    asked_to_route = bounds.max_msat
-    total_routed = 0
-    for i in range(max_tx):
-        bounds.max_msat = min(bounds.max_msat, asked_to_route - total_routed)
-        if not bounds.is_satisfiable(): break
-
-        amt_balanced = looper.loop_once_with_retries(out, in_, bounds)
-        total_routed += amt_balanced
-        if amt_balanced == 0: break
-        logger.info(f"rebalanced {amt_balanced} (total: {total_routed} of {asked_to_route})")
-        bounds.max_msat = min(bounds.max_msat, amt_balanced)
-
-    logger.info(f"rebalanced {total_routed} of {asked_to_route}")
-
 def main():
     logging.basicConfig()
     logger.setLevel(logging.INFO)
@@ -440,9 +558,10 @@ def main():
     loop_parser.set_defaults(action="loop")
     loop_parser.add_argument("out", help="peer id to send tx through")
     loop_parser.add_argument("in_", help="peer id to receive tx through")
-    loop_parser.add_argument("--min-msat", default="999", help="min to rebalance")
-    loop_parser.add_argument("--max-msat", default="1000000", help="max to rebalance")
-    loop_parser.add_argument("--max-tx", default="1", help="maximum times to rebalance")
+    loop_parser.add_argument("amount", help="total amount of msat to loop")
+    loop_parser.add_argument("--min-msat", default="999", help="min transaction size")
+    loop_parser.add_argument("--max-msat", default="1000000", help="max transaction size")
+    loop_parser.add_argument("--jobs", default="1", help="how many HTLCs to keep in-flight at once")
 
     args = parser.parse_args()
 
@@ -455,7 +574,7 @@ def main():
         show_status(rpc, full=args.full)
 
     if args.action == "loop":
-        balance_loop(rpc, out=args.out, in_=args.in_, min_msat=int(args.min_msat), max_msat=int(args.max_msat), max_tx=int(args.max_tx))
+        balance_loop(rpc, out=args.out, in_=args.in_, amount_msat=int(args.amount), min_msat=int(args.min_msat), max_msat=int(args.max_msat), parallelism=int(args.jobs))
 
 if __name__ == '__main__':
     main()