servo: clightning-sane: introduce parallelism
This commit is contained in:
parent
882cc5bfd0
commit
0e68533776
|
@ -12,6 +12,7 @@ import math
|
|||
import sys
|
||||
import time
|
||||
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
@ -24,6 +25,9 @@ RPC_FILE = "/var/lib/clightning/bitcoin/lightning-rpc"
|
|||
# set this too low and you might get inadvertent channel closures (?)
|
||||
CLTV = 18
|
||||
|
||||
MAX_ROUTE_ATTEMPTS = 20
|
||||
MAX_SEQUENTIAL_JOB_FAILURES = 100
|
||||
|
||||
class LoopError(Enum):
|
||||
""" error when trying to loop sats, or when unable to calculate a route for the loop """
|
||||
FAIL_TEMPORARY = "FAIL_TEMPORARY" # try again, we'll maybe find a different route
|
||||
|
@ -36,8 +40,8 @@ class RouteError(Enum):
|
|||
|
||||
@dataclass
|
||||
class TxBounds:
|
||||
min_msat: int
|
||||
max_msat: int
|
||||
min_msat: int = 0
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"TxBounds({self.min_msat} <= msat <= {self.max_msat})"
|
||||
|
@ -45,6 +49,12 @@ class TxBounds:
|
|||
def is_satisfiable(self) -> bool:
|
||||
return self.min_msat <= self.max_msat
|
||||
|
||||
def intersect(self, other: "TxBounds") -> "Self":
|
||||
return TxBounds(
|
||||
min_msat=max(self.min_msat, other.min_msat),
|
||||
max_msat=min(self.max_msat, other.max_msat),
|
||||
)
|
||||
|
||||
def restrict_to_htlc(self, ch: "LocalChannel", why: str = "") -> "Self":
|
||||
"""
|
||||
apply min/max HTLC size restrictions of the given channel.
|
||||
|
@ -255,10 +265,10 @@ class LoopRouter:
|
|||
assert len(channels) == 1, f"expected exactly 1 channel: {channels}"
|
||||
return channels[0]
|
||||
|
||||
def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = 20) -> int:
|
||||
def loop_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = MAX_ROUTE_ATTEMPTS) -> int:
|
||||
for i in range(retries):
|
||||
if i != 0:
|
||||
logger.info(f"retrying loop: {i} of {retries}\n")
|
||||
logger.info(f"retrying loop: {i} of {retries}")
|
||||
res = self.loop_once(out_scid, in_scid, tx)
|
||||
if res == LoopError.FAIL_PERMANENT:
|
||||
logger.info(f"loop {out_scid} -> {in_scid} is impossible (likely no route)")
|
||||
|
@ -394,6 +404,111 @@ class LoopRouter:
|
|||
def _add_route_delay(self, route: list[dict], delay: int) -> list[dict]:
|
||||
return [ dict(hop, delay=hop["delay"] + delay) for hop in route ]
|
||||
|
||||
@dataclass
|
||||
class LoopJob:
|
||||
out: str
|
||||
in_: str
|
||||
amount: int
|
||||
|
||||
|
||||
class AbstractLoopRunner:
|
||||
def __init__(self, looper: LoopRouter, bounds: TxBounds, parallelism: int):
|
||||
self.looper = looper
|
||||
self.bounds = bounds
|
||||
self.parallelism = parallelism
|
||||
self.bounds_map = {} # map (out:str, in_:str) -> TxBounds. it's a cache so we don't have to try 10 routes every time.
|
||||
|
||||
def pop_job(self) -> LoopJob | None:
|
||||
raise NotImplemented # abstract method
|
||||
|
||||
def finished_job(self, job: LoopJob, progress: int|LoopError) -> None:
|
||||
raise NotImplemented # abstract method
|
||||
|
||||
def run_to_completion(self) -> None:
|
||||
if self.parallelism == 1:
|
||||
self._worker_thread()
|
||||
else:
|
||||
with ThreadPoolExecutor(max_workers=self.parallelism) as executor:
|
||||
_ = list(executor.map(lambda _i: self._try_invoke(self._worker_thread), range(self.parallelism)))
|
||||
|
||||
|
||||
def _try_invoke(self, f, *args) -> None:
|
||||
"""
|
||||
try to invoke `f` with the provided `args`, and log if it fails.
|
||||
this overcomes the issue that background tasks which fail via Exception otherwise do so silently.
|
||||
"""
|
||||
try:
|
||||
f(*args)
|
||||
except Exception as e:
|
||||
logger.error(f"task failed: {e}")
|
||||
|
||||
|
||||
def _worker_thread(self) -> None:
|
||||
while True:
|
||||
job = self.pop_job()
|
||||
logger.debug(f"popped job: {job}")
|
||||
if job is None: return
|
||||
|
||||
result = self._execute_job(job)
|
||||
logger.debug(f"finishing job {job} with {result}")
|
||||
self.finished_job(job, result)
|
||||
|
||||
def _execute_job(self, job: LoopJob) -> LoopError|int:
|
||||
bounds = self.bounds_map.get((job.out, job.in_), self.bounds)
|
||||
bounds = bounds.intersect(TxBounds(max_msat=job.amount))
|
||||
if not bounds.is_satisfiable():
|
||||
logger.debug(f"TxBounds for job are unsatisfiable; skipping: {bounds} {job}")
|
||||
return LoopError.FAIL_PERMANENT
|
||||
|
||||
amt_looped = self.looper.loop_once(job.out, job.in_, bounds)
|
||||
if amt_looped in (0, LoopError.FAIL_PERMANENT, LoopError.FAIL_TEMPORARY):
|
||||
return amt_looped
|
||||
|
||||
logger.info(f"looped {amt_looped} from {job.out} -> {job.in_}")
|
||||
bounds = bounds.intersect(TxBounds(max_msat=amt_looped))
|
||||
|
||||
self.bounds_map[(job.out, job.in_)] = bounds
|
||||
return amt_looped
|
||||
|
||||
class LoopBalancer(AbstractLoopRunner):
|
||||
def __init__(self, out: str, in_: str, amount: int, looper: LoopRouter, bounds: TxBounds, parallelism: int=1):
|
||||
super().__init__(looper, bounds, parallelism)
|
||||
self.job = LoopJob(out=out, in_=in_, amount=amount)
|
||||
self.fail_count = 0
|
||||
self.out = out
|
||||
self.in_ = in_
|
||||
self.amount_target = amount
|
||||
self.amount_looped = 0
|
||||
self.amount_outstanding = 0
|
||||
|
||||
def pop_job(self) -> LoopJob | None:
|
||||
if self.fail_count >= MAX_SEQUENTIAL_JOB_FAILURES: return None
|
||||
|
||||
amount_avail = self.amount_target - self.amount_looped - self.amount_outstanding
|
||||
if amount_avail < self.bounds.min_msat: return None
|
||||
amount_this_job = min(amount_avail, self.bounds.max_msat)
|
||||
|
||||
self.amount_outstanding += amount_this_job
|
||||
return LoopJob(out=self.out, in_=self.in_, amount=amount_this_job)
|
||||
|
||||
def finished_job(self, job: LoopJob, progress: int) -> None:
|
||||
# TODO: drop bad_channels cache and bounds_map cache after so many errors
|
||||
if progress == LoopError.FAIL_PERMANENT: self.fail_count += MAX_SEQUENTIAL_JOB_FAILURES
|
||||
elif progress == LoopError.FAIL_TEMPORARY: self.fail_count += 1
|
||||
else:
|
||||
self.fail_count = 0
|
||||
self.amount_outstanding -= job.amount
|
||||
self.amount_looped += progress
|
||||
logger.info(f"loop progressed {progress}: {self.amount_looped} of {self.amount_target}")
|
||||
|
||||
|
||||
def balance_loop(rpc: RpcHelper, out: str, in_: str, amount_msat: int, min_msat: int, max_msat: int, parallelism: int):
|
||||
looper = LoopRouter(rpc)
|
||||
bounds = TxBounds(min_msat=min_msat, max_msat=max_msat)
|
||||
balancer = LoopBalancer(out, in_, amount_msat, looper, bounds, parallelism)
|
||||
|
||||
balancer.run_to_completion()
|
||||
|
||||
def show_status(rpc: RpcHelper, full: bool=False):
|
||||
"""
|
||||
show a table of channel balances between peers.
|
||||
|
@ -406,24 +521,6 @@ def show_status(rpc: RpcHelper, full: bool=False):
|
|||
else:
|
||||
print(ch.to_str(with_scid=True, with_bal_ratio=True, with_cost=True, with_ppm_theirs=True, with_ppm_mine=full, with_peer_id=full))
|
||||
|
||||
def balance_loop(rpc: RpcHelper, out: str, in_: str, min_msat: int, max_msat: int, max_tx: int):
|
||||
looper = LoopRouter(rpc)
|
||||
bounds = TxBounds(min_msat=min_msat, max_msat=max_msat)
|
||||
|
||||
asked_to_route = bounds.max_msat
|
||||
total_routed = 0
|
||||
for i in range(max_tx):
|
||||
bounds.max_msat = min(bounds.max_msat, asked_to_route - total_routed)
|
||||
if not bounds.is_satisfiable(): break
|
||||
|
||||
amt_balanced = looper.loop_once_with_retries(out, in_, bounds)
|
||||
total_routed += amt_balanced
|
||||
if amt_balanced == 0: break
|
||||
logger.info(f"rebalanced {amt_balanced} (total: {total_routed} of {asked_to_route})")
|
||||
bounds.max_msat = min(bounds.max_msat, amt_balanced)
|
||||
|
||||
logger.info(f"rebalanced {total_routed} of {asked_to_route}")
|
||||
|
||||
def main():
|
||||
logging.basicConfig()
|
||||
logger.setLevel(logging.INFO)
|
||||
|
@ -440,9 +537,10 @@ def main():
|
|||
loop_parser.set_defaults(action="loop")
|
||||
loop_parser.add_argument("out", help="peer id to send tx through")
|
||||
loop_parser.add_argument("in_", help="peer id to receive tx through")
|
||||
loop_parser.add_argument("--min-msat", default="999", help="min to rebalance")
|
||||
loop_parser.add_argument("--max-msat", default="1000000", help="max to rebalance")
|
||||
loop_parser.add_argument("--max-tx", default="1", help="maximum times to rebalance")
|
||||
loop_parser.add_argument("amount", help="total amount of msat to loop")
|
||||
loop_parser.add_argument("--min-msat", default="999", help="min transaction size")
|
||||
loop_parser.add_argument("--max-msat", default="1000000", help="max transaction size")
|
||||
loop_parser.add_argument("--jobs", default="1", help="how many HTLCs to keep in-flight at once")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
@ -455,7 +553,7 @@ def main():
|
|||
show_status(rpc, full=args.full)
|
||||
|
||||
if args.action == "loop":
|
||||
balance_loop(rpc, out=args.out, in_=args.in_, min_msat=int(args.min_msat), max_msat=int(args.max_msat), max_tx=int(args.max_tx))
|
||||
balance_loop(rpc, out=args.out, in_=args.in_, amount_msat=int(args.amount), min_msat=int(args.min_msat), max_msat=int(args.max_msat), parallelism=int(args.jobs))
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
Loading…
Reference in New Issue
Block a user