Skip to content

CSVLogger

CSVLogger

Bases: Callback

Callback that streams epoch results to a CSV file.

Supports all values that can be represented as a string, including 1D iterables such as np.ndarray.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `filepath` | `str \| PathLike` | Filepath of the CSV file, e.g. `'run/log.csv'`. | required |
| `separator` | `str` | String used to separate elements in the CSV file. | `','` |
| `append` | `bool` | `True`: append if the file exists (useful for continuing training). `False`: overwrite the existing file. | `False` |

Example:

```python
csv_logger = CSVLogger(filepath='training.log')
program.fit(x_train, y_train, callbacks=[csv_logger])
```
Source code in synalinks/src/callbacks/csv_logger.py
@synalinks_export("synalinks.callbacks.CSVLogger")
class CSVLogger(Callback):
    """Callback that streams epoch results to a CSV file.

    Supports all values that can be represented as a string,
    including 1D iterables such as `np.ndarray`.

    One row is written per epoch. The columns are `epoch` followed by the
    sorted keys of the first epoch's `logs`; keys missing from a later
    epoch's `logs` are written as `NA`.

    Args:
        filepath (str | os.PathLike): Filepath of the CSV file, e.g. `'run/log.csv'`.
        separator (str): String used to separate elements in the CSV file.
        append (bool): True: append if file exists (useful for continuing
            training). False: overwrite existing file.

    Example:

    ```python
    csv_logger = CSVLogger(filepath='training.log')
    program.fit(x_train, y_train, callbacks=[csv_logger])
    ```
    """

    def __init__(self, filepath, separator=",", append=False):
        super().__init__()
        self.sep = separator
        self.filepath = file_utils.path_to_string(filepath)
        self.append = append
        self.writer = None
        self.keys = None
        self.append_header = True
        # Defined up front so `on_train_end` is safe even when
        # `on_train_begin` was never called.
        self.csv_file = None

    def _close_file(self):
        """Close the current log file if one is open; otherwise do nothing."""
        csv_file = self.csv_file
        # `getattr` keeps this working for file-like objects that do not
        # expose a `closed` attribute: those are simply closed.
        if csv_file is not None and not getattr(csv_file, "closed", False):
            csv_file.close()

    def on_train_begin(self, logs=None):
        """Open the log file and reset the per-run writer state.

        In append mode the header is only written when the target file is
        missing or its first line is empty.
        """
        # Recomputed on every run so a value from a previous run cannot
        # suppress the header of a freshly created file.
        self.append_header = True
        if self.append:
            if file_utils.exists(self.filepath):
                with file_utils.File(self.filepath, "r") as f:
                    self.append_header = not f.readline()
            mode = "a"
        else:
            mode = "w"

        # A previous run may have been interrupted before `on_train_end`;
        # release its handle instead of leaking it.
        self._close_file()
        self.csv_file = file_utils.File(self.filepath, mode)
        # The writer is bound to a specific file object and the keys to a
        # specific run, so both are rebuilt lazily in `on_epoch_end`.
        self.writer = None
        self.keys = None

    def on_epoch_end(self, epoch, logs=None):
        """Append one row with the values of `logs` for this epoch."""
        logs = logs or {}

        def handle_value(k):
            # Strings are iterable but must be written verbatim.
            if isinstance(k, str):
                return k
            is_zero_dim_ndarray = isinstance(k, np.ndarray) and k.ndim == 0
            if isinstance(k, collections.abc.Iterable) and not is_zero_dim_ndarray:
                # Iterables are rendered as a quoted, bracketed list.
                return f'"[{", ".join(map(str, k))}]"'
            return k

        if self.keys is None:
            self.keys = sorted(logs.keys())
            # When validation_freq > 1, `val_` keys are not in first epoch logs.
            # Add the `val_` keys so that they are part of the writer's
            # fieldnames.
            if not any(key.startswith("val_") for key in self.keys):
                self.keys.extend(["val_" + k for k in self.keys])

        if not self.writer:

            class CustomDialect(csv.excel):
                delimiter = self.sep

            fieldnames = ["epoch"] + self.keys

            self.writer = csv.DictWriter(
                self.csv_file, fieldnames=fieldnames, dialect=CustomDialect
            )
            if self.append_header:
                self.writer.writeheader()

        row_dict = {"epoch": epoch}
        row_dict.update((key, handle_value(logs.get(key, "NA"))) for key in self.keys)
        self.writer.writerow(row_dict)
        # Flush every epoch so the file is readable while training runs.
        self.csv_file.flush()

    def on_train_end(self, logs=None):
        """Close the log file (if open) and drop the writer."""
        self._close_file()
        self.writer = None