cross: list: allow visit to be mutable

This commit is contained in:
2022-08-18 02:45:15 -07:00
parent f2b23ace17
commit 2a9c065cb0
2 changed files with 34 additions and 19 deletions

View File

@@ -98,7 +98,7 @@ pub trait ListConsumer<L> {
/// process the subsequent value (and by extension all values)
pub trait FoldOp<State, V> {
type Output;
fn feed(&self, prev: State, next: V) -> Self::Output;
fn feed(&mut self, prev: State, next: V) -> Self::Output;
}
pub struct FoldImpl<Op, State>(Op, State);
@@ -118,7 +118,7 @@ where
{
type Output = <FoldImpl<Op, Op::Output> as ListConsumer<T>>::Output;
fn consume(self, l: Node<H, T>) -> Self::Output {
let FoldImpl(op, state) = self;
let FoldImpl(mut op, state) = self;
let (head, tail) = l.split();
let next_state = op.feed(state, head);
FoldImpl(op, next_state).consume(tail)
@@ -140,7 +140,7 @@ where
{
type Output = <FoldImpl<Op, Op::Output> as ListConsumer<&'a T>>::Output;
fn consume(self, l: &'a Node<H, T>) -> Self::Output {
let FoldImpl(op, state) = self;
let FoldImpl(mut op, state) = self;
let (head, tail) = l.split_ref();
let next_state = op.feed(state, head);
FoldImpl(op, next_state).consume(tail)
@@ -163,7 +163,7 @@ where
}
pub trait Visitor<E> {
fn visit(&self, v: E);
fn visit(&mut self, v: E);
}
pub struct VisitOp<V>(V);
@@ -172,7 +172,7 @@ where
V: Visitor<Next>
{
type Output = ();
fn feed(&self, _prev: (), next: Next) {
fn feed(&mut self, _prev: (), next: Next) {
self.0.visit(next)
}
}
@@ -194,7 +194,7 @@ where
pub struct ReverseOp;
impl<Prev, Next> FoldOp<Prev, Next> for ReverseOp {
type Output = Node<Next, Prev>;
fn feed(&self, prev: Prev, next: Next) -> Self::Output {
fn feed(&mut self, prev: Prev, next: Next) -> Self::Output {
Node::new(next, prev)
}
}
@@ -220,7 +220,7 @@ where
Prev: core::ops::Add<Next>,
{
type Output = Prev::Output;
fn feed(&self, prev: Prev, next: Next) -> Self::Output {
fn feed(&mut self, prev: Prev, next: Next) -> Self::Output {
prev + next
}
}
@@ -252,7 +252,7 @@ where
Prev: Appendable<F::Output>,
{
type Output = Appended<Prev, F::Output>;
fn feed(&self, prev: Prev, next: Next) -> Self::Output {
fn feed(&mut self, prev: Prev, next: Next) -> Self::Output {
prev.append(self.0.map(next))
}
}
@@ -300,7 +300,7 @@ where
Prev: Extend<Next>,
{
type Output = Prev::Output;
fn feed(&self, prev: Prev, next: Next) -> Self::Output {
fn feed(&mut self, prev: Prev, next: Next) -> Self::Output {
prev.extend(next)
}
}
@@ -353,7 +353,7 @@ where
L: Appendable<Tagged<L::Length, Next>>,
{
type Output = Appended<L, Tagged<L::Length, Next>>;
fn feed(&self, prev: L, next: Next) -> Self::Output {
fn feed(&mut self, prev: L, next: Next) -> Self::Output {
prev.append(Tagged::new(next))
}
}
@@ -406,7 +406,7 @@ mod test {
impl FoldOp<i32, i32> for SumVal {
type Output = i32;
fn feed(&self, prev: i32, next: i32) -> Self::Output {
fn feed(&mut self, prev: i32, next: i32) -> Self::Output {
prev + next
}
}
@@ -419,19 +419,19 @@ mod test {
impl FoldOp<i32, f32> for SumVal {
type Output = f32;
fn feed(&self, prev: i32, next: f32) -> Self::Output {
fn feed(&mut self, prev: i32, next: f32) -> Self::Output {
prev as f32 + next
}
}
impl FoldOp<f32, f32> for SumVal {
type Output = f32;
fn feed(&self, prev: f32, next: f32) -> Self::Output {
fn feed(&mut self, prev: f32, next: f32) -> Self::Output {
prev + next
}
}
impl FoldOp<f32, i32> for SumVal {
type Output = f32;
fn feed(&self, prev: f32, next: i32) -> Self::Output {
fn feed(&mut self, prev: f32, next: i32) -> Self::Output {
prev + next as f32
}
}
@@ -455,13 +455,13 @@ mod test {
struct SumRef;
impl FoldOp<i32, &NotCopy> for SumRef {
type Output = i32;
fn feed(&self, prev: i32, next: &NotCopy) -> Self::Output {
fn feed(&mut self, prev: i32, next: &NotCopy) -> Self::Output {
prev + next.0
}
}
impl FoldOp<i32, &i32> for SumRef {
type Output = i32;
fn feed(&self, prev: i32, next: &i32) -> Self::Output {
fn feed(&mut self, prev: i32, next: &i32) -> Self::Output {
prev + *next
}
}
@@ -475,7 +475,7 @@ mod test {
struct NoopVisitor;
impl<V> Visitor<V> for NoopVisitor {
fn visit(&self, _e: V) {}
fn visit(&mut self, _e: V) {}
}
#[test]
@@ -487,6 +487,21 @@ mod test {
list.visit(NoopVisitor);
}
struct AccumVisitor(i32);
impl Visitor<i32> for &mut AccumVisitor {
fn visit(&mut self, e: i32) {
self.0 += e * 2;
}
}
#[test]
fn visit_mut() {
let list = (3i32, 4i32, 5i32).into_list();
let mut v = AccumVisitor(0);
list.visit(&mut v);
assert_eq!(v.0, 24);
}
#[test]
fn sum() {
let list = (3i32, 4i32, 5i32).into_list();