Skip to content

Branch module

Branch

Bases: Module

Use a LanguageModel to select which module(s) to call based on an arbitrary input, a question and a list of labels.

The selected branch(es) output the data model computed by the corresponding branch module from the inputs, while the others output None. The output is always a tuple of length len(branches) so each label has a fixed positional slot regardless of which one was selected.

The behaviour of the selector depends on decision_type:

  • decision_type=Decision (default) — exactly one branch is selected per call. All other slots are None.
  • decision_type=MultiDecision — one or more branches are selected per call. Non-selected slots remain None. Use this for multi-label routing where several branches may need to fire at once (e.g., an article that spans both science and finance, or a query that should be answered by both a retrieval and a tool-using sub-program).

Single-label example (one branch active per call):

import synalinks
import asyncio

async def main():
    class Query(synalinks.DataModel):
        query: str

    class Answer(synalinks.DataModel):
        answer: str

    class AnswerWithCritique(synalinks.DataModel):
        thinking: str
        critique: str
        answer: str

    language_model = synalinks.LanguageModel(
        model="ollama/mistral",
    )

    x0 = synalinks.Input(data_model=Query)
    (x1, x2) = await synalinks.Branch(
        question="What is the difficulty level of the above query?",
        labels=["easy", "difficult"],
        branches=[
            synalinks.Generator(
                data_model=Answer,
                language_model=language_model,
            ),
            synalinks.Generator(
                data_model=AnswerWithCritique,
                language_model=language_model,
            ),
        ],
        language_model=language_model,
    )(x0)
    x3 = x1 | x2

    program = synalinks.Program(
        inputs=x0,
        outputs=x3,
        name="adaptative_chain_of_thought",
        description="Useful to answer step by step only when needed",
    )

if __name__ == "__main__":
    asyncio.run(main())

Multi-label example (zero, one, or several branches active per call):

import synalinks
import asyncio

async def main():
    class Article(synalinks.DataModel):
        text: str

    class ScienceSummary(synalinks.DataModel):
        thinking: str
        science_summary: str

    class FinanceSummary(synalinks.DataModel):
        thinking: str
        finance_summary: str

    class SportsSummary(synalinks.DataModel):
        thinking: str
        sports_summary: str

    language_model = synalinks.LanguageModel(model="ollama/mistral")

    x0 = synalinks.Input(data_model=Article)
    # Each label has a fixed slot in the output tuple. With
    # MultiDecision, several may be populated at once; the rest
    # are None.
    (sci, fin, spo) = await synalinks.Branch(
        question="Which topics does this article cover?",
        labels=["science", "finance", "sports"],
        branches=[
            synalinks.Generator(
                data_model=ScienceSummary,
                language_model=language_model,
            ),
            synalinks.Generator(
                data_model=FinanceSummary,
                language_model=language_model,
            ),
            synalinks.Generator(
                data_model=SportsSummary,
                language_model=language_model,
            ),
        ],
        decision_type=synalinks.MultiDecision,
        language_model=language_model,
    )(x0)

if __name__ == "__main__":
    asyncio.run(main())

For a biotech-startup article the result might be (<ScienceSummary>, <FinanceSummary>, None) — science and finance are both active, sports stays None. The non-active slots can be combined downstream with | (logical OR) the same way as in the single-label example.

Parameters:

Name Type Description Default
question str

The question to ask.

None
labels list

The list of labels to choose from (strings).

None
branches list

The list of modules or programs to select from.

None
inject_decision bool

If True, inject the decision into the branch inputs (defaults to True).

True
return_decision bool

If True, return the decision with the branch outputs (defaults to True).

True
language_model LanguageModel

The language model to use.

None
prompt_template str

The default jinja2 prompt template to use (see Generator).

None
examples list

The default examples to use in the prompt (see Decision).

None
instructions list

The default instructions to use (see Decision).

None
seed_instructions list

Optional. A list of instructions to use as seed for the optimization. If not provided, use the default instructions as seed.

None
temperature float

Optional. The temperature for the LM call.

0.0
reasoning_effort string

Optional. The reasoning effort for the LM call, one of ['minimal', 'low', 'medium', 'high', 'disable', 'none', None]. Defaults to None (no reasoning).

None
use_inputs_schema bool

