Skip to content

Language Models Ops

Predict

Bases: Operation

Perform a prediction using a LanguageModel.

Source code in synalinks/src/ops/language_models.py
class Predict(Operation):
    """Perform a prediction using a `LanguageModel`."""

    def __init__(
        self,
        schema=None,
        data_model=None,
        language_model=None,
        streaming=False,
        name=None,
        description=None,
        **kwargs,
    ):
        super().__init__(
            name=name,
            description=description,
        )
        if not schema and data_model:
            schema = data_model.schema()
        self.schema = schema
        self.data_model = data_model
        self.language_model = language_model
        if schema and streaming:
            streaming = False
        self.streaming = streaming
        self.lm_kwargs = kwargs

    async def call(self, x):
        value = await self.language_model(
            x,
            schema=self.schema,
            streaming=self.streaming,
            **self.lm_kwargs,
        )
        if isinstance(value, StreamingIterator):
            return value
        if not value:
            return None
        if self.schema:
            return JsonDataModel(value=value, schema=self.schema, name=self.name)
        else:
            return JsonDataModel(value=value, schema=ChatMessage.schema(), name=self.name)

    async def compute_output_spec(self, x):
        if self.schema:
            return SymbolicDataModel(schema=self.schema, name=self.name)
        else:
            return SymbolicDataModel(schema=ChatMessage.schema(), name=self.name)

    def get_config(self):
        config = {
            "schema": self.schema,
            "streaming": self.streaming,
            "name": self.name,
            "description": self.description,
        }
        language_model_config = serialization_lib.serialize_synalinks_object(
            self.language_model
        )
        config.update({"lm_kwargs": self.lm_kwargs})
        return {"language_model": language_model_config, **config}

    @classmethod
    def from_config(cls, config):
        language_model = serialization_lib.deserialize_synalinks_object(
            config.pop("language_model")
        )
        lm_kwargs = config.pop("lm_kwargs")
        return cls(language_model=language_model, **config, **lm_kwargs)

predict(x, schema=None, data_model=None, language_model=None, streaming=False, name=None, description=None, **kwargs) async

Perform a prediction using a LanguageModel.

Predict consist in predicting a target data_model from an input data_model. This function uses a backend DataModel to get the target schema.

Parameters:

Name Type Description Default
x JsonDataModel | SymbolicDataModel

the input data model.

required
data_model DataModel

The target data model.

None
language_model LanguageModel

The language model to use

None
streaming bool

Enable streaming if True (Default to False)

False
name str

Optional. The name of the operation.

None
description str

Optional. The description of the operation.

None
**kwargs keyword arguments

Additional keywords forwarded to the LanguageModel call.

{}

Returns:

Type Description
JsonDataModel | SymbolicDataModel

The resulting data model.

Source code in synalinks/src/ops/language_models.py
@synalinks_export(["synalinks.ops.predict", "synalinks.ops.json.predict"])
async def predict(
    x,
    schema=None,
    data_model=None,
    language_model=None,
    streaming=False,
    name=None,
    description=None,
    **kwargs,
):
    """Perform a prediction using a `LanguageModel`.

    Predict consist in predicting a target data_model from an input data_model.
    This function uses a backend DataModel to get the target schema.

    Args:
        x (JsonDataModel | SymbolicDataModel): the input data model.
        data_model (DataModel): The target data model.
        language_model (LanguageModel): The language model to use
        streaming (bool): Enable streaming if True (Default to False)
        name (str): Optional. The name of the operation.
        description (str): Optional. The description of the operation.
        **kwargs (keyword arguments): Additional keywords forwarded to the
            LanguageModel call.

    Returns:
        (JsonDataModel | SymbolicDataModel): The resulting data model.
    """
    if language_model is None:
        raise ValueError("You should provide the `language_model` argument")
    if any_symbolic_data_models(x):
        return await Predict(
            schema=schema,
            data_model=data_model,
            language_model=language_model,
            streaming=False,
            name=name,
            description=description,
        ).symbolic_call(x)
    return await Predict(
        schema=schema,
        data_model=data_model,
        language_model=language_model,
        streaming=streaming,
        name=name,
        description=description,
        **kwargs,
    )(x)