servo: clightning-sane: introduce parallelism

This commit is contained in:
Colin 2024-01-12 23:30:52 +00:00
parent 882cc5bfd0
commit 0e68533776

View File

@ -12,6 +12,7 @@ import math
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum
@ -24,6 +25,9 @@ RPC_FILE = "/var/lib/clightning/bitcoin/lightning-rpc"
# set this too low and you might get inadvertent channel closures (?)
CLTV = 18
# default value of `retries` for `LoopRouter.loop_once_with_retries`
MAX_ROUTE_ATTEMPTS = 20
# `LoopBalancer.pop_job` stops handing out jobs once its `fail_count` reaches this value
MAX_SEQUENTIAL_JOB_FAILURES = 100
class LoopError(Enum):
""" error when trying to loop sats, or when unable to calculate a route for the loop """
FAIL_TEMPORARY = "FAIL_TEMPORARY" # try again, we'll maybe find a different route
@ -36,8 +40,8 @@ class RouteError(Enum):
@dataclass
class TxBounds:
min_msat: int
max_msat: int
min_msat: int = 0
def __repr__(self) -> str:
return f"TxBounds({self.min_msat} <= msat <= {self.max_msat})"
@ -45,6 +49,12 @@ class TxBounds:
def is_satisfiable(self) -> bool:
return self.min_msat <= self.max_msat
def intersect(self, other: "TxBounds") -> "Self":
return TxBounds(
min_msat=max(self.min_msat, other.min_msat),
max_msat=min(self.max_msat, other.max_msat),
)
def restrict_to_htlc(self, ch: "LocalChannel", why: str = "") -> "Self":
"""
apply min/max HTLC size restrictions of the given channel.
@ -255,10 +265,10 @@ class LoopRouter:
assert len(channels) == 1, f"expected exactly 1 channel: {channels}"
return channels[0]
def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = 20) -> int:
def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = MAX_ROUTE_ATTEMPTS) -> int:
for i in range(retries):
if i != 0:
logger.info(f"retrying loop: {i} of {retries}\n")
logger.info(f"retrying loop: {i} of {retries}")
res = self.loop_once(out_scid, in_scid, tx)
if res == LoopError.FAIL_PERMANENT:
logger.info(f"loop {out_scid} -> {in_scid} is impossible (likely no route)")
@ -394,6 +404,111 @@ class LoopRouter:
def _add_route_delay(self, route: list[dict], delay: int) -> list[dict]:
return [ dict(hop, delay=hop["delay"] + delay) for hop in route ]
@dataclass
class LoopJob:
    """ one unit of work for a loop runner: a single attempt to loop up to `amount` from `out` to `in_` """
    # handed to `LoopRouter.loop_once` as its first argument
    out: str
    # handed to `LoopRouter.loop_once` as its second argument
    in_: str
    # cap for this attempt; the runner applies it as `TxBounds(max_msat=amount)`
    amount: int
class AbstractLoopRunner:
    """
    base class which drives a `LoopRouter` from one or more worker threads.

    subclasses decide *what* to loop by implementing `pop_job` and `finished_job`.
    this class repeatedly pops a job, executes it via `LoopRouter.loop_once`, and
    reports the outcome back through `finished_job`.

    NOTE: with `parallelism > 1`, `pop_job` and `finished_job` are invoked from several
    threads at once, and `bounds_map` is shared between them without any locking.
    """
    def __init__(self, looper: "LoopRouter", bounds: "TxBounds", parallelism: int):
        self.looper = looper
        self.bounds = bounds
        self.parallelism = parallelism
        self.bounds_map = {}  # map (out:str, in_:str) -> TxBounds. it's a cache so we don't have to try 10 routes every time.

    def pop_job(self) -> "LoopJob | None":
        """ return the next job to execute, or None to make the calling worker exit """
        # `NotImplemented` is a comparison sentinel, not an exception: raising it yields a TypeError.
        raise NotImplementedError  # abstract method

    def finished_job(self, job: "LoopJob", progress: "int | LoopError") -> None:
        """ called once per popped job with the amount looped, or the LoopError it ended with """
        raise NotImplementedError  # abstract method

    def run_to_completion(self) -> None:
        """
        run workers until `pop_job` returns None to each of them.
        with `parallelism == 1` the worker runs on the calling thread and its exceptions propagate;
        otherwise `parallelism` workers run in a thread pool and their exceptions are only logged.
        """
        if self.parallelism == 1:
            self._worker_thread()
        else:
            with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
                # list(...) drains the iterator, so this blocks until every worker has returned
                _ = list(executor.map(lambda _i: self._try_invoke(self._worker_thread), range(self.parallelism)))

    def _try_invoke(self, f, *args) -> None:
        """
        try to invoke `f` with the provided `args`, and log if it fails.
        this overcomes the issue that background tasks which fail via Exception otherwise do so silently.
        """
        try:
            f(*args)
        except Exception as e:
            # logger.exception logs at ERROR level and appends the traceback
            logger.exception(f"task failed: {e}")

    def _worker_thread(self) -> None:
        """ worker loop: pop, execute and report jobs until `pop_job` returns None """
        while True:
            job = self.pop_job()
            logger.debug(f"popped job: {job}")
            if job is None: return
            result = self._execute_job(job)
            logger.debug(f"finishing job {job} with {result}")
            self.finished_job(job, result)

    def _execute_job(self, job: "LoopJob") -> "LoopError | int":
        """
        attempt a single loop for `job`.
        returns whatever `LoopRouter.loop_once` returned, or FAIL_PERMANENT without
        trying when the bounds for this (out, in_) pair cannot be satisfied.
        """
        # start from the bounds cached for this pair (if any), then cap them to this job's amount
        bounds = self.bounds_map.get((job.out, job.in_), self.bounds)
        bounds = bounds.intersect(TxBounds(max_msat=job.amount))
        if not bounds.is_satisfiable():
            logger.debug(f"TxBounds for job are unsatisfiable; skipping: {bounds} {job}")
            return LoopError.FAIL_PERMANENT

        amt_looped = self.looper.loop_once(job.out, job.in_, bounds)
        if amt_looped in (0, LoopError.FAIL_PERMANENT, LoopError.FAIL_TEMPORARY):
            return amt_looped

        logger.info(f"looped {amt_looped} from {job.out} -> {job.in_}")
        # remember what actually went through, so later jobs for this pair don't ask for more
        bounds = bounds.intersect(TxBounds(max_msat=amt_looped))
        self.bounds_map[(job.out, job.in_)] = bounds
        return amt_looped
class LoopBalancer(AbstractLoopRunner):
    """
    loop runner which moves up to `amount` msat from `out` to `in_`, split into jobs
    of at most `bounds.max_msat` each.

    `pop_job` and `finished_job` may be called concurrently by the worker threads of
    `run_to_completion`, so all bookkeeping is done while holding `self._lock`.
    """
    def __init__(self, out: str, in_: str, amount: int, looper: LoopRouter, bounds: TxBounds, parallelism: int=1):
        super().__init__(looper, bounds, parallelism)
        self.job = LoopJob(out=out, in_=in_, amount=amount)
        self.fail_count = 0  # failures since the last successful job; a permanent failure saturates it
        self.out = out
        self.in_ = in_
        self.amount_target = amount
        self.amount_looped = 0  # total reported as successfully looped so far
        self.amount_outstanding = 0  # total reserved by jobs which were popped but haven't finished
        self._lock = threading.Lock()  # guards fail_count, amount_looped and amount_outstanding

    def pop_job(self) -> LoopJob | None:
        """
        reserve and return the next chunk to loop, or None when there's nothing left to hand out:
        too many failures, or the unreserved remainder is empty or below `bounds.min_msat`.
        """
        with self._lock:
            if self.fail_count >= MAX_SEQUENTIAL_JOB_FAILURES: return None
            amount_avail = self.amount_target - self.amount_looped - self.amount_outstanding
            # `<= 0` matters when min_msat is 0: otherwise zero-amount jobs would be issued forever
            if amount_avail <= 0 or amount_avail < self.bounds.min_msat: return None
            amount_this_job = min(amount_avail, self.bounds.max_msat)
            # reserving under the same lock as the check keeps two workers from claiming the same remainder
            self.amount_outstanding += amount_this_job
        return LoopJob(out=self.out, in_=self.in_, amount=amount_this_job)

    def finished_job(self, job: LoopJob, progress: int | LoopError) -> None:
        """ record the outcome of a job previously returned by `pop_job` """
        # TODO: drop bad_channels cache and bounds_map cache after so many errors
        with self._lock:
            # the job is no longer in flight, whatever its outcome: release its reservation
            # so that a failed attempt doesn't permanently shrink what `pop_job` sees as available.
            self.amount_outstanding -= job.amount
            if progress == LoopError.FAIL_PERMANENT:
                # jump straight to the limit so that no further jobs are issued
                self.fail_count += MAX_SEQUENTIAL_JOB_FAILURES
            elif progress == LoopError.FAIL_TEMPORARY:
                self.fail_count += 1
            else:
                self.fail_count = 0
                self.amount_looped += progress
                logger.info(f"loop progressed {progress}: {self.amount_looped} of {self.amount_target}")
def balance_loop(rpc: RpcHelper, out: str, in_: str, amount_msat: int, min_msat: int, max_msat: int, parallelism: int):
    """
    loop up to `amount_msat` from `out` to `in_`, in transactions bounded by
    `min_msat`/`max_msat`, using `parallelism` workers. returns once the balancer has run to completion.
    """
    LoopBalancer(
        out,
        in_,
        amount_msat,
        looper=LoopRouter(rpc),
        bounds=TxBounds(min_msat=min_msat, max_msat=max_msat),
        parallelism=parallelism,
    ).run_to_completion()
def show_status(rpc: RpcHelper, full: bool=False):
"""
show a table of channel balances between peers.
@ -406,24 +521,6 @@ def show_status(rpc: RpcHelper, full: bool=False):
else:
print(ch.to_str(with_scid=True, with_bal_ratio=True, with_cost=True, with_ppm_theirs=True, with_ppm_mine=full, with_peer_id=full))
# sequential rebalance: makes up to `max_tx` calls to `loop_once_with_retries` from `out` to `in_`,
# treating `max_msat` both as the per-transaction cap and as the total amount to route.
# NOTE(review): this listing also contains a `balance_loop` taking `amount_msat`/`parallelism`;
# this one looks like the version that change replaces -- confirm which is live.
def balance_loop(rpc: RpcHelper, out: str, in_: str, min_msat: int, max_msat: int, max_tx: int):
looper = LoopRouter(rpc)
bounds = TxBounds(min_msat=min_msat, max_msat=max_msat)
# remember the requested total, since bounds.max_msat is narrowed below
asked_to_route = bounds.max_msat
total_routed = 0
for i in range(max_tx):
# never ask for more than what remains of the requested total
bounds.max_msat = min(bounds.max_msat, asked_to_route - total_routed)
# stop once the remainder has dropped below min_msat
if not bounds.is_satisfiable(): break
amt_balanced = looper.loop_once_with_retries(out, in_, bounds)
total_routed += amt_balanced
# nothing was moved this round: give up
if amt_balanced == 0: break
logger.info(f"rebalanced {amt_balanced} (total: {total_routed} of {asked_to_route})")
# later rounds don't ask for more than the amount that just went through
bounds.max_msat = min(bounds.max_msat, amt_balanced)
logger.info(f"rebalanced {total_routed} of {asked_to_route}")
def main():
logging.basicConfig()
logger.setLevel(logging.INFO)
@ -440,9 +537,10 @@ def main():
loop_parser.set_defaults(action="loop")
loop_parser.add_argument("out", help="peer id to send tx through")
loop_parser.add_argument("in_", help="peer id to receive tx through")
loop_parser.add_argument("--min-msat", default="999", help="min to rebalance")
loop_parser.add_argument("--max-msat", default="1000000", help="max to rebalance")
loop_parser.add_argument("--max-tx", default="1", help="maximum times to rebalance")
loop_parser.add_argument("amount", help="total amount of msat to loop")
loop_parser.add_argument("--min-msat", default="999", help="min transaction size")
loop_parser.add_argument("--max-msat", default="1000000", help="max transaction size")
loop_parser.add_argument("--jobs", default="1", help="how many HTLCs to keep in-flight at once")
args = parser.parse_args()
@ -455,7 +553,7 @@ def main():
show_status(rpc, full=args.full)
if args.action == "loop":
balance_loop(rpc, out=args.out, in_=args.in_, min_msat=int(args.min_msat), max_msat=int(args.max_msat), max_tx=int(args.max_tx))
balance_loop(rpc, out=args.out, in_=args.in_, amount_msat=int(args.amount), min_msat=int(args.min_msat), max_msat=int(args.max_msat), parallelism=int(args.jobs))
if __name__ == '__main__':
main()