diff --git a/crates/cross/src/compound/list/flat.rs b/crates/cross/src/compound/list/flat.rs index 2784fb7..205504e 100644 --- a/crates/cross/src/compound/list/flat.rs +++ b/crates/cross/src/compound/list/flat.rs @@ -309,6 +309,7 @@ pub trait SplitHead { type Head; type Tail; fn split(self) -> (Self::Head, Self::Tail); + fn split_ref<'a>(&'a self) -> (&'a Self::Head, &'a Self::Tail); } impl SplitHead for Node { type Head = H; @@ -316,6 +317,9 @@ impl SplitHead for Node { fn split(self) -> (Self::Head, Self::Tail) { (self.head, self.tail) } + fn split_ref<'a>(&'a self) -> (&'a Self::Head, &'a Self::Tail) { + (&self.head, &self.tail) + } } /// these are exported for the convenience of potential consumers: not needed internally diff --git a/crates/cross/src/compound/list/mod.rs b/crates/cross/src/compound/list/mod.rs index 6b2d531..2fe3923 100644 --- a/crates/cross/src/compound/list/mod.rs +++ b/crates/cross/src/compound/list/mod.rs @@ -97,7 +97,7 @@ where // - [x] map (calls into fold, with the accumulator being Appendable::append) // - [ ] visit (calls into fold, with a no-op accumulator) // - [ ] visit_ref, visit_mut (calls as_ref/as_mut, then visit) -// - [ ] enumerate (calls fold, with the accumulator being a Peano counter + value copy) +// - [x] enumerate (calls fold, with the accumulator being a Peano counter + value copy) pub trait ListConsumer { type Output; @@ -113,6 +113,7 @@ pub trait FoldOp { pub struct FoldImpl(Op, State); +//////// fold by-value impl ListConsumer for FoldImpl { type Output = State; fn consume(self, _l: Empty) -> Self::Output { @@ -134,6 +135,28 @@ where } } +//////// fold by-ref +impl ListConsumer<&Empty> for FoldImpl { + type Output = State; + fn consume(self, _l: &Empty) -> Self::Output { + self.1 + } +} + +impl<'a, H, T, Op, State> ListConsumer<&'a Node> for FoldImpl +where + Op: FoldOp, + FoldImpl: ListConsumer<&'a T>, +{ + type Output = as ListConsumer<&'a T>>::Output; + fn consume(self, l: &'a Node) -> Self::Output { + let FoldImpl(op, state) = self; + let (head, tail) = l.split_ref(); + let next_state = op.feed(state, head); + FoldImpl(op, next_state).consume(tail) + } +} + pub trait Fold { type Output; fn fold(self, op: Op, init: Init) -> Self::Output; @@ -309,6 +332,32 @@ mod test { assert_eq!(list.fold(Sum, 2i32), 28f32); } + + #[derive(PartialEq)] + struct NotCopy(i32); + + struct SumRef; + impl FoldOp for SumRef { + type Output = i32; + fn feed(&self, prev: i32, next: &NotCopy) -> Self::Output { + prev + next.0 + } + } + impl FoldOp for SumRef { + type Output = i32; + fn feed(&self, prev: i32, next: &i32) -> Self::Output { + prev + *next + } + } + + #[test] + fn fold_ref() { + let list = &(3i32, NotCopy(4i32), 5i32).into_list(); + assert_eq!(list.fold(SumRef, 2i32), 14i32); + assert!(list == list); // just check that it wasn't consumed + } + + #[test] fn reverse_empty() { assert!(Empty::default().reverse() == Empty::default());