simfish_rl package

Submodules

simfish_rl.eval_simfish_agent module

class simfish_rl.eval_simfish_agent.StateEnvLoop(environment: Environment, actor: Actor, counter: Counter | None = None, logger: Logger | None = None, should_update: bool = True, label: str = 'environment_loop', observers: Sequence[EnvLoopObserver] = ())

Bases: EnvironmentLoop

An environment loop that keeps track of the state.

run_episode() → Mapping[str, Any]

Run one episode.

Each episode is a loop that first interacts with the environment to get an observation and then gives that observation to the agent to retrieve an action.

Returns:

An instance of loggers.LoggingData.
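A minimal sketch of the episode loop described above. It assumes dm_env-style semantics (`reset()`/`step()` return a timestep with `observation`, `reward` and a `last()` method) and an actor with Acme's `select_action`/`observe_first`/`observe`/`update` methods. `TimeStep`, `ToyEnv`, `ConstantActor`, `run_one_episode` and the returned dict keys are hypothetical stand-ins; the real StateEnvLoop additionally tracks state and reports through loggers.LoggingData.

```python
from dataclasses import dataclass
from typing import Any, Dict


@dataclass
class TimeStep:
    """Hypothetical minimal stand-in for dm_env.TimeStep."""
    observation: Any
    reward: float
    is_last: bool

    def last(self) -> bool:
        return self.is_last


class ToyEnv:
    """Hypothetical environment: ends after `length` steps, reward 1.0 per step."""

    def __init__(self, length: int = 3):
        self._length = length
        self._t = 0

    def reset(self) -> TimeStep:
        self._t = 0
        return TimeStep(observation=0, reward=0.0, is_last=False)

    def step(self, action: int) -> TimeStep:
        self._t += 1
        return TimeStep(observation=self._t, reward=1.0,
                        is_last=self._t >= self._length)


class ConstantActor:
    """Hypothetical actor that always picks action 0 and records observations."""

    def __init__(self):
        self.seen = []

    def select_action(self, observation):
        return 0

    def observe_first(self, timestep):
        self.seen.append(timestep.observation)

    def observe(self, action, next_timestep):
        self.seen.append(next_timestep.observation)

    def update(self, wait: bool = False):
        pass  # A learning actor would fetch fresh parameters here.


def run_one_episode(env, actor, should_update: bool = True) -> Dict[str, Any]:
    """Observe, act, step the environment, and repeat until the last timestep."""
    timestep = env.reset()
    actor.observe_first(timestep)
    episode_steps, episode_return = 0, 0.0
    while not timestep.last():
        action = actor.select_action(timestep.observation)
        timestep = env.step(action)
        actor.observe(action, next_timestep=timestep)
        if should_update:
            actor.update()
        episode_steps += 1
        episode_return += timestep.reward
    return {"episode_length": episode_steps, "episode_return": episode_return}
```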

simfish_rl.eval_simfish_agent.eval_agent(experiment: ExperimentConfig, directory: str, num_episodes: int = 100, log_subdir: str = 'eval_model')

simfish_rl.hdf5_logger module

A simple HDF5 logger.

class simfish_rl.hdf5_logger.EnvInfoKeep

Bases: EnvLoopObserver

An observer that collects and accumulates scalars from env’s info.

get_metrics() → Dict[str, int | float]

Returns metrics collected for the current episode.

observe(env: Environment, timestep: TimeStep, action: ndarray, actor_state: ndarray | None = None) → None

Records one environment step.

observe_first(env: Environment, timestep: TimeStep) → None

Observes the initial state.
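A minimal sketch of how an observer like EnvInfoKeep might accumulate scalars, built on the EnvLoopObserver hooks documented above. It assumes the environment exposes a `get_info()` method returning a dict (a hypothetical name here) and that "accumulate" means summing each scalar over the episode; the real class may aggregate differently. `InfoSumObserver` is a hypothetical name.

```python
from numbers import Number
from typing import Dict, Optional, Union


class InfoSumObserver:
    """Hypothetical stand-in for EnvInfoKeep: sums scalar info entries per episode."""

    def __init__(self):
        self._metrics: Dict[str, Union[int, float]] = {}

    def _accumulate(self, env) -> None:
        # Assumption: the environment exposes get_info() -> dict.
        info = env.get_info() if hasattr(env, "get_info") else None
        if not info:
            return
        for key, value in info.items():
            # Only numeric scalars are accumulated; strings, arrays and bools are skipped.
            if isinstance(value, Number) and not isinstance(value, bool):
                self._metrics[key] = self._metrics.get(key, 0) + value

    def observe_first(self, env, timestep) -> None:
        self._metrics = {}  # Reset at the start of every episode.
        self._accumulate(env)

    def observe(self, env, timestep, action, actor_state: Optional[object] = None) -> None:
        self._accumulate(env)

    def get_metrics(self) -> Dict[str, Union[int, float]]:
        return dict(self._metrics)
```

Resetting inside `observe_first` keeps metrics per episode, which matches the "current episode" wording of get_metrics().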

class simfish_rl.hdf5_logger.HDF5Logger(directory_or_file: str | TextIO = '~/acme', label: str = '', add_uid: bool = True, wait_min: float = 10)

Bases: Logger

Standard HDF5 logger.

close()

Closes the logger; no further writes are expected.

write(data: Mapping[str, Any])

Writes data to destination (file, terminal, database, etc).

class simfish_rl.hdf5_logger.SimpleEnvInfoKeep

Bases: EnvLoopObserver

An observer that collects and accumulates scalars from env’s info.

get_metrics() → Dict[str, int | float]

Returns metrics collected for the current episode.

observe(env: Environment, timestep: TimeStep, action: ndarray, actor_state: ndarray | None = None) → None

Records one environment step.

observe_first(env: Environment, timestep: TimeStep) → None

Observes the initial state.

simfish_rl.simfish_r2d2_learner module

Simfish R2D2 learner with support for mirrored (reflected) training examples.

This module implements a custom R2D2 learner for the Simfish environment that extends the base Acme R2D2 learner. Unlike the standard Atari R2D2 learner, it adds data augmentation by reflecting (mirroring) observations and actions. The learner performs training steps on both original and reflected samples, effectively doubling the training data by exploiting the symmetry of the Simfish environment. This is useful in any environment where flipping an observation preserves its meaning, provided actions are remapped accordingly.

Key differences from the standard R2D2 learner:

  • Supports reflection of visual observations along spatial axes

  • Maps actions to their mirrored equivalents when observations are reflected

  • Randomly interleaves regular and mirrored training steps

  • Maintains an action mirror mapping for consistent reflection

Classes:

  • SimfishR2D2Learner: R2D2 learner with reflection-based data augmentation

Functions:

  • reflect_observations: Flips observations along spatial dimensions

  • reflect_actions: Maps actions to their mirrored equivalents

  • reflect_samples: Creates reflected versions of replay samples
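The two reflection helpers can be sketched in plain Python. This sketch assumes a hypothetical three-action mapping (0 = forward, 1 = turn left, 2 = turn right) and assumes that reflecting an observation means reversing its position axis; the real helpers operate on JAX arrays and nested specs, and their exact flip axes are not documented on this page. `mirror_actions`, `mirror_positions` and `ACTIONS_MIRROR` are hypothetical names, not the module's API.

```python
from typing import Any, Dict, List

# Hypothetical mapping: 0 = forward (its own mirror), 1 = turn left, 2 = turn right.
ACTIONS_MIRROR: Dict[int, int] = {1: 2, 2: 1}


def mirror_actions(values: Any, actions_mirror: Dict[int, int]) -> Any:
    """Map every integer action in a (possibly nested) list/tuple/dict to its mirror."""
    if isinstance(values, dict):
        return {k: mirror_actions(v, actions_mirror) for k, v in values.items()}
    if isinstance(values, (list, tuple)):
        return type(values)(mirror_actions(v, actions_mirror) for v in values)
    # In this sketch, actions with no entry in the mapping map to themselves.
    return actions_mirror.get(values, values)


def mirror_positions(obs: List[List[float]]) -> List[List[float]]:
    """Reverse the position axis of a (position, channel) observation."""
    return [list(row) for row in reversed(obs)]
```

Because a mirror mapping is its own inverse, applying either helper twice returns the original data, which is a cheap sanity check for any `actions_mirror` dictionary.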

class simfish_rl.simfish_r2d2_learner.SimfishR2D2Learner(networks: acme.jax.networks.base.UnrollableNetwork, batch_size: int, random_key: jax._src.prng.PRNGKeyArray, burn_in_length: int, discount: float, importance_sampling_exponent: float, max_priority_weight: float, target_update_period: int, iterator: typing.Iterator[acme.jax.utils.PrefetchingSplit], optimizer: optax._src.base.GradientTransformation, actions_mirror: typing.Dict[int, int], bootstrap_n: int = 5, tx_pair: rlax._src.nonlinear_bellman.TxPair = (<function signed_hyperbolic>, <function signed_parabolic>), clip_rewards: bool = False, max_abs_reward: float = 1.0, use_core_state: bool = True, prefetch_size: int = 2, replay_client: reverb.client.Client | None = None, counter: acme.utils.counting.Counter | None = None, logger: acme.utils.loggers.base.Logger | None = None)

Bases: R2D2Learner

A learner for the Simfish R2D2 agent, with support for reflection.
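The default `tx_pair` is rlax's signed hyperbolic / signed parabolic pair: the invertible value-rescaling function h applied to bootstrap targets in R2D2, and its inverse. Below is a pure-Python transcription for scalars, assuming rlax's default `eps=1e-3`; it illustrates that the two functions invert each other and is not the rlax implementation itself.

```python
import math


def signed_hyperbolic(x: float, eps: float = 1e-3) -> float:
    """h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x; compresses large returns."""
    return math.copysign(math.sqrt(abs(x) + 1.0) - 1.0, x) + eps * x


def signed_parabolic(x: float, eps: float = 1e-3) -> float:
    """Inverse of signed_hyperbolic: maps rescaled values back to raw returns."""
    # Solve eps*s^2 + s - (1 + eps + |x|) = 0 for s = sqrt(|y| + 1), then y = s^2 - 1.
    z = math.sqrt(1.0 + 4.0 * eps * (eps + 1.0 + abs(x))) / (2.0 * eps) - 1.0 / (2.0 * eps)
    return math.copysign(z * z - 1.0, x)
```

Rescaling keeps very large returns in a range where a network can regress them, while the inverse recovers the raw scale when forming n-step targets.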

step()

Performs an update step of the learner’s parameters.
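The module description says regular and mirrored steps are randomly interleaved. A minimal sketch of that choice, assuming each step independently uses the reflected batch with probability 0.5; the actual probability and RNG handling inside SimfishR2D2Learner.step() are not documented here. `interleaved_steps`, `update_fn` and `reflect_fn` are hypothetical placeholders for the base R2D2 update and reflect_samples.

```python
import random
from typing import Any, Callable, List


def interleaved_steps(
    batches: List[Any],
    update_fn: Callable[[Any], None],
    reflect_fn: Callable[[Any], Any],
    p_reflect: float = 0.5,
    seed: int = 0,
) -> List[bool]:
    """Run one update per batch on either the original or the reflected batch.

    Returns a list recording which steps used the reflected batch.
    """
    rng = random.Random(seed)
    used_reflection = []
    for batch in batches:
        reflect = rng.random() < p_reflect
        update_fn(reflect_fn(batch) if reflect else batch)
        used_reflection.append(reflect)
    return used_reflection
```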

simfish_rl.simfish_r2d2_learner.reflect_actions(values: Any | Array | Iterable[NestedSpec] | Mapping[Any, NestedSpec], actions_mirror: Dict[int, int]) → Any

Maps actions to their mirrored equivalents using actions_mirror.

simfish_rl.simfish_r2d2_learner.reflect_observations(values: Any | Array | Iterable[NestedSpec] | Mapping[Any, NestedSpec]) → Any

Flips observations along spatial dimensions.
simfish_rl.simfish_r2d2_learner.reflect_samples(samples: ReplaySample, actions_mirror: Dict[int, int]) → ReplaySample

Reflects the observations in the samples.
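A sketch of the reflection on a simplified sample structure. It assumes a hypothetical `Sample` namedtuple with `observation` (per-timestep lists shaped (position, channel)), `action` and `reward` fields, and it leaves rewards unchanged on the assumption that mirroring does not affect them. The real function operates on a reverb ReplaySample whose data is a nested structure of arrays. `mirror_sample` is a hypothetical name.

```python
from typing import Dict, List, NamedTuple


class Sample(NamedTuple):
    """Hypothetical simplified replay sample (one sequence of timesteps)."""
    observation: List[List[List[float]]]  # (time, position, channel)
    action: List[int]                     # (time,)
    reward: List[float]                   # (time,)


def mirror_sample(sample: Sample, actions_mirror: Dict[int, int]) -> Sample:
    """Return a mirrored copy: positions reversed, actions remapped, rewards kept."""
    mirrored_obs = [list(reversed(frame)) for frame in sample.observation]
    mirrored_actions = [actions_mirror.get(a, a) for a in sample.action]
    # _replace builds a new tuple, so the original sample is left untouched.
    return sample._replace(observation=mirrored_obs, action=mirrored_actions)
```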

simfish_rl.simfish_r2d2_network module

R2D2 network architecture for the Simfish RL environment.

This module implements a Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2) network architecture tailored for the Simfish environment. The network processes visual observations from two eyes along with action and reward information.

Architecture Overview:

  1. Input Processing (OAREmbedding):

  • Takes observation-action-reward (OAR) tuples

  • Processes observations through DeepSimfishTorso

  • One-hot encodes actions

  • Normalizes rewards using tanh to [-1, 1]

  • Concatenates all features

  2. Visual Processing (DeepSimfishTorso):

  • Splits binocular input into left and right eye channels

  • Each eye is processed by the shared retina module:

      ◦ Conv1D (8 filters, kernel 3, stride 1) + ReLU

      ◦ Conv1D (8 filters, kernel 5, stride 2) + ReLU

      ◦ Conv1D (8 filters, kernel 5, stride 2) + ReLU

  • Concatenates left and right eye features

  • Passes through an MLP head with configurable hidden sizes (R2D2SimfishNetwork uses 32; the DeepSimfishTorso default is 256)

  • Optionally concatenates internal state information

  3. Recurrent Core (R2D2SimfishNetwork):

  • LSTM with 64 hidden units for temporal processing

  • Maintains hidden and cell states across timesteps

  4. Action Value Estimation:

  • Duelling MLP head with 64 hidden units

  • Separates state value and action advantage streams

  • Outputs Q-values for all actions
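The input-processing step can be illustrated for a single timestep in plain Python. This sketch assumes the torso output is already a flat feature list and concatenates in the order (torso features, one-hot action, tanh reward), which is the order Acme's generic OAREmbedding uses; that this module keeps the same order is an assumption. `one_hot` and `oar_embed` are hypothetical helpers; the real module works on batched JAX arrays.

```python
import math
from typing import List


def one_hot(action: int, num_actions: int) -> List[float]:
    """Encode an integer action as a one-hot vector of length num_actions."""
    return [1.0 if i == action else 0.0 for i in range(num_actions)]


def oar_embed(torso_features: List[float], action: int, reward: float,
              num_actions: int) -> List[float]:
    """Concatenate torso features, a one-hot action and a tanh-squashed reward."""
    return torso_features + one_hot(action, num_actions) + [math.tanh(reward)]
```

With a 32-unit torso output, 4 actions and one reward entry, the embedding fed to the LSTM has 32 + 4 + 1 = 37 features.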

The network supports both single-step and batch unroll operations, making it suitable for distributed training with experience replay.

Classes:

  • OAREmbedding: Embeds observation, action, and reward inputs

  • Flatten: Utility module for flattening tensors

  • retina: 1D convolutional network for processing eye input

  • DeepSimfishTorso: Binocular visual processing network

  • R2D2SimfishNetwork: Main recurrent network with LSTM core and duelling head

Functions:

  • make_r2d2_networks: Factory function to create R2D2 networks from environment specs

class simfish_rl.simfish_r2d2_network.DeepSimfishTorso(hidden_sizes: Sequence[int] = (256,), name: str = 'deep_simfish_torso')

Bases: Module

Neural network torso for processing binocular visual input in Simfish.

This class processes left and right eye inputs through a shared retina network, concatenates the processed outputs, and passes them through an MLP head. Optionally concatenates internal state if provided.

Parameters:
  • hidden_sizes – Sequence of integers specifying the sizes of hidden layers in the MLP head. Defaults to (256,).

  • name – Name of the module. Defaults to ‘deep_simfish_torso’.

Input:
x: A list of jnp.ndarray where:
  • x[0]: Visual input with shape (batch, position, channel, eye), where the eye dimension has size 2 (left=0, right=1). Right-eye positions are reversed to match the left eye’s orientation.

  • x[1] (optional): Internal state to concatenate with processed visual output.

Returns:

Processed output combining binocular visual features from the MLP head, optionally concatenated with internal state if provided.

Return type:

jnp.ndarray

Note

The right eye input is reversed along the position axis to align with the left eye’s spatial orientation before processing through the shared retina network.
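A plain-Python sketch of the eye split described above for a single (position, channel, eye) sample: eye index 0 is the left eye, index 1 the right eye, and the right eye's positions are reversed before both eyes go through the same retina function. `split_eyes`, `binocular_features` and `shared_retina` are hypothetical; `shared_retina` (a per-position channel sum) only stands in for the real convolutional retina.

```python
from typing import Callable, List, Tuple

Eye = List[List[float]]          # (position, channel)
Obs = List[List[List[float]]]    # (position, channel, eye)


def split_eyes(x: Obs) -> Tuple[Eye, Eye]:
    """Return (left, right) as (position, channel) lists; right is position-reversed."""
    left = [[ch[0] for ch in pos] for pos in x]
    right = [[ch[1] for ch in pos] for pos in x]
    return left, right[::-1]


def shared_retina(eye: Eye) -> List[float]:
    """Hypothetical stand-in for the retina: sum the channels at each position."""
    return [sum(channels) for channels in eye]


def binocular_features(x: Obs, retina: Callable[[Eye], List[float]] = shared_retina) -> List[float]:
    left, right = split_eyes(x)
    # One function processes both eyes (shared weights in the real module),
    # and the two outputs are concatenated left first.
    return retina(left) + retina(right)
```

Reversing the right eye lets one set of retina weights treat "towards the midline" the same way for both eyes.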

class simfish_rl.simfish_r2d2_network.Flatten(name='Flatten')

Bases: Module

Utility module for flattening tensors.

class simfish_rl.simfish_r2d2_network.OAREmbedding(torso: SupportsCall, num_actions: int)

Bases: Module

Module for embedding (observation, action, reward) inputs together.

num_actions: int

torso: SupportsCall

class simfish_rl.simfish_r2d2_network.R2D2SimfishNetwork(num_actions: int)

Bases: RNNCore

R2D2 network architecture for Simfish environment.

This network implements the R2D2 (Recurrent Replay Distributed DQN) architecture tailored for the Simfish environment. It combines observation, action, and reward embeddings with an LSTM core and a duelling DQN head.

The network processes sequences of (observation, action, reward) tuples through:

  1. An embedding layer (OAREmbedding with DeepSimfishTorso)

  2. An LSTM core for temporal dependencies

  3. A duelling MLP head for Q-value estimation

_embed

OAREmbedding layer that processes observations, actions, and rewards using a DeepSimfishTorso with hidden sizes [32].

_core

LSTM layer with 64 hidden units for recurrent processing.

_duelling_head

Duelling DQN head with hidden sizes [64] for Q-value estimation.

_num_actions

Number of possible actions in the environment.
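The duelling head combines a state-value stream V(s) with an advantage stream A(s, a). The sketch below shows only the standard mean-subtracted combination from the duelling-DQN formulation (which Acme's DuellingMLP also follows), in plain Python with the MLP layers omitted; `duelling_q_values` is a hypothetical helper.

```python
from typing import List


def duelling_q_values(value: float, advantages: List[float]) -> List[float]:
    """Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')."""
    mean_adv = sum(advantages) / len(advantages)
    return [value + a - mean_adv for a in advantages]
```

Subtracting the mean advantage makes V and A identifiable: the Q-values average to V(s), and the greedy action is unchanged because every advantage is shifted by the same amount.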

__call__()

Processes a single timestep of inputs through the network.

initial_state()

Returns the initial LSTM hidden state for a given batch size.

unroll()

Efficiently processes a sequence of inputs using static unrolling.

initial_state(batch_size: int | None, **unused_kwargs) → LSTMState

Constructs an initial state for this core.

Parameters:

batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.

Returns:

Arbitrarily nested initial state for this core.
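A sketch of what the initial LSTM state might look like, assuming zero-initialised hidden and cell vectors of width 64 (the LSTM size listed above) and choosing the unbatched option that the docstring mentions for `batch_size=None`. `LSTMState` here is a plain namedtuple standing in for Haiku's hk.LSTMState, which has the same `hidden`/`cell` fields; `zero_lstm_state` is a hypothetical name.

```python
from typing import List, NamedTuple, Optional, Union

Vector = List[float]


class LSTMState(NamedTuple):
    """Stand-in for hk.LSTMState (fields: hidden, cell)."""
    hidden: Union[Vector, List[Vector]]
    cell: Union[Vector, List[Vector]]


def zero_lstm_state(batch_size: Optional[int], hidden_size: int = 64) -> LSTMState:
    """Zero state of shape (batch, hidden); no batch dimension when batch_size is None."""
    if batch_size is None:
        return LSTMState(hidden=[0.0] * hidden_size, cell=[0.0] * hidden_size)
    return LSTMState(hidden=[[0.0] * hidden_size for _ in range(batch_size)],
                     cell=[[0.0] * hidden_size for _ in range(batch_size)])
```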

unroll(inputs: OAR, state: LSTMState) → Tuple[Array, LSTMState]

Efficient unroll that applies the torso, core, and duelling MLP in one pass.
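Conceptually, an unroll is a loop over time that threads the recurrent state through each step; Haiku's hk.static_unroll performs this loop at trace time. The pure-Python sketch below shows only that state-threading pattern with a hypothetical running-sum core (`unroll_sketch` and `running_sum_core` are hypothetical names). It does not reproduce the batched torso and duelling-head passes of the real unroll().

```python
from typing import Callable, List, Tuple, TypeVar

S = TypeVar("S")


def unroll_sketch(core: Callable[[float, S], Tuple[float, S]],
                  inputs: List[float], state: S) -> Tuple[List[float], S]:
    """Apply `core` to each timestep in order, carrying the state forward."""
    outputs = []
    for x in inputs:
        out, state = core(x, state)
        outputs.append(out)
    return outputs, state


def running_sum_core(x: float, state: float) -> Tuple[float, float]:
    """Hypothetical core whose state is the sum of the inputs seen so far."""
    new_state = state + x
    return new_state, new_state
```

Because the final state is returned, a sequence can be unrolled in chunks (for example a burn-in prefix followed by the training part) and give the same outputs as a single unroll.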

simfish_rl.simfish_r2d2_network.make_r2d2_networks(env_spec: EnvironmentSpec) → R2D2SimfishNetwork

Builds the default R2D2 networks for the Simfish environment from an environment spec.

class simfish_rl.simfish_r2d2_network.retina

Bases: Module

A convolutional neural network module for processing retinal input data.

This module implements a 1D convolutional network that processes retinal information through three convolutional layers, each followed by a ReLU activation:

  • Conv1D: 8 filters, kernel size 3, stride 1, then ReLU

  • Conv1D: 8 filters, kernel size 5, stride 2, then ReLU

  • Conv1D: 8 filters, kernel size 5, stride 2, then ReLU

Parameters:

inputs – Input array with shape either:

  • (position, channel): PC format for a single sample

  • (batch, position, channel): BPC format for batched samples

Returns:

Flattened output features with shape:
  • (batch, features) for batched inputs

  • (features,) for single sample inputs

Return type:

jnp.ndarray

Raises:

ValueError – If input rank is not 2 (PC) or 3 (BPC).

Example

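>>> import haiku as hk
>>> import jax
>>> import jax.numpy as jnp
>>> from simfish_rl.simfish_r2d2_network import retina
>>> # Haiku modules must be instantiated inside a transformed function.
>>> forward = hk.without_apply_rng(hk.transform(lambda x: retina()(x)))
>>> params = forward.init(jax.random.PRNGKey(0), jnp.ones((100, 2)))
>>> # Single sample: (position, channel)
>>> output = forward.apply(params, jnp.ones((100, 2)))
>>> # Batched: (batch, position, channel)
>>> output = forward.apply(params, jnp.ones((32, 100, 2)))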
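The flattened feature count produced by the retina depends on the number of positions and on the padding mode, which this page does not state. The sketch below computes the count for the three layers listed above under both conventions. Haiku's hk.Conv1D defaults to 'SAME' padding, but whether retina keeps that default is an assumption to check against the source. `conv1d_out_len` and `retina_feature_count` are hypothetical helpers.

```python
import math
from typing import Sequence, Tuple

# (kernel size, stride) for the three Conv1D layers listed above.
RETINA_LAYERS: Sequence[Tuple[int, int]] = ((3, 1), (5, 2), (5, 2))


def conv1d_out_len(length: int, kernel: int, stride: int, padding: str) -> int:
    """Output length of a 1D convolution under 'SAME' or 'VALID' padding."""
    if padding == "SAME":
        return math.ceil(length / stride)
    if padding == "VALID":
        return (length - kernel) // stride + 1
    raise ValueError(f"unknown padding: {padding}")


def retina_feature_count(positions: int, padding: str = "SAME", filters: int = 8) -> int:
    """Flattened size per sample: final length times the number of output filters."""
    length = positions
    for kernel, stride in RETINA_LAYERS:
        length = conv1d_out_len(length, kernel, stride, padding)
    return length * filters
```

For the 100-position inputs in the example above, this gives 25 × 8 = 200 features under 'SAME' padding and 22 × 8 = 176 under 'VALID'; the input channel count does not affect the result.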

Module contents