Training Programs

As in machine learning, an LM application needs to be trained. In this case, however, we don't update the model's weights. Instead, we optimize the prompts, either by automatically selecting the best examples or by generating instructions that help the program perform better on your dataset.

In production settings, this means that you can use smaller and more cost-effective models from your preferred provider while enhancing their accuracy with Synalinks.
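To make "optimizing the prompt instead of the weights" concrete, here is a minimal sketch that chooses which few-shot examples to place in a prompt by scoring every candidate subset on a small validation set. Everything in it is hypothetical: toy_program is a word-overlap stand-in for an LM call, the exact-match reward is an assumption, and the brute-force search over subsets is only practical for tiny pools. It is not how Synalinks' optimizers are implemented.

```python
from itertools import combinations


# Hypothetical stand-in for an LM call: answer with the label of the
# in-prompt example whose question shares the most words with the input.
def toy_program(question, demos):
    words = set(question.lower().split())
    best = max(demos, key=lambda d: len(words & set(d[0].lower().split())))
    return best[1]


def exact_match(y_true, y_pred):
    # Assumed reward: 1.0 for an exact match, 0.0 otherwise.
    return 1.0 if y_true == y_pred else 0.0


def mean_reward(demos, x_val, y_val):
    scores = [exact_match(y, toy_program(x, demos)) for x, y in zip(x_val, y_val)]
    return sum(scores) / len(scores)


def select_demos(pool, x_val, y_val, k=2):
    """Try every k-example subset of the pool and keep the one with the
    highest validation reward. The 'model' never changes; only the
    examples placed in its prompt do."""
    best_demos, best_reward = None, -1.0
    for demos in combinations(pool, k):
        reward = mean_reward(demos, x_val, y_val)
        if reward > best_reward:
            best_demos, best_reward = list(demos), reward
    return best_demos, best_reward


pool = [("is the sky blue", "yes"), ("is snow hot", "no"), ("what is two plus two", "4")]
x_val = ["is water wet", "is ice hot"]
y_val = ["yes", "no"]
demos, reward = select_demos(pool, x_val, y_val, k=2)
print(demos, reward)  # [('is the sky blue', 'yes'), ('is snow hot', 'no')] 1.0
```

toy_program is identical in every trial; only the examples it receives change, which is why the validation reward varies between 0.5 and 1.0 across subsets.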

Training Flow

graph TB
    subgraph Data
        A[x_train, y_train]
        B[x_test, y_test]
    end
    subgraph Training
        C[program.compile] --> D[program.fit]
        D --> E{val_reward improved?}
        E -->|Yes| F[Save Checkpoint]
        E -->|No| D
    end
    subgraph Evaluation
        G[program.evaluate]
    end
    A --> D
    B --> G
    F --> G

Loading a Dataset

Synalinks provides built-in datasets for training and evaluation:

import synalinks

(x_train, y_train), (x_test, y_test) = synalinks.datasets.gsm8k.load_data()
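If you want to bring your own data, one option is to mirror the tuple layout that load_data() returns. The sketch below is a hypothetical loader (load_my_data is not part of Synalinks) that shuffles GSM8K-style records deterministically, holds out a test fraction, and keeps only the final answer after the "####" marker. It uses plain dicts as an assumption; adapt the conversion to whatever input and output data models your program expects.

```python
import random


def load_my_data(records, test_fraction=0.2, seed=42):
    """Hypothetical loader returning ((x_train, y_train), (x_test, y_test)),
    the same layout as synalinks.datasets.gsm8k.load_data(). Plain dicts are
    an assumption here; adapt them to your program's data models."""
    rows = list(records)
    random.Random(seed).shuffle(rows)  # deterministic shuffle before splitting
    n_test = int(len(rows) * test_fraction)
    test_rows, train_rows = rows[:n_test], rows[n_test:]

    def to_xy(subset):
        x = [{"query": r["question"]} for r in subset]
        # GSM8K-style targets end with "#### <final answer>"; keep only that part.
        y = [{"answer": r["answer"].split("####")[-1].strip()} for r in subset]
        return x, y

    return to_xy(train_rows), to_xy(test_rows)


# Toy records, invented for illustration.
records = [
    {"question": f"What is {i} + {i}?", "answer": f"{i} + {i} = {2 * i}\n#### {2 * i}"}
    for i in range(10)
]
(x_train, y_train), (x_test, y_test) = load_my_data(records)
print(len(x_train), len(x_test))  # 8 2
```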

Training with fit()

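Training a program works much like training a Keras model: once the program has been compiled with program.compile(), call the fit() method on your data. Because fit() is a coroutine, it must be awaited: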

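# fit() is a coroutine: run this inside an async function
# (or in a notebook that supports top-level await).
# program_checkpoint_callback is defined in "Saving and Loading Checkpoints" below.
history = await program.fit(
    x=x_train,
    y=y_train,
    validation_split=0.2,  # hold out 20% of the training data for validation
    epochs=20,
    batch_size=32,
    callbacks=[program_checkpoint_callback],
)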
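The validation_split and batch_size arguments can be illustrated with plain lists. This is a minimal sketch assuming Synalinks follows the Keras convention, where the last fraction of the provided samples (taken before any shuffling) is held out to compute validation metrics such as val_reward; split_and_batch is a hypothetical helper, not a Synalinks function.

```python
import math


def split_and_batch(x, y, validation_split=0.2, batch_size=32):
    """Hypothetical helper (not part of Synalinks). Assumes the Keras convention:
    the *last* fraction of the samples, taken before any shuffling, becomes
    the validation set; the remaining samples are chunked into batches."""
    split_at = int(math.floor(len(x) * (1.0 - validation_split)))
    x_train, y_train = x[:split_at], y[:split_at]
    x_val, y_val = x[split_at:], y[split_at:]
    batches = [
        (x_train[i:i + batch_size], y_train[i:i + batch_size])
        for i in range(0, len(x_train), batch_size)
    ]
    return batches, (x_val, y_val)


x = list(range(1010))
y = [str(v) for v in x]
batches, (x_val, y_val) = split_and_batch(x, y)
print(len(batches), len(batches[-1][0]), len(x_val))  # 26 8 202
```

With 1,010 samples, 808 are kept for training and split into 26 batches (the last holding only 8 samples), and the remaining 202 form the validation set.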

Saving and Loading Checkpoints

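Use callbacks to save the best-performing program during training. The ProgramCheckpoint callback below monitors val_reward and, with mode="max" and save_best_only=True, only writes the checkpoint file when the validation reward reaches a new best: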

program_checkpoint_callback = synalinks.callbacks.ProgramCheckpoint(
    filepath="checkpoint.program.json",
    monitor="val_reward",
    mode="max",
    save_best_only=True,
)

# Load the best checkpoint after training
program.load("checkpoint.program.json")
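The interplay of monitor, mode, and save_best_only can be sketched in a few lines. The class below is a hypothetical re-implementation that assumes the same semantics as Keras' ModelCheckpoint; it keeps the best snapshot in memory rather than writing checkpoint.program.json, and the validation rewards fed to it are made up purely to exercise the logic.

```python
import copy
import math


class BestOnlyCheckpointSketch:
    """Hypothetical re-implementation of monitor/mode/save_best_only,
    assumed to behave like Keras' ModelCheckpoint. Keeps the best
    snapshot in memory instead of writing a file."""

    def __init__(self, monitor="val_reward", mode="max"):
        self.monitor = monitor
        self.mode = mode
        self.best = -math.inf if mode == "max" else math.inf
        self.best_state = None
        self.saved_epochs = []

    def is_improvement(self, value):
        # "max" suits rewards (higher is better); "min" would suit a loss.
        return value > self.best if self.mode == "max" else value < self.best

    def on_epoch_end(self, epoch, logs, program_state):
        value = logs[self.monitor]
        if self.is_improvement(value):
            self.best = value
            self.best_state = copy.deepcopy(program_state)
            self.saved_epochs.append(epoch)


# Made-up validation rewards, purely to exercise the logic.
val_rewards = [0.42, 0.55, 0.51, 0.60, 0.58]
ckpt = BestOnlyCheckpointSketch()
for epoch, r in enumerate(val_rewards):
    ckpt.on_epoch_end(epoch, {"val_reward": r}, {"epoch": epoch})
print(ckpt.saved_epochs, ckpt.best, ckpt.best_state)  # [0, 1, 3] 0.6 {'epoch': 3}
```

Epochs 2 and 4 do not beat the running best, so the sketch keeps the epoch 3 snapshot; in a real run with save_best_only=True, the saved file would likewise hold the program from its best epoch.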

Key Takeaways

  • Dataset Loading: Use built-in datasets or create your own for training.
  • Training Loop: The fit() method handles the training process with configurable epochs, batch size, and validation split.
  • Checkpointing: Save the best-performing program during training using the ProgramCheckpoint callback.
  • Evaluation: Use evaluate() to measure performance before and after training.
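Measuring performance before and after training only makes sense if both measurements use the same held-out data. The sketch below assumes an exact-match reward and uses two plain functions as hypothetical stand-ins for an untrained and a trained program; evaluate here is a toy, not the signature of Synalinks' evaluate().

```python
def evaluate(program, x_test, y_test):
    """Hypothetical stand-in for program.evaluate(): the mean exact-match
    reward of `program` over a held-out test set."""
    scores = [1.0 if program(x) == y else 0.0 for x, y in zip(x_test, y_test)]
    return {"reward": sum(scores) / len(scores)}


# Toy test set and two toy "programs", invented for illustration.
x_test = ["2 + 2", "3 * 3", "10 - 4", "8 / 2"]
y_test = ["4", "9", "6", "4"]


def baseline(question):
    return "4"  # an "untrained" program that always answers "4"


trained_answers = {"2 + 2": "4", "3 * 3": "9", "10 - 4": "6", "8 / 2": "2"}


def trained(question):
    return trained_answers[question]  # still wrong on "8 / 2"


before = evaluate(baseline, x_test, y_test)["reward"]
after = evaluate(trained, x_test, y_test)["reward"]
print(before, after)  # 0.5 0.75
```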

Program Visualization

[Figure: gsm8k_baseline program visualization]
