Monitor

Bases: Callback

Monitor callback for logging training metrics to MLflow.

This callback logs training progress and evaluation metrics to MLflow for experiment tracking and visualization.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `experiment_name` | `str` | Name of the MLflow experiment. If `None`, uses the program name. | `None` |
| `run_name` | `str` | Name of the MLflow run. If `None`, auto-generated. | `None` |
| `tracking_uri` | `str` | MLflow tracking server URI. If `None`, uses the default (local `./mlruns` directory or the `MLFLOW_TRACKING_URI` env var). | `None` |
| `log_batch_metrics` | `bool` | Whether to log metrics at the batch level. | `False` |
| `log_epoch_metrics` | `bool` | Whether to log metrics at the epoch level. | `True` |
| `log_program_plot` | `bool` | Whether to log the program plot as an artifact at the beginning of training. | `True` |
| `log_program_model` | `bool` | Whether to log the program as an MLflow model at the end of training. | `True` |
| `tags` | `dict` | Optional tags to add to the MLflow run. | `None` |

Example:

import synalinks

# Basic usage - uses local MLflow storage
monitor = synalinks.callbacks.Monitor(experiment_name="my_experiment")

# With custom MLflow tracking server
monitor = synalinks.callbacks.Monitor(
    tracking_uri="http://localhost:5000",
    experiment_name="my_experiment",
    run_name="training_run_1",
    log_program_plot=True,
    log_program_model=True,
    tags={"model_type": "chain_of_thought"}
)

# Use in training
program.fit(
    x=train_data,
    y=train_labels,
    epochs=10,
    callbacks=[monitor]
)
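
The same Monitor instance can also be passed to a standalone evaluation. When no training run is active, the callback opens a separate MLflow run (suffixed with "test") and closes it when evaluation finishes. A minimal sketch, assuming evaluate() accepts x, y, and callbacks the same way fit() does; test_data and test_labels are placeholders:

# Standalone evaluation - the callback starts and ends its own "test" run
results = program.evaluate(
    x=test_data,
    y=test_labels,
    callbacks=[monitor]
)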
Note

For tracing module calls along with training metrics, call synalinks.enable_observability() at the beginning of your script; it configures the Monitor hook and callback:

synalinks.enable_observability(
    tracking_uri="http://localhost:5000",
    experiment_name="my_traces"
)
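
When log_program_plot or log_program_model is enabled, the plot and the program state are stored as run artifacts (under program_plots/ and model/, with the state saved as state_tree.json plus model_info.json). With a remote tracking_uri, uploads go through the MLflow REST artifact API, which requires the server to be started with --serve-artifacts. A minimal sketch of retrieving the logged state afterwards with MLflow's artifact API; the run ID is a placeholder to copy from the MLflow UI:

import json
import mlflow

mlflow.set_tracking_uri("http://localhost:5000")

# Download the "model" artifact directory (state_tree.json + model_info.json)
local_dir = mlflow.artifacts.download_artifacts(
    run_id="<run_id>",  # placeholder run ID
    artifact_path="model",
)

# Inspect the logged trainable state
with open(f"{local_dir}/state_tree.json") as f:
    state_tree = json.load(f)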
Source code in synalinks/src/callbacks/monitor.py
@synalinks_export("synalinks.callbacks.Monitor")
class Monitor(Callback):
    """Monitor callback for logging training metrics to MLflow.

    This callback logs training progress and evaluation metrics to MLflow
    for experiment tracking and visualization.

    Args:
        experiment_name (str): Name of the MLflow experiment. If None, uses
            the program name.
        run_name (str): Name of the MLflow run. If None, auto-generated.
        tracking_uri (str): MLflow tracking server URI. If None, uses the
            default (local ./mlruns directory or MLFLOW_TRACKING_URI env var).
        log_batch_metrics (bool): Whether to log metrics at batch level
            (default: False).
        log_epoch_metrics (bool): Whether to log metrics at epoch level
            (default: True).
        log_program_plot (bool): Whether to log the program plot as an artifact
            at the beginning of training (default: True).
        log_program_model (bool): Whether to log the program as an MLflow model
            at the end of training (default: True).
        tags (dict): Optional tags to add to the MLflow run.

    Example:

    ```python
    import synalinks

    # Basic usage - uses local MLflow storage
    monitor = synalinks.callbacks.Monitor(experiment_name="my_experiment")

    # With custom MLflow tracking server
    monitor = synalinks.callbacks.Monitor(
        tracking_uri="http://localhost:5000",
        experiment_name="my_experiment",
        run_name="training_run_1",
        log_program_plot=True,
        log_program_model=True,
        tags={"model_type": "chain_of_thought"}
    )

    # Use in training
    program.fit(
        x=train_data,
        y=train_labels,
        epochs=10,
        callbacks=[monitor]
    )
    ```

    Note:
        For tracing module calls along with training metrics, call
        `synalinks.enable_observability()` at the beginning of your script;
        it configures the Monitor hook and callback:

        ```python
        synalinks.enable_observability(
            tracking_uri="http://localhost:5000",
            experiment_name="my_traces"
        )
        ```
    """

    def __init__(
        self,
        experiment_name=None,
        run_name=None,
        tracking_uri=None,
        log_batch_metrics=False,
        log_epoch_metrics=True,
        log_program_plot=True,
        log_program_model=True,
        tags=None,
    ):
        super().__init__()
        if not MLFLOW_AVAILABLE:
            raise ImportError(
                "mlflow is required for the Monitor callback. "
                "Install it with: pip install mlflow"
            )

        self.experiment_name = experiment_name
        self.run_name = run_name
        self.tracking_uri = tracking_uri
        self.log_batch_metrics = log_batch_metrics
        self.log_epoch_metrics = log_epoch_metrics
        self.log_program_plot = log_program_plot
        self.log_program_model = log_program_model
        self.tags = tags or {}
        self.logger = logging.getLogger(__name__)

        self._run = None
        self._step = 0
        self._epoch = 0
        # Track if we're inside fit() to avoid ending run during validation
        self._in_training = False

    def _setup_mlflow(self):
        """Configure MLflow tracking."""
        if self.tracking_uri:
            mlflow.set_tracking_uri(self.tracking_uri)

        experiment_name = self.experiment_name
        if experiment_name is None and self.program is not None:
            experiment_name = self.program.name or "synalinks_experiment"

        mlflow.set_experiment(experiment_name)

    def _start_run(self, run_name_suffix=""):
        """Start a new MLflow run."""
        run_name = self.run_name
        if run_name and run_name_suffix:
            run_name = f"{run_name}_{run_name_suffix}"
        elif run_name_suffix:
            run_name = run_name_suffix

        self._run = mlflow.start_run(run_name=run_name)

        tags = dict(self.tags)
        if self.program is not None:
            if self.program.name:
                tags["program_name"] = self.program.name
            if self.program.description:
                tags["program_description"] = self.program.description

        if tags:
            mlflow.set_tags(tags)

        self._step = 0
        self._epoch = 0

    def _end_run(self):
        """End the current MLflow run."""
        if self._run is not None:
            mlflow.end_run()
            self._run = None

    async def _log_metrics(self, logs, step=None):
        """Log metrics to MLflow asynchronously."""
        if logs is None or self._run is None:
            return

        metrics = {}
        for key, value in logs.items():
            if isinstance(value, (int, float)):
                metrics[key] = value

        if metrics:
            await asyncio.to_thread(mlflow.log_metrics, metrics, step=step)

    async def _upload_artifact_via_http(self, local_path, artifact_path, run_id):
        """Upload artifact via HTTP to MLflow server asynchronously.

        This method uses the MLflow REST API to upload artifacts directly,
        bypassing local filesystem artifact repo issues. Requires the MLflow
        server to be started with --serve-artifacts flag.
        """
        import requests

        if not self.tracking_uri:
            raise ValueError("tracking_uri is required for HTTP artifact upload")

        # Get the run's artifact URI to determine the correct upload path
        client = mlflow.MlflowClient(tracking_uri=self.tracking_uri)
        run = await asyncio.to_thread(client.get_run, run_id)
        artifact_uri = run.info.artifact_uri

        filename = os.path.basename(local_path)
        if artifact_path:
            full_artifact_path = f"{artifact_path}/{filename}"
        else:
            full_artifact_path = filename

        # Parse the artifact URI to construct the correct upload URL
        # artifact_uri can be:
        #   - mlflow-artifacts:/<experiment_id>/<run_id>/artifacts
        #   - mlflow-artifacts://host:port/<experiment_id>/<run_id>/artifacts
        #   - /mlflow/artifacts/<experiment_id>/<run_id>/artifacts (server local path)
        if artifact_uri.startswith("mlflow-artifacts:"):
            # Extract the path part after the scheme
            uri_path = artifact_uri.replace("mlflow-artifacts://", "").replace(
                "mlflow-artifacts:/", ""
            )
            # Remove host:port if present (will use tracking_uri instead)
            if "/" in uri_path and not uri_path.startswith("/"):
                parts = uri_path.split("/", 1)
                if ":" in parts[0] or "." in parts[0]:
                    # First part looks like host:port, skip it
                    uri_path = parts[1] if len(parts) > 1 else ""
        elif artifact_uri.startswith("/"):
            # Server-side local path like /mlflow/artifacts/<exp_id>/<run_id>/artifacts
            # Extract the relative path: <exp_id>/<run_id>/artifacts
            # Find the pattern after the base artifacts directory
            parts = artifact_uri.split("/")
            # Look for 'artifacts' in the path and take everything after the first one
            try:
                artifacts_idx = parts.index("artifacts")
                uri_path = "/".join(parts[artifacts_idx + 1 :])
            except ValueError:
                # Fallback: use experiment_id/run_id/artifacts pattern
                uri_path = f"0/{run_id}/artifacts"
        else:
            # Fallback for other URI schemes
            uri_path = f"0/{run_id}/artifacts"

        # Construct the full upload URL
        base = f"{self.tracking_uri}/api/2.0/mlflow-artifacts/artifacts"
        url = f"{base}/{uri_path}/{full_artifact_path}"

        with open(local_path, "rb") as f:
            content = f.read()

        # Determine content type based on file extension
        content_type = "application/octet-stream"
        if local_path.endswith(".png"):
            content_type = "image/png"
        elif local_path.endswith(".json"):
            content_type = "application/json"

        headers = {"Content-Type": content_type}
        response = await asyncio.to_thread(
            requests.put, url, data=content, headers=headers
        )

        if response.status_code not in (200, 201, 204):
            raise Exception(
                f"Failed to upload artifact: {response.status_code} {response.text}"
            )

    async def _log_program_plot_artifact(self):
        """Log the program plot as an MLflow artifact asynchronously."""
        if self._run is None:
            self.logger.warning("No MLflow run active, skipping plot logging")
            return

        if self.program is None:
            self.logger.warning("No program set, skipping plot logging")
            return

        if not self.program.built:
            self.logger.warning("Program not built, skipping plot logging")
            return

        try:
            from synalinks.src.utils.program_visualization import check_graphviz
            from synalinks.src.utils.program_visualization import check_pydot
            from synalinks.src.utils.program_visualization import plot_program

            if not check_pydot() or not check_graphviz():
                self.logger.warning(
                    "pydot or graphviz not available, skipping program plot"
                )
                return

            run_id = self._run.info.run_id

            with tempfile.TemporaryDirectory() as tmpdir:
                plot_filename = f"{self.program.name or 'program'}.png"
                plot_path = os.path.join(tmpdir, plot_filename)

                # Run plot generation in thread pool
                await asyncio.to_thread(
                    plot_program,
                    self.program,
                    to_file=plot_filename,
                    to_folder=tmpdir,
                    show_schemas=True,
                    show_module_names=True,
                    show_trainable=True,
                    dpi=96,  # Lower DPI for smaller file size
                )

                if os.path.exists(plot_path):
                    # Use HTTP upload if tracking_uri is set (remote server),
                    # otherwise fall back to direct artifact logging (local)
                    if self.tracking_uri:
                        await self._upload_artifact_via_http(
                            plot_path, artifact_path="program_plots", run_id=run_id
                        )
                    else:
                        await asyncio.to_thread(
                            mlflow.log_artifact,
                            plot_path,
                            artifact_path="program_plots",
                            run_id=run_id,
                        )
                    self.logger.info(f"Logged program plot: {plot_filename}")
                else:
                    self.logger.warning(f"Plot file not created: {plot_path}")

        except Exception as e:
            self.logger.warning(f"Failed to log program plot: {e}")

    async def _log_params(self):
        """Log training hyperparameters to MLflow asynchronously."""
        if self._run is None or self.params is None:
            return

        try:
            params_to_log = {}
            for key, value in self.params.items():
                if isinstance(value, (str, int, float, bool)):
                    params_to_log[key] = value

            if params_to_log:
                await asyncio.to_thread(mlflow.log_params, params_to_log)
                self.logger.debug(f"Logged params: {params_to_log}")
        except Exception as e:
            self.logger.warning(f"Failed to log params: {e}")

    async def _log_program_model(self):
        """Log the program trainable state as an MLflow artifact asynchronously.

        This saves only the trainable variables (state), not the full
        program architecture. This is useful for checkpointing the learned
        parameters like few-shot examples, optimized prompts, etc.
        """
        if self._run is None or self.program is None:
            self.logger.warning("No run or program, skipping model logging")
            return

        try:
            import orjson

            # Get the state tree (trainable, non-trainable, optimizer variables)
            state_tree = self.program.get_state_tree()

            # Create model info
            model_info = {
                "program_name": self.program.name or "program",
                "program_description": self.program.description or "",
                "framework": "synalinks",
                "num_trainable_variables": len(self.program.trainable_variables),
            }

            run_id = self._run.info.run_id

            # Write to temp files and log as artifacts
            with tempfile.TemporaryDirectory() as tmpdir:
                # Save state tree
                state_path = os.path.join(tmpdir, "state_tree.json")
                with open(state_path, "wb") as f:
                    f.write(orjson.dumps(state_tree, option=orjson.OPT_INDENT_2))

                # Save model info
                info_path = os.path.join(tmpdir, "model_info.json")
                with open(info_path, "wb") as f:
                    f.write(orjson.dumps(model_info, option=orjson.OPT_INDENT_2))

                # Upload artifacts
                if self.tracking_uri:
                    await self._upload_artifact_via_http(
                        state_path, artifact_path="model", run_id=run_id
                    )
                    await self._upload_artifact_via_http(
                        info_path, artifact_path="model", run_id=run_id
                    )
                else:
                    await asyncio.to_thread(
                        mlflow.log_artifact,
                        state_path,
                        artifact_path="model",
                        run_id=run_id,
                    )
                    await asyncio.to_thread(
                        mlflow.log_artifact,
                        info_path,
                        artifact_path="model",
                        run_id=run_id,
                    )

            self.logger.info(
                f"Logged program state: {self.program.name} "
                f"({len(self.program.trainable_variables)} trainable variables)"
            )

        except Exception as e:
            self.logger.warning(f"Failed to log program model: {e}")

    def on_train_begin(self, logs=None):
        """Called at the beginning of training."""
        self._in_training = True
        self._setup_mlflow()
        self._start_run(run_name_suffix="train")
        self.logger.debug("MLflow run started for training")

        # Log hyperparameters
        run_maybe_nested(self._log_params())

        # Log program plot
        if self.log_program_plot:
            run_maybe_nested(self._log_program_plot_artifact())

    def on_train_end(self, logs=None):
        """Called at the end of training."""
        run_maybe_nested(self._log_metrics(logs, step=self._step))

        # Log program as model at end of training
        if self.log_program_model:
            run_maybe_nested(self._log_program_model())

        self._end_run()
        self._in_training = False
        self.logger.debug("MLflow run ended for training")

    def on_epoch_begin(self, epoch, logs=None):
        """Called at the start of an epoch."""
        self._epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        """Called at the end of an epoch."""
        if not self.log_epoch_metrics:
            return

        self._epoch = epoch
        run_maybe_nested(self._log_metrics(logs, step=epoch))
        self.logger.debug(f"Logged metrics for epoch {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        """Called at the beginning of a training batch."""
        pass

    def on_train_batch_end(self, batch, logs=None):
        """Called at the end of a training batch."""
        if not self.log_batch_metrics:
            return

        self._step += 1
        run_maybe_nested(self._log_metrics(logs, step=self._step))

    def on_test_begin(self, logs=None):
        """Called at the beginning of evaluation or validation."""
        # Only start a new run if we're not already in a training run
        if self._run is None and not self._in_training:
            self._setup_mlflow()
            self._start_run(run_name_suffix="test")
            self.logger.debug("MLflow run started for testing")

    def on_test_end(self, logs=None):
        """Called at the end of evaluation or validation."""
        run_maybe_nested(self._log_metrics(logs, step=self._step))
        # Only end the run if we're not in training (standalone evaluate() call)
        if self._run is not None and not self._in_training:
            self._end_run()
            self.logger.debug("MLflow run ended for testing")

    def on_test_batch_begin(self, batch, logs=None):
        """Called at the beginning of a test batch."""
        pass

    def on_test_batch_end(self, batch, logs=None):
        """Called at the end of a test batch."""
        if not self.log_batch_metrics:
            return

        self._step += 1
        run_maybe_nested(self._log_metrics(logs, step=self._step))

    def on_predict_begin(self, logs=None):
        """Called at the beginning of prediction."""
        pass

    def on_predict_end(self, logs=None):
        """Called at the end of prediction."""
        pass

    def on_predict_batch_begin(self, batch, logs=None):
        """Called at the beginning of a prediction batch."""
        pass

    def on_predict_batch_end(self, batch, logs=None):
        """Called at the end of a prediction batch."""
        pass

    def __del__(self):
        """Cleanup any open MLflow run."""
        if hasattr(self, "_run") and self._run is not None:
            try:
                mlflow.end_run()
            except Exception:
                pass

__del__()

Cleanup any open MLflow run.

Source code in synalinks/src/callbacks/monitor.py
def __del__(self):
    """Cleanup any open MLflow run."""
    if hasattr(self, "_run") and self._run is not None:
        try:
            mlflow.end_run()
        except Exception:
            pass

on_epoch_begin(epoch, logs=None)

Called at the start of an epoch.

Source code in synalinks/src/callbacks/monitor.py
def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch."""
    self._epoch = epoch

on_epoch_end(epoch, logs=None)

Called at the end of an epoch.

Source code in synalinks/src/callbacks/monitor.py
def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch."""
    if not self.log_epoch_metrics:
        return

    self._epoch = epoch
    run_maybe_nested(self._log_metrics(logs, step=epoch))
    self.logger.debug(f"Logged metrics for epoch {epoch}")

on_predict_batch_begin(batch, logs=None)

Called at the beginning of a prediction batch.

Source code in synalinks/src/callbacks/monitor.py
def on_predict_batch_begin(self, batch, logs=None):
    """Called at the beginning of a prediction batch."""
    pass

on_predict_batch_end(batch, logs=None)

Called at the end of a prediction batch.

Source code in synalinks/src/callbacks/monitor.py
def on_predict_batch_end(self, batch, logs=None):
    """Called at the end of a prediction batch."""
    pass

on_predict_begin(logs=None)

Called at the beginning of prediction.

Source code in synalinks/src/callbacks/monitor.py
def on_predict_begin(self, logs=None):
    """Called at the beginning of prediction."""
    pass

on_predict_end(logs=None)

Called at the end of prediction.

Source code in synalinks/src/callbacks/monitor.py
def on_predict_end(self, logs=None):
    """Called at the end of prediction."""
    pass

on_test_batch_begin(batch, logs=None)

Called at the beginning of a test batch.

Source code in synalinks/src/callbacks/monitor.py
def on_test_batch_begin(self, batch, logs=None):
    """Called at the beginning of a test batch."""
    pass

on_test_batch_end(batch, logs=None)

Called at the end of a test batch.

Source code in synalinks/src/callbacks/monitor.py
def on_test_batch_end(self, batch, logs=None):
    """Called at the end of a test batch."""
    if not self.log_batch_metrics:
        return

    self._step += 1
    run_maybe_nested(self._log_metrics(logs, step=self._step))

on_test_begin(logs=None)

Called at the beginning of evaluation or validation.

Source code in synalinks/src/callbacks/monitor.py
def on_test_begin(self, logs=None):
    """Called at the beginning of evaluation or validation."""
    # Only start a new run if we're not already in a training run
    if self._run is None and not self._in_training:
        self._setup_mlflow()
        self._start_run(run_name_suffix="test")
        self.logger.debug("MLflow run started for testing")

on_test_end(logs=None)

Called at the end of evaluation or validation.

Source code in synalinks/src/callbacks/monitor.py
def on_test_end(self, logs=None):
    """Called at the end of evaluation or validation."""
    run_maybe_nested(self._log_metrics(logs, step=self._step))
    # Only end the run if we're not in training (standalone evaluate() call)
    if self._run is not None and not self._in_training:
        self._end_run()
        self.logger.debug("MLflow run ended for testing")

on_train_batch_begin(batch, logs=None)

Called at the beginning of a training batch.

Source code in synalinks/src/callbacks/monitor.py
def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch."""
    pass

on_train_batch_end(batch, logs=None)

Called at the end of a training batch.

Source code in synalinks/src/callbacks/monitor.py
def on_train_batch_end(self, batch, logs=None):
    """Called at the end of a training batch."""
    if not self.log_batch_metrics:
        return

    self._step += 1
    run_maybe_nested(self._log_metrics(logs, step=self._step))

on_train_begin(logs=None)

Called at the beginning of training.

Source code in synalinks/src/callbacks/monitor.py
def on_train_begin(self, logs=None):
    """Called at the beginning of training."""
    self._in_training = True
    self._setup_mlflow()
    self._start_run(run_name_suffix="train")
    self.logger.debug("MLflow run started for training")

    # Log hyperparameters
    run_maybe_nested(self._log_params())

    # Log program plot
    if self.log_program_plot:
        run_maybe_nested(self._log_program_plot_artifact())

on_train_end(logs=None)

Called at the end of training.

Source code in synalinks/src/callbacks/monitor.py
def on_train_end(self, logs=None):
    """Called at the end of training."""
    run_maybe_nested(self._log_metrics(logs, step=self._step))

    # Log program as model at end of training
    if self.log_program_model:
        run_maybe_nested(self._log_program_model())

    self._end_run()
    self._in_training = False
    self.logger.debug("MLflow run ended for training")