Program training API

Source code in synalinks/src/trainers/trainer.py
class Trainer:
    def __init__(self):
        self._lock = False
        self._run_eagerly = False
        self.compiled = False
        self.reward = None
        self.steps_per_execution = 1
        # Can be set by callbacks in on_train_begin
        self._initial_epoch = None
        self._compute_reward_has_training_arg = (
            "training" in inspect.signature(self.compute_reward).parameters
        )
        # Placeholders used in `compile`
        self._optimizer = None
        self._compile_reward = None
        self._compile_metrics = None
        self._reward_tracker = None

    @tracking.no_automatic_dependency_tracking
    def compile(
        self,
        optimizer=None,
        reward=None,
        reward_weights=None,
        metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
    ):
        """Configures the program for training.

        Example:

        ```python
        program.compile(
            optimizer=synalinks.optimizers.RandomFewShot(),
            reward=synalinks.rewards.ExactMatch(),
            metrics=[
                synalinks.metrics.MeanMetricWrapper(synalinks.rewards.exact_match),
            ],
        )
        ```

        Args:
            optimizer (Optimizer): Optimizer instance. See `synalinks.optimizers`.
            reward (Reward): Reward function. A `synalinks.rewards.Reward`
                instance. See `synalinks.rewards`. A reward function is
                any callable with the signature `reward = fn(y_true, y_pred)`,
                where `y_true` are the ground truth values, and `y_pred`
                are the program's predictions.
                `y_true` should be a list of batch size length `[d0, .. dN]`.
                `y_pred` should be a list of batch size length `[d0, .. dN]`.
                The reward function should return a float.
            reward_weights (list): Optional list specifying scalar coefficients
                (Python floats) to weight the reward contributions of
                different program outputs. The reward value that will be maximized
                by the program will then be the *weighted sum* of all individual
                rewards, weighted by the `reward_weights` coefficients. It is
                expected to have a 1:1 mapping to the program's outputs.
            metrics (list): List of metrics to be evaluated by the program during
                training and testing. Each of them is a `synalinks.metrics.Metric`
                instance. See `synalinks.metrics`. A function is any callable with the
                signature `result = fn(y_true, y_pred)`.
            run_eagerly (bool): If `True`, this program's forward pass
                 will never be compiled. It is recommended to leave this
                 as `False` when training (for best performance),
                 and to set it to `True` when debugging.
            steps_per_execution (int): The number of batches to run
                during a single compiled function call. Running multiple
                batches inside a single compiled function call can
                greatly improve performance on TPUs or small programs with a large
                Python overhead. At most, one full epoch will be run each
                execution. If a number larger than the size of the epoch is
                passed, the execution will be truncated to the size of the
                epoch. Note that if `steps_per_execution` is set to `N`,
                `Callback.on_batch_begin` and `Callback.on_batch_end` methods
                will only be called every `N` batches (i.e. before/after
                each compiled function execution).
        """
        self._clear_previous_trainer_metrics()
        self._optimizer = optimizer

        if hasattr(self, "output_names"):
            output_names = self.output_names
        else:
            output_names = None
        if reward is not None:
            self._compile_reward = CompileReward(
                reward, reward_weights, output_names=output_names
            )
            self.reward = reward
        if metrics is not None:
            self._compile_metrics = CompileMetrics(metrics, output_names=output_names)
        self.run_eagerly = run_eagerly
        self.stop_training = False
        self.compiled = True
        self._reward_tracker = metrics_module.Mean(name="reward")
        self.steps_per_execution = steps_per_execution

        self._compile_config = serialization_lib.SerializableDict(
            optimizer=optimizer,
            reward=reward,
            reward_weights=reward_weights,
            metrics=metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
        )

    @property
    def optimizer(self):
        return self._optimizer

    @property
    def metrics(self):
        # Order: reward tracker, individual reward trackers, compiled metrics,
        # custom metrics, submodule metrics.
        metrics = []
        if self.compiled:
            if self._reward_tracker is not None:
                metrics.append(self._reward_tracker)
            if self._compile_metrics is not None:
                metrics.append(self._compile_metrics)
            if self._compile_reward is not None:
                metrics.extend(self._compile_reward.metrics)
        metrics.extend(self._metrics)
        for module in self._flatten_modules(include_self=False):
            if isinstance(module, Trainer):
                # All Trainer-related metrics in submodules should be ignored
                # because a new Trainer has been instantiated.
                continue
            metrics.extend(module.metrics)
        return metrics

    @property
    def metrics_names(self):
        return [m.name for m in self.metrics]

    def reset_metrics(self):
        for m in self.metrics:
            m.reset_state()

    def _get_own_metrics(self):
        metrics = []
        if self._reward_tracker is not None:
            metrics.append(self._reward_tracker)
        if self._compile_metrics is not None:
            metrics.append(self._compile_metrics)
        if self._compile_reward is not None:
            metrics.extend(self._compile_reward.metrics)
        metrics.extend(self._metrics)
        return metrics

    def _clear_previous_trainer_metrics(self):
        for module in self._flatten_modules(include_self=False):
            if not isinstance(module, Trainer):
                continue
            # A submodule might be a Trainer. In that case, we need to clear
            # the Trainer-related metrics, as they are not usable when a
            # new Trainer is instantiated.
            for m in self._get_own_metrics():
                module._tracker.untrack(m)
            module._reward_tracker = None
            module._compile_metrics = None
            if module._compile_reward is not None:
                module._compile_reward._metrics.clear()
            module._metrics.clear()

    @property
    def run_eagerly(self):
        return self._run_eagerly

    @run_eagerly.setter
    def run_eagerly(self, value):
        self._run_eagerly = value

    async def compute_reward(
        self,
        x=None,
        y=None,
        y_pred=None,
        sample_weight=None,
        training=True,
    ):
        """Compute the total reward, validate it, and return it.

        Subclasses can optionally override this method to provide custom reward
        computation logic.

        Args:
            x (list): Input data.
            y (list): Target data.
            y_pred (list): Predictions returned by the program (output of `program(x)`).
            training (bool): Whether we are training or evaluating the program.

        Returns:
            (float | None): The total reward as a scalar, or `None` if no reward results
                (which is the case when called by `Program.test_step`).
        """
        # The default implementation does not use `x` or `training`.
        del x
        del training
        rewards = []
        if self._compile_reward is not None:
            for y_t, y_p in zip(y, y_pred):
                reward = await self._compile_reward(y_t, y_p)
                if reward is not None:
                    rewards.append(reward)
        for reward in self.rewards:
            rewards.append(numpy.sum(reward))
        if len(rewards) == 1:
            total_reward = rewards[0]
        elif len(rewards) == 0:
            total_reward = numpy.zeros(())
        else:
            total_reward = numpy.mean(rewards)
        return float(total_reward)

    def stateless_compute_reward(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x=None,
        y=None,
        y_pred=None,
        sample_weight=None,
        training=True,
    ):
        var_mapping = list(zip(self.trainable_variables, trainable_variables))
        var_mapping.extend(zip(self.non_trainable_variables, non_trainable_variables))
        var_mapping.extend(zip(self.metrics_variables, metrics_variables))
        with backend.StatelessScope(state_mapping=var_mapping) as scope:
            # Note that this is needed for the regularization reward, which needs
            # the latest values of trainable/non-trainable variables.
            reward = self._compute_reward(
                x,
                y,
                y_pred,
                sample_weight=sample_weight,
                training=training,
            )

        # Update non trainable vars (may have been updated in compute_reward)
        non_trainable_variables = []
        for v in self.non_trainable_variables:
            new_v = scope.get_current_value(v)
            non_trainable_variables.append(new_v)

        # Update metrics vars (may have been updated in compute_reward)
        metrics_variables = []
        for v in self.metrics_variables:
            new_v = scope.get_current_value(v)
            metrics_variables.append(new_v)
        return reward, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        )

    async def compute_metrics(self, x, y, y_pred):
        """Update metric states and collect all metrics to be returned.

        Subclasses can optionally override this method to provide custom metric
        updating and collection logic. Custom metrics are not passed in
        `compile()`, they can be created in `__init__` or `build`. They are
        automatically tracked and returned by `self.metrics`.

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the program (output of `program.call(x)`).

        Returns:
            A `dict` containing values that will be passed to
                `synalinks.callbacks.CallbackList.on_train_batch_end()`. Typically,
                the values of the metrics listed in `self.metrics` are returned.
                Example: `{'reward': 0.2, 'accuracy': 0.7}`.
        """
        del x  # The default implementation does not use `x`.
        if self._compile_metrics is not None:
            for y_t, y_p in zip(y, y_pred):
                await self._compile_metrics.update_state(y_t, y_p)
        return self.get_metrics_result()

    def get_metrics_result(self):
        """Returns the program's metrics values as a dict.

        If any of the metric results is a dict (containing multiple metrics),
        each of them gets added to the top level returned dict of this method.

        Returns:
            (dict): A `dict` containing values of the metrics listed in `self.metrics`.
                Example: `{'reward': 0.2, 'accuracy': 0.7}`.
        """
        return_metrics = {}
        for metric in self.metrics:
            result = metric.result()
            if isinstance(result, dict):
                return_metrics.update(result)
            else:
                return_metrics[metric.name] = result
        return python_utils.pythonify_logs(return_metrics)

    async def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Trains the program for a fixed number of epochs (dataset iterations).

        Args:
            x (np.ndarray | generator): Input data. It can be:
                - A NumPy array (or array-like), or a list of `DataModel` arrays
                    (in case the model has multiple inputs).
                - A list of dict mapping input names to the corresponding `DataModel`s,
                    if the program has named inputs.
                - A Python generator function yielding `(inputs, targets)`.
            y (np.ndarray): Target data. Like the input data `x`, it can be either NumPy
                array(s) or `DataModel`(s). If `x` is a Python generator function,
                `y` should not be specified since targets will be obtained from
                `x`.
            batch_size (int): Integer or `None`.
                Number of samples per batch of computation.
                If unspecified, `batch_size` will default to 32.
                Do not specify the `batch_size` if your input data `x` is a
                Python generator function since they generate batches.
            epochs (int): Integer. Number of epochs to train the program.
                An epoch is an iteration over the entire `x` and `y`
                data provided (unless the `steps_per_epoch` flag is set to
                something other than None).
                Note that in conjunction with `initial_epoch`,
                `epochs` is to be understood as "final epoch".
                The program is not trained for a number of iterations
                given by `epochs`, but merely until the epoch
                of index `epochs` is reached.
            verbose (int | str): `"auto"`, 0, 1, or 2. Verbosity mode.
                0 = silent, 1 = progress bar, 2 = one line per epoch.
                "auto" becomes 1 for most cases.
                Note that the progress bar is not
                particularly useful when logged to a file,
                so `verbose=2` is recommended when not running interactively
                (e.g., in a production environment). Defaults to `"auto"`.
            callbacks (list): List of `synalinks.callbacks.Callback` instances.
                List of callbacks to apply during training.
                See `synalinks.callbacks`. Note
                `synalinks.callbacks.ProgbarLogger` and
                `synalinks.callbacks.History` callbacks are created
                automatically and need not be passed to `program.fit()`.
                `synalinks.callbacks.ProgbarLogger` is created
                or not based on the `verbose` argument in `program.fit()`.
            validation_split (float): Float between 0 and 1.
                Fraction of the training data to be used as validation data.
                The program will set apart this fraction of the training data,
                will not train on it, and will evaluate the reward and any program
                metrics on this data at the end of each epoch. The validation
                data is selected from the last samples in the `x` and `y` data
                provided, before shuffling.
                This argument is only supported when `x` and `y` are made of
                data_models.
                If both `validation_data` and `validation_split` are provided,
                `validation_data` will override `validation_split`.
            validation_data (tuple | iterator): Data on which to evaluate
                the reward and any program metrics at the end of each epoch.
                The program will not be trained on this data.
                `validation_data` will override `validation_split`.
                It can be:
                - A tuple `(x_val, y_val)` of `DataModel`s lists.
            shuffle (bool): Whether to shuffle the training data before each
                epoch. This argument is ignored when `x` is a Python generator function.
            initial_epoch (int): Integer.
                Epoch at which to start training
                (useful for resuming a previous training run).
            steps_per_epoch (int): Integer or `None`.
                Total number of steps (batches of samples) before declaring one
                epoch finished and starting the next epoch. When training with
                input `DataModel` arrays, the default `None` means that the
                value used is the number of samples in your dataset divided by
                the batch size, or 1 if that cannot be determined.
                If `x` is a Python generator function, the
                epoch will run until the input dataset is exhausted. When
                passing an infinitely repeating dataset, you must specify the
                `steps_per_epoch` argument, otherwise the training will run
                indefinitely.
            validation_steps (int): Integer or `None`.
                Only relevant if `validation_data` is provided.
                Total number of steps (batches of samples) to draw before
                stopping when performing validation at the end of every epoch.
                If `validation_steps` is `None`, validation will run until the
                `validation_data` dataset is exhausted. In the case of an
                infinitely repeating dataset, it will run indefinitely. If
                `validation_steps` is specified and only part of the dataset
                is consumed, the evaluation will start from the beginning of the
                dataset at each epoch. This ensures that the same validation
                samples are used every time.
            validation_batch_size (int): Integer or `None`.
                Number of samples per validation batch.
                If unspecified, will default to `batch_size`.
                Do not specify the `validation_batch_size` if your data is a
                `synalinks.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function
                since they generate batches.
            validation_freq (int): Only relevant if validation data is provided.
                Specifies how many training epochs to run
                before a new validation run is performed,
                e.g. `validation_freq=2` runs validation every 2 epochs.

        Returns:
            (History): A `History` object. Its `History.history` attribute is
                a record of training reward values and metrics values
                at successive epochs, as well as validation reward values
                and validation metrics values (if applicable).
        """
        self._assert_compile_called("fit")
        # TODO: respect compiled trainable state
        self._eval_epoch_iterator = None
        if validation_split and validation_data is None:
            # Create the validation data using the training data. Only supported
            # for numpy arrays.
            (x, y), validation_data = array_slicing.train_validation_split(
                (x, y), validation_split=validation_split
            )

        if validation_data is not None:
            (val_x, val_y) = data_adapter_utils.unpack_x_y(validation_data)
        # Create an iterator that yields batches of input/target data.
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

        if not all(module.built for module in self._flatten_modules()):
            # Build the model on one batch of data.
            for _, data in epoch_iterator:
                data_batch = data[0]
                self._symbolic_build(data_batch)
                break
        epoch_iterator.reset()

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=epochs,
                steps=steps_per_epoch,
                program=self,
            )

        self.stop_training = False
        callbacks.on_train_begin()
        training_logs = None
        logs = {}
        initial_epoch = self._initial_epoch or initial_epoch

        for epoch in range(initial_epoch, epochs):
            self.reset_metrics()
            callbacks.on_epoch_begin(epoch)
            with epoch_iterator.catch_stop_iteration():
                for step, iterator in epoch_iterator:
                    data = iterator[0]
                    x_batch, y_batch = data_adapter_utils.unpack_x_y(data)
                    callbacks.on_train_batch_begin(step)
                    logs = await self.train_on_batch(
                        x=x_batch,
                        y=y_batch,
                        return_dict=True,
                    )
                    callbacks.on_train_batch_end(step, logs)
                    if self.stop_training:
                        break

            # Override with model metrics instead of last step logs if needed.
            epoch_logs = dict(self._get_metrics_result_or_logs(logs))

            # Run validation.
            if validation_data is not None and self._should_eval(epoch, validation_freq):
                # Create EpochIterator for evaluation and cache it.
                if getattr(self, "_eval_epoch_iterator", None) is None:
                    self._eval_epoch_iterator = EpochIterator(
                        x=val_x,
                        y=val_y,
                        batch_size=validation_batch_size or batch_size,
                        steps_per_execution=self.steps_per_execution,
                        steps_per_epoch=validation_steps,
                        shuffle=False,
                    )
                val_logs = await self.evaluate(
                    x=val_x,
                    y=val_y,
                    batch_size=validation_batch_size or batch_size,
                    steps=validation_steps,
                    callbacks=callbacks,
                    _use_cached_eval_dataset=True,
                )
                val_logs = {"val_" + name: val for name, val in val_logs.items()}
                epoch_logs.update(val_logs)

            callbacks.on_epoch_end(epoch, epoch_logs)
            training_logs = epoch_logs
            if self.stop_training:
                break

        if isinstance(self.optimizer, optimizers_module.Optimizer) and epochs > 0:
            await self.optimizer.finalize_variable_values(self.trainable_variables)

        # If _eval_epoch_iterator exists, delete it after all epochs are done.
        if getattr(self, "_eval_epoch_iterator", None) is not None:
            del self._eval_epoch_iterator
        callbacks.on_train_end(logs=training_logs)
        return self.history

    async def evaluate(
        self,
        x=None,
        y=None,
        batch_size=None,
        verbose="auto",
        steps=None,
        callbacks=None,
        return_dict=True,
        **kwargs,
    ):
        """Returns the reward value & metrics values for the program in test mode.

        Computation is done in batches (see the `batch_size` arg.)

        Args:
            x (np.ndarray | generator): Input data. It can be:
                - A NumPy array (or array-like), or a list of `DataModel` arrays
                    (in case the model has multiple inputs).
                - A list of dict mapping input names to the corresponding `DataModel`s,
                    if the program has named inputs.
                - A Python generator function yielding `(inputs, targets)`.
            y (np.ndarray): Target data. Like the input data `x`, it can be either NumPy
                array(s) or `DataModel`(s). If `x` is a Python generator function,
                `y` should not be specified since targets will be obtained from
                `x`.
            batch_size (int): Integer or `None`.
                Number of samples per batch of computation.
                If unspecified, `batch_size` will default to 32.
                Do not specify the `batch_size` if your input data `x` is a
                Python generator function since they generate batches.
            verbose (int | str): `"auto"`, 0, 1, or 2. Verbosity mode.
                0 = silent, 1 = progress bar, 2 = single line.
                `"auto"` becomes 1 for most cases.
                Note that the progress bar is not
                particularly useful when logged to a file, so `verbose=2` is
                recommended when not running interactively
                (e.g. in a production environment). Defaults to `"auto"`.
            steps (int): Integer or `None`.
                Total number of steps (batches of samples) to draw before
                declaring the evaluation round finished. If `steps` is `None`,
                it will run until `x` is exhausted. In the case of an infinitely
                repeating dataset, it will run indefinitely.
            callbacks (list): List of `synalinks.callbacks.Callback` instances.
                List of callbacks to apply during evaluation.
            return_dict (bool): If `True`, reward and metric results are returned as a
                dict, with each key being the name of the metric.
                If `False`, they are returned as a list.

        Returns:
            (float | list | dict): Scalar test reward
                (if the program has a single output and no metrics)
                or list of scalars (if the program has multiple outputs
                and/or metrics). The attribute `program.metrics_names` will give you
                the display labels for the scalar outputs.
        """
        self._assert_compile_called("evaluate")
        use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
        if kwargs:
            raise ValueError(f"Arguments not recognized: {kwargs}")
        # Create an iterator that yields batches of input/target data.
        if use_cached_eval_dataset:
            epoch_iterator = self._eval_epoch_iterator
        else:
            epoch_iterator = EpochIterator(
                x=x,
                y=y,
                batch_size=batch_size,
                steps_per_epoch=steps,
                shuffle=False,
                steps_per_execution=self.steps_per_execution,
            )

        if not all(module.built for module in self._flatten_modules()):
            # Build the model on one batch of data.
            for _, data in epoch_iterator:
                data_batch = data[0]
                self._symbolic_build(data_batch)
                break
        epoch_iterator.reset()

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=False,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                program=self,
            )

        self.stop_evaluating = False
        callbacks.on_test_begin()
        logs = {}
        self.reset_metrics()
        for step, iterator in epoch_iterator:
            callbacks.on_test_batch_begin(step)
            data = iterator[0]
            x_batch, y_batch = data_adapter_utils.unpack_x_y(data)
            logs = await self.test_on_batch(
                x=x_batch,
                y=y_batch,
                return_dict=True,
            )
            callbacks.on_test_batch_end(step, logs)
            if self.stop_evaluating:
                break
        logs = self._get_metrics_result_or_logs(logs)
        callbacks.on_test_end(logs)

        if return_dict:
            return logs
        return self._flatten_metrics_in_order(logs)

    async def predict(
        self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
    ):
        """Generates output predictions for the input samples.

        Computation is done in batches. This method is designed for batch
        processing of large numbers of inputs. It is not intended for use inside
        of loops that iterate over your data and process small numbers of inputs
        at a time.

        For small numbers of inputs that fit in one batch,
        directly use `__call__()` for faster execution, e.g.,
        `program(x)`, or `program(x, training=False)` if you have modules
        that behave differently during inference.

        Args:
            x (np.ndarray | generator): Input data. It can be:
                - A NumPy array (or array-like), or a list of `DataModel` arrays
                    (in case the model has multiple inputs).
                - A list of dict mapping input names to the corresponding `DataModel`s,
                    if the program has named inputs.
                - A Python generator function yielding `(inputs, targets)`.
            batch_size (int): Integer or `None`.
                Number of samples per batch of computation.
                If unspecified, `batch_size` will default to 32.
                Do not specify the `batch_size` if your input data `x` is a
                `synalinks.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function
                since they generate batches.
            verbose (int | str): `"auto"`, 0, 1, or 2. Verbosity mode.
                0 = silent, 1 = progress bar, 2 = single line.
                `"auto"` becomes 1 for most cases. Note that the progress bar
                is not particularly useful when logged to a file,
                so `verbose=2` is recommended when not running interactively
                (e.g. in a production environment). Defaults to `"auto"`.
            steps (int): Total number of steps (batches of samples) to draw before
                declaring the prediction round finished. If `steps` is `None`,
                it will run until `x` is exhausted. In the case of an infinitely
                repeating dataset, it will run indefinitely.
            callbacks (list): List of `synalinks.callbacks.Callback` instances.
                List of callbacks to apply during prediction.

        Returns:
            (list): `JsonDataModel` array(s) of predictions.
                If the pipeline fails, `None` is added to the predictions.
        """
        # Create an iterator that yields batches of input data.
        epoch_iterator = EpochIterator(
            x=x,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

        # Container that configures and calls callbacks.
        if not isinstance(callbacks, callbacks_module.CallbackList):
            callbacks = callbacks_module.CallbackList(
                callbacks,
                add_history=True,
                add_progbar=verbose != 0,
                verbose=verbose,
                epochs=1,
                steps=epoch_iterator.num_batches,
                program=self,
            )

        self.stop_predicting = False
        callbacks.on_predict_begin()
        outputs = []
        for step, iterator in epoch_iterator:
            callbacks.on_predict_batch_begin(step)
            data = iterator[0]
            x_batch, _ = data_adapter_utils.unpack_x_y(data)
            batch_outputs = await self.predict_on_batch(x_batch)
            outputs.extend(batch_outputs)
            callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
            if self.stop_predicting:
                break
        callbacks.on_predict_end()
        return numpy.array(outputs, dtype="object")

    async def train_on_batch(
        self,
        x,
        y=None,
        return_dict=False,
    ):
        """Runs a single backpropagation/optimization update on a single batch of data.

        Args:
            x (np.ndarray): Input data. Must be array-like.
            y (np.ndarray): Target data. Must be array-like.
            return_dict (bool): If `True`, reward and metric results are returned as a
                dict, with each key being the name of the metric. If `False`,
                they are returned as a list.

        Returns:
            (float | list | dict): A scalar reward value
                (when no metrics and `return_dict=False`), a list of reward
                and metric values (if there are metrics and `return_dict=False`),
                or a dict of metric and reward values (if `return_dict=True`).
        """
        y_pred = await self.predict_on_batch(x, training=True)

        reward = await self.compute_reward(
            x=x,
            y=y,
            y_pred=y_pred,
            training=True,
        )

        await self._reward_tracker.update_state(reward)

        # Perform training/optimization
        if self.trainable_variables:
            await self.optimizer.apply_optimization(
                self.trainable_variables,
                reward=reward,
            )
        else:
            warnings.warn("The program does not have any trainable variables.")

        metrics = await self.compute_metrics(x, y, y_pred)

        if return_dict:
            return metrics
        return self._flatten_metrics_in_order(metrics)

    async def test_on_batch(
        self,
        x,
        y=None,
        return_dict=False,
    ):
        """Test the program on a single batch of samples.

        Args:
            x (np.ndarray): Input data. Must be array-like.
            y (np.ndarray): Target data. Must be array-like.
            return_dict (bool): If `True`, reward and metric results are returned as a
                dict, with each key being the name of the metric. If `False`,
                they are returned as a list.

        Returns:
            (float | list | dict): A scalar reward value
                (when no metrics and `return_dict=False`), a list of reward
                and metric values (if there are metrics and `return_dict=False`),
                or a dict of metric and reward values (if `return_dict=True`).
        """
        y_pred = await self.predict_on_batch(x, training=False)

        reward = await self.compute_reward(
            x=x,
            y=y,
            y_pred=y_pred,
            training=False,
        )

        await self._reward_tracker.update_state(reward)

        metrics = await self.compute_metrics(x, y, y_pred)

        if return_dict:
            return metrics
        return self._flatten_metrics_in_order(metrics)

    async def predict_on_batch(self, x, training=False):
        """Returns predictions for a single batch of samples.

        Args:
            x (np.ndarray): Input data. Must be array-like.
            training (bool): Boolean. True if training.

        Returns:
            (list): list(s) of JsonDataModel predictions.
        """
        tasks = []
        for inputs in x:
            tasks.append(self(inputs, training=training))
        y_pred = await asyncio.gather(*tasks)
        return y_pred

    def get_compile_config(self):
        """Returns a serialized config with information for compiling the program.

        This method returns a config dictionary containing all the information
        (optimizer, reward, metrics, etc.) with which the program was compiled.

        Returns:
            (dict): A dict containing information for compiling the program.
        """
        if self.compiled and hasattr(self, "_compile_config"):
            return self._compile_config.serialize()

    def compile_from_config(self, config):
        """Compiles the program with the information given in config.

        This method uses the information in the config (optimizer, reward,
        metrics, etc.) to compile the program.

        Args:
            config (dict): Dict containing information for compiling the program.
        """
        has_overridden_compile = self.__class__.compile != Trainer.compile
        if has_overridden_compile:
            warnings.warn(
                "`compile()` was not called as part of program loading "
                "because the program's `compile()` method is custom. "
                "All subclassed Models that have `compile()` "
                "overridden should also override "
                "`get_compile_config()` and `compile_from_config(config)`. "
                "Alternatively, you can "
                "call `compile()` manually after loading.",
                stacklevel=2,
            )
            return
        config = serialization_lib.deserialize_synalinks_object(config)
        self.compile(**config)
        if hasattr(self, "optimizer") and self.built:
            # Create optimizer variables.
            self.optimizer.build(self.trainable_variables)

    def _should_reward(self, epoch, validation_freq):
        epoch = epoch + 1  # one-index the user-facing epoch.
        if isinstance(validation_freq, int):
            return epoch % validation_freq == 0
        elif isinstance(validation_freq, list):
            return epoch in validation_freq
        else:
            raise ValueError(
                "Expected `validation_freq` to be a list or int. "
                f"Received: validation_freq={validation_freq} of the "
                f"type {type(validation_freq)}."
            )

    def _get_metrics_result_or_logs(self, logs):
        """Returns program metrics as a dict if the keys match with input logs.

        When training / evaluation is performed with asynchronous steps,
        the last scheduled `train / test_step` may not give the latest metrics
        because it is not guaranteed to be executed last. This method gets
        metrics from the program directly instead of relying on the return
        value of the last step function.

        When the user has custom train / test step functions, the metrics
        returned may be different from `Program.metrics`. In those instances,
        this function will be a no-op and return the logs passed in.

        Args:
            logs (dict): A `dict` of metrics returned by train / test step function.

        Returns:
            (dict): A `dict` containing values of the metrics listed in `self.metrics`
                when logs and program metrics keys match. Otherwise it returns input
                `logs`.
        """
        metric_logs = self.get_metrics_result()
        # Verify that train / test step logs passed and metric logs have
        # matching keys. It could be different when using custom step functions,
        # in which case we return the logs from the last step.
        if isinstance(logs, dict) and set(logs.keys()) == set(metric_logs.keys()):
            return metric_logs
        return logs

    def _flatten_metrics_in_order(self, logs):
        """Turns `logs` dict into a list as per key order of `metrics_names`."""
        metric_names = []
        for metric in self.metrics:
            if isinstance(metric, CompileMetrics):
                metric_names += [sub_metric.name for sub_metric in metric.metrics]
            else:
                metric_names.append(metric.name)
        results = []
        for name in metric_names:
            if name in logs:
                results.append(logs[name])
        for key in sorted(logs.keys()):
            if key not in metric_names:
                results.append(logs[key])
        if len(results) == 1:
            return results[0]
        return results

    def _assert_compile_called(self, method_name=None):
        if not self.compiled:
            msg = "You must call `compile()` before "
            if metrics_module:
                msg += "using the program."
            else:
                msg += f"calling `{method_name}()`."
            raise ValueError(msg)

    def _symbolic_build(self, iterator=None, data_batch=None):
        program_unbuilt = not all(module.built for module in self._flatten_modules())
        compile_metrics_unbuilt = (
            self._compile_metrics is not None and not self._compile_metrics.built
        )
        compile_reward_unbuilt = (
            self._compile_reward is not None and not self._compile_reward.built
        )
        optimizer_unbuilt = self.optimizer is not None and not self.optimizer.built
        if program_unbuilt or compile_metrics_unbuilt or compile_reward_unbuilt:
            # Create symbolic data_models matching an input batch.

            def to_symbolic_input(v):
                if v is None:
                    return None
                if backend.is_data_model(v):
                    return backend.SymbolicDataModel(schema=v.schema())
                else:
                    return backend.SymbolicDataModel(schema=v.schema)

            if data_batch is None:
                for _, data_or_iterator in iterator:
                    if isinstance(data_or_iterator, (list, tuple)):
                        data_batch = data_or_iterator[0]
                    else:
                        data_batch = next(data_or_iterator)
                    break
            data_batch = tree.map_structure(to_symbolic_input, data_batch)
            (x, y) = data_batch
            # Build all program state with `backend.compute_output_spec`.
            try:
                y_pred = asyncio.get_event_loop().run_until_complete(
                    backend.compute_output_spec(self, x, training=False)
                )
            except Exception as e:
                raise RuntimeError(
                    "Unable to automatically build the program. "
                    "Please build it yourself before calling "
                    "fit/evaluate/predict. "
                    "A program is 'built' when its variables have "
                    "been created and its `self.built` attribute "
                    "is True. Usually, calling the program on a batch "
                    "of data is the right way to build it.\n"
                    "Exception encountered:\n"
                    f"'{e}'"
                )
            if compile_metrics_unbuilt:
                # Build all metric state with `backend.compute_output_spec`.
                asyncio.get_event_loop().run_until_complete(
                    backend.compute_output_spec(
                        self.compute_metrics,
                        x,
                        y,
                        y_pred,
                    )
                )
            if compile_reward_unbuilt:
                # Build `CompileReward` state with `backend.compute_output_spec`.
                asyncio.get_event_loop().run_until_complete(
                    backend.compute_output_spec(
                        self._compute_reward,
                        x,
                        y,
                        y_pred,
                        training=False,
                    )
                )
        if optimizer_unbuilt:
            # Build optimizer
            self.optimizer.build(self.trainable_variables)
        self._post_build()

    def _should_eval(self, epoch, validation_freq):
        epoch = epoch + 1  # one-index the user-facing epoch.
        if isinstance(validation_freq, int):
            return epoch % validation_freq == 0
        elif isinstance(validation_freq, list):
            return epoch in validation_freq
        else:
            raise ValueError(
                "Expected `validation_freq` to be a list or int. "
                f"Received: validation_freq={validation_freq} of the "
                f"type {type(validation_freq)}."
            )
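
The listing above covers the full training workflow: `compile()` wires up the optimizer, reward, and metrics, `fit()` iterates over epochs and batches, and `evaluate()` / `predict()` run the test and inference passes. Below is a minimal sketch of how these pieces fit together. The program construction and the data (`x_train`, `y_train`, `x_val`, `y_val` as lists of `DataModel` instances) are placeholder assumptions for illustration; only the `compile()`, `fit()`, `evaluate()`, and `predict()` calls mirror the API documented here.

```python
# Minimal sketch of the documented training workflow.
# `program`, `x_train`, `y_train`, `x_val`, `y_val` are placeholders: an already-built
# synalinks Program and lists of DataModel instances. They are not part of this API
# reference.
import synalinks


async def train_and_evaluate(program, x_train, y_train, x_val, y_val):
    # Configure the program for training (see `compile()` above).
    program.compile(
        optimizer=synalinks.optimizers.RandomFewShot(),
        reward=synalinks.rewards.ExactMatch(),
        metrics=[
            synalinks.metrics.MeanMetricWrapper(synalinks.rewards.exact_match),
        ],
    )

    # Train for a few epochs; `fit()` is a coroutine and returns a `History` object.
    history = await program.fit(
        x=x_train,
        y=y_train,
        epochs=3,
        batch_size=32,
        validation_data=(x_val, y_val),
    )

    # Evaluate on held-out data; returns a dict because `return_dict=True` by default.
    scores = await program.evaluate(x=x_val, y=y_val)

    # Generate predictions for new inputs.
    predictions = await program.predict(x=x_val)
    return history, scores, predictions
```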

_flatten_metrics_in_order(logs)

Turns logs dict into a list as per key order of metrics_names.

Source code in synalinks/src/trainers/trainer.py
def _flatten_metrics_in_order(self, logs):
    """Turns `logs` dict into a list as per key order of `metrics_names`."""
    metric_names = []
    for metric in self.metrics:
        if isinstance(metric, CompileMetrics):
            metric_names += [sub_metric.name for sub_metric in metric.metrics]
        else:
            metric_names.append(metric.name)
    results = []
    for name in metric_names:
        if name in logs:
            results.append(logs[name])
    for key in sorted(logs.keys()):
        if key not in metric_names:
            results.append(logs[key])
    if len(results) == 1:
        return results[0]
    return results
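
As an illustration of the ordering implemented above: tracked metrics come first, in `metrics_names` order, followed by any remaining log keys in sorted order, and a single remaining value is returned as a scalar. The names and values below are hypothetical, not taken from the library.

```python
# Hypothetical illustration of the flattening order (names/values are examples only).
# Suppose the program's metrics_names is ["reward", "accuracy"] and the last step
# also logged one key that is not a tracked metric.
logs = {"accuracy": 0.7, "reward": 0.2, "extra_diagnostic": 1.0}

# _flatten_metrics_in_order(logs) would produce:
#   [0.2, 0.7, 1.0]
# i.e. tracked metrics first (in metrics_names order), then the remaining keys in
# sorted order. If the logs contained only {"reward": 0.2}, the scalar 0.2 would be
# returned instead of a one-element list.
```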

_get_metrics_result_or_logs(logs)

Returns program metrics as a dict if the keys match with input logs.

When training / evaluation is performed with asynchronous steps, the last scheduled train / test_step may not give the latest metrics, because it is not guaranteed to be executed last. This method gets metrics from the program directly instead of relying on the return value of the last step function.

When the user has custom train / test step functions, the metrics returned may be different from Program.metrics. In those instances, this function will be a no-op and return the logs passed in.

Parameters:

    logs (dict): A `dict` of metrics returned by the train / test step function. Required.

Returns:

    (dict): A `dict` containing values of the metrics listed in `self.metrics` when the
        logs and program metrics keys match. Otherwise, the input `logs` are returned.

Source code in synalinks/src/trainers/trainer.py
def _get_metrics_result_or_logs(self, logs):
    """Returns program metrics as a dict if the keys match with input logs.

    When training / evaluation is performed with asynchronous steps,
    the last scheduled `train / test_step` may not give the latest metrics
    because it is not guaranteed to be executed last. This method gets
    metrics from the program directly instead of relying on the return
    value of the last step function.

    When the user has custom train / test step functions, the metrics
    returned may be different from `Program.metrics`. In those instances,
    this function will be a no-op and return the logs passed in.

    Args:
        logs (dict): A `dict` of metrics returned by train / test step function.

    Returns:
        (dict): A `dict` containing values of the metrics listed in `self.metrics`
            when logs and program metrics keys match. Otherwise it returns input
            `logs`.
    """
    metric_logs = self.get_metrics_result()
    # Verify that train / test step logs passed and metric logs have
    # matching keys. It could be different when using custom step functions,
    # in which case we return the logs from the last step.
    if isinstance(logs, dict) and set(logs.keys()) == set(metric_logs.keys()):
        return metric_logs
    return logs
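
A short hypothetical example of the key-matching rule above (names and values are illustrative only, not library output):

```python
# If the logs from the last step use exactly the same keys as
# self.get_metrics_result(), the freshly computed program metrics win:
step_logs = {"reward": 0.18, "accuracy": 0.65}        # from the last train/test step
program_metrics = {"reward": 0.20, "accuracy": 0.70}  # from self.get_metrics_result()
# -> _get_metrics_result_or_logs(step_logs) returns program_metrics.

# With a custom step that logs different keys, the input logs pass through unchanged:
custom_logs = {"my_custom_score": 0.9}
# -> _get_metrics_result_or_logs(custom_logs) returns custom_logs.
```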

compile(optimizer=None, reward=None, reward_weights=None, metrics=None, run_eagerly=False, steps_per_execution=1)

Configures the program for training.

Example:

program.compile(
    optimizer=synalinks.optimizers.RandomFewShot(),
    reward=synalinks.rewards.ExactMatch(),
    metrics=[
        synalinks.metrics.MeanMetricWrapper(synalinks.rewards.exact_match),
    ],
)

Parameters:

    optimizer (Optimizer): Optimizer instance. See `synalinks.optimizers`. Defaults to `None`.

    reward (Reward): Reward function. A `synalinks.rewards.Reward` instance. See
        `synalinks.rewards`. A reward function is any callable with the signature
        `reward = fn(y_true, y_pred)`, where `y_true` are the ground truth values and
        `y_pred` are the program's predictions. `y_true` and `y_pred` should be lists of
        batch-size length `[d0, .. dN]`. The reward function should return a float.
        Defaults to `None`.

    reward_weights (list): Optional list specifying scalar coefficients (Python floats)
        to weight the reward contributions of different program outputs. The reward value
        maximized by the program will then be the *weighted sum* of all individual
        rewards, weighted by the `reward_weights` coefficients. It is expected to have a
        1:1 mapping to the program's outputs. Defaults to `None`.

    metrics (list): List of metrics to be evaluated by the program during training and
        testing. Each of them is a `synalinks.metrics.Metric` instance. See
        `synalinks.metrics`. A function is any callable with the signature
        `result = fn(y_true, y_pred)`. Defaults to `None`.

    run_eagerly (bool): If `True`, this program's forward pass will never be compiled.
        It is recommended to leave this as `False` when training (for best performance)
        and to set it to `True` when debugging. Defaults to `False`.

    steps_per_execution (int): The number of batches to run during a single compiled
        function call. Running multiple batches inside a single compiled function call
        can greatly improve performance on TPUs or for small programs with a large Python
        overhead. At most, one full epoch will be run each execution. If a number larger
        than the size of the epoch is passed, the execution will be truncated to the size
        of the epoch. Note that if `steps_per_execution` is set to `N`,
        `Callback.on_batch_begin` and `Callback.on_batch_end` will only be called every
        `N` batches (i.e. before/after each compiled function execution). Defaults to `1`.
Source code in synalinks/src/trainers/trainer.py
@tracking.no_automatic_dependency_tracking
def compile(
    self,
    optimizer=None,
    reward=None,
    reward_weights=None,
    metrics=None,
    run_eagerly=False,
    steps_per_execution=1,
):
    """Configures the program for training.

    Example:

    ```python
    program.compile(
        optimizer=synalinks.optimizers.RandomFewShot(),
        reward=synalinks.rewards.ExactMatch(),
        metrics=[
            synalinks.metrics.MeanMetricWrapper(synalinks.rewards.exact_match),
        ],
    )
    ```

    Args:
        optimizer (Optimizer): Optimizer instance. See `synalinks.optimizers`.
        reward (Reward): Reward function. A `synalinks.rewards.Reward`
            instance. See `synalinks.rewards`. A reward function is
            any callable with the signature `reward = fn(y_true, y_pred)`,
            where `y_true` are the ground truth values, and `y_pred`
            are the program's predictions.
            `y_true` should be a list of batch size length `[d0, .. dN]`.
            `y_pred` should be a list of batch size length `[d0, .. dN]`.
            The reward function should return a float.
        reward_weights (list): Optional list specifying scalar coefficients
            (Python floats) to weight the reward contributions of
            different program outputs. The reward value that will be maximized
            by the program will then be the *weighted sum* of all individual
            rewards, weighted by the `reward_weights` coefficients. It is
            expected to have a 1:1 mapping to the program's outputs.
        metrics (list): List of metrics to be evaluated by the program during
            training and testing. Each of them is a `synalinks.metrics.Metric`
            instance. See `synalinks.metrics`. A metric function is any callable
            with the signature `result = fn(y_true, y_pred)`.
        run_eagerly (bool): If `True`, this program's forward pass
             will never be compiled. It is recommended to leave this
             as `False` when training (for best performance),
             and to set it to `True` when debugging.
        steps_per_execution (int): The number of batches to run
            during a single compiled function call. Running multiple
            batches inside a single compiled function call can
            greatly improve performance on TPUs or small programs with a large
            Python overhead. At most, one full epoch will be run each
            execution. If a number larger than the size of the epoch is
            passed, the execution will be truncated to the size of the
            epoch. Note that if `steps_per_execution` is set to `N`,
            `Callback.on_batch_begin` and `Callback.on_batch_end` methods
            will only be called every `N` batches (i.e. before/after
            each compiled function execution).
    """
    self._clear_previous_trainer_metrics()
    self._optimizer = optimizer

    if hasattr(self, "output_names"):
        output_names = self.output_names
    else:
        output_names = None
    if reward is not None:
        self._compile_reward = CompileReward(
            reward, reward_weights, output_names=output_names
        )
        self.reward = reward
    if metrics is not None:
        self._compile_metrics = CompileMetrics(metrics, output_names=output_names)
    self.run_eagerly = run_eagerly
    self.stop_training = False
    self.compiled = True
    self._reward_tracker = metrics_module.Mean(name="reward")
    self.steps_per_execution = steps_per_execution

    self._compile_config = serialization_lib.SerializableDict(
        optimizer=optimizer,
        reward=reward,
        reward_weights=reward_weights,
        metrics=metrics,
        run_eagerly=run_eagerly,
        steps_per_execution=steps_per_execution,
    )

compile_from_config(config)

Compiles the program with the information given in config.

This method uses the information in the config (optimizer, reward, metrics, etc.) to compile the program.

Parameters:

- `config` (`dict`): Dict containing information for compiling the program. Required.
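A minimal round-trip sketch, assuming `program` has already been compiled once (so that `get_compile_config()` has something to serialize):

```python
# Capture the serialized compile configuration (optimizer, reward, metrics, ...).
config = program.get_compile_config()

# Later, e.g. after reloading the program, restore the same training setup.
program.compile_from_config(config)
```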
Source code in synalinks/src/trainers/trainer.py
def compile_from_config(self, config):
    """Compiles the program with the information given in config.

    This method uses the information in the config (optimizer, reward,
    metrics, etc.) to compile the program.

    Args:
        config (dict): Dict containing information for compiling the program.
    """
    has_overridden_compile = self.__class__.compile != Trainer.compile
    if has_overridden_compile:
        warnings.warn(
            "`compile()` was not called as part of program loading "
            "because the program's `compile()` method is custom. "
            "All subclassed Models that have `compile()` "
            "overridden should also override "
            "`get_compile_config()` and `compile_from_config(config)`. "
            "Alternatively, you can "
            "call `compile()` manually after loading.",
            stacklevel=2,
        )
        return
    config = serialization_lib.deserialize_synalinks_object(config)
    self.compile(**config)
    if hasattr(self, "optimizer") and self.built:
        # Create optimizer variables.
        self.optimizer.build(self.trainable_variables)

compute_metrics(x, y, y_pred) async

Update metric states and collect all metrics to be returned.

Subclasses can optionally override this method to provide custom metric updating and collection logic. Custom metrics are not passed in `compile()`; they can be created in `__init__` or `build`. They are automatically tracked and returned by `self.metrics`.

Parameters:

- `x`: Input data.
- `y`: Target data.
- `y_pred`: Predictions returned by the program (output of `program.call(x)`).

Returns:

- (`dict`): A dict containing values that will be passed to `synalinks.callbacks.CallbackList.on_train_batch_end()`. Typically, the values of the metrics listed in `self.metrics` are returned. Example: `{'reward': 0.2, 'accuracy': 0.7}`.
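A hedged sketch of overriding this hook in a `Program` subclass (the `synalinks.Program` base class and the extra metric key are assumptions for illustration):

```python
import synalinks


class MyProgram(synalinks.Program):  # assumed subclassing entry point
    async def compute_metrics(self, x, y, y_pred):
        # Let the compiled metrics update as usual...
        metrics = await super().compute_metrics(x, y, y_pred)
        # ...then add any custom values to the dict passed to the callbacks.
        metrics["num_predictions"] = len(y_pred)  # illustrative extra entry
        return metrics
```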

Source code in synalinks/src/trainers/trainer.py
async def compute_metrics(self, x, y, y_pred):
    """Update metric states and collect all metrics to be returned.

    Subclasses can optionally override this method to provide custom metric
    updating and collection logic. Custom metrics are not passed in
    `compile()`; they can be created in `__init__` or `build`. They are
    automatically tracked and returned by `self.metrics`.

    Args:
        x: Input data.
        y: Target data.
        y_pred: Predictions returned by the program output of `program.call(x)`.

    Returns:
        A `dict` containing values that will be passed to
            `synalinks.callbacks.CallbackList.on_train_batch_end()`. Typically,
            the values of the metrics listed in `self.metrics` are returned.
            Example: `{'reward': 0.2, 'accuracy': 0.7}`.
    """
    del x  # The default implementation does not use `x`.
    if self._compile_metrics is not None:
        for y_t, y_p in zip(y, y_pred):
            await self._compile_metrics.update_state(y_t, y_p)
    return self.get_metrics_result()

compute_reward(x=None, y=None, y_pred=None, sample_weight=None, training=True) async

Compute the total reward, validate it, and return it.

Subclasses can optionally override this method to provide custom reward computation logic.

Parameters:

- `x` (`list`): Input data. Defaults to `None`.
- `y` (`list`): Target data. Defaults to `None`.
- `y_pred` (`list`): Predictions returned by the program (output of `program(x)`). Defaults to `None`.
- `training` (`bool`): Whether we are training or evaluating the program. Defaults to `True`.

Returns:

- (`float | None`): The total reward as a scalar, or `None` if no reward results (which is the case when called by `Program.test_step`).
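A hedged sketch of overriding the reward computation in a `Program` subclass (the `synalinks.Program` base class and the penalty term are assumptions for illustration):

```python
import synalinks


class MyProgram(synalinks.Program):  # assumed subclassing entry point
    async def compute_reward(
        self, x=None, y=None, y_pred=None, sample_weight=None, training=True
    ):
        # Start from the compiled reward...
        reward = await super().compute_reward(
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=training
        )
        if reward is None:
            return None
        # ...then, purely as an example, penalize failed (None) predictions.
        penalty = 0.1 * sum(1 for p in y_pred if p is None)
        return float(reward - penalty)
```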

Source code in synalinks/src/trainers/trainer.py
async def compute_reward(
    self,
    x=None,
    y=None,
    y_pred=None,
    sample_weight=None,
    training=True,
):
    """Compute the total reward, validate it, and return it.

    Subclasses can optionally override this method to provide custom reward
    computation logic.

    Args:
        x (list): Input data.
        y (list): Target data.
        y_pred (list): Predictions returned by the program (output of `program(x)`).
        training (bool): Whether we are training or evaluating the program.

    Returns:
        (float | None): The total reward as a scalar, or `None` if no reward results
            (which is the case when called by `Program.test_step`).
    """
    # The default implementation does not use `x` or `training`.
    del x
    del training
    rewards = []
    if self._compile_reward is not None:
        for y_t, y_p in zip(y, y_pred):
            reward = await self._compile_reward(y_t, y_p)
            if reward is not None:
                rewards.append(reward)
    for reward in self.rewards:
        rewards.append(numpy.sum(reward))
    if len(rewards) == 1:
        total_reward = rewards[0]
    elif len(rewards) == 0:
        total_reward = numpy.zeros(())
    else:
        total_reward = numpy.mean(rewards)
    return float(total_reward)

evaluate(x=None, y=None, batch_size=None, verbose='auto', steps=None, callbacks=None, return_dict=True, **kwargs) async

Returns the reward value & metrics values for the program in test mode.

Computation is done in batches (see the batch_size arg.)

Parameters:

- `x` (`np.ndarray | generator`): Input data. It can be: a NumPy array (or array-like), or a list of `DataModel` arrays (in case the model has multiple inputs); a list of dict mapping input names to the corresponding `DataModel`s, if the program has named inputs; or a Python generator function yielding `(inputs, targets)`. Defaults to `None`.
- `y` (`np.ndarray`): Target data. Like the input data `x`, it can be either NumPy array(s) or `DataModel`(s). If `x` is a Python generator function, `y` should not be specified since targets will be obtained from `x`. Defaults to `None`.
- `batch_size` (`int`): Integer or `None`. Number of samples per batch of computation. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your input data `x` is a Python generator function, since generators produce their own batches. Defaults to `None`.
- `verbose` (`int | str`): `"auto"`, 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = single line. `"auto"` becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so `verbose=2` is recommended when not running interactively (e.g. in a production environment). Defaults to `"auto"`.
- `steps` (`int`): Integer or `None`. Total number of steps (batches of samples) to draw before declaring the evaluation round finished. If `steps` is `None`, it will run until `x` is exhausted. In the case of an infinitely repeating dataset, it will run indefinitely. Defaults to `None`.
- `callbacks` (`list`): List of `synalinks.callbacks.Callback` instances to apply during evaluation. Defaults to `None`.
- `return_dict` (`bool`): If `True`, reward and metric results are returned as a dict, with each key being the name of the metric. If `False`, they are returned as a list. Defaults to `True`.

Returns:

- (`float | list | dict`): Scalar test reward (if the program has a single output and no metrics) or list of scalars (if the program has multiple outputs and/or metrics). The attribute `program.metrics_names` will give you the display labels for the scalar outputs.
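A minimal usage sketch, run inside an async context and assuming `program` is already compiled and `x_test`/`y_test` are lists of `DataModel` instances (illustrative names):

```python
results = await program.evaluate(
    x=x_test,
    y=y_test,
    batch_size=32,
)
# With the default return_dict=True, results maps metric names to values,
# e.g. {"reward": 0.8, ...}.
print(results)
```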

Source code in synalinks/src/trainers/trainer.py
async def evaluate(
    self,
    x=None,
    y=None,
    batch_size=None,
    verbose="auto",
    steps=None,
    callbacks=None,
    return_dict=True,
    **kwargs,
):
    """Returns the reward value & metrics values for the program in test mode.

    Computation is done in batches (see the `batch_size` arg.)

    Args:
        x (np.ndarray | generator): Input data. It can be:
            - A NumPy array (or array-like), or a list of `DataModel` arrays
                (in case the model has multiple inputs).
            - A list of dict mapping input names to the corresponding `DataModel`s,
                if the program has named inputs.
            - A Python generator function yielding `(inputs, targets)`.
        y (np.ndarray): Target data. Like the input data `x`, it can be either NumPy
            array(s) or `DataModel`(s). If `x` is a Python generator function,
            `y` should not be specified since targets will be obtained from
            `x`.
        batch_size (int): Integer or `None`.
            Number of samples per batch of computation.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your input data `x` is a
            Python generator function since they generate batches.
        verbose (int | str): `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = single line.
            `"auto"` becomes 1 for most cases.
            Note that the progress bar is not
            particularly useful when logged to a file, so `verbose=2` is
            recommended when not running interactively
            (e.g. in a production environment). Defaults to `"auto"`.
        steps (int): Integer or `None`.
            Total number of steps (batches of samples) to draw before
            declaring the evaluation round finished. If `steps` is `None`,
            it will run until `x` is exhausted. In the case of an infinitely
            repeating dataset, it will run indefinitely.
        callbacks (list): List of `synalinks.callbacks.Callback` instances.
            List of callbacks to apply during evaluation.
        return_dict (bool): If `True`, reward and metric results are returned as a
            dict, with each key being the name of the metric.
            If `False`, they are returned as a list.

    Returns:
        (float | list | dict): Scalar test reward
            (if the program has a single output and no metrics)
            or list of scalars (if the program has multiple outputs
            and/or metrics). The attribute `program.metrics_names` will give you
            the display labels for the scalar outputs.
    """
    self._assert_compile_called("evaluate")
    use_cached_eval_dataset = kwargs.pop("_use_cached_eval_dataset", False)
    if kwargs:
        raise ValueError(f"Arguments not recognized: {kwargs}")
    # Create an iterator that yields batches of input/target data.
    if use_cached_eval_dataset:
        epoch_iterator = self._eval_epoch_iterator
    else:
        epoch_iterator = EpochIterator(
            x=x,
            y=y,
            batch_size=batch_size,
            steps_per_epoch=steps,
            shuffle=False,
            steps_per_execution=self.steps_per_execution,
        )

    if not all(module.built for module in self._flatten_modules()):
        # Build the model on one batch of data.
        for _, data in epoch_iterator:
            data_batch = data[0]
            self._symbolic_build(data_batch)
            break
    epoch_iterator.reset()

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=False,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            program=self,
        )

    self.stop_evaluating = False
    callbacks.on_test_begin()
    logs = {}
    self.reset_metrics()
    for step, iterator in epoch_iterator:
        callbacks.on_test_batch_begin(step)
        data = iterator[0]
        x_batch, y_batch = data_adapter_utils.unpack_x_y(data)
        logs = await self.test_on_batch(
            x=x_batch,
            y=y_batch,
            return_dict=True,
        )
        callbacks.on_test_batch_end(step, logs)
        if self.stop_evaluating:
            break
    logs = self._get_metrics_result_or_logs(logs)
    callbacks.on_test_end(logs)

    if return_dict:
        return logs
    return self._flatten_metrics_in_order(logs)

fit(x=None, y=None, batch_size=None, epochs=1, verbose='auto', callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, initial_epoch=0, steps_per_epoch=None, validation_steps=None, validation_batch_size=None, validation_freq=1) async

Trains the program for a fixed number of epochs (dataset iterations).

Parameters:

- `x` (`np.ndarray | generator`): Input data. It can be: a NumPy array (or array-like), or a list of `DataModel` arrays (in case the model has multiple inputs); a list of dict mapping input names to the corresponding `DataModel`s, if the program has named inputs; or a Python generator function yielding `(inputs, targets)`. Defaults to `None`.
- `y` (`np.ndarray`): Target data. Like the input data `x`, it can be either NumPy array(s) or `DataModel`(s). If `x` is a Python generator function, `y` should not be specified since targets will be obtained from `x`. Defaults to `None`.
- `batch_size` (`int`): Integer or `None`. Number of samples per batch of computation. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your input data `x` is a Python generator function, since generators produce their own batches. Defaults to `None`.
- `epochs` (`int`): Number of epochs to train the program. An epoch is an iteration over the entire `x` and `y` data provided (unless the `steps_per_epoch` flag is set to something other than `None`). Note that in conjunction with `initial_epoch`, `epochs` is to be understood as "final epoch". The program is not trained for a number of iterations given by `epochs`, but merely until the epoch of index `epochs` is reached. Defaults to `1`.
- `verbose` (`int`): `"auto"`, 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = one line per epoch. `"auto"` becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so `verbose=2` is recommended when not running interactively (e.g. in a production environment). Defaults to `"auto"`.
- `callbacks` (`list`): List of `synalinks.callbacks.Callback` instances to apply during training. See `synalinks.callbacks`. Note that the `synalinks.callbacks.ProgbarLogger` and `synalinks.callbacks.History` callbacks are created automatically and need not be passed to `program.fit()`. `synalinks.callbacks.ProgbarLogger` is created or not based on the `verbose` argument to `program.fit()`. Defaults to `None`.
- `validation_split` (`float`): Float between 0 and 1. Fraction of the training data to be used as validation data. The program will set apart this fraction of the training data, will not train on it, and will evaluate the reward and any program metrics on this data at the end of each epoch. The validation data is selected from the last samples in the `x` and `y` data provided, before shuffling. This argument is only supported when `x` and `y` are made of data models. If both `validation_data` and `validation_split` are provided, `validation_data` will override `validation_split`. Defaults to `0.0`.
- `validation_data` (`tuple | iterator`): Data on which to evaluate the reward and any program metrics at the end of each epoch. The program will not be trained on this data. `validation_data` will override `validation_split`. It can be a tuple `(x_val, y_val)` of `DataModel` lists. Defaults to `None`.
- `shuffle` (`bool`): Whether to shuffle the training data before each epoch. This argument is ignored when `x` is a Python generator function. Defaults to `True`.
- `initial_epoch` (`int`): Epoch at which to start training (useful for resuming a previous training run). Defaults to `0`.
- `steps_per_epoch` (`int`): Integer or `None`. Total number of steps (batches of samples) before declaring one epoch finished and starting the next epoch. When training with input data model arrays, the default `None` means that the value used is the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If `x` is a Python generator function, the epoch will run until the input dataset is exhausted. When passing an infinitely repeating dataset, you must specify the `steps_per_epoch` argument, otherwise the training will run indefinitely. Defaults to `None`.
- `validation_steps` (`int`): Integer or `None`. Only relevant if `validation_data` is provided. Total number of steps (batches of samples) to draw before stopping when performing validation at the end of every epoch. If `validation_steps` is `None`, validation will run until the `validation_data` dataset is exhausted. In the case of an infinitely repeating dataset, it will run indefinitely. If `validation_steps` is specified and only part of the dataset is consumed, the evaluation will start from the beginning of the dataset at each epoch, ensuring that the same validation samples are used every time. Defaults to `None`.
- `validation_batch_size` (`int`): Integer or `None`. Number of samples per validation batch. If unspecified, will default to `batch_size`. Do not specify the `validation_batch_size` if your data is a `synalinks.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader` or Python generator function, since they generate batches. Defaults to `None`.
- `validation_freq` (`int`): Only relevant if validation data is provided. Specifies how many training epochs to run before a new validation run is performed, e.g. `validation_freq=2` runs validation every 2 epochs. Defaults to `1`.

Returns:

- (`History`): A `History` object. Its `History.history` attribute is a record of training reward values and metrics values at successive epochs, as well as validation reward values and validation metrics values (if applicable).
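A minimal end-to-end sketch, assuming `program` is a built synalinks program and `x_train`/`y_train` are lists of `DataModel` instances (illustrative names):

```python
import asyncio
import synalinks


async def main():
    program.compile(
        optimizer=synalinks.optimizers.RandomFewShot(),
        reward=synalinks.rewards.ExactMatch(),
    )
    history = await program.fit(
        x=x_train,
        y=y_train,
        batch_size=32,
        epochs=4,
        validation_split=0.2,
    )
    # History.history records reward/metric values per epoch
    # (plus val_* entries when validation data is used).
    print(history.history)


asyncio.run(main())
```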

Source code in synalinks/src/trainers/trainer.py
async def fit(
    self,
    x=None,
    y=None,
    batch_size=None,
    epochs=1,
    verbose="auto",
    callbacks=None,
    validation_split=0.0,
    validation_data=None,
    shuffle=True,
    initial_epoch=0,
    steps_per_epoch=None,
    validation_steps=None,
    validation_batch_size=None,
    validation_freq=1,
):
    """Trains the program for a fixed number of epochs (dataset iterations).

    Args:
        x (np.ndarray | generator): Input data. It can be:
            - A NumPy array (or array-like), or a list of `DataModel` arrays
                (in case the model has multiple inputs).
            - A list of dict mapping input names to the corresponding `DataModel`s,
                if the program has named inputs.
            - A Python generator function yielding `(inputs, targets)`.
        y (np.ndarray): Target data. Like the input data `x`, it can be either NumPy
            array(s) or `DataModel`(s). If `x` is a Python generator function,
            `y` should not be specified since targets will be obtained from
            `x`.
        batch_size (int): Integer or `None`.
            Number of samples per batch of computation.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your input data `x` is a
            Python generator function since they generate batches.
        epochs (int): Integer. Number of epochs to train the program.
            An epoch is an iteration over the entire `x` and `y`
            data provided (unless the `steps_per_epoch` flag is set to
            something other than None).
            Note that in conjunction with `initial_epoch`,
            `epochs` is to be understood as "final epoch".
            The program is not trained for a number of iterations
            given by `epochs`, but merely until the epoch
            of index `epochs` is reached.
        verbose (int): `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = one line per epoch.
            "auto" becomes 1 for most cases.
            Note that the progress bar is not
            particularly useful when logged to a file,
            so `verbose=2` is recommended when not running interactively
            (e.g., in a production environment). Defaults to `"auto"`.
        callbacks (list): List of `synalinks.callbacks.Callback` instances.
            List of callbacks to apply during training.
            See `synalinks.callbacks`. Note
            `synalinks.callbacks.ProgbarLogger` and
            `synalinks.callbacks.History` callbacks are created
            automatically and need not be passed to `program.fit()`.
            `synalinks.callbacks.ProgbarLogger` is created
            or not based on the `verbose` argument in `program.fit()`.
        validation_split (float): Float between 0 and 1.
            Fraction of the training data to be used as validation data.
            The program will set apart this fraction of the training data,
            will not train on it, and will evaluate the reward and any program
            metrics on this data at the end of each epoch. The validation
            data is selected from the last samples in the `x` and `y` data
            provided, before shuffling.
            This argument is only supported when `x` and `y` are made of
            data_models.
            If both `validation_data` and `validation_split` are provided,
            `validation_data` will override `validation_split`.
        validation_data (tuple | iterator): Data on which to evaluate
            the reward and any program metrics at the end of each epoch.
            The program will not be trained on this data.
            `validation_data` will override `validation_split`.
            It can be:
            - A tuple `(x_val, y_val)` of `DataModel`s lists.
        shuffle (bool): Whether to shuffle the training data before each
            epoch. This argument is ignored when `x` is a Python generator function.
        initial_epoch (int): Integer.
            Epoch at which to start training
            (useful for resuming a previous training run).
        steps_per_epoch (int): Integer or `None`.
            Total number of steps (batches of samples) before declaring one
            epoch finished and starting the next epoch. When training with
            input data_models arrays, the default `None` means that the
            value used is the number of samples in your dataset divided by
            the batch size, or 1 if that cannot be determined.
            If `x` is a Python generator function, the
            epoch will run until the input dataset is exhausted. When
            passing an infinitely repeating dataset, you must specify the
            `steps_per_epoch` argument, otherwise the training will run
            indefinitely.
        validation_steps (int): Integer or `None`.
            Only relevant if `validation_data` is provided.
            Total number of steps (batches of samples) to draw before
            stopping when performing validation at the end of every epoch.
            If `validation_steps` is `None`, validation will run until the
            `validation_data` dataset is exhausted. In the case of an
            infinitely repeating dataset, it will run indefinitely. If
            `validation_steps` is specified and only part of the dataset
            is consumed, the evaluation will start from the beginning of the
            dataset at each epoch. This ensures that the same validation
            samples are used every time.
        validation_batch_size (int): Integer or `None`.
            Number of samples per validation batch.
            If unspecified, will default to `batch_size`.
            Do not specify the `validation_batch_size` if your data is a
            `synalinks.utils.PyDataset`, `tf.data.Dataset`,
            `torch.utils.data.DataLoader` or Python generator function
            since they generate batches.
        validation_freq (int): Only relevant if validation data is provided.
            Specifies how many training epochs to run
            before a new validation run is performed,
            e.g. `validation_freq=2` runs validation every 2 epochs.

    Returns:
        (History): A `History` object. Its `History.history` attribute is
            a record of training reward values and metrics values
            at successive epochs, as well as validation reward values
            and validation metrics values (if applicable).
    """
    self._assert_compile_called("fit")
    # TODO: respect compiled trainable state
    self._eval_epoch_iterator = None
    if validation_split and validation_data is None:
        # Create the validation data using the training data. Only supported
        # for numpy arrays.
        (x, y), validation_data = array_slicing.train_validation_split(
            (x, y), validation_split=validation_split
        )

    if validation_data is not None:
        (val_x, val_y) = data_adapter_utils.unpack_x_y(validation_data)
    # Create an iterator that yields batches of input/target data.
    epoch_iterator = EpochIterator(
        x=x,
        y=y,
        batch_size=batch_size,
        steps_per_epoch=steps_per_epoch,
        shuffle=False,
        steps_per_execution=self.steps_per_execution,
    )

    if not all(module.built for module in self._flatten_modules()):
        # Build the model on one batch of data.
        for _, data in epoch_iterator:
            data_batch = data[0]
            self._symbolic_build(data_batch)
            break
    epoch_iterator.reset()

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=epochs,
            steps=steps_per_epoch,
            program=self,
        )

    self.stop_training = False
    callbacks.on_train_begin()
    training_logs = None
    logs = {}
    initial_epoch = self._initial_epoch or initial_epoch

    for epoch in range(initial_epoch, epochs):
        self.reset_metrics()
        callbacks.on_epoch_begin(epoch)
        with epoch_iterator.catch_stop_iteration():
            for step, iterator in epoch_iterator:
                data = iterator[0]
                x_batch, y_batch = data_adapter_utils.unpack_x_y(data)
                callbacks.on_train_batch_begin(step)
                logs = await self.train_on_batch(
                    x=x_batch,
                    y=y_batch,
                    return_dict=True,
                )
                callbacks.on_train_batch_end(step, logs)
                if self.stop_training:
                    break

        # Override with model metrics instead of last step logs if needed.
        epoch_logs = dict(self._get_metrics_result_or_logs(logs))

        # Run validation.
        if validation_data is not None and self._should_eval(epoch, validation_freq):
            # Create EpochIterator for evaluation and cache it.
            if getattr(self, "_eval_epoch_iterator", None) is None:
                self._eval_epoch_iterator = EpochIterator(
                    x=val_x,
                    y=val_y,
                    batch_size=validation_batch_size or batch_size,
                    steps_per_execution=self.steps_per_execution,
                    steps_per_epoch=validation_steps,
                    shuffle=False,
                )
            val_logs = await self.evaluate(
                x=val_x,
                y=val_y,
                batch_size=validation_batch_size or batch_size,
                steps=validation_steps,
                callbacks=callbacks,
                _use_cached_eval_dataset=True,
            )
            val_logs = {"val_" + name: val for name, val in val_logs.items()}
            epoch_logs.update(val_logs)

        callbacks.on_epoch_end(epoch, epoch_logs)
        training_logs = epoch_logs
        if self.stop_training:
            break

    if isinstance(self.optimizer, optimizers_module.Optimizer) and epochs > 0:
        await self.optimizer.finalize_variable_values(self.trainable_variables)

    # If _eval_epoch_iterator exists, delete it after all epochs are done.
    if getattr(self, "_eval_epoch_iterator", None) is not None:
        del self._eval_epoch_iterator
    callbacks.on_train_end(logs=training_logs)
    return self.history

get_compile_config()

Returns a serialized config with information for compiling the program.

This method returns a config dictionary containing all the information (optimizer, reward, metrics, etc.) with which the program was compiled.

Returns:

- (`dict`): A dict containing information for compiling the program.

Source code in synalinks/src/trainers/trainer.py
def get_compile_config(self):
    """Returns a serialized config with information for compiling the program.

    This method returns a config dictionary containing all the information
    (optimizer, reward, metrics, etc.) with which the program was compiled.

    Returns:
        (dict): A dict containing information for compiling the program.
    """
    if self.compiled and hasattr(self, "_compile_config"):
        return self._compile_config.serialize()

get_metrics_result()

Returns the program's metrics values as a dict.

If any of the metric result is a dict (containing multiple metrics), each of them gets added to the top level returned dict of this method.

Returns:

- (`dict`): A dict containing values of the metrics listed in `self.metrics`. Example: `{'reward': 0.2, 'accuracy': 0.7}`.
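For instance (sketch; the exact metric names depend on what was compiled):

```python
# Typically called after some train/test batches have updated the metric states.
print(program.get_metrics_result())
# e.g. {'reward': 0.2, 'accuracy': 0.7}
```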

Source code in synalinks/src/trainers/trainer.py
def get_metrics_result(self):
    """Returns the program's metrics values as a dict.

    If any of the metric result is a dict (containing multiple metrics),
    each of them gets added to the top level returned dict of this method.

    Returns:
        (dict): A `dict` containing values of the metrics listed in `self.metrics`.
            Example: `{'reward': 0.2, 'accuracy': 0.7}`.
    """
    return_metrics = {}
    for metric in self.metrics:
        result = metric.result()
        if isinstance(result, dict):
            return_metrics.update(result)
        else:
            return_metrics[metric.name] = result
    return python_utils.pythonify_logs(return_metrics)

predict(x, batch_size=None, verbose='auto', steps=None, callbacks=None) async

Generates output predictions for the input samples.

Computation is done in batches. This method is designed for batch processing of large numbers of inputs. It is not intended for use inside of loops that iterate over your data and process small numbers of inputs at a time.

For small numbers of inputs that fit in one batch, directly use __call__() for faster execution, e.g., program(x), or program(x, training=False) if you have modules that behave differently during inference.

Parameters:

- `x` (`np.ndarray | generator`): Input data. It can be: a NumPy array (or array-like), or a list of `DataModel` arrays (in case the model has multiple inputs); a list of dict mapping input names to the corresponding `DataModel`s, if the program has named inputs; or a Python generator function yielding `(inputs, targets)`. Required.
- `batch_size` (`int`): Integer or `None`. Number of samples per batch of computation. If unspecified, `batch_size` will default to 32. Do not specify the `batch_size` if your input data `x` is a `synalinks.utils.PyDataset`, `tf.data.Dataset`, `torch.utils.data.DataLoader` or Python generator function, since they generate batches. Defaults to `None`.
- `verbose` (`int`): `"auto"`, 0, 1, or 2. Verbosity mode. 0 = silent, 1 = progress bar, 2 = single line. `"auto"` becomes 1 for most cases. Note that the progress bar is not particularly useful when logged to a file, so `verbose=2` is recommended when not running interactively (e.g. in a production environment). Defaults to `"auto"`.
- `steps` (`int`): Total number of steps (batches of samples) to draw before declaring the prediction round finished. If `steps` is `None`, it will run until `x` is exhausted. In the case of an infinitely repeating dataset, it will run indefinitely. Defaults to `None`.
- `callbacks` (`list`): List of `synalinks.callbacks.Callback` instances to apply during prediction. Defaults to `None`.

Returns:

- (`list`): `JsonDataModel` array(s) of predictions. If the pipeline failed, a `None` is added to the predictions.
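A minimal usage sketch, run inside an async context and assuming `x_test` is a list of `DataModel` inputs (illustrative name):

```python
predictions = await program.predict(x_test, batch_size=32)

# Each entry is a JsonDataModel, or None if that pipeline run failed.
for pred in predictions:
    if pred is not None:
        print(pred)
```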

Source code in synalinks/src/trainers/trainer.py
async def predict(
    self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
):
    """Generates output predictions for the input samples.

    Computation is done in batches. This method is designed for batch
    processing of large numbers of inputs. It is not intended for use inside
    of loops that iterate over your data and process small numbers of inputs
    at a time.

    For small numbers of inputs that fit in one batch,
    directly use `__call__()` for faster execution, e.g.,
    `program(x)`, or `program(x, training=False)` if you have modules
    that behave differently during inference.

    Args:
        x (np.ndarray | generator): Input data. It can be:
            - A NumPy array (or array-like), or a list of `DataModel` arrays
                (in case the model has multiple inputs).
            - A list of dict mapping input names to the corresponding `DataModel`s,
                if the program has named inputs.
            - A Python generator function yielding `(inputs, targets)`.
        batch_size (int): Integer or `None`.
            Number of samples per batch of computation.
            If unspecified, `batch_size` will default to 32.
            Do not specify the `batch_size` if your input data `x` is a
            `synalinks.utils.PyDataset`, `tf.data.Dataset`,
            `torch.utils.data.DataLoader` or Python generator function
            since they generate batches.
        verbose (int): `"auto"`, 0, 1, or 2. Verbosity mode.
            0 = silent, 1 = progress bar, 2 = single line.
            `"auto"` becomes 1 for most cases. Note that the progress bar
            is not particularly useful when logged to a file,
            so `verbose=2` is recommended when not running interactively
            (e.g. in a production environment). Defaults to `"auto"`.
        steps (int): Total number of steps (batches of samples) to draw before
            declaring the prediction round finished. If `steps` is `None`,
            it will run until `x` is exhausted. In the case of an infinitely
            repeating dataset, it will run indefinitely.
        callbacks (list): List of `synalinks.callbacks.Callback` instances.
            List of callbacks to apply during prediction.

    Returns:
        (list): `JsonDataModel` array(s) of predictions.
            If the pipeline failed, a None is added to the predictions.
    """
    # Create an iterator that yields batches of input data.
    epoch_iterator = EpochIterator(
        x=x,
        batch_size=batch_size,
        steps_per_epoch=steps,
        shuffle=False,
        steps_per_execution=self.steps_per_execution,
    )

    # Container that configures and calls callbacks.
    if not isinstance(callbacks, callbacks_module.CallbackList):
        callbacks = callbacks_module.CallbackList(
            callbacks,
            add_history=True,
            add_progbar=verbose != 0,
            verbose=verbose,
            epochs=1,
            steps=epoch_iterator.num_batches,
            program=self,
        )

    self.stop_predicting = False
    callbacks.on_predict_begin()
    outputs = []
    for step, iterator in epoch_iterator:
        callbacks.on_predict_batch_begin(step)
        data = iterator[0]
        x_batch, _ = data_adapter_utils.unpack_x_y(data)
        batch_outputs = await self.predict_on_batch(x_batch)
        outputs.extend(batch_outputs)
        callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
        if self.stop_predicting:
            break
    callbacks.on_predict_end()
    return np.array(outputs, dtype="object")

predict_on_batch(x, training=False) async

Returns predictions for a single batch of samples.

Parameters:

- `x` (`np.ndarray`): Input data. Must be array-like. Required.
- `training` (`bool`): Boolean. `True` if training. Defaults to `False`.

Returns:

- (`list`): List(s) of `JsonDataModel` predictions.
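A minimal sketch, run inside an async context and assuming `x_batch` is an array-like batch of `DataModel` inputs (illustrative name):

```python
# Runs the program concurrently over every sample in the batch.
y_pred = await program.predict_on_batch(x_batch)
```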

Source code in synalinks/src/trainers/trainer.py
async def predict_on_batch(self, x, training=False):
    """Returns predictions for a single batch of samples.

    Args:
        x (np.ndarray): Input data. Must be array-like.
        training (bool): Boolean. True if training.

    Returns:
        (list): list(s) of JsonDataModel predictions.
    """
    tasks = []
    for inputs in x:
        tasks.append(self(inputs, training=training))
    y_pred = await asyncio.gather(*tasks)
    return y_pred

test_on_batch(x, y=None, return_dict=False) async

Test the program on a single batch of samples.

Parameters:

- `x` (`np.ndarray`): Input data. Must be array-like. Required.
- `y` (`np.ndarray`): Target data. Must be array-like. Defaults to `None`.
- `return_dict` (`bool`): If `True`, reward and metric results are returned as a dict, with each key being the name of the metric. If `False`, they are returned as a list. Defaults to `False`.

Returns:

- (`float | list | dict`): A scalar reward value (when no metrics and `return_dict=False`), a list of reward and metric values (if there are metrics and `return_dict=False`), or a dict of metric and reward values (if `return_dict=True`).
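A minimal sketch, run inside an async context and assuming a compiled `program` and one batch `x_batch`/`y_batch` of `DataModel` instances (illustrative names):

```python
logs = await program.test_on_batch(x_batch, y=y_batch, return_dict=True)
# e.g. {'reward': 0.5, ...}
print(logs)
```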

Source code in synalinks/src/trainers/trainer.py
async def test_on_batch(
    self,
    x,
    y=None,
    return_dict=False,
):
    """Test the program on a single batch of samples.

    Args:
        x (np.ndarray): Input data. Must be array-like.
        y (np.ndarray): Target data. Must be array-like.
        return_dict (bool): If `True`, reward and metric results are returned as a
            dict, with each key being the name of the metric. If `False`,
            they are returned as a list.

    Returns:
        (float | list | dict): A scalar reward value
            (when no metrics and `return_dict=False`), a list of reward
            and metric values (if there are metrics and `return_dict=False`),
            or a dict of metric and reward values (if `return_dict=True`).
    """
    y_pred = await self.predict_on_batch(x, training=False)

    reward = await self.compute_reward(
        x=x,
        y=y,
        y_pred=y_pred,
        training=False,
    )

    await self._reward_tracker.update_state(reward)

    metrics = await self.compute_metrics(x, y, y_pred)

    if return_dict:
        return metrics
    return self._flatten_metrics_in_order(metrics)

train_on_batch(x, y=None, return_dict=False) async

Runs a single backpropagation/optimization update on a single batch of data.

Parameters:

- `x` (`np.ndarray`): Input data. Must be array-like. Required.
- `y` (`np.ndarray`): Target data. Must be array-like. Defaults to `None`.
- `return_dict` (`bool`): If `True`, reward and metric results are returned as a dict, with each key being the name of the metric. If `False`, they are returned as a list. Defaults to `False`.

Returns:

- (`float | list | dict`): A scalar reward value (when no metrics and `return_dict=False`), a list of reward and metric values (if there are metrics and `return_dict=False`), or a dict of metric and reward values (if `return_dict=True`).
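A hedged sketch of a custom training loop built on this method, run inside an async context and assuming a compiled `program` and a `batches` iterable of `(x_batch, y_batch)` pairs (illustrative names):

```python
for epoch in range(3):
    program.reset_metrics()
    for x_batch, y_batch in batches:
        logs = await program.train_on_batch(x_batch, y=y_batch, return_dict=True)
    # logs holds the reward/metric values after the last batch of the epoch.
    print(f"epoch {epoch}: {logs}")
```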

Source code in synalinks/src/trainers/trainer.py
async def train_on_batch(
    self,
    x,
    y=None,
    return_dict=False,
):
    """Runs a single backpropagation/optimization update on a single batch of data.

    Args:
        x (np.ndarray): Input data. Must be array-like.
        y (np.ndarray): Target data. Must be array-like.
        return_dict (bool): If `True`, reward and metric results are returned as a
            dict, with each key being the name of the metric. If `False`,
            they are returned as a list.

    Returns:
        (float | list | dict): A scalar reward value
            (when no metrics and `return_dict=False`), a list of reward
            and metric values (if there are metrics and `return_dict=False`),
            or a dict of metric and reward values (if `return_dict=True`).
    """
    y_pred = await self.predict_on_batch(x, training=True)

    reward = await self.compute_reward(
        x=x,
        y=y,
        y_pred=y_pred,
        training=True,
    )

    await self._reward_tracker.update_state(reward)

    # Perform training/optimization
    if self.trainable_variables:
        await self.optimizer.apply_optimization(
            self.trainable_variables,
            reward=reward,
        )
    else:
        warnings.warn("The program does not have any trainable variables.")

    metrics = await self.compute_metrics(x, y, y_pred)

    if return_dict:
        return metrics
    return self._flatten_metrics_in_order(metrics)