servo: clightning-sane: tidy

This commit is contained in:
Colin 2024-01-12 01:25:56 +00:00
parent 432a66bf5f
commit 913403aac6

View File

@ -12,6 +12,7 @@ import math
import sys
import time
from dataclasses import dataclass
from enum import Enum
from pyln.client import LightningRpc, Millisatoshi, RpcError
@ -31,6 +32,50 @@ class RouteError(Enum):
HAS_BASE_FEE = "HAS_BASE_FEE"
NO_ROUTE = "NO_ROUTE"
@dataclass
class TxBounds:
min_msat: int
max_msat: int
def is_satisfiable(self) -> bool:
return self.min_msat <= self.max_msat
def restrict_to_htlc(self, ch: "LocalChannel") -> "Self":
"""
apply min/max HTLC size restrictions of the given channel.
"""
new_min, new_max = self.min_msat, self.max_msat
if ch.htlc_minimum_msat > self.min_msat:
new_min = ch.htlc_minimum_msat
logger.debug(f"raising min_msat due to HTLC requirements: {self.min_msat} -> {new_min}")
if ch.htlc_maximum_msat < self.max_msat:
new_max = ch.htlc_maximum_msat
logger.debug(f"lowering max_msat due to HTLC requirements: {self.max_msat} -> {new_max}")
return TxBounds(min_msat=new_min, max_msat=new_max)
def restrict_to_zero_fees(self, ch: "LocalChannel"=None, base: int=0, ppm: int=0) -> "Self":
"""
restrict tx size such that PPM fees are zero.
if the channel has a base fee, then `max_msat` is forced to 0.
"""
if ch:
self = self.restrict_to_zero_fees(base=ch.to_me["base_fee_millisatoshi"], ppm=ch.to_me["fee_per_millionth"])
new_max = self.max_msat
if ppm != 0:
new_max = math.ceil(1000000 / ppm) - 1
if new_max < self.max_msat:
logger.debug(f"decreasing max_msat due to fee ppm: {self.max_msat} -> {new_max}")
if base != 0:
logger.debug("free route impossible: channel has base fees")
new_max = 0
return TxBounds(
min_msat = self.min_msat,
max_msat = new_max,
)
class LocalChannel:
def __init__(self, channels: list, self_id: str):
assert len(channels) == 2, f"unexpected: more than 2 channels: {channels}"
@ -109,11 +154,11 @@ class Balancer:
assert len(channels) == 1, f"expected exactly 1 channel: {channels}"
return channels[0]
def balance_once_with_retries(self, out_scid: str, in_scid: str, min_tx_msat: int, max_tx_msat: int, retries: int = 20) -> None:
def balance_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = 20) -> None:
for i in range(retries):
if i != 0:
logger.info(f"retrying rebalance: {i} of {retries}\n")
res = self.balance_once(out_scid, in_scid, min_tx_msat, max_tx_msat)
res = self.balance_once(out_scid, in_scid, tx)
if res == RebalanceResult.SUCCESS:
logger.info(f"rebalanced once with success {out_scid} -> {in_scid}")
break
@ -123,7 +168,7 @@ class Balancer:
else:
logger.info(f"failed to rebalance {out_scid} -> {in_scid} within {retries} attempts")
def balance_once(self, out_scid: str, in_scid: str, min_tx_msat: int, max_tx_msat: int) -> None:
def balance_once(self, out_scid: str, in_scid: str, bounds: TxBounds) -> None:
out_ch = self._localchannel(out_scid)
in_ch = self._localchannel(in_scid)
@ -131,12 +176,13 @@ class Balancer:
logger.info(f"rebalance {out_scid} -> {in_scid} failed in our own channel")
return RebalanceResult.FAIL_PERMANENT
tx_bounds = self._bound_tx_size(out_ch, in_ch, min_tx_msat, max_tx_msat)
if tx_bounds is None:
bounds = bounds.restrict_to_htlc(out_ch)
bounds = bounds.restrict_to_htlc(in_ch)
bounds = bounds.restrict_to_zero_fees(in_ch)
if not bounds.is_satisfiable():
return RebalanceResult.FAIL_PERMANENT # no valid bounds
min_tx_msat, max_tx_msat = tx_bounds
route = self.route(out_ch, in_ch, min_tx_msat, max_tx_msat)
route = self.route(out_ch, in_ch, bounds)
logger.debug(f"route: {route}")
if route == RouteError.NO_ROUTE:
return RebalanceResult.FAIL_PERMANENT
@ -166,36 +212,7 @@ class Balancer:
else:
return RebalanceResult.SUCCESS
def _bound_tx_size(self, out_ch: LocalChannel, in_ch: LocalChannel, min_tx_msat: int, max_tx_msat: int) -> tuple[int, int] | None:
# don't even try to route if the channels advertise to not support our request
min_, max_ = min_tx_msat, max_tx_msat
min_tx_msat = max(min_tx_msat, out_ch.htlc_minimum_msat)
min_tx_msat = max(min_tx_msat, in_ch.htlc_minimum_msat)
max_tx_msat = min(max_tx_msat, out_ch.htlc_maximum_msat)
max_tx_msat = min(max_tx_msat, in_ch.htlc_maximum_msat)
if min_ != min_tx_msat:
logger.debug(f"increased min_tx_msat due to route requirements: {min_} -> {min_tx_msat}")
if max_ != max_tx_msat:
logger.debug(f"decreased max_tx_msat due to route requirements: {max_} -> {max_tx_msat}")
if in_ch.to_me["base_fee_millisatoshi"] != 0:
logger.info(f"aborting route because inbound requires base fees")
return None
per_mili = in_ch.to_me["fee_per_millionth"]
if per_mili != 0:
new_max = math.ceil(1000000 / per_mili) - 1
if new_max < max_tx_msat:
logger.debug(f"decreased max_tx_msat due to inbound fee ppm: {max_tx_msat} -> {new_max}")
max_tx_msat = new_max
if min_tx_msat > max_tx_msat:
logger.info(f"aborting route because of conflicting HTLC min/max requirements ({min_tx_msat} > {max_tx_msat})")
return None
return min_tx_msat, max_tx_msat
def route(self, out_ch: LocalChannel, in_ch: LocalChannel, min_tx_msat: int, max_tx_msat: int) -> list[dict] | RouteError:
def route(self, out_ch: LocalChannel, in_ch: LocalChannel, bounds: TxBounds) -> list[dict] | RouteError:
exclude = [
# ensure the payment doesn't cross either channel in reverse.
# note that this doesn't preclude it from taking additional trips through self, with other peers.
@ -208,63 +225,46 @@ class Balancer:
out_peer = out_ch.remote_peer
in_peer = in_ch.remote_peer
route_or_max_tx = self._find_partial_route(out_peer, in_peer, max_tx_msat, exclude=exclude)
while isinstance(route_or_max_tx, int):
logger.debug(f"max feeless tx: {route_or_max_tx}")
try_again_msat = max(min_tx_msat, route_or_max_tx)
if try_again_msat == route_or_max_tx:
route_or_bounds = bounds
while isinstance(route_or_bounds, TxBounds):
old_bounds = route_or_bounds
route_or_bounds = self._find_partial_route(out_peer, in_peer, old_bounds, exclude=exclude)
if route_or_bounds == old_bounds:
return RouteError.NO_ROUTE
# due to per-channel HTLC size requirements, we have to try again and we'll maybe get a different route
route_or_max_tx = self._find_partial_route(out_peer, in_peer, try_again_msat, exclude=exclude)
route = route_or_max_tx
if isinstance(route_or_bounds, RouteError):
return route_or_bounds
if isinstance(route, RouteError):
return route
route = self._add_route_endpoints(route, out_ch, in_ch)
route = self._add_route_endpoints(route_or_bounds, out_ch, in_ch)
return route
def _find_partial_route(self, out_peer: str, in_peer: str, tx_msat: int, exclude: list[str]=[]) -> list[dict] | RouteError | int:
route = self.rpc.getroute(in_peer, amount_msat=tx_msat, riskfactor=0, fromid=out_peer, exclude=exclude, cltv=CLTV)
def _find_partial_route(self, out_peer: str, in_peer: str, bounds: TxBounds, exclude: list[str]=[]) -> list[dict] | RouteError | TxBounds:
route = self.rpc.getroute(in_peer, amount_msat=bounds.max_msat, riskfactor=0, fromid=out_peer, exclude=exclude, cltv=CLTV)
route = route["route"]
if route == []:
logger.debug(f"no route for {tx_msat}msat {out_peer} -> {in_peer}")
logger.debug(f"no route for {bounds.max_msat}msat {out_peer} -> {in_peer}")
return RouteError.NO_ROUTE
send_msat = route[0]["amount_msat"]
if send_msat != Millisatoshi(tx_msat):
logger.debug(f"found route with non-zero fee: {send_msat} -> {tx_msat}. {route}")
return self._max_feeless_tx_for_route(route)
if send_msat != Millisatoshi(bounds.max_msat):
logger.debug(f"found route with non-zero fee: {send_msat} -> {bounds.max_msat}. {route}")
for hop in route:
hop_scid = hop["channel"]
hop_dir = hop["direction"]
ch = self._get_directed_scid(hop_scid, hop_dir)
if ch["base_fee_millisatoshi"] != 0:
self.nonzero_base_channels.append(f"{hop_scid}/{hop_dir}")
bounds = bounds.restrict_to_zero_fees(ppm=ch["fee_per_millionth"])
if any(hop["base_fee_millisatoshi"] != 0 for hop in route):
return RouteError.HAS_BASE_FEE
return bounds
return route
def _max_feeless_tx_for_route(self, route: list[dict]) -> int|None:
has_base_fee = False
max_fee_per_mili = 0
for hop in route:
hop_scid = hop["channel"]
hop_dir = hop["direction"]
ch = self._get_directed_scid(hop_scid, hop_dir)
feebase = ch["base_fee_millisatoshi"]
if feebase:
has_base_fee = True
self.nonzero_base_channels.append(f"{hop_scid}/{hop_dir}")
per_mili = ch["fee_per_millionth"]
max_fee_per_mili = max(max_fee_per_mili, per_mili)
if has_base_fee:
return RouteError.HAS_BASE_FEE
if max_fee_per_mili == 0:
return int(route[0]["amount_msat"]) # no practical limit
return math.ceil(1000000 / max_fee_per_mili) - 1
def _add_route_endpoints(self, route, out_ch: LocalChannel, in_ch: LocalChannel):
inbound_hop = dict(
id=self.self_id,
@ -310,7 +310,11 @@ def main():
rpc = LightningRpc(RPC_FILE)
balancer = Balancer(rpc)
balancer.balance_once_with_retries(args.out, args.in_, int(args.min_msat), int(args.max_msat))
bounds = TxBounds(
min_msat = int(args.min_msat),
max_msat = int(args.max_msat),
)
balancer.balance_once_with_retries(args.out, args.in_, bounds)
if __name__ == '__main__':
main()