"""Demonstrate a pipeline that splits dynamically based on the input type.

Every step is wrapped so that a list input is fanned out element by element,
while any other input is passed straight to the step. A ``MergeRunnable`` step
does the opposite: it hands the whole (list) input to its bound runnable in a
single call.
"""

import functools
import operator
from collections.abc import Iterable
from typing import Any

from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.runnables import (
    Runnable,
    RunnableConfig,
    RunnableLambda,
    RunnableSerializable,
    patch_config,
)
from langchain_core.runnables.base import RunnableEach
from langchain_core.runnables.utils import Input, Output

# Tracing is on when this list is non-empty. It stays a list (rather than a
# bool) so it can be switched on in place with ``verbose.append(1)``.
verbose = []

# Inputs used by ``run_pipe`` when the caller does not supply any.
_DEMO_INPUTS = ([5, 6], 4)


def _trace(step_name: str, item: Any, out: Any) -> None:
    """Print one step's input and output when tracing is enabled."""
    if verbose:
        print(f" {step_name}: in: {item}, out: {out}")


def get_input_list(item):
    """Return ``[0, 1, ..., item - 1]`` as a list."""
    out = list(range(item))
    _trace("get_input_list", item, out)
    return out


def get_input_string(item):
    """Return the string form of ``item`` repeated twice."""
    out = str(item) * 2
    _trace("get_input_string", item, out)
    return out


def process(item):
    """Return ``item * 2`` (doubles numbers, repeats sequences and strings)."""
    out = item * 2
    _trace("process", item, out)
    return out


def process_merge(item):
    """Join the string form of every element of ``item`` with ``"-"``."""
    out = "-".join(str(i) for i in item)
    _trace("process_merge", item, out)
    return out


class FlexibleRunnableEach(RunnableEach):
    """Runnable that will split the pipeline if the input is a list.

    Otherwise, it will invoke the bound runnable.
    """

    # This shouldn't really extend `RunnableEach`, this is just for
    # demonstration purposes. It should rather be a standalone runnable
    # extending RunnableSerializable.

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output | list[Output]:
        """Fan out over a list input, or invoke the bound runnable directly.

        Args:
            inputs: A single input, or a list of inputs to split over.
            run_manager: Callback manager of the current run; its child
                manager is attached to the config used for nested calls.
            config: Config of the current run.
            **kwargs: Extra arguments, forwarded on both branches.

        Returns:
            A list of results for a list input, otherwise the single result
            of the bound runnable.
        """
        child_config = patch_config(config, callbacks=run_manager.get_child())
        if isinstance(inputs, list):
            # NOTE(review): `self.batch` (not `self.bound.batch`) looks
            # deliberate - the inherited batch is expected to call this
            # wrapper's own invoke per element, so nested lists are split
            # again. Confirm against langchain_core's `Runnable.batch`.
            return self.batch(inputs, child_config, **kwargs)
        return self.bound.invoke(inputs, child_config, **kwargs)


class MergeRunnable(RunnableSerializable[Input, Output]):
    """Runnable that passes its whole input to the bound runnable as a list.

    A list input is handed over unchanged; any other input is first wrapped
    in a one-element list.
    """

    # The runnable that receives the (list) input in a single call.
    bound: Runnable[Input, Output]

    class Config:
        arbitrary_types_allowed = True

    def _invoke(
        self,
        inputs: Input | list[Input],
        run_manager: CallbackManagerForChainRun,
        config: RunnableConfig,
        **kwargs: Any,
    ) -> Output:
        """Invoke the bound runnable once with ``inputs`` as a list.

        Args:
            inputs: A list of inputs, or a single input to wrap in a list.
            run_manager: Callback manager of the current run; its child
                manager is attached to the config used for the nested call.
            config: Config of the current run.
            **kwargs: Extra arguments, forwarded to the bound runnable.

        Returns:
            Whatever the bound runnable returns for the list input.
        """
        if not isinstance(inputs, list):
            inputs = [inputs]
        child_config = patch_config(config, callbacks=run_manager.get_child())
        return self.bound.invoke(inputs, child_config, **kwargs)

    def invoke(
        self,
        inputs: Input | list[Input],
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> Output:
        """Run ``_invoke`` through ``_call_with_config`` and return its result."""
        return self._call_with_config(self._invoke, inputs, config, **kwargs)


def main():
    """Build the four demonstration pipelines and run each of them."""
    # set_debug(True)
    pipes = [
        [
            RunnableLambda(get_input_list),
            RunnableLambda(get_input_list),
            RunnableLambda(process),
        ],
        [
            RunnableLambda(get_input_list),
            MergeRunnable(bound=RunnableLambda(process_merge)),
            RunnableLambda(get_input_string),
        ],
        [
            RunnableLambda(get_input_string),
            RunnableLambda(get_input_string),
            RunnableLambda(process),
        ],
        [
            RunnableLambda(get_input_string),
            RunnableLambda(get_input_string),
            MergeRunnable(bound=RunnableLambda(process_merge)),
        ],
    ]
    for pipe in pipes:
        run_pipe(pipe)


def run_pipe(steps, inputs: Iterable[Any] | None = None):
    """Chain ``steps`` into one pipeline and print its output for each input.

    Args:
        steps: Non-empty sequence of runnables, applied in order.
        inputs: Values to feed through the chain one at a time. Defaults to
            the demo inputs ``[5, 6]`` and ``4``.
    """
    if inputs is None:
        inputs = _DEMO_INPUTS

    # Wrap each step in a runnable that will split the pipeline if the input
    # is a list. This could be made more explicit if we know that a step is
    # expected to return a list or not.
    def _wrap_step(step: RunnableSerializable):
        # Merge steps must see the whole list, so they are left unwrapped.
        if isinstance(step, MergeRunnable):
            return step
        return FlexibleRunnableEach(bound=step)

    # Compose the wrapped steps left to right with the `|` operator.
    chain = functools.reduce(operator.or_, map(_wrap_step, steps))
    print(f"\nRunning chain: {chain}\n================================")
    for val in inputs:
        output = chain.invoke(val)
        print(f" input: '{val}', output: '{output}'")
        print(" -----------------------------")


if __name__ == "__main__":
    verbose.append(1)
    main()