Skip to content

Commit

Permalink
Merge pull request #118 from sony/feature/20240314-add-format-specifier
Browse files Browse the repository at this point in the history
Add format specifier to file writer
  • Loading branch information
TakayoshiTakayanagi authored Mar 27, 2024
2 parents ade5c23 + 99558f9 commit 2b4b7e8
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 5 deletions.
6 changes: 3 additions & 3 deletions nnabla_rl/writers/file_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -22,14 +22,14 @@


class FileWriter(Writer):
def __init__(self, outdir, file_prefix):
def __init__(self, outdir, file_prefix, fmt="%.3f"):
super(FileWriter, self).__init__()
if isinstance(outdir, str):
outdir = pathlib.Path(outdir)
self._outdir = outdir
files.create_dir_if_not_exist(outdir=outdir)
self._file_prefix = file_prefix
self._fmt = '%.3f'
self._fmt = fmt

def write_scalar(self, iteration_num, scalar):
outfile = self._outdir / (self._file_prefix + '_scalar.tsv')
Expand Down
2 changes: 2 additions & 0 deletions test_resources/writers/evaluation_results_scalar%.3f.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
iteration mean std_dev min max median
1 2.000 1.414 0.000 4.000 2.000
2 changes: 2 additions & 0 deletions test_resources/writers/evaluation_results_scalar%.5f.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
iteration mean std_dev min max median
1 2.00000 1.41421 0.00000 4.00000 2.00000
2 changes: 2 additions & 0 deletions test_resources/writers/evaluation_results_scalar%f.tsv
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
iteration mean std_dev min max median
1 2.000000 1.414214 0.000000 4.000000 2.000000
25 changes: 23 additions & 2 deletions tests/writers/test_file_writer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright 2020,2021 Sony Corporation.
# Copyright 2021 Sony Group Corporation.
# Copyright 2021,2022,2023,2024 Sony Group Corporation.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
import tempfile

import numpy as np
import pytest

from nnabla_rl.writers.file_writer import FileWriter

Expand Down Expand Up @@ -62,6 +63,27 @@ def test_write_histogram(self):
os.path.join(test_file_dir, 'evaluation_results_histogram.tsv')
self._check_same_tsv_file(file_path, test_file_path)

@pytest.mark.parametrize("format", ["%f", "%.3f", "%.5f"])
def test_data_formatting(self, format):
with tempfile.TemporaryDirectory() as tmpdir:
test_returns = np.arange(5)
test_results = {}
test_results['mean'] = np.mean(test_returns)
test_results['std_dev'] = np.std(test_returns)
test_results['min'] = np.min(test_returns)
test_results['max'] = np.max(test_returns)
test_results['median'] = np.median(test_returns)

writer = FileWriter(outdir=tmpdir, file_prefix='actual_results', fmt=format)
writer.write_scalar(1, test_results)

actual_file_path = os.path.join(tmpdir, 'actual_results_scalar.tsv')

this_file_dir = os.path.dirname(__file__)
expected_file_dir = this_file_dir.replace('tests', 'test_resources')
expected_file_path = os.path.join(expected_file_dir, f'evaluation_results_scalar{format}.tsv')
self._check_same_tsv_file(actual_file_path, expected_file_path)

def _check_same_tsv_file(self, file_path1, file_path2):
# check each line
with open(file_path1, mode='rt') as data_1, \
Expand All @@ -71,5 +93,4 @@ def _check_same_tsv_file(self, file_path1, file_path2):


if __name__ == "__main__":
import pytest
pytest.main()

0 comments on commit 2b4b7e8

Please sign in to comment.