servo: clightning-sane: tidy

This commit is contained in:
Colin 2024-01-12 01:25:56 +00:00
parent 432a66bf5f
commit 913403aac6

View File

@ -12,6 +12,7 @@ import math
import sys import sys
import time import time
from dataclasses import dataclass
from enum import Enum from enum import Enum
from pyln.client import LightningRpc, Millisatoshi, RpcError from pyln.client import LightningRpc, Millisatoshi, RpcError
@ -31,6 +32,50 @@ class RouteError(Enum):
HAS_BASE_FEE = "HAS_BASE_FEE" HAS_BASE_FEE = "HAS_BASE_FEE"
NO_ROUTE = "NO_ROUTE" NO_ROUTE = "NO_ROUTE"
@dataclass
class TxBounds:
min_msat: int
max_msat: int
def is_satisfiable(self) -> bool:
return self.min_msat <= self.max_msat
def restrict_to_htlc(self, ch: "LocalChannel") -> "Self":
"""
apply min/max HTLC size restrictions of the given channel.
"""
new_min, new_max = self.min_msat, self.max_msat
if ch.htlc_minimum_msat > self.min_msat:
new_min = ch.htlc_minimum_msat
logger.debug(f"raising min_msat due to HTLC requirements: {self.min_msat} -> {new_min}")
if ch.htlc_maximum_msat < self.max_msat:
new_max = ch.htlc_maximum_msat
logger.debug(f"lowering max_msat due to HTLC requirements: {self.max_msat} -> {new_max}")
return TxBounds(min_msat=new_min, max_msat=new_max)
def restrict_to_zero_fees(self, ch: "LocalChannel"=None, base: int=0, ppm: int=0) -> "Self":
"""
restrict tx size such that PPM fees are zero.
if the channel has a base fee, then `max_msat` is forced to 0.
"""
if ch:
self = self.restrict_to_zero_fees(base=ch.to_me["base_fee_millisatoshi"], ppm=ch.to_me["fee_per_millionth"])
new_max = self.max_msat
if ppm != 0:
new_max = math.ceil(1000000 / ppm) - 1
if new_max < self.max_msat:
logger.debug(f"decreasing max_msat due to fee ppm: {self.max_msat} -> {new_max}")
if base != 0:
logger.debug("free route impossible: channel has base fees")
new_max = 0
return TxBounds(
min_msat = self.min_msat,
max_msat = new_max,
)
class LocalChannel: class LocalChannel:
def __init__(self, channels: list, self_id: str): def __init__(self, channels: list, self_id: str):
assert len(channels) == 2, f"unexpected: more than 2 channels: {channels}" assert len(channels) == 2, f"unexpected: more than 2 channels: {channels}"
@ -109,11 +154,11 @@ class Balancer:
assert len(channels) == 1, f"expected exactly 1 channel: {channels}" assert len(channels) == 1, f"expected exactly 1 channel: {channels}"
return channels[0] return channels[0]
def balance_once_with_retries(self, out_scid: str, in_scid: str, min_tx_msat: int, max_tx_msat: int, retries: int = 20) -> None: def balance_once_with_retries(self, out_scid: str, in_scid: str, tx: TxBounds, retries: int = 20) -> None:
for i in range(retries): for i in range(retries):
if i != 0: if i != 0:
logger.info(f"retrying rebalance: {i} of {retries}\n") logger.info(f"retrying rebalance: {i} of {retries}\n")
res = self.balance_once(out_scid, in_scid, min_tx_msat, max_tx_msat) res = self.balance_once(out_scid, in_scid, tx)
if res == RebalanceResult.SUCCESS: if res == RebalanceResult.SUCCESS:
logger.info(f"rebalanced once with success {out_scid} -> {in_scid}") logger.info(f"rebalanced once with success {out_scid} -> {in_scid}")
break break
@ -123,7 +168,7 @@ class Balancer:
else: else:
logger.info(f"failed to rebalance {out_scid} -> {in_scid} within {retries} attempts") logger.info(f"failed to rebalance {out_scid} -> {in_scid} within {retries} attempts")
def balance_once(self, out_scid: str, in_scid: str, min_tx_msat: int, max_tx_msat: int) -> None: def balance_once(self, out_scid: str, in_scid: str, bounds: TxBounds) -> None:
out_ch = self._localchannel(out_scid) out_ch = self._localchannel(out_scid)
in_ch = self._localchannel(in_scid) in_ch = self._localchannel(in_scid)
@ -131,12 +176,13 @@ class Balancer:
logger.info(f"rebalance {out_scid} -> {in_scid} failed in our own channel") logger.info(f"rebalance {out_scid} -> {in_scid} failed in our own channel")
return RebalanceResult.FAIL_PERMANENT return RebalanceResult.FAIL_PERMANENT
tx_bounds = self._bound_tx_size(out_ch, in_ch, min_tx_msat, max_tx_msat) bounds = bounds.restrict_to_htlc(out_ch)
if tx_bounds is None: bounds = bounds.restrict_to_htlc(in_ch)
bounds = bounds.restrict_to_zero_fees(in_ch)
if not bounds.is_satisfiable():
return RebalanceResult.FAIL_PERMANENT # no valid bounds return RebalanceResult.FAIL_PERMANENT # no valid bounds
min_tx_msat, max_tx_msat = tx_bounds
route = self.route(out_ch, in_ch, min_tx_msat, max_tx_msat) route = self.route(out_ch, in_ch, bounds)
logger.debug(f"route: {route}") logger.debug(f"route: {route}")
if route == RouteError.NO_ROUTE: if route == RouteError.NO_ROUTE:
return RebalanceResult.FAIL_PERMANENT return RebalanceResult.FAIL_PERMANENT
@ -166,36 +212,7 @@ class Balancer:
else: else:
return RebalanceResult.SUCCESS return RebalanceResult.SUCCESS
def _bound_tx_size(self, out_ch: LocalChannel, in_ch: LocalChannel, min_tx_msat: int, max_tx_msat: int) -> tuple[int, int] | None: def route(self, out_ch: LocalChannel, in_ch: LocalChannel, bounds: TxBounds) -> list[dict] | RouteError:
# don't even try to route if the channels advertise to not support our request
min_, max_ = min_tx_msat, max_tx_msat
min_tx_msat = max(min_tx_msat, out_ch.htlc_minimum_msat)
min_tx_msat = max(min_tx_msat, in_ch.htlc_minimum_msat)
max_tx_msat = min(max_tx_msat, out_ch.htlc_maximum_msat)
max_tx_msat = min(max_tx_msat, in_ch.htlc_maximum_msat)
if min_ != min_tx_msat:
logger.debug(f"increased min_tx_msat due to route requirements: {min_} -> {min_tx_msat}")
if max_ != max_tx_msat:
logger.debug(f"decreased max_tx_msat due to route requirements: {max_} -> {max_tx_msat}")
if in_ch.to_me["base_fee_millisatoshi"] != 0:
logger.info(f"aborting route because inbound requires base fees")
return None
per_mili = in_ch.to_me["fee_per_millionth"]
if per_mili != 0:
new_max = math.ceil(1000000 / per_mili) - 1
if new_max < max_tx_msat:
logger.debug(f"decreased max_tx_msat due to inbound fee ppm: {max_tx_msat} -> {new_max}")
max_tx_msat = new_max
if min_tx_msat > max_tx_msat:
logger.info(f"aborting route because of conflicting HTLC min/max requirements ({min_tx_msat} > {max_tx_msat})")
return None
return min_tx_msat, max_tx_msat
def route(self, out_ch: LocalChannel, in_ch: LocalChannel, min_tx_msat: int, max_tx_msat: int) -> list[dict] | RouteError:
exclude = [ exclude = [
# ensure the payment doesn't cross either channel in reverse. # ensure the payment doesn't cross either channel in reverse.
# note that this doesn't preclude it from taking additional trips through self, with other peers. # note that this doesn't preclude it from taking additional trips through self, with other peers.
@ -208,63 +225,46 @@ class Balancer:
out_peer = out_ch.remote_peer out_peer = out_ch.remote_peer
in_peer = in_ch.remote_peer in_peer = in_ch.remote_peer
route_or_max_tx = self._find_partial_route(out_peer, in_peer, max_tx_msat, exclude=exclude)
while isinstance(route_or_max_tx, int): route_or_bounds = bounds
logger.debug(f"max feeless tx: {route_or_max_tx}") while isinstance(route_or_bounds, TxBounds):
try_again_msat = max(min_tx_msat, route_or_max_tx) old_bounds = route_or_bounds
route_or_bounds = self._find_partial_route(out_peer, in_peer, old_bounds, exclude=exclude)
if try_again_msat == route_or_max_tx: if route_or_bounds == old_bounds:
return RouteError.NO_ROUTE return RouteError.NO_ROUTE
# due to per-channel HTLC size requirements, we have to try again and we'll maybe get a different route if isinstance(route_or_bounds, RouteError):
route_or_max_tx = self._find_partial_route(out_peer, in_peer, try_again_msat, exclude=exclude) return route_or_bounds
route = route_or_max_tx
if isinstance(route, RouteError): route = self._add_route_endpoints(route_or_bounds, out_ch, in_ch)
return route
route = self._add_route_endpoints(route, out_ch, in_ch)
return route return route
def _find_partial_route(self, out_peer: str, in_peer: str, tx_msat: int, exclude: list[str]=[]) -> list[dict] | RouteError | int: def _find_partial_route(self, out_peer: str, in_peer: str, bounds: TxBounds, exclude: list[str]=[]) -> list[dict] | RouteError | TxBounds:
route = self.rpc.getroute(in_peer, amount_msat=tx_msat, riskfactor=0, fromid=out_peer, exclude=exclude, cltv=CLTV) route = self.rpc.getroute(in_peer, amount_msat=bounds.max_msat, riskfactor=0, fromid=out_peer, exclude=exclude, cltv=CLTV)
route = route["route"] route = route["route"]
if route == []: if route == []:
logger.debug(f"no route for {tx_msat}msat {out_peer} -> {in_peer}") logger.debug(f"no route for {bounds.max_msat}msat {out_peer} -> {in_peer}")
return RouteError.NO_ROUTE return RouteError.NO_ROUTE
send_msat = route[0]["amount_msat"] send_msat = route[0]["amount_msat"]
if send_msat != Millisatoshi(tx_msat): if send_msat != Millisatoshi(bounds.max_msat):
logger.debug(f"found route with non-zero fee: {send_msat} -> {tx_msat}. {route}") logger.debug(f"found route with non-zero fee: {send_msat} -> {bounds.max_msat}. {route}")
return self._max_feeless_tx_for_route(route)
for hop in route:
hop_scid = hop["channel"]
hop_dir = hop["direction"]
ch = self._get_directed_scid(hop_scid, hop_dir)
if ch["base_fee_millisatoshi"] != 0:
self.nonzero_base_channels.append(f"{hop_scid}/{hop_dir}")
bounds = bounds.restrict_to_zero_fees(ppm=ch["fee_per_millionth"])
if any(hop["base_fee_millisatoshi"] != 0 for hop in route):
return RouteError.HAS_BASE_FEE
return bounds
return route return route
def _max_feeless_tx_for_route(self, route: list[dict]) -> int|None:
has_base_fee = False
max_fee_per_mili = 0
for hop in route:
hop_scid = hop["channel"]
hop_dir = hop["direction"]
ch = self._get_directed_scid(hop_scid, hop_dir)
feebase = ch["base_fee_millisatoshi"]
if feebase:
has_base_fee = True
self.nonzero_base_channels.append(f"{hop_scid}/{hop_dir}")
per_mili = ch["fee_per_millionth"]
max_fee_per_mili = max(max_fee_per_mili, per_mili)
if has_base_fee:
return RouteError.HAS_BASE_FEE
if max_fee_per_mili == 0:
return int(route[0]["amount_msat"]) # no practical limit
return math.ceil(1000000 / max_fee_per_mili) - 1
def _add_route_endpoints(self, route, out_ch: LocalChannel, in_ch: LocalChannel): def _add_route_endpoints(self, route, out_ch: LocalChannel, in_ch: LocalChannel):
inbound_hop = dict( inbound_hop = dict(
id=self.self_id, id=self.self_id,
@ -310,7 +310,11 @@ def main():
rpc = LightningRpc(RPC_FILE) rpc = LightningRpc(RPC_FILE)
balancer = Balancer(rpc) balancer = Balancer(rpc)
balancer.balance_once_with_retries(args.out, args.in_, int(args.min_msat), int(args.max_msat)) bounds = TxBounds(
min_msat = int(args.min_msat),
max_msat = int(args.max_msat),
)
balancer.balance_once_with_retries(args.out, args.in_, bounds)
if __name__ == '__main__': if __name__ == '__main__':
main() main()