(feat): refactor superkmer to use obipipeline with flat transforms

- Replace crossbeam-channel-based threading model
- Introduce obipipeline crate with Stage::Transform/Flat support  
- Replace single input + format detection by multiple inputs via PathIter
- Implement pipeline stages: open_chunks → normalize → build_superkmers (flat) + write_batch
- Add SharedFlatFn for 1→N transformations with delta tracking in scheduler loop
This commit is contained in:
Eric Coissac
2026-04-24 18:16:47 +02:00
parent f1c8fc85c9
commit d4e4289aff
5 changed files with 540 additions and 489 deletions
+1 -1
View File
@@ -668,10 +668,10 @@ name = "obikmer"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"clap", "clap",
"crossbeam-channel",
"obifastwrite", "obifastwrite",
"obikrope", "obikrope",
"obikseq", "obikseq",
"obipipeline",
"obiread", "obiread",
"obiskbuilder", "obiskbuilder",
] ]
+1 -1
View File
@@ -12,6 +12,6 @@ obikseq = { path = "../obikseq" }
obiread = { path = "../obiread" } obiread = { path = "../obiread" }
obiskbuilder = { path = "../obiskbuilder" } obiskbuilder = { path = "../obiskbuilder" }
obifastwrite = { path = "../obifastwrite" } obifastwrite = { path = "../obifastwrite" }
obipipeline = { path = "../obipipeline" }
clap = { version = "4", features = ["derive"] } clap = { version = "4", features = ["derive"] }
crossbeam-channel = "0.5"
obikrope = { path = "../obikrope" } obikrope = { path = "../obikrope" }
+78 -102
View File
@@ -1,17 +1,19 @@
use std::io::{self, BufWriter, Write}; use std::io::{self, BufWriter, Write};
use std::thread; use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use clap::Args; use clap::Args;
use crossbeam_channel::bounded;
use obifastwrite::write_scatter; use obifastwrite::write_scatter;
use obikrope::Rope; use obikrope::Rope;
use obikseq::superkmer::SuperKmer; use obikseq::superkmer::SuperKmer;
use obipipeline::{WorkerPool, make_pipeline};
use obiskbuilder::SuperKmerIter; use obiskbuilder::SuperKmerIter;
#[derive(Args)] #[derive(Args)]
pub struct SuperkmerArgs { pub struct SuperkmerArgs {
/// Input: file path, URL (http/https), or `-` for stdin /// Input files or directories (FASTA/FASTQ, optionally gzip-compressed)
pub input: String, #[arg(num_args = 1..)]
pub inputs: Vec<String>,
/// k-mer size /// k-mer size
#[arg(short, long, default_value_t = 31)] #[arg(short, long, default_value_t = 31)]
@@ -34,36 +36,70 @@ pub struct SuperkmerArgs {
pub partitions: u64, pub partitions: u64,
/// Number of worker threads /// Number of worker threads
#[arg(short = 'T', long, default_value_t = 4)] #[arg(short = 'T', long, default_value_t = 16)]
pub threads: usize, pub threads: usize,
/// Force input format: fasta or fastq (default: auto-detect from extension)
#[arg(long)]
pub format: Option<String>,
} }
#[derive(Clone, Copy)] enum PipelineData {
enum Format { Path(PathBuf),
Fasta, RawChunk(Rope),
Fastq, NormChunk(Rope),
Batch(Vec<(u64, SuperKmer)>),
} }
fn detect_format(source: &str, hint: Option<&str>) -> Format { // SAFETY: Rope contains Cell<u8> which is !Sync, but pipeline ownership transfers
if let Some(h) = hint { // exclusively through channels — no item is ever shared across threads.
return match h.to_lowercase().as_str() { unsafe impl Send for PipelineData {}
"fastq" | "fq" => Format::Fastq, unsafe impl Sync for PipelineData {}
_ => Format::Fasta,
}; // ── Stage functions ───────────────────────────────────────────────────────────
/// Opens a sequence file and returns an iterator over its raw Rope chunks.
/// Chunk-level I/O errors are logged and skipped.
fn open_chunks(path: PathBuf) -> io::Result<impl Iterator<Item = Rope>> {
let path_str = path
.to_str()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "non-UTF-8 path"))?;
let iter = obiread::read_sequence_chunks(path_str)?;
Ok(iter.filter_map(|r| match r {
Ok(rope) => Some(rope),
Err(e) => {
eprintln!("chunk read error: {e}");
None
} }
if source.ends_with(".fq") }))
|| source.ends_with(".fastq") }
|| source.ends_with(".fq.gz")
|| source.ends_with(".fastq.gz") /// Normalises a raw sequence chunk (FASTA or FASTQ) into a compact ACGT/NUL rope.
{ fn normalize(rope: Rope, k: usize) -> io::Result<Rope> {
Format::Fastq obiread::normalize_sequence_chunk(rope, k)
} else { }
Format::Fasta
/// Extracts all super-kmers from a normalised rope.
fn build_superkmers(
rope: Rope,
k: usize,
m: usize,
level_max: usize,
theta: f64,
) -> Vec<(u64, SuperKmer)> {
SuperKmerIter::new(&rope, k, m, level_max, theta).collect()
}
/// Writes a batch of super-kmers to the output sink.
fn write_batch(
batch: Vec<(u64, SuperKmer)>,
out: &Mutex<BufWriter<io::Stdout>>,
partitions: u64,
k: usize,
m: usize,
) -> io::Result<()> {
let mut w = out.lock().unwrap();
for (min_hash, sk) in batch {
let partition = (mix64(min_hash) % partitions) as u32;
write_scatter(&sk, &mut *w, k, m, partition, min_hash)?;
} }
Ok(())
} }
#[inline] #[inline]
@@ -75,9 +111,9 @@ fn mix64(x: u64) -> u64 {
x ^ (x >> 31) x ^ (x >> 31)
} }
// ── Entry point ───────────────────────────────────────────────────────────────
pub fn run(args: SuperkmerArgs) { pub fn run(args: SuperkmerArgs) {
let format = detect_format(&args.input, args.format.as_deref());
let input = args.input.clone();
let k = args.kmer_size; let k = args.kmer_size;
let m = args.minimizer_size; let m = args.minimizer_size;
let theta = args.theta; let theta = args.theta;
@@ -85,81 +121,21 @@ pub fn run(args: SuperkmerArgs) {
let partitions = args.partitions; let partitions = args.partitions;
let n_workers = args.threads.max(1); let n_workers = args.threads.max(1);
// raw chunks (reader → workers) let paths = args.inputs.iter().map(PathBuf::from).collect();
let (raw_tx, raw_rx) = bounded::<Rope>(n_workers * 2); let path_source = obiread::PathIter::new(paths);
// superkmer batches (workers → output)
let (sk_tx, sk_rx) = bounded::<Vec<(u64, SuperKmer)>>(n_workers * 2);
// ── reader thread ───────────────────────────────────────────────────────── let out = Arc::new(Mutex::new(BufWriter::new(io::stdout())));
let reader = thread::spawn(move || { let out_sink = Arc::clone(&out);
let file = obiread::xopen(&input).expect("cannot open input");
match format {
Format::Fasta => {
for chunk in obiread::fasta_chunks(file) {
raw_tx.send(chunk.expect("read error")).unwrap();
}
}
Format::Fastq => {
for chunk in obiread::fastq_chunks(file) {
raw_tx.send(chunk.expect("read error")).unwrap();
}
}
}
// raw_tx drops here → workers observe disconnect
});
// ── worker threads ──────────────────────────────────────────────────────── let pipeline = make_pipeline! {
let workers: Vec<_> = (0..n_workers) PipelineData,
.map(|_| { source path_source => Path,
let raw_rx = raw_rx.clone(); ||? open_chunks : Path => RawChunk,
let sk_tx = sk_tx.clone(); |? { move |rope| normalize(rope, k) } : RawChunk => NormChunk,
thread::spawn(move || { | { move |rope| build_superkmers(rope, k, m, level_max, theta) }: NormChunk => Batch,
for raw_chunk in raw_rx { sink? { move |batch| write_batch(batch, &out_sink, partitions, k, m) } @ Batch,
let norm = match format {
Format::Fasta => obiread::normalize::normalize_fasta_chunk(raw_chunk, k),
Format::Fastq => obiread::normalize::normalize_fastq_chunk(raw_chunk, k),
}; };
const BATCH_SIZE: usize = 10_000;
let mut batch = Vec::with_capacity(BATCH_SIZE);
for sk in SuperKmerIter::new(&norm, k, m, level_max, theta) {
batch.push(sk);
if batch.len() == BATCH_SIZE {
sk_tx
.send(std::mem::replace(
&mut batch,
Vec::with_capacity(BATCH_SIZE),
))
.unwrap();
}
}
if !batch.is_empty() {
sk_tx.send(batch).unwrap();
}
}
// sk_tx clone drops here
})
})
.collect();
// drop the extra sk_tx clone held by this thread so the output thread exits WorkerPool::new(pipeline, n_workers, 1).run();
drop(sk_tx); out.lock().unwrap().flush().expect("flush error");
// ── output thread ─────────────────────────────────────────────────────────
let output = thread::spawn(move || {
let stdout = io::stdout();
let mut out = BufWriter::new(stdout.lock());
for batch in sk_rx {
for (min_hash, sk) in batch {
let partition = (mix64(min_hash) % partitions) as u32;
write_scatter(&sk, &mut out, k, m, partition, min_hash).expect("write error");
}
}
out.flush().expect("flush error");
});
reader.join().expect("reader thread panicked");
for w in workers {
w.join().expect("worker thread panicked");
}
output.join().expect("output thread panicked");
} }
+11 -1
View File
@@ -2,9 +2,19 @@ mod scheduler;
pub use scheduler::Pipeline; pub use scheduler::Pipeline;
pub use scheduler::PipelineError; pub use scheduler::PipelineError;
pub use scheduler::SharedFlatFn;
pub use scheduler::SharedFn; pub use scheduler::SharedFn;
pub use scheduler::SinkFn; pub use scheduler::SinkFn;
pub use scheduler::SourceFn; pub use scheduler::SourceFn;
pub use scheduler::Stage;
pub use scheduler::WorkerPool; pub use scheduler::WorkerPool;
/// Re-export de `crossbeam_channel::Sender` utilisé dans les macros flat transform.
/// Permet aux macros `make_flat_transform!` / `make_flat_transform_fallible!` d'utiliser
/// `$crate::PipelineSender` sans que le crate appelant n'ait besoin de dépendre
/// directement de `crossbeam_channel`.
pub use crossbeam_channel::Sender as PipelineSender;
// make_sink, make_sink_fallible, make_source, make_source_fallible, // make_sink, make_sink_fallible, make_source, make_source_fallible,
// make_transform, make_transform_fallible are exported at crate root via #[macro_export] // make_transform, make_transform_fallible, make_flat_transform,
// make_flat_transform_fallible sont exportés à la racine du crate via #[macro_export]
+443 -378
View File
@@ -37,333 +37,52 @@ impl Error for PipelineError {
} }
} }
/// Represents a single processing stage in a data pipeline. // ── Function types ────────────────────────────────────────────────────────────
///
/// `StepKind` abstracts over three fundamental types of operations:
/// - **Source**: Produces a sequence of data items (e.g., from a file, iterator, generator).
/// - **Transform**: Converts one data item into another (pure transformation).
/// - **Sink**: Consumes a final data item (e.g., print, store, aggregate).
///
/// The type `DATA` is typically an enum that unifies all data variants that can flow through
/// the pipeline. The type `ERROR` is the error type used by fallible operations.
///
/// # Generics
/// - `DATA`: The type of values passed between stages (must be `Send + Sync`).
/// - `ERROR`: The error type for fallible computations (usually `PipelineError`).
///
/// # Example
/// ```
/// # use std::io;
/// # enum MyData { Number(i32), Text(String) }
/// # type MyError = io::Error;
/// # let source_iter = vec![Ok(MyData::Number(42))].into_iter();
/// # let transform_fn = |d: MyData| -> Result<MyData, MyError> { Ok(d) };
/// # let sink_fn = |d: MyData| {};
/// let source = StepKind::Source(Box::new(source_iter));
/// let transform = StepKind::Transform(Box::new(transform_fn));
/// let sink = StepKind::Sink(Box::new(sink_fn));
/// ```
///
/// # Note
/// - A `Source` does not take an input; it yields an iterator of `Result<DATA, ERROR>`.
/// - A `Transform` takes a `DATA` and returns a `Result<DATA, ERROR>`.
/// - A `Sink` takes a `DATA` and performs a side effect (no return value).
/// Fonction source : appelée répétitivement, retourne le prochain item ou EndOfStream. /// Fonction source : appelée répétitivement, retourne le prochain item ou EndOfStream.
/// `FnMut` car elle maintient un état interne (position dans l'itérateur). /// `FnMut` car elle maintient un état interne (position dans l'itérateur).
pub type SourceFn<D> = Box<dyn FnMut() -> Result<D, PipelineError> + Send + Sync>; pub type SourceFn<D> = Box<dyn FnMut() -> Result<D, PipelineError> + Send>;
/// Fonction sink : consomme un item final, peut échouer (erreur d'I/O, etc.). /// Fonction sink : consomme un item final, peut échouer (erreur d'I/O, etc.).
pub type SinkFn<D> = Box<dyn Fn(D) -> Result<(), PipelineError> + Send + Sync>; pub type SinkFn<D> = Box<dyn Fn(D) -> Result<(), PipelineError> + Send>;
/// Creates a `StepKind::Source` from an iterator of plain values.
/// Each value is wrapped with the specified `$enum` variant and returned as `Ok`.
/// The source returns `Err(PipelineError::EndOfStream)` when exhausted.
///
/// # Example
/// ```
/// let iter = vec![1.0, 2.0].into_iter();
/// let source = make_source!(PipelineData, iter, Numeric);
/// ```
#[macro_export]
macro_rules! make_source {
($enum:ident, $iterator:expr, $output:ident) => {{
let mut iter = $iterator.into_iter();
Box::new(
move || -> ::std::result::Result<$enum, $crate::PipelineError> {
match iter.next() {
Some(x) => Ok($enum::$output(x)),
None => Err($crate::PipelineError::EndOfStream),
}
},
)
as Box<dyn FnMut() -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync>
}};
}
/// Creates a `StepKind::Source` from an iterator of `Result<T, E>`.
/// On `Ok(x)` it wraps the value into the specified `$enum` variant.
/// On `Err(e)` it returns `Err(PipelineError::StepError(Box::new(e)))`.
/// Returns `Err(PipelineError::EndOfStream)` when the iterator ends.
///
/// # Example
/// ```
/// let iter = vec![Ok(1.0), Err("oops")].into_iter();
/// let source = make_source_fallible!(PipelineData, iter, Numeric);
/// ```
#[macro_export]
macro_rules! make_source_fallible {
($enum:ident, $iterator:expr, $output:ident) => {{
let mut iter = $iterator.into_iter();
Box::new(
move || -> ::std::result::Result<$enum, $crate::PipelineError> {
match iter.next() {
Some(Ok(x)) => Ok($enum::$output(x)),
Some(Err(e)) => Err($crate::PipelineError::StepError(Box::new(e))),
None => Err($crate::PipelineError::EndOfStream),
}
},
)
as Box<dyn FnMut() -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync>
}};
}
/// Creates a pipeline stage from a pure (non-fallible) function.
///
/// This macro generates a closure that implements the `PipelineStage` trait by pattern
/// matching on the input `PipelineData` variant, applying the provided function, and
/// wrapping the result in the output variant.
///
/// # Arguments
/// * `$func` - The function to apply: `Fn(T) -> U`
/// * `$input` - Input `PipelineData` variant pattern (e.g., `Int(i64)`)
/// * `$output` - Output `PipelineData` variant pattern (e.g., `String(String)`)
///
/// # Example
///
/// ```ignore
/// // Define PipelineData enum
/// enum PipelineData {
/// Int(i64),
/// String(String),
/// }
///
/// // Create pure stage
/// let to_string_stage = make_stage!(
/// to_string,
/// Int(i64),
/// String(String)
/// );
///
/// // Use in pipeline
/// let result = to_string_stage(PipelineData::Int(42)).unwrap();
/// assert!(matches!(result, PipelineData::String(s) if s == "42"));
/// ```
#[macro_export]
macro_rules! make_transform {
($enum:ident, $func:ident, $input:ident, $output:ident) => {
::std::sync::Arc::from(Box::new(
|data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> {
match data {
$enum::$input(x) => Ok($enum::$output($func(x))),
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<
dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync,
>)
};
}
/// Creates a pipeline stage from a fallible function.
///
/// This macro generates a closure that pattern matches on the input `PipelineData`
/// variant, applies the provided function, and wraps the result in the output variant.
/// If the function returns an error, it is boxed and wrapped in `PipelineError::StepError`.
///
/// # Arguments
/// * `$func` - The fallible function to apply: `Fn(T) -> Result<U, E>`
/// * `$input` - Input `PipelineData` variant pattern (e.g., `Int(i64)`)
/// * `$output` - Output `PipelineData` variant pattern (e.g., `String(String)`)
///
/// # Example
///
/// ```ignore
/// fn parse_int(s: &str) -> Result<i64, std::num::ParseIntError> {
/// s.parse()
/// }
///
/// // Define PipelineData enum
/// enum PipelineData {
/// String(String),
/// Int(i64),
/// }
///
/// // Create fallible stage
/// let parse_stage = make_stage_fallible!(
/// parse_int,
/// String(String),
/// Int(i64)
/// );
///
/// // Use in pipeline
/// let result = parse_stage(PipelineData::String("42".into())).unwrap();
/// assert!(matches!(result, PipelineData::Int(n) if n == 42));
/// ```
#[macro_export]
macro_rules! make_transform_fallible {
($enum:ident, $func:ident, $input:ident, $output:ident) => {
::std::sync::Arc::from(Box::new(
|data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> {
match data {
$enum::$input(inner) => {
let result = $func(inner)
.map_err(|e| $crate::PipelineError::StepError(Box::new(e)))?;
Ok($enum::$output(result))
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<
dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync,
>)
};
}
/// Creates a `StepKind::Sink` from a function that consumes a concrete value and returns `()`.
/// The returned sink always returns `Ok(())`.
/// The function is wrapped to accept `$enum::$input` and ignores the result.
///
/// # Example
/// ```
/// fn print_number(n: f64) { println!("{}", n); }
/// let sink = make_sink!(PipelineData, print_number, Numeric);
/// ```
#[macro_export]
macro_rules! make_sink {
($enum:ident, $func:ident, $input:ident) => {
Box::new(
|data: $enum| -> ::std::result::Result<(), $crate::PipelineError> {
match data {
$enum::$input(x) => {
$func(x);
Ok(())
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<dyn Fn($enum) -> ::std::result::Result<(), $crate::PipelineError> + Send + Sync>
};
}
/// Creates a `StepKind::Sink` from a fallible function that returns `Result<(), E>`.
/// Errors from `$func` are wrapped in `PipelineError::StepError`.
///
/// # Example
/// ```
/// fn save_to_file(hash: u64) -> std::io::Result<()> { Ok(()) }
/// let sink = make_sink_fallible!(PipelineData, save_to_file, Hash);
/// ```
#[macro_export]
macro_rules! make_sink_fallible {
($enum:ident, $func:ident, $input:ident) => {
Box::new(
|data: $enum| -> ::std::result::Result<(), $crate::PipelineError> {
match data {
$enum::$input(inner) => {
$func(inner).map_err(|e| $crate::PipelineError::StepError(Box::new(e)))
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<dyn Fn($enum) -> ::std::result::Result<(), $crate::PipelineError> + Send + Sync>
};
}
/// Construit un `Pipeline` à partir d'une source, d'une liste de transforms et d'un sink.
///
/// Syntaxe :
/// ```ignore
/// make_pipeline! {
/// MyData,
/// source my_iter => Variant, // source non-fallible
/// source? my_iter => Variant, // source fallible (Result<T, E>)
/// | func: In => Out, // transform non-fallible (répété 0..N fois)
/// |? func: In => Out, // transform fallible (répété 0..N fois)
/// sink my_func @ Variant, // sink non-fallible
/// sink? my_func @ Variant, // sink fallible
/// }
/// ```
///
/// Implémenté comme un TT muncher : la règle interne `@build` traite les transforms
/// un par un en les accumulant, puis termine sur `sink`/`sink?`.
#[macro_export]
macro_rules! make_pipeline {
// ── Points d'entrée ──────────────────────────────────────────────────
($enum:ident, source $src:expr => $src_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum,
{ $crate::make_source!($enum, $src, $src_out) },
[],
$($rest)*)
};
($enum:ident, source? $src:expr => $src_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum,
{ $crate::make_source_fallible!($enum, $src, $src_out) },
[],
$($rest)*)
};
// ── Accumulation des transforms ──────────────────────────────────────
// transform non-fallible : |
(@build $enum:ident, $source:tt, [$($acc:tt)*],
| $tf:ident : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_transform!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// transform fallible : |?
(@build $enum:ident, $source:tt, [$($acc:tt)*],
|? $tf:ident : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_transform_fallible!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// ── Terminaison : sink ───────────────────────────────────────────────
(@build $enum:ident, $source:tt, [$($acc:tt)*],
sink $sink_fn:ident @ $sink_in:ident $(,)?) => {
$crate::Pipeline::new(
$source,
vec![$($acc)*],
$crate::make_sink!($enum, $sink_fn, $sink_in),
)
};
(@build $enum:ident, $source:tt, [$($acc:tt)*],
sink? $sink_fn:ident @ $sink_in:ident $(,)?) => {
$crate::Pipeline::new(
$source,
vec![$($acc)*],
$crate::make_sink_fallible!($enum, $sink_fn, $sink_in),
)
};
}
/// Fonction de transformation partagée entre workers via Arc. /// Fonction de transformation partagée entre workers via Arc.
/// Arc<dyn Fn> permet à plusieurs workers de partager la même closure
/// sans la copier (Arc::clone = simple incrément de compteur).
pub type SharedFn<D> = Arc<dyn Fn(D) -> Result<D, PipelineError> + Send + Sync>; pub type SharedFn<D> = Arc<dyn Fn(D) -> Result<D, PipelineError> + Send + Sync>;
/// Tâche envoyée à un worker : donnée + fonction à appliquer + canal de résultat. /// Fonction de transformation 1→N (flat map) partagée entre workers via Arc.
/// Le worker n'a pas besoin de connaître sa position dans le pipeline ; ///
/// le scheduler lui dit exactement quoi faire et où envoyer le résultat. /// La fonction reçoit l'item d'entrée, un canal `push` pour envoyer chaque item
pub type WorkerTask<D> = (D, SharedFn<D>, Sender<Result<D, PipelineError>>); /// produit, et un canal `delta` pour signaler au scheduler combien d'items
/// supplémentaires sont entrés dans le pipeline (N-1 si N items produits).
/// Elle doit appeler `delta.send(N - 1)` **après** avoir poussé tous les items.
pub type SharedFlatFn<D> =
Arc<dyn Fn(D, &Sender<Result<D, PipelineError>>, &Sender<isize>) + Send + Sync>;
// ── Stage enum ────────────────────────────────────────────────────────────────
/// Une étape du pipeline : transform classique (1→1) ou flat transform (1→N).
pub enum Stage<D> {
Transform(SharedFn<D>),
Flat(SharedFlatFn<D>),
}
impl<D> Clone for Stage<D> {
fn clone(&self) -> Self {
match self {
Stage::Transform(f) => Stage::Transform(Arc::clone(f)),
Stage::Flat(f) => Stage::Flat(Arc::clone(f)),
}
}
}
// ── Worker task ───────────────────────────────────────────────────────────────
enum WorkerTask<D> {
Transform(D, SharedFn<D>, Sender<Result<D, PipelineError>>),
Flat(D, SharedFlatFn<D>, Sender<Result<D, PipelineError>>, Sender<isize>),
}
// ── Thread runners ────────────────────────────────────────────────────────────
fn source_runner<DATA>( fn source_runner<DATA>(
mut source: SourceFn<DATA>, mut source: SourceFn<DATA>,
@@ -381,7 +100,7 @@ where
match source() { match source() {
Ok(data) => { Ok(data) => {
if tx.send(Ok(data)).is_err() { if tx.send(Ok(data)).is_err() {
break; // récepteur disparu break;
} }
} }
Err(PipelineError::EndOfStream) => break, Err(PipelineError::EndOfStream) => break,
@@ -398,28 +117,29 @@ where
/// Lance un thread worker du pool. /// Lance un thread worker du pool.
/// ///
/// Le worker attend des tâches sur `task_rx`. Chaque tâche est un triplet /// Gère deux types de tâches :
/// `(data, f, result_tx)` : il applique `f(data)` et envoie le résultat /// - `Transform` : applique `f(data)` et envoie le résultat dans `result_tx`.
/// dans `result_tx`. C'est le scheduler qui décide quelle fonction envoyer /// - `Flat` : appelle `f(data, &push_tx, &delta_tx)` ; la fonction elle-même
/// et quel canal de résultat utiliser — le worker lui-même est générique. /// pousse ses items dans `push_tx` et envoie `N-1` dans `delta_tx`.
fn transform_runner<DATA>(task_rx: Receiver<WorkerTask<DATA>>) -> thread::JoinHandle<()> fn transform_runner<DATA>(task_rx: Receiver<WorkerTask<DATA>>) -> thread::JoinHandle<()>
where where
DATA: Send + Sync + 'static, DATA: Send + Sync + 'static,
{ {
thread::spawn(move || { thread::spawn(move || {
while let Ok((data, f, result_tx)) = task_rx.recv() { while let Ok(task) = task_rx.recv() {
match task {
WorkerTask::Transform(data, f, result_tx) => {
let _ = result_tx.send(f(data)); let _ = result_tx.send(f(data));
} }
WorkerTask::Flat(data, f, push_tx, delta_tx) => {
f(data, &push_tx, &delta_tx);
}
}
}
}) })
} }
/// Lance le thread sink. /// Lance le thread sink.
///
/// Retourne :
/// - `Sender<DATA>` : le scheduler y envoie les données finales
/// - `Receiver<PipelineError>` : le sink y pousse toute erreur rencontrée ;
/// le scheduler surveille ce canal en priorité absolue
/// pour interrompre le pipeline dès qu'une erreur survient.
fn sink_runner<DATA>( fn sink_runner<DATA>(
sink: SinkFn<DATA>, sink: SinkFn<DATA>,
capacity: usize, capacity: usize,
@@ -437,29 +157,33 @@ where
for data in data_rx { for data in data_rx {
if let Err(e) = sink(data) { if let Err(e) = sink(data) {
let _ = err_tx.send(e); let _ = err_tx.send(e);
break; // on arrête dès la première erreur break;
} }
} }
}); });
(data_tx, err_rx, handle) (data_tx, err_rx, handle)
} }
// ── Pipeline ──────────────────────────────────────────────────────────────────
pub struct Pipeline<DATA> { pub struct Pipeline<DATA> {
source: SourceFn<DATA>, source: SourceFn<DATA>,
transforms: Vec<SharedFn<DATA>>, stages: Vec<Stage<DATA>>,
sink: SinkFn<DATA>, sink: SinkFn<DATA>,
} }
impl<DATA> Pipeline<DATA> { impl<DATA> Pipeline<DATA> {
pub fn new( pub fn new(
source: SourceFn<DATA>, source: SourceFn<DATA>,
transforms: Vec<SharedFn<DATA>>, stages: Vec<Stage<DATA>>,
sink: SinkFn<DATA>, sink: SinkFn<DATA>,
) -> Self { ) -> Self {
Self { source, transforms, sink } Self { source, stages, sink }
} }
} }
// ── WorkerPool ────────────────────────────────────────────────────────────────
pub struct WorkerPool<DATA> { pub struct WorkerPool<DATA> {
pipeline: Pipeline<DATA>, pipeline: Pipeline<DATA>,
handles: Vec<std::thread::JoinHandle<()>>, handles: Vec<std::thread::JoinHandle<()>>,
@@ -481,11 +205,10 @@ where
} }
pub fn run(mut self) { pub fn run(mut self) {
let n = self.pipeline.transforms.len(); let n = self.pipeline.stages.len();
// ── Canaux inter-stages ──────────────────────────────────────────── // ── Canaux inter-stages ────────────────────────────────────────────
// stage_txs[i] : le worker qui exécute transform[i] y envoie son résultat // stage_txs[i] / stage_rxs[i] : sortie du stage i
// stage_rxs[i] : le scheduler lit ici pour dispatcher au transform[i+1] (ou sink)
let mut stage_txs: Vec<Sender<Result<DATA, PipelineError>>> = Vec::new(); let mut stage_txs: Vec<Sender<Result<DATA, PipelineError>>> = Vec::new();
let mut stage_rxs: Vec<Receiver<Result<DATA, PipelineError>>> = Vec::new(); let mut stage_rxs: Vec<Receiver<Result<DATA, PipelineError>>> = Vec::new();
for _ in 0..n { for _ in 0..n {
@@ -498,12 +221,9 @@ where
let (source_rx, src_handle) = source_runner(self.pipeline.source, self.capacity); let (source_rx, src_handle) = source_runner(self.pipeline.source, self.capacity);
self.handles.push(src_handle); self.handles.push(src_handle);
// Les transforms sont déjà des SharedFn<DATA> — pas de conversion nécessaire. let stages = self.pipeline.stages;
let transforms = self.pipeline.transforms;
// ── Worker pool ──────────────────────────────────────────────────── // ── Worker pool ────────────────────────────────────────────────────
// Canal partagé par tous les workers : le scheduler y pousse des WorkerTask,
// chaque worker en dépile une à la fois.
let (worker_tx, worker_rx): (Sender<WorkerTask<DATA>>, Receiver<WorkerTask<DATA>>) = let (worker_tx, worker_rx): (Sender<WorkerTask<DATA>>, Receiver<WorkerTask<DATA>>) =
bounded(self.capacity); bounded(self.capacity);
@@ -515,40 +235,51 @@ where
let (sink_tx, sink_err_rx, sink_handle) = sink_runner(self.pipeline.sink, self.capacity); let (sink_tx, sink_err_rx, sink_handle) = sink_runner(self.pipeline.sink, self.capacity);
self.handles.push(sink_handle); self.handles.push(sink_handle);
// ── Boucle principale ───────────────────────────────────────────── // ── Canal delta pour les flat stages ───────────────────────────────
// Chaque flat worker envoie `N-1` ici après avoir poussé N items.
// Le scheduler ajuste `in_flight` en conséquence.
let (flat_delta_tx, flat_delta_rx) = bounded::<isize>(self.capacity);
// ── Boucle principale ──────────────────────────────────────────────
// //
// Le Select est reconstruit à chaque itération, ce qui permet de // `in_flight` (isize) = nb d'items qui doivent encore atteindre le sink.
// retirer source_rx une fois la source épuisée. // Peut temporairement être négatif si un flat worker a poussé ses items
// avant que le scheduler ait reçu le delta correspondant.
// //
// Priorités (biased = index le plus bas gagne) : // `flat_workers_active` = nb de flat workers en cours d'exécution.
// index 0 → sink_err_rx (arrêt immédiat sur erreur sink) // Empêche la terminaison prématurée quand in_flight vaut 0 mais qu'un
// index 1..=N → stage_rxs[N-1..0] (vider le pipeline en priorité) // flat worker n'a pas encore envoyé son delta.
// index N+1 → source_rx (dernier recours : nouvelles données)
// //
// Quand k == 0 : erreur du sink // Priorités du Select biaisé (index le plus bas = priorité la plus haute) :
// Quand 1 <= k <= N : stage concerné = N-k // 0 sink_err_rx (arrêt immédiat sur erreur sink)
// Quand k == N+1 : item venant de la source // 1 → flat_delta_rx (mettre à jour in_flight avant de dispatcher)
// 2..=n+1 → stage_rxs[n-1..0] (vider le pipeline en priorité)
// n+2 → source_rx (dernier recours : nouvelles données)
// //
// Terminaison : on quitte uniquement quand source_done ET in_flight == 0, // Quand k = 0 : erreur du sink
// ce qui garantit que tous les items ont traversé le pipeline jusqu'au sink. // Quand k = 1 : delta d'un flat worker
// Quand 2 ≤ k ≤ n+1 : résultat du stage n+1-k
// Quand k = n+2 : item source
//
// Terminaison : source tarie ET in_flight == 0 ET aucun flat worker actif.
{ {
let mut source_done = false; let mut source_done = false;
let mut in_flight: usize = 0; let mut in_flight: isize = 0;
let mut flat_workers_active: usize = 0;
loop { loop {
// Condition de sortie : plus rien en vol et source tarie if source_done && in_flight == 0 && flat_workers_active == 0 {
if source_done && in_flight == 0 {
break; break;
} }
// Reconstruction du Select (sans source_rx si source épuisée)
let mut sel = Select::new_biased(); let mut sel = Select::new_biased();
sel.recv(&sink_err_rx); // index 0 sel.recv(&sink_err_rx); // index 0
sel.recv(&flat_delta_rx); // index 1
for rx in stage_rxs.iter().rev() { for rx in stage_rxs.iter().rev() {
sel.recv(rx); // indices 1 .. N sel.recv(rx); // indices 2..=n+1
} }
let src_idx = if !source_done { let src_idx = if !source_done {
Some(sel.recv(&source_rx)) // index N+1 (seulement si encore active) Some(sel.recv(&source_rx)) // index n+2
} else { } else {
None None
}; };
@@ -557,54 +288,63 @@ where
let k = oper.index(); let k = oper.index();
if k == 0 { if k == 0 {
// ── Erreur du sink : on arrête tout ────────────────── // ── Erreur du sink ────────────────────────────────────
match oper.recv(&sink_err_rx) { match oper.recv(&sink_err_rx) {
Ok(e) => { eprintln!("Sink error: {:?}", e); break; } Ok(e) => { eprintln!("Sink error: {:?}", e); break; }
Err(_) => break, Err(_) => break,
} }
} else if k == 1 {
// ── Delta d'un flat worker ────────────────────────────
// delta = N - 1 (N items poussés, 1 item consommé)
match oper.recv(&flat_delta_rx) {
Ok(delta) => {
in_flight += delta;
flat_workers_active -= 1;
}
Err(_) => {}
}
} else if src_idx == Some(k) { } else if src_idx == Some(k) {
// ── Nouvel item depuis la source ────────────────────── // ── Nouvel item depuis la source ──────────────────────
match oper.recv(&source_rx) { match oper.recv(&source_rx) {
Ok(Ok(data)) => { Ok(Ok(data)) => {
if n == 0 { if n == 0 {
let _ = sink_tx.send(data); // source → sink direct let _ = sink_tx.send(data);
} else { } else {
in_flight += 1; in_flight += 1;
let _ = worker_tx.send(( dispatch(
data, data, 0,
transforms[0].clone(), &stages, &stage_txs, &worker_tx,
stage_txs[0].clone(), &flat_delta_tx, &mut flat_workers_active,
)); );
} }
} }
Ok(Err(e)) => eprintln!("Source error: {:?}", e), Ok(Err(e)) => eprintln!("Source error: {:?}", e),
Err(_) => source_done = true, // source fermée, on continue à drainer Err(_) => source_done = true,
} }
} else { } else {
// ── Résultat d'un stage intermédiaire ───────────────── // ── Résultat d'un stage intermédiaire ─────────────────
// k ∈ [1, N] → stage = N-k // k ∈ [2, n+1] → stage = n+1 - k
let stage = n - k; let stage = n + 1 - k;
match oper.recv(&stage_rxs[stage]) { match oper.recv(&stage_rxs[stage]) {
Ok(Ok(data)) => { Ok(Ok(data)) => {
if stage == n - 1 { if stage == n - 1 {
in_flight -= 1; in_flight -= 1;
let _ = sink_tx.send(data); // dernière étape → sink let _ = sink_tx.send(data);
} else { } else {
let _ = worker_tx.send(( dispatch(
data, data, stage + 1,
transforms[stage + 1].clone(), &stages, &stage_txs, &worker_tx,
stage_txs[stage + 1].clone(), &flat_delta_tx, &mut flat_workers_active,
)); );
} }
} }
Ok(Err(e)) => eprintln!("Stage {} error: {:?}", stage, e), Ok(Err(e)) => eprintln!("Stage {} error: {:?}", stage, e),
Err(_) => break, // fermeture inattendue d'un canal Err(_) => break,
} }
} }
} }
} }
// Signaler la fin aux workers et au sink
drop(worker_tx); drop(worker_tx);
drop(sink_tx); drop(sink_tx);
@@ -613,3 +353,328 @@ where
} }
} }
} }
/// Envoie `data` au stage `stage_idx`.
/// Pour un `Transform`, empile une `WorkerTask::Transform`.
/// Pour un `Flat`, incrémente `flat_workers_active` et empile une `WorkerTask::Flat`.
#[inline]
fn dispatch<DATA>(
data: DATA,
stage_idx: usize,
stages: &[Stage<DATA>],
stage_txs: &[Sender<Result<DATA, PipelineError>>],
worker_tx: &Sender<WorkerTask<DATA>>,
flat_delta_tx: &Sender<isize>,
flat_workers_active: &mut usize,
) {
match &stages[stage_idx] {
Stage::Transform(f) => {
let _ = worker_tx.send(WorkerTask::Transform(
data,
Arc::clone(f),
stage_txs[stage_idx].clone(),
));
}
Stage::Flat(f) => {
*flat_workers_active += 1;
let _ = worker_tx.send(WorkerTask::Flat(
data,
Arc::clone(f),
stage_txs[stage_idx].clone(),
flat_delta_tx.clone(),
));
}
}
}
// ── Macros ────────────────────────────────────────────────────────────────────
/// Creates a `SourceFn` from an iterator of plain values.
#[macro_export]
macro_rules! make_source {
($enum:ident, $iterator:expr, $output:ident) => {{
let mut iter = $iterator.into_iter();
Box::new(
move || -> ::std::result::Result<$enum, $crate::PipelineError> {
match iter.next() {
Some(x) => Ok($enum::$output(x)),
None => Err($crate::PipelineError::EndOfStream),
}
},
)
as Box<dyn FnMut() -> ::std::result::Result<$enum, $crate::PipelineError> + Send>
}};
}
/// Creates a `SourceFn` from an iterator of `Result<T, E>`.
#[macro_export]
macro_rules! make_source_fallible {
($enum:ident, $iterator:expr, $output:ident) => {{
let mut iter = $iterator.into_iter();
Box::new(
move || -> ::std::result::Result<$enum, $crate::PipelineError> {
match iter.next() {
Some(Ok(x)) => Ok($enum::$output(x)),
Some(Err(e)) => Err($crate::PipelineError::StepError(Box::new(e))),
None => Err($crate::PipelineError::EndOfStream),
}
},
)
as Box<dyn FnMut() -> ::std::result::Result<$enum, $crate::PipelineError> + Send>
}};
}
/// Creates a `Stage::Transform` from a pure (non-fallible) function `Fn(T) -> U`.
#[macro_export]
macro_rules! make_transform {
($enum:ident, $func:tt, $input:ident, $output:ident) => {{
let __f = $func;
$crate::Stage::Transform(
::std::sync::Arc::from(Box::new(
move |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> {
match data {
$enum::$input(x) => Ok($enum::$output(__f(x))),
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<
dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError>
+ Send
+ Sync,
>)
)
}};
}
/// Creates a `Stage::Transform` from a fallible function `Fn(T) -> Result<U, E>`.
#[macro_export]
macro_rules! make_transform_fallible {
($enum:ident, $func:tt, $input:ident, $output:ident) => {{
let __f = $func;
$crate::Stage::Transform(
::std::sync::Arc::from(Box::new(
move |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> {
match data {
$enum::$input(inner) => {
let result = __f(inner)
.map_err(|e| $crate::PipelineError::StepError(Box::new(e)))?;
Ok($enum::$output(result))
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<
dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError>
+ Send
+ Sync,
>)
)
}};
}
/// Creates a `Stage::Flat` from a function `Fn(T) -> impl IntoIterator<Item = U>`.
///
/// Pour chaque item produit par l'itérateur, il est poussé individuellement dans
/// le canal de sortie, permettant au scheduler de dispatcher les items en parallèle
/// dès qu'un worker est disponible.
#[macro_export]
macro_rules! make_flat_transform {
($enum:ident, $func:tt, $input:ident, $output:ident) => {{
let __f = $func;
$crate::Stage::Flat(
::std::sync::Arc::new(
move |data: $enum,
push: &$crate::PipelineSender<
::std::result::Result<$enum, $crate::PipelineError>,
>,
delta: &$crate::PipelineSender<isize>| {
match data {
$enum::$input(inner) => {
let mut count: isize = 0;
for item in __f(inner) {
push.send(Ok($enum::$output(item))).ok();
count += 1;
}
delta.send(count - 1).ok();
}
_ => {
push.send(Err($crate::PipelineError::TypeMismatch)).ok();
delta.send(0).ok();
}
}
},
) as $crate::SharedFlatFn<$enum>
)
}};
}
/// Creates a `Stage::Flat` from a fallible function
/// `Fn(T) -> Result<impl IntoIterator<Item = U>, E>`.
///
/// Si la fonction retourne `Err`, une erreur est poussée dans le canal et aucun
/// item normal n'est produit.
#[macro_export]
macro_rules! make_flat_transform_fallible {
($enum:ident, $func:tt, $input:ident, $output:ident) => {{
let __f = $func;
$crate::Stage::Flat(
::std::sync::Arc::new(
move |data: $enum,
push: &$crate::PipelineSender<
::std::result::Result<$enum, $crate::PipelineError>,
>,
delta: &$crate::PipelineSender<isize>| {
match data {
$enum::$input(inner) => match __f(inner) {
Ok(iter) => {
let mut count: isize = 0;
for item in iter {
push.send(Ok($enum::$output(item))).ok();
count += 1;
}
delta.send(count - 1).ok();
}
Err(e) => {
push.send(Err($crate::PipelineError::StepError(Box::new(e))))
.ok();
delta.send(0).ok();
}
},
_ => {
push.send(Err($crate::PipelineError::TypeMismatch)).ok();
delta.send(0).ok();
}
}
},
) as $crate::SharedFlatFn<$enum>
)
}};
}
/// Creates a `SinkFn` from a function that consumes a concrete value and returns `()`.
#[macro_export]
macro_rules! make_sink {
($enum:ident, $func:tt, $input:ident) => {{
let __f = $func;
Box::new(
move |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> {
match data {
$enum::$input(x) => {
__f(x);
Ok(())
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<dyn Fn($enum) -> ::std::result::Result<(), $crate::PipelineError> + Send>
}};
}
/// Creates a `SinkFn` from a fallible function that returns `Result<(), E>`.
#[macro_export]
macro_rules! make_sink_fallible {
($enum:ident, $func:tt, $input:ident) => {{
let __f = $func;
Box::new(
move |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> {
match data {
$enum::$input(inner) => {
__f(inner).map_err(|e| $crate::PipelineError::StepError(Box::new(e)))
}
_ => Err($crate::PipelineError::TypeMismatch),
}
},
)
as Box<dyn Fn($enum) -> ::std::result::Result<(), $crate::PipelineError> + Send>
}};
}
/// Construit un `Pipeline` à partir d'une source, d'une liste de stages et d'un sink.
///
/// Syntaxe :
/// ```ignore
/// make_pipeline! {
/// MyData,
/// source my_iter => Variant, // source non-fallible
/// source? my_iter => Variant, // source fallible (Result<T, E>)
/// | func: In => Out, // transform 1→1 non-fallible
/// |? func: In => Out, // transform 1→1 fallible
/// || func: In => Out, // flat transform 1→N non-fallible
/// ||? func: In => Out, // flat transform 1→N fallible
/// sink my_func @ Variant, // sink non-fallible
/// sink? my_func @ Variant, // sink fallible
/// }
/// ```
#[macro_export]
macro_rules! make_pipeline {
// ── Points d'entrée ──────────────────────────────────────────────────
($enum:ident, source $src:expr => $src_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum,
{ $crate::make_source!($enum, $src, $src_out) },
[],
$($rest)*)
};
($enum:ident, source? $src:expr => $src_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum,
{ $crate::make_source_fallible!($enum, $src, $src_out) },
[],
$($rest)*)
};
// ── Accumulation des stages ──────────────────────────────────────────
// transform 1→1 non-fallible
(@build $enum:ident, $source:tt, [$($acc:tt)*],
| $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_transform!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// transform 1→1 fallible
(@build $enum:ident, $source:tt, [$($acc:tt)*],
|? $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_transform_fallible!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// flat transform 1→N non-fallible
(@build $enum:ident, $source:tt, [$($acc:tt)*],
|| $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_flat_transform!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// flat transform 1→N fallible
(@build $enum:ident, $source:tt, [$($acc:tt)*],
||? $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => {
$crate::make_pipeline!(@build $enum, $source,
[$($acc)* $crate::make_flat_transform_fallible!($enum, $tf, $t_in, $t_out),],
$($rest)*)
};
// ── Terminaison : sink ───────────────────────────────────────────────
(@build $enum:ident, $source:tt, [$($acc:tt)*],
sink $sink_fn:tt @ $sink_in:ident $(,)?) => {
$crate::Pipeline::new(
$source,
vec![$($acc)*],
$crate::make_sink!($enum, $sink_fn, $sink_in),
)
};
(@build $enum:ident, $source:tt, [$($acc:tt)*],
sink? $sink_fn:tt @ $sink_in:ident $(,)?) => {
$crate::Pipeline::new(
$source,
vec![$($acc)*],
$crate::make_sink_fallible!($enum, $sink_fn, $sink_in),
)
};
}