diff --git a/crates/cross/src/compound/list/mod.rs b/crates/cross/src/compound/list/mod.rs index 47f4f4d..2d836e0 100644 --- a/crates/cross/src/compound/list/mod.rs +++ b/crates/cross/src/compound/list/mod.rs @@ -207,6 +207,31 @@ where } } +pub struct FlattenOp; +impl FoldOp for FlattenOp +where + Prev: Extend, +{ + type Output = Prev::Output; + fn feed(&self, prev: Prev, next: Next) -> Self::Output { + prev.extend(next) + } +} + +pub trait Flatten { + type Output; + fn flatten(self) -> Self::Output; +} +impl Flatten for L +where + L: Fold +{ + type Output = L::Output; + fn flatten(self) -> Self::Output { + self.fold(FlattenOp, Empty::default()) + } +} + #[derive(Copy, Clone, Default, PartialEq)] pub struct Tagged { @@ -435,6 +460,50 @@ mod test { assert!(l0.extend(l1) == expected); } + #[test] + fn flatten_empty() { + assert!(Empty::default().flatten() == Empty::default()); + } + #[test] + fn flatten_inner_empty1() { + let l = (Empty::default(),).into_list(); + assert!(l.flatten() == Empty::default()); + } + #[test] + fn flatten_inner_empty2() { + let l = (Empty::default(), Empty::default()).into_list(); + assert!(l.flatten() == Empty::default()); + } + #[test] + fn flatten_mixed() { + let l = ( + (2u32, 3f32).into_list(), + (4u32,).into_list(), + ).into_list(); + let expected = ( + 2u32, + 3f32, + 4u32, + ).into_list(); + assert!(l.flatten() == expected); + } + #[test] + fn flatten_nested() { + let l = ( + (2u32, 3f32).into_list(), + ("hello", ("every", "one").into_list()).into_list(), + (4u32,).into_list(), + ).into_list(); + let expected = ( + 2u32, + 3f32, + "hello", + ("every", "one").into_list(), + 4u32, + ).into_list(); + assert!(l.flatten() == expected); + } + #[test] fn enumerate_empty() { assert!(Empty::default().enumerate() == Empty::default());