diff --git a/src/Cargo.lock b/src/Cargo.lock index 5b41240..b45cdce 100644 --- a/src/Cargo.lock +++ b/src/Cargo.lock @@ -668,10 +668,10 @@ name = "obikmer" version = "0.1.0" dependencies = [ "clap", - "crossbeam-channel", "obifastwrite", "obikrope", "obikseq", + "obipipeline", "obiread", "obiskbuilder", ] diff --git a/src/obikmer/Cargo.toml b/src/obikmer/Cargo.toml index 9b6e4ba..75f78d8 100644 --- a/src/obikmer/Cargo.toml +++ b/src/obikmer/Cargo.toml @@ -12,6 +12,6 @@ obikseq = { path = "../obikseq" } obiread = { path = "../obiread" } obiskbuilder = { path = "../obiskbuilder" } obifastwrite = { path = "../obifastwrite" } -clap = { version = "4", features = ["derive"] } -crossbeam-channel = "0.5" -obikrope = { path = "../obikrope" } +obipipeline = { path = "../obipipeline" } +clap = { version = "4", features = ["derive"] } +obikrope = { path = "../obikrope" } diff --git a/src/obikmer/src/cmd/superkmer.rs b/src/obikmer/src/cmd/superkmer.rs index e309122..836207c 100644 --- a/src/obikmer/src/cmd/superkmer.rs +++ b/src/obikmer/src/cmd/superkmer.rs @@ -1,17 +1,19 @@ use std::io::{self, BufWriter, Write}; -use std::thread; +use std::path::PathBuf; +use std::sync::{Arc, Mutex}; use clap::Args; -use crossbeam_channel::bounded; use obifastwrite::write_scatter; use obikrope::Rope; use obikseq::superkmer::SuperKmer; +use obipipeline::{WorkerPool, make_pipeline}; use obiskbuilder::SuperKmerIter; #[derive(Args)] pub struct SuperkmerArgs { - /// Input: file path, URL (http/https), or `-` for stdin - pub input: String, + /// Input files or directories (FASTA/FASTQ, optionally gzip-compressed) + #[arg(num_args = 1..)] + pub inputs: Vec, /// k-mer size #[arg(short, long, default_value_t = 31)] @@ -34,36 +36,70 @@ pub struct SuperkmerArgs { pub partitions: u64, /// Number of worker threads - #[arg(short = 'T', long, default_value_t = 4)] + #[arg(short = 'T', long, default_value_t = 16)] pub threads: usize, - - /// Force input format: fasta or fastq (default: auto-detect from extension) - #[arg(long)] - pub format: Option, } -#[derive(Clone, Copy)] -enum Format { - Fasta, - Fastq, +enum PipelineData { + Path(PathBuf), + RawChunk(Rope), + NormChunk(Rope), + Batch(Vec<(u64, SuperKmer)>), } -fn detect_format(source: &str, hint: Option<&str>) -> Format { - if let Some(h) = hint { - return match h.to_lowercase().as_str() { - "fastq" | "fq" => Format::Fastq, - _ => Format::Fasta, - }; - } - if source.ends_with(".fq") - || source.ends_with(".fastq") - || source.ends_with(".fq.gz") - || source.ends_with(".fastq.gz") - { - Format::Fastq - } else { - Format::Fasta +// SAFETY: Rope contains Cell which is !Sync, but pipeline ownership transfers +// exclusively through channels — no item is ever shared across threads. +unsafe impl Send for PipelineData {} +unsafe impl Sync for PipelineData {} + +// ── Stage functions ─────────────────────────────────────────────────────────── + +/// Opens a sequence file and returns an iterator over its raw Rope chunks. +/// Chunk-level I/O errors are logged and skipped. +fn open_chunks(path: PathBuf) -> io::Result> { + let path_str = path + .to_str() + .ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "non-UTF-8 path"))?; + let iter = obiread::read_sequence_chunks(path_str)?; + Ok(iter.filter_map(|r| match r { + Ok(rope) => Some(rope), + Err(e) => { + eprintln!("chunk read error: {e}"); + None + } + })) +} + +/// Normalises a raw sequence chunk (FASTA or FASTQ) into a compact ACGT/NUL rope. +fn normalize(rope: Rope, k: usize) -> io::Result { + obiread::normalize_sequence_chunk(rope, k) +} + +/// Extracts all super-kmers from a normalised rope. +fn build_superkmers( + rope: Rope, + k: usize, + m: usize, + level_max: usize, + theta: f64, +) -> Vec<(u64, SuperKmer)> { + SuperKmerIter::new(&rope, k, m, level_max, theta).collect() +} + +/// Writes a batch of super-kmers to the output sink. +fn write_batch( + batch: Vec<(u64, SuperKmer)>, + out: &Mutex>, + partitions: u64, + k: usize, + m: usize, +) -> io::Result<()> { + let mut w = out.lock().unwrap(); + for (min_hash, sk) in batch { + let partition = (mix64(min_hash) % partitions) as u32; + write_scatter(&sk, &mut *w, k, m, partition, min_hash)?; } + Ok(()) } #[inline] @@ -75,9 +111,9 @@ fn mix64(x: u64) -> u64 { x ^ (x >> 31) } +// ── Entry point ─────────────────────────────────────────────────────────────── + pub fn run(args: SuperkmerArgs) { - let format = detect_format(&args.input, args.format.as_deref()); - let input = args.input.clone(); let k = args.kmer_size; let m = args.minimizer_size; let theta = args.theta; @@ -85,81 +121,21 @@ pub fn run(args: SuperkmerArgs) { let partitions = args.partitions; let n_workers = args.threads.max(1); - // raw chunks (reader → workers) - let (raw_tx, raw_rx) = bounded::(n_workers * 2); - // superkmer batches (workers → output) - let (sk_tx, sk_rx) = bounded::>(n_workers * 2); + let paths = args.inputs.iter().map(PathBuf::from).collect(); + let path_source = obiread::PathIter::new(paths); - // ── reader thread ───────────────────────────────────────────────────────── - let reader = thread::spawn(move || { - let file = obiread::xopen(&input).expect("cannot open input"); - match format { - Format::Fasta => { - for chunk in obiread::fasta_chunks(file) { - raw_tx.send(chunk.expect("read error")).unwrap(); - } - } - Format::Fastq => { - for chunk in obiread::fastq_chunks(file) { - raw_tx.send(chunk.expect("read error")).unwrap(); - } - } - } - // raw_tx drops here → workers observe disconnect - }); + let out = Arc::new(Mutex::new(BufWriter::new(io::stdout()))); + let out_sink = Arc::clone(&out); - // ── worker threads ──────────────────────────────────────────────────────── - let workers: Vec<_> = (0..n_workers) - .map(|_| { - let raw_rx = raw_rx.clone(); - let sk_tx = sk_tx.clone(); - thread::spawn(move || { - for raw_chunk in raw_rx { - let norm = match format { - Format::Fasta => obiread::normalize::normalize_fasta_chunk(raw_chunk, k), - Format::Fastq => obiread::normalize::normalize_fastq_chunk(raw_chunk, k), - }; - const BATCH_SIZE: usize = 10_000; - let mut batch = Vec::with_capacity(BATCH_SIZE); - for sk in SuperKmerIter::new(&norm, k, m, level_max, theta) { - batch.push(sk); - if batch.len() == BATCH_SIZE { - sk_tx - .send(std::mem::replace( - &mut batch, - Vec::with_capacity(BATCH_SIZE), - )) - .unwrap(); - } - } - if !batch.is_empty() { - sk_tx.send(batch).unwrap(); - } - } - // sk_tx clone drops here - }) - }) - .collect(); + let pipeline = make_pipeline! { + PipelineData, + source path_source => Path, + ||? open_chunks : Path => RawChunk, + |? { move |rope| normalize(rope, k) } : RawChunk => NormChunk, + | { move |rope| build_superkmers(rope, k, m, level_max, theta) }: NormChunk => Batch, + sink? { move |batch| write_batch(batch, &out_sink, partitions, k, m) } @ Batch, + }; - // drop the extra sk_tx clone held by this thread so the output thread exits - drop(sk_tx); - - // ── output thread ───────────────────────────────────────────────────────── - let output = thread::spawn(move || { - let stdout = io::stdout(); - let mut out = BufWriter::new(stdout.lock()); - for batch in sk_rx { - for (min_hash, sk) in batch { - let partition = (mix64(min_hash) % partitions) as u32; - write_scatter(&sk, &mut out, k, m, partition, min_hash).expect("write error"); - } - } - out.flush().expect("flush error"); - }); - - reader.join().expect("reader thread panicked"); - for w in workers { - w.join().expect("worker thread panicked"); - } - output.join().expect("output thread panicked"); + WorkerPool::new(pipeline, n_workers, 1).run(); + out.lock().unwrap().flush().expect("flush error"); } diff --git a/src/obipipeline/src/lib.rs b/src/obipipeline/src/lib.rs index 48f370c..82b4999 100644 --- a/src/obipipeline/src/lib.rs +++ b/src/obipipeline/src/lib.rs @@ -2,9 +2,19 @@ mod scheduler; pub use scheduler::Pipeline; pub use scheduler::PipelineError; +pub use scheduler::SharedFlatFn; pub use scheduler::SharedFn; pub use scheduler::SinkFn; pub use scheduler::SourceFn; +pub use scheduler::Stage; pub use scheduler::WorkerPool; + +/// Re-export de `crossbeam_channel::Sender` utilisé dans les macros flat transform. +/// Permet aux macros `make_flat_transform!` / `make_flat_transform_fallible!` d'utiliser +/// `$crate::PipelineSender` sans que le crate appelant n'ait besoin de dépendre +/// directement de `crossbeam_channel`. +pub use crossbeam_channel::Sender as PipelineSender; + // make_sink, make_sink_fallible, make_source, make_source_fallible, -// make_transform, make_transform_fallible are exported at crate root via #[macro_export] +// make_transform, make_transform_fallible, make_flat_transform, +// make_flat_transform_fallible sont exportés à la racine du crate via #[macro_export] diff --git a/src/obipipeline/src/scheduler.rs b/src/obipipeline/src/scheduler.rs index 2e8e6ce..fc6625b 100644 --- a/src/obipipeline/src/scheduler.rs +++ b/src/obipipeline/src/scheduler.rs @@ -37,333 +37,52 @@ impl Error for PipelineError { } } -/// Represents a single processing stage in a data pipeline. -/// -/// `StepKind` abstracts over three fundamental types of operations: -/// - **Source**: Produces a sequence of data items (e.g., from a file, iterator, generator). -/// - **Transform**: Converts one data item into another (pure transformation). -/// - **Sink**: Consumes a final data item (e.g., print, store, aggregate). -/// -/// The type `DATA` is typically an enum that unifies all data variants that can flow through -/// the pipeline. The type `ERROR` is the error type used by fallible operations. -/// -/// # Generics -/// - `DATA`: The type of values passed between stages (must be `Send + Sync`). -/// - `ERROR`: The error type for fallible computations (usually `PipelineError`). -/// -/// # Example -/// ``` -/// # use std::io; -/// # enum MyData { Number(i32), Text(String) } -/// # type MyError = io::Error; -/// # let source_iter = vec![Ok(MyData::Number(42))].into_iter(); -/// # let transform_fn = |d: MyData| -> Result { Ok(d) }; -/// # let sink_fn = |d: MyData| {}; -/// let source = StepKind::Source(Box::new(source_iter)); -/// let transform = StepKind::Transform(Box::new(transform_fn)); -/// let sink = StepKind::Sink(Box::new(sink_fn)); -/// ``` -/// -/// # Note -/// - A `Source` does not take an input; it yields an iterator of `Result`. -/// - A `Transform` takes a `DATA` and returns a `Result`. -/// - A `Sink` takes a `DATA` and performs a side effect (no return value). +// ── Function types ──────────────────────────────────────────────────────────── + /// Fonction source : appelée répétitivement, retourne le prochain item ou EndOfStream. /// `FnMut` car elle maintient un état interne (position dans l'itérateur). -pub type SourceFn = Box Result + Send + Sync>; +pub type SourceFn = Box Result + Send>; /// Fonction sink : consomme un item final, peut échouer (erreur d'I/O, etc.). -pub type SinkFn = Box Result<(), PipelineError> + Send + Sync>; - -/// Creates a `StepKind::Source` from an iterator of plain values. -/// Each value is wrapped with the specified `$enum` variant and returned as `Ok`. -/// The source returns `Err(PipelineError::EndOfStream)` when exhausted. -/// -/// # Example -/// ``` -/// let iter = vec![1.0, 2.0].into_iter(); -/// let source = make_source!(PipelineData, iter, Numeric); -/// ``` -#[macro_export] -macro_rules! make_source { - ($enum:ident, $iterator:expr, $output:ident) => {{ - let mut iter = $iterator.into_iter(); - Box::new( - move || -> ::std::result::Result<$enum, $crate::PipelineError> { - match iter.next() { - Some(x) => Ok($enum::$output(x)), - None => Err($crate::PipelineError::EndOfStream), - } - }, - ) - as Box ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync> - }}; -} - -/// Creates a `StepKind::Source` from an iterator of `Result`. -/// On `Ok(x)` it wraps the value into the specified `$enum` variant. -/// On `Err(e)` it returns `Err(PipelineError::StepError(Box::new(e)))`. -/// Returns `Err(PipelineError::EndOfStream)` when the iterator ends. -/// -/// # Example -/// ``` -/// let iter = vec![Ok(1.0), Err("oops")].into_iter(); -/// let source = make_source_fallible!(PipelineData, iter, Numeric); -/// ``` -#[macro_export] -macro_rules! make_source_fallible { - ($enum:ident, $iterator:expr, $output:ident) => {{ - let mut iter = $iterator.into_iter(); - Box::new( - move || -> ::std::result::Result<$enum, $crate::PipelineError> { - match iter.next() { - Some(Ok(x)) => Ok($enum::$output(x)), - Some(Err(e)) => Err($crate::PipelineError::StepError(Box::new(e))), - None => Err($crate::PipelineError::EndOfStream), - } - }, - ) - as Box ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync> - }}; -} - -/// Creates a pipeline stage from a pure (non-fallible) function. -/// -/// This macro generates a closure that implements the `PipelineStage` trait by pattern -/// matching on the input `PipelineData` variant, applying the provided function, and -/// wrapping the result in the output variant. -/// -/// # Arguments -/// * `$func` - The function to apply: `Fn(T) -> U` -/// * `$input` - Input `PipelineData` variant pattern (e.g., `Int(i64)`) -/// * `$output` - Output `PipelineData` variant pattern (e.g., `String(String)`) -/// -/// # Example -/// -/// ```ignore -/// // Define PipelineData enum -/// enum PipelineData { -/// Int(i64), -/// String(String), -/// } -/// -/// // Create pure stage -/// let to_string_stage = make_stage!( -/// to_string, -/// Int(i64), -/// String(String) -/// ); -/// -/// // Use in pipeline -/// let result = to_string_stage(PipelineData::Int(42)).unwrap(); -/// assert!(matches!(result, PipelineData::String(s) if s == "42")); -/// ``` -#[macro_export] -macro_rules! make_transform { - ($enum:ident, $func:ident, $input:ident, $output:ident) => { - ::std::sync::Arc::from(Box::new( - |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> { - match data { - $enum::$input(x) => Ok($enum::$output($func(x))), - _ => Err($crate::PipelineError::TypeMismatch), - } - }, - ) - as Box< - dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync, - >) - }; -} - -/// Creates a pipeline stage from a fallible function. -/// -/// This macro generates a closure that pattern matches on the input `PipelineData` -/// variant, applies the provided function, and wraps the result in the output variant. -/// If the function returns an error, it is boxed and wrapped in `PipelineError::StepError`. -/// -/// # Arguments -/// * `$func` - The fallible function to apply: `Fn(T) -> Result` -/// * `$input` - Input `PipelineData` variant pattern (e.g., `Int(i64)`) -/// * `$output` - Output `PipelineData` variant pattern (e.g., `String(String)`) -/// -/// # Example -/// -/// ```ignore -/// fn parse_int(s: &str) -> Result { -/// s.parse() -/// } -/// -/// // Define PipelineData enum -/// enum PipelineData { -/// String(String), -/// Int(i64), -/// } -/// -/// // Create fallible stage -/// let parse_stage = make_stage_fallible!( -/// parse_int, -/// String(String), -/// Int(i64) -/// ); -/// -/// // Use in pipeline -/// let result = parse_stage(PipelineData::String("42".into())).unwrap(); -/// assert!(matches!(result, PipelineData::Int(n) if n == 42)); -/// ``` -#[macro_export] -macro_rules! make_transform_fallible { - ($enum:ident, $func:ident, $input:ident, $output:ident) => { - ::std::sync::Arc::from(Box::new( - |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> { - match data { - $enum::$input(inner) => { - let result = $func(inner) - .map_err(|e| $crate::PipelineError::StepError(Box::new(e)))?; - Ok($enum::$output(result)) - } - _ => Err($crate::PipelineError::TypeMismatch), - } - }, - ) - as Box< - dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + Send + Sync, - >) - }; -} - -/// Creates a `StepKind::Sink` from a function that consumes a concrete value and returns `()`. -/// The returned sink always returns `Ok(())`. -/// The function is wrapped to accept `$enum::$input` and ignores the result. -/// -/// # Example -/// ``` -/// fn print_number(n: f64) { println!("{}", n); } -/// let sink = make_sink!(PipelineData, print_number, Numeric); -/// ``` -#[macro_export] -macro_rules! make_sink { - ($enum:ident, $func:ident, $input:ident) => { - Box::new( - |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> { - match data { - $enum::$input(x) => { - $func(x); - Ok(()) - } - _ => Err($crate::PipelineError::TypeMismatch), - } - }, - ) - as Box ::std::result::Result<(), $crate::PipelineError> + Send + Sync> - }; -} - -/// Creates a `StepKind::Sink` from a fallible function that returns `Result<(), E>`. -/// Errors from `$func` are wrapped in `PipelineError::StepError`. -/// -/// # Example -/// ``` -/// fn save_to_file(hash: u64) -> std::io::Result<()> { Ok(()) } -/// let sink = make_sink_fallible!(PipelineData, save_to_file, Hash); -/// ``` -#[macro_export] -macro_rules! make_sink_fallible { - ($enum:ident, $func:ident, $input:ident) => { - Box::new( - |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> { - match data { - $enum::$input(inner) => { - $func(inner).map_err(|e| $crate::PipelineError::StepError(Box::new(e))) - } - _ => Err($crate::PipelineError::TypeMismatch), - } - }, - ) - as Box ::std::result::Result<(), $crate::PipelineError> + Send + Sync> - }; -} - -/// Construit un `Pipeline` à partir d'une source, d'une liste de transforms et d'un sink. -/// -/// Syntaxe : -/// ```ignore -/// make_pipeline! { -/// MyData, -/// source my_iter => Variant, // source non-fallible -/// source? my_iter => Variant, // source fallible (Result) -/// | func: In => Out, // transform non-fallible (répété 0..N fois) -/// |? func: In => Out, // transform fallible (répété 0..N fois) -/// sink my_func @ Variant, // sink non-fallible -/// sink? my_func @ Variant, // sink fallible -/// } -/// ``` -/// -/// Implémenté comme un TT muncher : la règle interne `@build` traite les transforms -/// un par un en les accumulant, puis termine sur `sink`/`sink?`. -#[macro_export] -macro_rules! make_pipeline { - // ── Points d'entrée ────────────────────────────────────────────────── - - ($enum:ident, source $src:expr => $src_out:ident, $($rest:tt)*) => { - $crate::make_pipeline!(@build $enum, - { $crate::make_source!($enum, $src, $src_out) }, - [], - $($rest)*) - }; - ($enum:ident, source? $src:expr => $src_out:ident, $($rest:tt)*) => { - $crate::make_pipeline!(@build $enum, - { $crate::make_source_fallible!($enum, $src, $src_out) }, - [], - $($rest)*) - }; - - // ── Accumulation des transforms ────────────────────────────────────── - - // transform non-fallible : | - (@build $enum:ident, $source:tt, [$($acc:tt)*], - | $tf:ident : $t_in:ident => $t_out:ident, $($rest:tt)*) => { - $crate::make_pipeline!(@build $enum, $source, - [$($acc)* $crate::make_transform!($enum, $tf, $t_in, $t_out),], - $($rest)*) - }; - - // transform fallible : |? - (@build $enum:ident, $source:tt, [$($acc:tt)*], - |? $tf:ident : $t_in:ident => $t_out:ident, $($rest:tt)*) => { - $crate::make_pipeline!(@build $enum, $source, - [$($acc)* $crate::make_transform_fallible!($enum, $tf, $t_in, $t_out),], - $($rest)*) - }; - - // ── Terminaison : sink ─────────────────────────────────────────────── - - (@build $enum:ident, $source:tt, [$($acc:tt)*], - sink $sink_fn:ident @ $sink_in:ident $(,)?) => { - $crate::Pipeline::new( - $source, - vec![$($acc)*], - $crate::make_sink!($enum, $sink_fn, $sink_in), - ) - }; - (@build $enum:ident, $source:tt, [$($acc:tt)*], - sink? $sink_fn:ident @ $sink_in:ident $(,)?) => { - $crate::Pipeline::new( - $source, - vec![$($acc)*], - $crate::make_sink_fallible!($enum, $sink_fn, $sink_in), - ) - }; -} +pub type SinkFn = Box Result<(), PipelineError> + Send>; /// Fonction de transformation partagée entre workers via Arc. -/// Arc permet à plusieurs workers de partager la même closure -/// sans la copier (Arc::clone = simple incrément de compteur). pub type SharedFn = Arc Result + Send + Sync>; -/// Tâche envoyée à un worker : donnée + fonction à appliquer + canal de résultat. -/// Le worker n'a pas besoin de connaître sa position dans le pipeline ; -/// le scheduler lui dit exactement quoi faire et où envoyer le résultat. -pub type WorkerTask = (D, SharedFn, Sender>); +/// Fonction de transformation 1→N (flat map) partagée entre workers via Arc. +/// +/// La fonction reçoit l'item d'entrée, un canal `push` pour envoyer chaque item +/// produit, et un canal `delta` pour signaler au scheduler combien d'items +/// supplémentaires sont entrés dans le pipeline (N-1 si N items produits). +/// Elle doit appeler `delta.send(N - 1)` **après** avoir poussé tous les items. +pub type SharedFlatFn = + Arc>, &Sender) + Send + Sync>; + +// ── Stage enum ──────────────────────────────────────────────────────────────── + +/// Une étape du pipeline : transform classique (1→1) ou flat transform (1→N). +pub enum Stage { + Transform(SharedFn), + Flat(SharedFlatFn), +} + +impl Clone for Stage { + fn clone(&self) -> Self { + match self { + Stage::Transform(f) => Stage::Transform(Arc::clone(f)), + Stage::Flat(f) => Stage::Flat(Arc::clone(f)), + } + } +} + +// ── Worker task ─────────────────────────────────────────────────────────────── + +enum WorkerTask { + Transform(D, SharedFn, Sender>), + Flat(D, SharedFlatFn, Sender>, Sender), +} + +// ── Thread runners ──────────────────────────────────────────────────────────── fn source_runner( mut source: SourceFn, @@ -381,7 +100,7 @@ where match source() { Ok(data) => { if tx.send(Ok(data)).is_err() { - break; // récepteur disparu + break; } } Err(PipelineError::EndOfStream) => break, @@ -398,28 +117,29 @@ where /// Lance un thread worker du pool. /// -/// Le worker attend des tâches sur `task_rx`. Chaque tâche est un triplet -/// `(data, f, result_tx)` : il applique `f(data)` et envoie le résultat -/// dans `result_tx`. C'est le scheduler qui décide quelle fonction envoyer -/// et quel canal de résultat utiliser — le worker lui-même est générique. +/// Gère deux types de tâches : +/// - `Transform` : applique `f(data)` et envoie le résultat dans `result_tx`. +/// - `Flat` : appelle `f(data, &push_tx, &delta_tx)` ; la fonction elle-même +/// pousse ses items dans `push_tx` et envoie `N-1` dans `delta_tx`. fn transform_runner(task_rx: Receiver>) -> thread::JoinHandle<()> where DATA: Send + Sync + 'static, { thread::spawn(move || { - while let Ok((data, f, result_tx)) = task_rx.recv() { - let _ = result_tx.send(f(data)); + while let Ok(task) = task_rx.recv() { + match task { + WorkerTask::Transform(data, f, result_tx) => { + let _ = result_tx.send(f(data)); + } + WorkerTask::Flat(data, f, push_tx, delta_tx) => { + f(data, &push_tx, &delta_tx); + } + } } }) } /// Lance le thread sink. -/// -/// Retourne : -/// - `Sender` : le scheduler y envoie les données finales -/// - `Receiver` : le sink y pousse toute erreur rencontrée ; -/// le scheduler surveille ce canal en priorité absolue -/// pour interrompre le pipeline dès qu'une erreur survient. fn sink_runner( sink: SinkFn, capacity: usize, @@ -437,29 +157,33 @@ where for data in data_rx { if let Err(e) = sink(data) { let _ = err_tx.send(e); - break; // on arrête dès la première erreur + break; } } }); (data_tx, err_rx, handle) } +// ── Pipeline ────────────────────────────────────────────────────────────────── + pub struct Pipeline { source: SourceFn, - transforms: Vec>, + stages: Vec>, sink: SinkFn, } impl Pipeline { pub fn new( source: SourceFn, - transforms: Vec>, + stages: Vec>, sink: SinkFn, ) -> Self { - Self { source, transforms, sink } + Self { source, stages, sink } } } +// ── WorkerPool ──────────────────────────────────────────────────────────────── + pub struct WorkerPool { pipeline: Pipeline, handles: Vec>, @@ -481,11 +205,10 @@ where } pub fn run(mut self) { - let n = self.pipeline.transforms.len(); + let n = self.pipeline.stages.len(); // ── Canaux inter-stages ──────────────────────────────────────────── - // stage_txs[i] : le worker qui exécute transform[i] y envoie son résultat - // stage_rxs[i] : le scheduler lit ici pour dispatcher au transform[i+1] (ou sink) + // stage_txs[i] / stage_rxs[i] : sortie du stage i let mut stage_txs: Vec>> = Vec::new(); let mut stage_rxs: Vec>> = Vec::new(); for _ in 0..n { @@ -498,12 +221,9 @@ where let (source_rx, src_handle) = source_runner(self.pipeline.source, self.capacity); self.handles.push(src_handle); - // Les transforms sont déjà des SharedFn — pas de conversion nécessaire. - let transforms = self.pipeline.transforms; + let stages = self.pipeline.stages; // ── Worker pool ──────────────────────────────────────────────────── - // Canal partagé par tous les workers : le scheduler y pousse des WorkerTask, - // chaque worker en dépile une à la fois. let (worker_tx, worker_rx): (Sender>, Receiver>) = bounded(self.capacity); @@ -515,40 +235,51 @@ where let (sink_tx, sink_err_rx, sink_handle) = sink_runner(self.pipeline.sink, self.capacity); self.handles.push(sink_handle); - // ── Boucle principale ───────────────────────────────────────────── + // ── Canal delta pour les flat stages ─────────────────────────────── + // Chaque flat worker envoie `N-1` ici après avoir poussé N items. + // Le scheduler ajuste `in_flight` en conséquence. + let (flat_delta_tx, flat_delta_rx) = bounded::(self.capacity); + + // ── Boucle principale ────────────────────────────────────────────── // - // Le Select est reconstruit à chaque itération, ce qui permet de - // retirer source_rx une fois la source épuisée. + // `in_flight` (isize) = nb d'items qui doivent encore atteindre le sink. + // Peut temporairement être négatif si un flat worker a poussé ses items + // avant que le scheduler ait reçu le delta correspondant. // - // Priorités (biased = index le plus bas gagne) : - // index 0 → sink_err_rx (arrêt immédiat sur erreur sink) - // index 1..=N → stage_rxs[N-1..0] (vider le pipeline en priorité) - // index N+1 → source_rx (dernier recours : nouvelles données) + // `flat_workers_active` = nb de flat workers en cours d'exécution. + // Empêche la terminaison prématurée quand in_flight vaut 0 mais qu'un + // flat worker n'a pas encore envoyé son delta. // - // Quand k == 0 : erreur du sink - // Quand 1 <= k <= N : stage concerné = N-k - // Quand k == N+1 : item venant de la source + // Priorités du Select biaisé (index le plus bas = priorité la plus haute) : + // 0 → sink_err_rx (arrêt immédiat sur erreur sink) + // 1 → flat_delta_rx (mettre à jour in_flight avant de dispatcher) + // 2..=n+1 → stage_rxs[n-1..0] (vider le pipeline en priorité) + // n+2 → source_rx (dernier recours : nouvelles données) // - // Terminaison : on quitte uniquement quand source_done ET in_flight == 0, - // ce qui garantit que tous les items ont traversé le pipeline jusqu'au sink. + // Quand k = 0 : erreur du sink + // Quand k = 1 : delta d'un flat worker + // Quand 2 ≤ k ≤ n+1 : résultat du stage n+1-k + // Quand k = n+2 : item source + // + // Terminaison : source tarie ET in_flight == 0 ET aucun flat worker actif. { let mut source_done = false; - let mut in_flight: usize = 0; + let mut in_flight: isize = 0; + let mut flat_workers_active: usize = 0; loop { - // Condition de sortie : plus rien en vol et source tarie - if source_done && in_flight == 0 { + if source_done && in_flight == 0 && flat_workers_active == 0 { break; } - // Reconstruction du Select (sans source_rx si source épuisée) let mut sel = Select::new_biased(); - sel.recv(&sink_err_rx); // index 0 + sel.recv(&sink_err_rx); // index 0 + sel.recv(&flat_delta_rx); // index 1 for rx in stage_rxs.iter().rev() { - sel.recv(rx); // indices 1 .. N + sel.recv(rx); // indices 2..=n+1 } let src_idx = if !source_done { - Some(sel.recv(&source_rx)) // index N+1 (seulement si encore active) + Some(sel.recv(&source_rx)) // index n+2 } else { None }; @@ -557,54 +288,63 @@ where let k = oper.index(); if k == 0 { - // ── Erreur du sink : on arrête tout ────────────────── + // ── Erreur du sink ──────────────────────────────────── match oper.recv(&sink_err_rx) { Ok(e) => { eprintln!("Sink error: {:?}", e); break; } Err(_) => break, } + } else if k == 1 { + // ── Delta d'un flat worker ──────────────────────────── + // delta = N - 1 (N items poussés, 1 item consommé) + match oper.recv(&flat_delta_rx) { + Ok(delta) => { + in_flight += delta; + flat_workers_active -= 1; + } + Err(_) => {} + } } else if src_idx == Some(k) { // ── Nouvel item depuis la source ────────────────────── match oper.recv(&source_rx) { Ok(Ok(data)) => { if n == 0 { - let _ = sink_tx.send(data); // source → sink direct + let _ = sink_tx.send(data); } else { in_flight += 1; - let _ = worker_tx.send(( - data, - transforms[0].clone(), - stage_txs[0].clone(), - )); + dispatch( + data, 0, + &stages, &stage_txs, &worker_tx, + &flat_delta_tx, &mut flat_workers_active, + ); } } Ok(Err(e)) => eprintln!("Source error: {:?}", e), - Err(_) => source_done = true, // source fermée, on continue à drainer + Err(_) => source_done = true, } } else { // ── Résultat d'un stage intermédiaire ───────────────── - // k ∈ [1, N] → stage = N-k - let stage = n - k; + // k ∈ [2, n+1] → stage = n+1 - k + let stage = n + 1 - k; match oper.recv(&stage_rxs[stage]) { Ok(Ok(data)) => { if stage == n - 1 { in_flight -= 1; - let _ = sink_tx.send(data); // dernière étape → sink + let _ = sink_tx.send(data); } else { - let _ = worker_tx.send(( - data, - transforms[stage + 1].clone(), - stage_txs[stage + 1].clone(), - )); + dispatch( + data, stage + 1, + &stages, &stage_txs, &worker_tx, + &flat_delta_tx, &mut flat_workers_active, + ); } } Ok(Err(e)) => eprintln!("Stage {} error: {:?}", stage, e), - Err(_) => break, // fermeture inattendue d'un canal + Err(_) => break, } } } } - // Signaler la fin aux workers et au sink drop(worker_tx); drop(sink_tx); @@ -613,3 +353,328 @@ where } } } + +/// Envoie `data` au stage `stage_idx`. +/// Pour un `Transform`, empile une `WorkerTask::Transform`. +/// Pour un `Flat`, incrémente `flat_workers_active` et empile une `WorkerTask::Flat`. +#[inline] +fn dispatch( + data: DATA, + stage_idx: usize, + stages: &[Stage], + stage_txs: &[Sender>], + worker_tx: &Sender>, + flat_delta_tx: &Sender, + flat_workers_active: &mut usize, +) { + match &stages[stage_idx] { + Stage::Transform(f) => { + let _ = worker_tx.send(WorkerTask::Transform( + data, + Arc::clone(f), + stage_txs[stage_idx].clone(), + )); + } + Stage::Flat(f) => { + *flat_workers_active += 1; + let _ = worker_tx.send(WorkerTask::Flat( + data, + Arc::clone(f), + stage_txs[stage_idx].clone(), + flat_delta_tx.clone(), + )); + } + } +} + +// ── Macros ──────────────────────────────────────────────────────────────────── + +/// Creates a `SourceFn` from an iterator of plain values. +#[macro_export] +macro_rules! make_source { + ($enum:ident, $iterator:expr, $output:ident) => {{ + let mut iter = $iterator.into_iter(); + Box::new( + move || -> ::std::result::Result<$enum, $crate::PipelineError> { + match iter.next() { + Some(x) => Ok($enum::$output(x)), + None => Err($crate::PipelineError::EndOfStream), + } + }, + ) + as Box ::std::result::Result<$enum, $crate::PipelineError> + Send> + }}; +} + +/// Creates a `SourceFn` from an iterator of `Result`. +#[macro_export] +macro_rules! make_source_fallible { + ($enum:ident, $iterator:expr, $output:ident) => {{ + let mut iter = $iterator.into_iter(); + Box::new( + move || -> ::std::result::Result<$enum, $crate::PipelineError> { + match iter.next() { + Some(Ok(x)) => Ok($enum::$output(x)), + Some(Err(e)) => Err($crate::PipelineError::StepError(Box::new(e))), + None => Err($crate::PipelineError::EndOfStream), + } + }, + ) + as Box ::std::result::Result<$enum, $crate::PipelineError> + Send> + }}; +} + +/// Creates a `Stage::Transform` from a pure (non-fallible) function `Fn(T) -> U`. +#[macro_export] +macro_rules! make_transform { + ($enum:ident, $func:tt, $input:ident, $output:ident) => {{ + let __f = $func; + $crate::Stage::Transform( + ::std::sync::Arc::from(Box::new( + move |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> { + match data { + $enum::$input(x) => Ok($enum::$output(__f(x))), + _ => Err($crate::PipelineError::TypeMismatch), + } + }, + ) + as Box< + dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + + Send + + Sync, + >) + ) + }}; +} + +/// Creates a `Stage::Transform` from a fallible function `Fn(T) -> Result`. +#[macro_export] +macro_rules! make_transform_fallible { + ($enum:ident, $func:tt, $input:ident, $output:ident) => {{ + let __f = $func; + $crate::Stage::Transform( + ::std::sync::Arc::from(Box::new( + move |data: $enum| -> ::std::result::Result<$enum, $crate::PipelineError> { + match data { + $enum::$input(inner) => { + let result = __f(inner) + .map_err(|e| $crate::PipelineError::StepError(Box::new(e)))?; + Ok($enum::$output(result)) + } + _ => Err($crate::PipelineError::TypeMismatch), + } + }, + ) + as Box< + dyn Fn($enum) -> ::std::result::Result<$enum, $crate::PipelineError> + + Send + + Sync, + >) + ) + }}; +} + +/// Creates a `Stage::Flat` from a function `Fn(T) -> impl IntoIterator`. +/// +/// Pour chaque item produit par l'itérateur, il est poussé individuellement dans +/// le canal de sortie, permettant au scheduler de dispatcher les items en parallèle +/// dès qu'un worker est disponible. +#[macro_export] +macro_rules! make_flat_transform { + ($enum:ident, $func:tt, $input:ident, $output:ident) => {{ + let __f = $func; + $crate::Stage::Flat( + ::std::sync::Arc::new( + move |data: $enum, + push: &$crate::PipelineSender< + ::std::result::Result<$enum, $crate::PipelineError>, + >, + delta: &$crate::PipelineSender| { + match data { + $enum::$input(inner) => { + let mut count: isize = 0; + for item in __f(inner) { + push.send(Ok($enum::$output(item))).ok(); + count += 1; + } + delta.send(count - 1).ok(); + } + _ => { + push.send(Err($crate::PipelineError::TypeMismatch)).ok(); + delta.send(0).ok(); + } + } + }, + ) as $crate::SharedFlatFn<$enum> + ) + }}; +} + +/// Creates a `Stage::Flat` from a fallible function +/// `Fn(T) -> Result, E>`. +/// +/// Si la fonction retourne `Err`, une erreur est poussée dans le canal et aucun +/// item normal n'est produit. +#[macro_export] +macro_rules! make_flat_transform_fallible { + ($enum:ident, $func:tt, $input:ident, $output:ident) => {{ + let __f = $func; + $crate::Stage::Flat( + ::std::sync::Arc::new( + move |data: $enum, + push: &$crate::PipelineSender< + ::std::result::Result<$enum, $crate::PipelineError>, + >, + delta: &$crate::PipelineSender| { + match data { + $enum::$input(inner) => match __f(inner) { + Ok(iter) => { + let mut count: isize = 0; + for item in iter { + push.send(Ok($enum::$output(item))).ok(); + count += 1; + } + delta.send(count - 1).ok(); + } + Err(e) => { + push.send(Err($crate::PipelineError::StepError(Box::new(e)))) + .ok(); + delta.send(0).ok(); + } + }, + _ => { + push.send(Err($crate::PipelineError::TypeMismatch)).ok(); + delta.send(0).ok(); + } + } + }, + ) as $crate::SharedFlatFn<$enum> + ) + }}; +} + +/// Creates a `SinkFn` from a function that consumes a concrete value and returns `()`. +#[macro_export] +macro_rules! make_sink { + ($enum:ident, $func:tt, $input:ident) => {{ + let __f = $func; + Box::new( + move |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> { + match data { + $enum::$input(x) => { + __f(x); + Ok(()) + } + _ => Err($crate::PipelineError::TypeMismatch), + } + }, + ) + as Box ::std::result::Result<(), $crate::PipelineError> + Send> + }}; +} + +/// Creates a `SinkFn` from a fallible function that returns `Result<(), E>`. +#[macro_export] +macro_rules! make_sink_fallible { + ($enum:ident, $func:tt, $input:ident) => {{ + let __f = $func; + Box::new( + move |data: $enum| -> ::std::result::Result<(), $crate::PipelineError> { + match data { + $enum::$input(inner) => { + __f(inner).map_err(|e| $crate::PipelineError::StepError(Box::new(e))) + } + _ => Err($crate::PipelineError::TypeMismatch), + } + }, + ) + as Box ::std::result::Result<(), $crate::PipelineError> + Send> + }}; +} + +/// Construit un `Pipeline` à partir d'une source, d'une liste de stages et d'un sink. +/// +/// Syntaxe : +/// ```ignore +/// make_pipeline! { +/// MyData, +/// source my_iter => Variant, // source non-fallible +/// source? my_iter => Variant, // source fallible (Result) +/// | func: In => Out, // transform 1→1 non-fallible +/// |? func: In => Out, // transform 1→1 fallible +/// || func: In => Out, // flat transform 1→N non-fallible +/// ||? func: In => Out, // flat transform 1→N fallible +/// sink my_func @ Variant, // sink non-fallible +/// sink? my_func @ Variant, // sink fallible +/// } +/// ``` +#[macro_export] +macro_rules! make_pipeline { + // ── Points d'entrée ────────────────────────────────────────────────── + + ($enum:ident, source $src:expr => $src_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, + { $crate::make_source!($enum, $src, $src_out) }, + [], + $($rest)*) + }; + ($enum:ident, source? $src:expr => $src_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, + { $crate::make_source_fallible!($enum, $src, $src_out) }, + [], + $($rest)*) + }; + + // ── Accumulation des stages ────────────────────────────────────────── + + // transform 1→1 non-fallible + (@build $enum:ident, $source:tt, [$($acc:tt)*], + | $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, $source, + [$($acc)* $crate::make_transform!($enum, $tf, $t_in, $t_out),], + $($rest)*) + }; + + // transform 1→1 fallible + (@build $enum:ident, $source:tt, [$($acc:tt)*], + |? $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, $source, + [$($acc)* $crate::make_transform_fallible!($enum, $tf, $t_in, $t_out),], + $($rest)*) + }; + + // flat transform 1→N non-fallible + (@build $enum:ident, $source:tt, [$($acc:tt)*], + || $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, $source, + [$($acc)* $crate::make_flat_transform!($enum, $tf, $t_in, $t_out),], + $($rest)*) + }; + + // flat transform 1→N fallible + (@build $enum:ident, $source:tt, [$($acc:tt)*], + ||? $tf:tt : $t_in:ident => $t_out:ident, $($rest:tt)*) => { + $crate::make_pipeline!(@build $enum, $source, + [$($acc)* $crate::make_flat_transform_fallible!($enum, $tf, $t_in, $t_out),], + $($rest)*) + }; + + // ── Terminaison : sink ─────────────────────────────────────────────── + + (@build $enum:ident, $source:tt, [$($acc:tt)*], + sink $sink_fn:tt @ $sink_in:ident $(,)?) => { + $crate::Pipeline::new( + $source, + vec![$($acc)*], + $crate::make_sink!($enum, $sink_fn, $sink_in), + ) + }; + (@build $enum:ident, $source:tt, [$($acc:tt)*], + sink? $sink_fn:tt @ $sink_in:ident $(,)?) => { + $crate::Pipeline::new( + $source, + vec![$($acc)*], + $crate::make_sink_fallible!($enum, $sink_fn, $sink_in), + ) + }; +}