simfish_rl package
Submodules
simfish_rl.eval_simfish_agent module
- class simfish_rl.eval_simfish_agent.StateEnvLoop(environment: Environment, actor: Actor, counter: Counter | None = None, logger: Logger | None = None, should_update: bool = True, label: str = 'environment_loop', observers: Sequence[EnvLoopObserver] = ())
Bases: EnvironmentLoop
An environment loop that keeps track of the state.
simfish_rl.hdf5_logger module
A simple HDF5 logger.
- class simfish_rl.hdf5_logger.EnvInfoKeep
Bases: EnvLoopObserver
An observer that collects and accumulates scalars from env’s info.
- class simfish_rl.hdf5_logger.HDF5Logger(directory_or_file: str | TextIO = '~/acme', label: str = '', add_uid: bool = True, wait_min: float = 10)
Bases: Logger
Standard HDF5 logger.
simfish_rl.simfish_r2d2_learner module
Simfish R2D2 Learner with support for mirrored/reflected training examples.
This module implements a custom R2D2 learner for the Simfish environment that extends the base Acme R2D2 learner. Unlike the standard Atari R2D2 learner, this implementation adds data augmentation through reflection/mirroring of observations and actions. The learner performs training steps on both original and reflected samples, effectively doubling the training data by exploiting the symmetry in the Simfish environment. This is particularly useful for environments where flipping preserves semantic meaning with appropriate action mappings.
Key differences from the standard R2D2 learner:
- Supports reflection of visual observations along spatial axes
- Maps actions appropriately when observations are reflected
- Randomly interleaves regular and mirrored training steps
- Maintains action mirror mappings for consistent reflection
Classes:
- SimfishR2D2Learner: R2D2 learner with reflection-based data augmentation
Functions:
- reflect_observations: Flips observations along spatial dimensions
- reflect_actions: Maps actions to their mirrored equivalents
- reflect_samples: Creates reflected versions of replay samples
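The reflection idea described in this module summary can be sketched as follows. This is a minimal illustration, not the module's implementation: the name `flip_observations_sketch`, the choice of axis 1 as the spatial axis, the toy shapes, and the 0.5 interleaving probability are all assumptions. NumPy stands in for `jax.numpy`, which offers the same `flip` call.

```python
import numpy as np


def flip_observations_sketch(obs: np.ndarray, axis: int = 1) -> np.ndarray:
    """Mirror a batch of observations along one spatial axis (assumed axis=1)."""
    return np.flip(obs, axis=axis)


# Toy batch with shape (batch, position, channel); real Simfish shapes may differ.
obs = np.arange(2 * 4 * 3, dtype=np.float32).reshape(2, 4, 3)
mirrored = flip_observations_sketch(obs)

# Randomly choose between the regular and the mirrored batch for one step.
# The 0.5 probability is an assumption for illustration only.
rng = np.random.default_rng(0)
use_mirror = bool(rng.random() < 0.5)
step_batch = mirrored if use_mirror else obs
```

Flipping twice returns the original batch, which is a quick sanity check for any reflection routine of this kind.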
- class simfish_rl.simfish_r2d2_learner.SimfishR2D2Learner(networks: acme.jax.networks.base.UnrollableNetwork, batch_size: int, random_key: jax._src.prng.PRNGKeyArray, burn_in_length: int, discount: float, importance_sampling_exponent: float, max_priority_weight: float, target_update_period: int, iterator: typing.Iterator[acme.jax.utils.PrefetchingSplit], optimizer: optax._src.base.GradientTransformation, actions_mirror: typing.Dict[int, int], bootstrap_n: int = 5, tx_pair: rlax._src.nonlinear_bellman.TxPair = (<function signed_hyperbolic>, <function signed_parabolic>), clip_rewards: bool = False, max_abs_reward: float = 1.0, use_core_state: bool = True, prefetch_size: int = 2, replay_client: reverb.client.Client | None = None, counter: acme.utils.counting.Counter | None = None, logger: acme.utils.loggers.base.Logger | None = None)
Bases: R2D2Learner
A learner for the Simfish R2D2 agent, with support for reflection.
- simfish_rl.simfish_r2d2_learner.reflect_actions(values: Any | Array | Iterable[NestedSpec] | Mapping[Any, NestedSpec], actions_mirror: Dict[int, int]) -> Any
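One way an `actions_mirror` dictionary could be applied to a flat array of discrete actions is shown below. The action meanings (0 = forward, 1 = turn left, 2 = turn right) and the helper name `mirror_actions_sketch` are hypothetical. The real `reflect_actions` accepts nested structures, which this sketch does not handle.

```python
import numpy as np

# Hypothetical mapping: forward stays forward, left and right turns swap.
actions_mirror = {0: 0, 1: 2, 2: 1}


def mirror_actions_sketch(actions, actions_mirror):
    """Map each discrete action index to its mirrored counterpart."""
    # Build a lookup table; indices without an entry map to themselves.
    lookup = np.arange(max(actions_mirror) + 1)
    for source, target in actions_mirror.items():
        lookup[source] = target
    return lookup[np.asarray(actions)]


mirrored_actions = mirror_actions_sketch([0, 1, 2, 1], actions_mirror)
```

A lookup table keeps the mapping vectorized, so the same indexing expression would work on a whole batch of action sequences.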
simfish_rl.simfish_r2d2_network module
R2D2 Network Architecture for Simfish RL Environment.
This module implements a Recurrent Experience Replay in Distributed Reinforcement Learning (R2D2) network architecture tailored for the Simfish environment. The network processes visual observations from two eyes along with action and reward information.
Architecture Overview:
Input Processing (OAREmbedding):
Takes observation-action-reward (OAR) tuples
Processes observations through DeepSimfishTorso
One-hot encodes actions
Normalizes rewards using tanh to [-1, 1]
Concatenates all features
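The input-processing steps listed above can be sketched in a few lines. The function name `oar_embed_sketch`, the concatenation order (features, action, reward), and the toy sizes are assumptions made for illustration. NumPy stands in for `jax.numpy` here.

```python
import numpy as np


def oar_embed_sketch(torso_features, actions, rewards, num_actions):
    """Concatenate torso features, one-hot actions, and tanh-squashed rewards."""
    # One-hot encode the integer actions: shape (batch, num_actions).
    one_hot = np.eye(num_actions, dtype=np.float32)[actions]
    # Squash rewards into [-1, 1] and add a trailing feature axis: (batch, 1).
    squashed = np.tanh(rewards).astype(np.float32)[:, None]
    return np.concatenate([torso_features, one_hot, squashed], axis=-1)


torso_features = np.zeros((2, 32), dtype=np.float32)  # stand-in torso output
actions = np.array([0, 3])
rewards = np.array([0.0, 100.0])
embedded = oar_embed_sketch(torso_features, actions, rewards, num_actions=4)
```

With 32 torso features and 4 actions, the embedding has 32 + 4 + 1 = 37 columns, and a large reward saturates close to 1.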
Visual Processing (DeepSimfishTorso):
Splits binocular input into left and right eye channels
Each eye is processed by a retina module:
* Conv1D(8 filters, kernel=3, stride=1) + ReLU
* Conv1D(8 filters, kernel=5, stride=2) + ReLU
* Conv1D(8 filters, kernel=5, stride=2) + ReLU
Concatenates left and right eye features
Passes through an MLP head with configurable hidden sizes (32 as configured in R2D2SimfishNetwork; the DeepSimfishTorso constructor default is (256,))
Optionally concatenates internal state information
Recurrent Core (R2D2SimfishNetwork):
LSTM with 64 hidden units for temporal processing
Maintains hidden and cell states across timesteps
Action Value Estimation:
Duelling MLP head with 64 hidden units
Separates state value and action advantage streams
Outputs Q-values for all actions
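The value and advantage streams of a duelling head are commonly combined by subtracting the mean advantage. The sketch below assumes that standard combination; the source does not state which variant this network uses. The name `duelling_q_sketch` is hypothetical.

```python
import numpy as np


def duelling_q_sketch(value: np.ndarray, advantages: np.ndarray) -> np.ndarray:
    """Combine streams as Q = V + A - mean(A), the common duelling form."""
    # value: (batch, 1), advantages: (batch, num_actions)
    return value + advantages - advantages.mean(axis=-1, keepdims=True)


value = np.array([[1.0]])
advantages = np.array([[1.0, 2.0, 3.0]])
q_values = duelling_q_sketch(value, advantages)
```

Centering the advantages means the average Q-value across actions equals the state value, which keeps the two streams identifiable.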
The network supports both single-step and batch unroll operations, making it suitable for distributed training with experience replay.
Classes:
- OAREmbedding: Embeds observation, action, and reward inputs
- Flatten: Utility module for flattening tensors
- retina: 1D convolutional network for processing eye input
- DeepSimfishTorso: Binocular visual processing network
- R2D2SimfishNetwork: Main recurrent network with LSTM core and duelling head
Functions:
- make_r2d2_networks: Factory function to create R2D2 networks from environment specs
- class simfish_rl.simfish_r2d2_network.DeepSimfishTorso(hidden_sizes: Sequence[int] = (256,), name: str = 'deep_simfish_torso')
Bases: Module
Neural network torso for processing binocular visual input in Simfish.
This class processes left and right eye inputs through a shared retina network, concatenates the processed outputs, and passes them through an MLP head. Optionally concatenates internal state if provided.
- Parameters:
hidden_sizes – Sequence of integers specifying the sizes of hidden layers in the MLP head. Defaults to (256,).
name – Name of the module. Defaults to ‘deep_simfish_torso’.
- Input:
- x: A list of jnp.ndarray where:
x[0]: Visual input with shape (batch, position, channel, eye) where eye dimension has size 2 (left=0, right=1). Right eye positions are reversed to match left eye orientation.
x[1] (optional): Internal state to concatenate with processed visual output.
- Returns:
Processed output combining binocular visual features from the MLP head, optionally concatenated with internal state if provided.
- Return type:
jnp.ndarray
Note
The right eye input is reversed along the position axis to align with the left eye’s spatial orientation before processing through the shared retina network.
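The eye split and right-eye reversal described in the note can be sketched as below. `shared_retina_stub` is a hypothetical stand-in that only flattens its input; the real retina applies convolutions. The toy shapes are also assumptions, and NumPy stands in for `jax.numpy`.

```python
import numpy as np


def shared_retina_stub(eye: np.ndarray) -> np.ndarray:
    """Stand-in for the shared retina: flattens (batch, position, channel)."""
    return eye.reshape(eye.shape[0], -1)


def binocular_features_sketch(visual: np.ndarray) -> np.ndarray:
    """Split eyes, reverse the right eye along position, and concatenate."""
    left = visual[..., 0]                # (batch, position, channel)
    right = visual[..., 1][:, ::-1, :]   # reversed along the position axis
    return np.concatenate(
        [shared_retina_stub(left), shared_retina_stub(right)], axis=-1
    )


# One sample, four positions, one channel, two eyes.
visual = np.zeros((1, 4, 1, 2), dtype=np.float32)
visual[0, :, 0, 0] = [1, 2, 3, 4]  # left eye
visual[0, :, 0, 1] = [4, 3, 2, 1]  # right eye sees the mirror image
features = binocular_features_sketch(visual)
```

In this toy input the right eye holds the mirror image of the left eye, so after reversal both halves of `features` are identical, which is what lets one set of retina weights serve both eyes.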
- class simfish_rl.simfish_r2d2_network.OAREmbedding(torso: SupportsCall, num_actions: int)
Bases: Module
Module for embedding (observation, action, reward) inputs together.
- num_actions: int
- torso: SupportsCall
- class simfish_rl.simfish_r2d2_network.R2D2SimfishNetwork(num_actions: int)
Bases: RNNCore
R2D2 network architecture for Simfish environment.
This network implements the R2D2 (Recurrent Replay Distributed DQN) architecture tailored for the Simfish environment. It combines observation, action, and reward embeddings with an LSTM core and a duelling DQN head.
The network processes sequences of (observation, action, reward) tuples through:
1. An embedding layer (OAREmbedding with DeepSimfishTorso)
2. An LSTM core for temporal dependencies
3. A duelling MLP head for Q-value estimation
- _embed
OAREmbedding layer that processes observations, actions, and rewards using a DeepSimfishTorso with hidden sizes [32].
- _core
LSTM layer with 64 hidden units for recurrent processing.
- _duelling_head
Duelling DQN head with hidden sizes [64] for Q-value estimation.
- _num_actions
Number of possible actions in the environment.
- initial_state(batch_size: int | None, **unused_kwargs) -> LSTMState
Constructs an initial state for this core.
- Parameters:
batch_size – Optional int or an integral scalar tensor representing batch size. If None, the core may either fail or (experimentally) return an initial state without a batch dimension.
- Returns:
Arbitrarily nested initial state for this core.
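The effect of `batch_size` on the state shape can be illustrated with a small stand-in. Zero initialization, the field names `hidden` and `cell`, and the names `LSTMStateSketch` and `initial_state_sketch` are assumptions. Only the hidden size of 64 comes from the attribute description above.

```python
from typing import NamedTuple, Optional

import numpy as np


class LSTMStateSketch(NamedTuple):
    """Hypothetical stand-in for an LSTM state with hidden and cell arrays."""
    hidden: np.ndarray
    cell: np.ndarray


def initial_state_sketch(batch_size: Optional[int], hidden_size: int = 64):
    """Return a zero state, with a batch axis only if batch_size is given."""
    shape = (hidden_size,) if batch_size is None else (batch_size, hidden_size)
    return LSTMStateSketch(
        hidden=np.zeros(shape, dtype=np.float32),
        cell=np.zeros(shape, dtype=np.float32),
    )


batched_state = initial_state_sketch(batch_size=8)
unbatched_state = initial_state_sketch(batch_size=None)
```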
- simfish_rl.simfish_r2d2_network.make_r2d2_networks(env_spec: EnvironmentSpec) -> R2D2SimfishNetwork
Builds the default R2D2 network for the given environment spec.
- class simfish_rl.simfish_r2d2_network.retina
Bases: Module
A convolutional neural network module for processing retinal input data.
This module implements a 1D convolutional network designed to process retinal information, consisting of three convolutional layers with ReLU activations.
The network architecture:
- Conv1D: 8 filters, kernel size 3, stride 1
- ReLU activation
- Conv1D: 8 filters, kernel size 5, stride 2
- ReLU activation
- Conv1D: 8 filters, kernel size 5, stride 2
- ReLU activation
- Parameters:
inputs – Input array with shape either:
(position, channel) - PC format for single sample
(batch, position, channel) - BPC format for batched samples
- Returns:
- Flattened output features with shape:
(batch, features) for batched inputs
(features,) for single sample inputs
- Return type:
jnp.ndarray
- Raises:
ValueError – If input rank is not 2 (PC) or 3 (BPC).
Example
>>> import jax.numpy as jnp
>>> retina_net = retina()
>>> # Single sample: (position, channel)
>>> output = retina_net(jnp.ones((100, 2)))
>>> # Batched: (batch, position, channel)
>>> output = retina_net(jnp.ones((32, 100, 2)))
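The size of the flattened output follows from the three layer specifications. The worked arithmetic below covers both common padding modes, because the documentation does not state which one the retina uses. The helper name `retina_feature_count` and the input length of 100 positions (taken from the example above) are illustrative only.

```python
import math

# (kernel size, stride) for the three Conv1D layers listed in the docstring.
LAYERS = [(3, 1), (5, 2), (5, 2)]
FILTERS = 8  # output channels of the last convolution


def retina_feature_count(positions: int, padding: str = "SAME") -> int:
    """Number of flattened features per sample for a given input length."""
    length = positions
    for kernel, stride in LAYERS:
        if padding == "SAME":
            # SAME padding: output length depends only on the stride.
            length = math.ceil(length / stride)
        elif padding == "VALID":
            # VALID padding: no padding, the kernel must fit entirely.
            length = (length - kernel) // stride + 1
        else:
            raise ValueError(f"Unknown padding mode: {padding}")
    return length * FILTERS


same_features = retina_feature_count(100, "SAME")    # 100 -> 100 -> 50 -> 25
valid_features = retina_feature_count(100, "VALID")  # 100 -> 98 -> 47 -> 22
```

For 100 positions this gives 25 × 8 = 200 features under SAME padding and 22 × 8 = 176 under VALID padding, so the padding mode needs to be checked in the source before sizing downstream layers.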