Optional. Whether or not to use the inputs schema in the decision prompt (defaults to False) (see Decision).

False
use_outputs_schema bool

Optional. Whether or not to use the outputs schema in the decision prompt (defaults to False) (see Decision).

False
decision_type type

Optional. The decision module class. Defaults to Decision (single-label, exactly one branch active). Pass MultiDecision to enable multi-label routing where several branches may be active simultaneously.

Decision
name str

Optional. The name of the module.

None
description str

Optional. The description of the module.

None
trainable bool

Whether the module's variables should be trainable.

True
Source code in synalinks/src/modules/core/branch.py
@synalinks_export(["synalinks.modules.Branch", "synalinks.Branch"])
class Branch(Module):
    """Use a `LanguageModel` to select which module(s) to call based on an
    arbitrary input, a question and a list of labels.

    The selected branch(es) output the data model computed by the
    corresponding branch module from the inputs, while the others output
    `None`. The output is always a tuple of length `len(branches)` so each
    label has a fixed positional slot regardless of which one was selected.
    If the inputs are empty or the decision yields no choice, every slot
    is `None`.

    The behaviour of the selector depends on `decision_type`:

    - `decision_type=Decision` (default) — exactly **one** branch is
      selected per call. All other slots are `None`.
    - `decision_type=MultiDecision` — **one or more** branches are
      selected per call. Non-selected slots remain `None`. Use this for
      multi-label routing where several branches may need to fire at
      once (e.g., an article that spans both `science` and `finance`,
      or a query that should be answered by both a retrieval and a
      tool-using sub-program).

    Single-label example (one branch active per call):

    ```python
    import synalinks
    import asyncio

    async def main():
        class Query(synalinks.DataModel):
            query: str

        class Answer(synalinks.DataModel):
            answer: str

        class AnswerWithCritique(synalinks.DataModel):
            thinking: str
            critique: str
            answer: str

        language_model = synalinks.LanguageModel(
            model="ollama/mistral",
        )

        x0 = synalinks.Input(data_model=Query)
        (x1, x2) = await synalinks.Branch(
            question="What is the difficulty level of the above query?",
            labels=["easy", "difficult"],
            branches=[
                synalinks.Generator(
                    data_model=Answer,
                    language_model=language_model,
                ),
                synalinks.Generator(
                    data_model=AnswerWithCritique,
                    language_model=language_model,
                ),
            ],
            language_model=language_model,
        )(x0)
        x3 = x1 | x2

        program = synalinks.Program(
            inputs=x0,
            outputs=x3,
            name="adaptative_chain_of_thought",
            description="Useful to answer step by step only when needed",
        )

    if __name__ == "__main__":
        asyncio.run(main())
    ```

    Multi-label example (zero, one, or several branches active per call):

    ```python
    import synalinks
    import asyncio

    async def main():
        class Article(synalinks.DataModel):
            text: str

        class ScienceSummary(synalinks.DataModel):
            thinking: str
            science_summary: str

        class FinanceSummary(synalinks.DataModel):
            thinking: str
            finance_summary: str

        class SportsSummary(synalinks.DataModel):
            thinking: str
            sports_summary: str

        language_model = synalinks.LanguageModel(model="ollama/mistral")

        x0 = synalinks.Input(data_model=Article)
        # Each label has a fixed slot in the output tuple. With
        # MultiDecision, several may be populated at once; the rest
        # are None.
        (sci, fin, spo) = await synalinks.Branch(
            question="Which topics does this article cover?",
            labels=["science", "finance", "sports"],
            branches=[
                synalinks.Generator(
                    data_model=ScienceSummary,
                    language_model=language_model,
                ),
                synalinks.Generator(
                    data_model=FinanceSummary,
                    language_model=language_model,
                ),
                synalinks.Generator(
                    data_model=SportsSummary,
                    language_model=language_model,
                ),
            ],
            decision_type=synalinks.MultiDecision,
            language_model=language_model,
        )(x0)

    if __name__ == "__main__":
        asyncio.run(main())
    ```

    For a biotech-startup article the result might be
    `(<ScienceSummary>, <FinanceSummary>, None)` — `science` and
    `finance` are both active, `sports` stays `None`. The non-active
    slots can be combined downstream with `|` (logical OR) the same
    way as in the single-label example.

    Args:
        question (str): The question to ask.
        labels (list): The list of labels to choose from (strings). Labels
            must be unique and map one-to-one, in order, onto `branches`.
        branches (list): The list of modules or programs to select from.
        inject_decision (bool): If True, inject the decision into the branch
            inputs (defaults to True).
        return_decision (bool): If True, return the decision with the branch
            outputs (defaults to True).
        language_model (LanguageModel): The language model to use.
        prompt_template (str): The default jinja2 prompt template
            to use (see `Generator`).
        examples (list): The default examples to use in the prompt
            (see `Decision`).
        instructions (list): The default instructions to use (see `Decision`).
        seed_instructions (list): Optional. A list of instructions to use as seed for the
            optimization. If not provided, use the default instructions as seed.
        temperature (float): Optional. The temperature for the LM call.
        reasoning_effort (string): Optional. The reasoning effort for the LM call,
            one of ['minimal', 'low', 'medium', 'high', 'disable', 'none', None].
            Defaults to None (no reasoning).
        use_inputs_schema (bool): Optional. Whether or not to use the inputs
            schema in the decision prompt (defaults to False) (see `Decision`).
        use_outputs_schema (bool): Optional. Whether or not to use the outputs
            schema in the decision prompt (defaults to False) (see `Decision`).
        decision_type (type): Optional. The decision module class. Defaults to
            `Decision` (single-label, exactly one branch active). Pass
            `MultiDecision` to enable multi-label routing where several
            branches may be active simultaneously.
        name (str): Optional. The name of the module.
        description (str): Optional. The description of the module.
        trainable (bool): Whether the module's variables should be trainable.

    Raises:
        ValueError: If `branches` is missing or not a list, if `labels` is
            missing, if `labels` and `branches` differ in length, or if
            `labels` contains duplicates.
    """

    def __init__(
        self,
        *,
        question=None,
        labels=None,
        branches=None,
        inject_decision=True,
        return_decision=True,
        language_model=None,
        prompt_template=None,
        examples=None,
        instructions=None,
        seed_instructions=None,
        temperature=0.0,
        reasoning_effort=None,
        use_inputs_schema=False,
        use_outputs_schema=False,
        decision_type=Decision,
        name=None,
        description=None,
        trainable=True,
        **kwargs,  # accepted for call-site compatibility; not used here
    ):
        super().__init__(
            name=name,
            description=description,
            trainable=trainable,
        )
        if not branches:
            raise ValueError("The `branches` argument must be provided.")
        if not isinstance(branches, list):
            raise ValueError("The `branches` must be a list of `Module` or `Program`.")
        if not labels:
            # Explicit error instead of a `TypeError` from `len(None)` below.
            raise ValueError("The `labels` argument must be provided.")
        if len(labels) != len(branches):
            raise ValueError("The `labels` and `branches` must have the same length.")
        if len(set(labels)) != len(labels):
            # Duplicates would silently drop branches from the label->module
            # mapping and break the one-slot-per-label output contract.
            raise ValueError("The `labels` must be unique.")
        self.question = question
        self.labels = labels
        # Label -> branch module; insertion order follows `labels`.
        self.branches = dict(zip(labels, branches))
        self.inject_decision = inject_decision
        self.return_decision = return_decision
        self.language_model = _get_lm(language_model)
        self.prompt_template = prompt_template
        self.examples = examples
        self.instructions = instructions
        self.seed_instructions = seed_instructions
        self.temperature = temperature
        self.reasoning_effort = reasoning_effort
        self.use_inputs_schema = use_inputs_schema
        self.use_outputs_schema = use_outputs_schema
        self.decision = decision_type(
            question=self.question,
            labels=self.labels,
            language_model=self.language_model,
            prompt_template=self.prompt_template,
            examples=self.examples,
            instructions=self.instructions,
            seed_instructions=self.seed_instructions,
            temperature=self.temperature,
            reasoning_effort=self.reasoning_effort,
            use_inputs_schema=self.use_inputs_schema,
            use_outputs_schema=self.use_outputs_schema,
            name="decision_" + self.name,
        )

    @staticmethod
    def _is_selected(label, choice):
        """Return True if `label` is the decided choice or among the choices."""
        if isinstance(choice, str):
            return label == choice
        if isinstance(choice, (list, tuple, set, frozenset)):
            return label in choice
        return False

    async def _run_branch(self, inputs, module, decision, training=False):
        """Call one selected branch, optionally merging the decision into its output."""
        if not inputs:
            return None
        branch_outputs = await module(inputs, training=training)
        if self.return_decision:
            return await ops.logical_and(
                decision,
                branch_outputs,
                name="with_decision_" + self.name,
            )
        return branch_outputs

    async def call(self, inputs, training=False):
        # One positional slot per label; unselected slots stay `None`.
        outputs = [None] * len(self.labels)

        if not inputs:
            return tuple(outputs)

        decision = await self.decision(
            inputs,
            training=training,
        )
        if not decision:
            return tuple(outputs)

        # Accept either a single `choice` (str) or a collection of `choices`.
        choice = decision.get("choice", decision.get("choices"))
        if not choice:
            return tuple(outputs)

        if self.inject_decision:
            inputs = await ops.concat(
                inputs,
                decision,
                name="inputs_with_decision_" + self.name,
            )

        selected = [
            index
            for index, label in enumerate(self.labels)
            if self._is_selected(label, choice) and self.branches[label]
        ]
        # Selected branches run concurrently; each result is written back
        # into its label's fixed slot.
        results = await asyncio.gather(
            *(
                self._run_branch(
                    inputs,
                    self.branches[self.labels[index]],
                    decision,
                    training=training,
                )
                for index in selected
            )
        )
        for index, result in zip(selected, results):
            outputs[index] = result
        return tuple(outputs)

    async def compute_output_spec(self, inputs, training=False):
        # Unlike `call`, no branch selection happens here: every branch is
        # called so that each positional slot receives an output spec.
        decision = await self.decision(
            inputs,
            training=training,
        )
        if self.inject_decision:
            inputs = await ops.concat(
                inputs,
                decision,
                name="inputs_with_decision_" + self.name,
            )
        outputs = []
        for label in self.labels:
            module = self.branches[label]
            branch_outputs = await module(
                inputs,
                training=training,
            )
            if self.return_decision:
                branch_outputs = await ops.logical_and(
                    decision,
                    branch_outputs,
                    name="with_decision_" + self.name,
                )
            outputs.append(branch_outputs)
        return tuple(outputs)

    def get_config(self):
        config = {
            "question": self.question,
            "labels": self.labels,
            "inject_decision": self.inject_decision,
            "return_decision": self.return_decision,
            "prompt_template": self.prompt_template,
            "examples": self.examples,
            "instructions": self.instructions,
            "seed_instructions": self.seed_instructions,
            "temperature": self.temperature,
            "reasoning_effort": self.reasoning_effort,
            "use_inputs_schema": self.use_inputs_schema,
            "use_outputs_schema": self.use_outputs_schema,
            "name": self.name,
            "description": self.description,
            "trainable": self.trainable,
        }
        language_model_config = {
            "language_model": serialization_lib.serialize_synalinks_object(
                self.language_model
            )
        }
        branches_config = {
            "branches": [
                serialization_lib.serialize_synalinks_object(branch)
                for branch in self.branches.values()
            ]
        }
        decision_config = {}
        if type(self.decision) is not Decision:
            # Record non-default decision modules (e.g. `MultiDecision`) so
            # `from_config` can restore `decision_type` instead of silently
            # falling back to single-label routing. Default configs are
            # left unchanged.
            decision_config["decision"] = serialization_lib.serialize_synalinks_object(
                self.decision
            )
        return {
            **config,
            **language_model_config,
            **branches_config,
            **decision_config,
        }

    @classmethod
    def from_config(cls, config, custom_objects=None):
        language_model = serialization_lib.deserialize_synalinks_object(
            config.pop("language_model")
        )
        branches = [
            serialization_lib.deserialize_synalinks_object(
                branch_config, custom_objects=custom_objects
            )
            for branch_config in config.pop("branches")
        ]
        # Older configs carry no "decision" entry and keep the default type.
        decision_config = config.pop("decision", None)
        if decision_config is not None and "decision_type" not in config:
            decision = serialization_lib.deserialize_synalinks_object(
                decision_config, custom_objects=custom_objects
            )
            config["decision_type"] = type(decision)
        return cls(language_model=language_model, branches=branches, **config)