Source code for simfish_rl.hdf5_logger

# Copyright 2025 Asaph Zylbertal
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""A simple HDF5 logger.

"""

import h5py
import numpy as np
import os
import time
from typing import TextIO, Union

from acme.utils import paths
from acme.utils.loggers import base as base_loggers
from acme.utils.observers import base as base_observers
import dm_env

from typing import Dict, Optional

[docs] class EnvInfoKeep(base_observers.EnvLoopObserver): """An observer that collects and accumulates scalars from env's info.""" def __init__(self): self._metrics = None self.env_variables = None def _accumulate_metrics(self, env: dm_env.Environment, obs, actor_state: Optional[np.ndarray] = None) -> None: if not hasattr(env, 'get_info'): return info = getattr(env, 'get_info')() info['action'] = [int(obs.action)] info['reward'] = [obs.reward] info['vis_observation'] = [obs.observation[0]] info['internal_state'] = [obs.observation[1]] if actor_state is not None: info['actor_state'] = [actor_state] if not info: return for k, v in info.items(): self._metrics[k] = self._metrics.get(k, []) + v
[docs] def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep ) -> None: """Observes the initial state.""" sediment = env.arena.global_sediment_grating self._metrics = {'sediment': [sediment], 'salt_location': [env.salt_location], 'env_variables': env.env_variables} self._accumulate_metrics(env, timestep.observation)
[docs] def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray, actor_state: Optional[np.ndarray] = None) -> None: """Records one environment step.""" self._accumulate_metrics(env, timestep.observation, actor_state)
[docs] def get_metrics(self) -> Dict[str, base_observers.Number]: """Returns metrics collected for the current episode.""" return self._metrics
[docs] class SimpleEnvInfoKeep(base_observers.EnvLoopObserver): """An observer that collects and accumulates scalars from env's info.""" def __init__(self): self._metrics = None
[docs] def observe_first(self, env: dm_env.Environment, timestep: dm_env.TimeStep ) -> None: """Observes the initial state.""" self.env = env self._metrics = {}
[docs] def observe(self, env: dm_env.Environment, timestep: dm_env.TimeStep, action: np.ndarray, actor_state: Optional[np.ndarray] = None) -> None: """Records one environment step.""" pass
[docs] def get_metrics(self) -> Dict[str, base_observers.Number]: """Returns metrics collected for the current episode.""" self._metrics['reward_salt'] = self.env.salt_associated_reward self._metrics['reward_energy'] = self.env.energy_associated_reward self._metrics['reward_predator'] = self.env.predator_associated_reward self._metrics['reward_wall'] = self.env.wall_associated_reward self._metrics['reward_consumption'] = self.env.consumption_associated_reward self._metrics['predator_avoided'] = self.env.total_attacks_avoided self._metrics['predator_caught'] = self.env.total_attacks_captured return self._metrics
[docs] class HDF5Logger(base_loggers.Logger): """Standard HDF5 logger. """ _open = open def __init__( self, directory_or_file: Union[str, TextIO] = '~/acme', label: str = '', add_uid: bool = True, wait_min: float = 10, ): """Instantiates the logger. Args: directory_or_file: Either a directory path as a string, or a file TextIO object. label: Extra label to add to logger. This is added as a suffix to the directory. add_uid: Whether to add a UID to the file path. See `paths.process_path` for details. """ self._add_uid = add_uid directory = paths.process_path( directory_or_file, 'logs', label, add_uid=self._add_uid) self.base_path = os.path.join(directory, 'logs_') self.called = 0 self.wait_min = wait_min
[docs] def write(self, data: base_loggers.LoggingData): data = base_loggers.to_numpy(data) self.called += 1 # write the fields of data into an hdf5 file with h5py.File(self.base_path + str(self.called) + '.hdf5', 'w') as f: for key, value in data.items(): # if key contains 'prey'... if key == 'prey_x' or key == 'prey_y': # value is a list of lists with variable length. find the maximal length max_prey_num = max([len(x) for x in value]) # create a 2d numpy array with the maximal length prey_array = np.zeros((len(value), max_prey_num)) # fill with nans prey_array.fill(np.nan) # fill the array with the values from the list of lists for i in range(len(value)): prey_array[i, :len(value[i])] = value[i] # save the array f.create_dataset(key, data=prey_array) elif 'env_variables' in key: # if key contains 'env_variables', save it as a group if 'env_variables' not in f: f.create_group('env_variables') # save the env variables for env_key, env_value in value.items(): if isinstance(value, np.ndarray): f['env_variables'].create_dataset(env_key, data=env_value) else: f['env_variables'].attrs[env_key] = env_value else: f.create_dataset(key, data=value) # wait to space out the writes time.sleep(self.wait_min * 60)
[docs] def close(self): pass