Source code for simulation.simfish_env

# Copyright 2025 Asaph Zylbertal & Sam Pink
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import h5py
import numpy as np
import pymunk
import dm_env
from dm_env import specs

from .arena import Arena
from .fish import Fish
from acme.wrappers import observation_action_reward
from .constants import PHYS_DAMP, LARGE_MASS, SMALL_MASS

OAR = observation_action_reward.OAR

[docs] class BaseEnvironment(dm_env.Environment): """Base class for the Simfish environment.""" def __init__(self, env_variables, actions, seed=None): super().__init__() self.rng = np.random.default_rng(seed=seed) self.env_variables = env_variables self.phys_dt = self.env_variables['sim_step_duration_seconds'] / self.env_variables['phys_steps_per_sim_step'] self.actions = actions self.num_actions = len(actions) self.max_uv_range = np.absolute(np.log(0.001) / self.env_variables["arena_light_decay_rate"]) self.arena = Arena(self.env_variables, rng=self.rng) self.dark_row = int(self.env_variables['arena_height'] * self.env_variables['arena_dark_fraction']) self.space = pymunk.Space() self.space.gravity = pymunk.Vec2d(0.0, 0.0) self.space.damping = PHYS_DAMP self.fish = Fish(env_variables=env_variables, max_uv_range=self.max_uv_range, rng=self.rng, actions=actions) if "prey_stim_file" in self.env_variables: self.prey_stim_file = self.env_variables["prey_stim_file"] with h5py.File(self.prey_stim_file, 'r') as f: self.prey_stim_locations = f['prey_loc'][:] else: self.prey_stim_locations = None if "predator_stim_file" in self.env_variables: self.predator_stim_file = self.env_variables["predator_stim_file"] with h5py.File(self.predator_stim_file, 'r') as f: self.predator_stim_location = f['predator_loc'][:] else: self.predator_stim_location = None if self.env_variables["salt_enabled"]: self.salt_gradient = None self.xp, self.yp = np.arange(self.env_variables['arena_width']), np.arange( self.env_variables['arena_height']) self.salt_location = None self.capture_fraction = int( self.env_variables["phys_steps_per_sim_step"] * self.env_variables['capture_swim_permissive_time_fraction']) self.capture_start = 1 # int((self.env_variables['phys_steps_per_sim_step'] - self.capture_fraction) / 2) self.capture_end = self.capture_start + self.capture_fraction self.space.add(self.fish.body, self.fish.mouth, self.fish.head, self.fish.tail) self.prey_impulse_slow = self.env_variables["prey_velocity_slow"] * SMALL_MASS * (1 - PHYS_DAMP ** self.phys_dt) self.prey_impulse_fast = self.env_variables["prey_velocity_fast"] * SMALL_MASS * (1 - PHYS_DAMP ** self.phys_dt) self.prey_impulse_jump = self.env_variables["prey_velocity_jump"] * SMALL_MASS * (1 - PHYS_DAMP ** self.phys_dt) self.prey_inertia = pymunk.moment_for_circle(SMALL_MASS, 0, self.env_variables['prey_radius'], (0, 0)) self.predator_impulse = self.env_variables["predator_velocity"] * LARGE_MASS * (1 - PHYS_DAMP ** self.phys_dt) self.predator_inertia = pymunk.moment_for_circle(LARGE_MASS, 0, self.env_variables['predator_radius'], (0, 0)) self.prey_shapes = [] self.prey_cloud_wall_shapes = [] self.predator_shape = None self.energy_associated_reward = 0 self.consumption_associated_reward = 0 self.salt_associated_reward = 0 self.predator_associated_reward = 0 self.wall_associated_reward = 0 self._prey_escape_p_per_physics_step = 1 - (1-self.env_variables["prey_p_escape"])**(1/self.env_variables["phys_steps_per_sim_step"]) self._create_walls() self._define_collisions() self._reset_next_step = True self.action_used = np.zeros(self.num_actions)
[docs] def reset(self) -> dm_env.TimeStep: """Reset the environment to an initial state.""" self._reset_next_step = False self.tested_predator = False self.num_steps = 0 self.fish.touched_edge_this_step = False self.prey_caught = 0 self.predator_attacks_avoided = 0 self.energy_level_log = [] self.salt_concentration = 0 self.switch_step = None self.fish.left_eye.rng_p = np.random.default_rng(seed=self.rng.integers(0, 10000)) self.fish.right_eye.rng_p = np.random.default_rng(seed=self.rng.integers(0, 10000)) if "fish_init_energy_level" in self.env_variables: self.fish.energy_level = self.env_variables["fish_init_energy_level"] else: self.fish.energy_level = 1 # Reset salt gradient if self.env_variables["salt_enabled"]: self._reset_salt_gradient() else: self.salt_location = [np.nan, np.nan] self._clear_environmental_features() self.arena.reset() self.mask_buffer = [] self.action_buffer = [] self.position_buffer = [] self.fish_angle_buffer = [] self.failed_capture_attempts = 0 if self.env_variables['test_sensory_system']: self.fish.body.position = (self.env_variables['arena_width'] / 2, self.env_variables['arena_height'] / 2) self.fish.body.angle = 0 else: if "fish_init_x" in self.env_variables and "fish_init_y" in self.env_variables: self.fish.body.position = (self.env_variables['fish_init_x'], self.env_variables['fish_init_y']) else: self.fish.body.position = (self.rng.integers(self.env_variables['fish_mouth_radius'] + 40, self.env_variables['arena_width'] - (self.env_variables[ 'fish_mouth_radius'] + 40)), self.rng.integers(self.env_variables['fish_mouth_radius'] + 40, self.env_variables['arena_height'] - (self.env_variables[ 'fish_mouth_radius'] + 40))) if "fish_init_angle" in self.env_variables: self.fish.body.angle = self.env_variables['fish_init_angle'] else: self.fish.body.angle = self.rng.random() * 2 * np.pi self.fish.body.velocity = (0, 0) self.fish.capture_possible = False if self.env_variables["prey_cloud_num"] > 0: self.prey_cloud_locations = [ [self.rng.integers( low=(self.env_variables["prey_cloud_region_size"] / 2) + self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'], high=self.env_variables['arena_width'] - ( self.env_variables['prey_radius'] + self.env_variables[ 'fish_mouth_radius']) - (self.env_variables["prey_cloud_region_size"] / 2)), self.rng.integers( low=(self.env_variables["prey_cloud_region_size"] / 2) + self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'], high=self.env_variables['arena_height'] - ( self.env_variables['prey_radius'] + self.env_variables[ 'fish_mouth_radius']) - (self.env_variables["prey_cloud_region_size"] / 2))] for cloud in range(int(self.env_variables["prey_cloud_num"]))] if not self.env_variables["prey_reproduction_mode"]: self._create_prey_cloud_walls() if self.env_variables["test_sensory_system"]: self._create_prey(prey_position=(self.env_variables['arena_width'] / 2 + 30, self.env_variables['arena_height'] / 2 + 30)) self._create_prey(prey_position=(self.env_variables['arena_width'] / 2, self.env_variables['arena_height'] / 2 - 40)) self._create_prey(prey_position=(self.env_variables['arena_width'] / 2, self.env_variables['arena_height'] / 2 + 40)) self._create_prey(prey_position=(self.env_variables['arena_width'] / 2 + 30, self.env_variables['arena_height'] / 2 - 30)) self._create_prey(prey_position=(self.env_variables['arena_width'] / 2 + 5, self.env_variables['arena_height'] / 2 + 50)) else: for i in range(int(self.env_variables['prey_num'])): self._create_prey() self.recent_cause_of_death = None self.available_prey = self.env_variables["prey_num"] self.predator_body = None self.predator_shape = None self.predator_target = None self.last_action = None self.prey_consumed_this_step = False self.event_captured_by_predator = False self.event_survived_predator = False self.survived_attack = False self.predator_prob = np.zeros(self.env_variables['max_sim_steps_per_episode']) predator_epoch_starts = self.rng.integers( low=self.env_variables['predator_immunity_steps'], high=self.env_variables['max_sim_steps_per_episode'] - self.env_variables['predator_epoch_duration'], size=self.env_variables['predator_epoch_num']) for i in predator_epoch_starts: self.predator_prob[i:i + self.env_variables['predator_epoch_duration']] = self.env_variables['predator_probability_per_epoch_step'] # For Reward tracking (debugging) print(f"""REWARD CONTRIBUTIONS: Energy: {self.energy_associated_reward} Consumption: {self.consumption_associated_reward} Salt: {self.salt_associated_reward} Predator: {self.predator_associated_reward} Wall: {self.wall_associated_reward} """) print(f"actions used: {self.action_used / np.sum(self.action_used)}") self.energy_associated_reward = 0 self.consumption_associated_reward = 0 self.salt_associated_reward = 0 self.predator_associated_reward = 0 self.wall_associated_reward = 0 self.total_attacks_avoided = 0 self.total_attacks_captured = 0 self.action_used = np.zeros(self.num_actions) return dm_env.restart(self.get_observation(action=0, reward=0.))
def _define_collisions(self): """Specifies the collisions that occur in the Pymunk simulation.""" # Collision Types: # 1: Edge # 2: Prey # 3: Fish mouth # 4: Sand grains (not implemented yet) # 5: Predator # 6: Fish body # 7: Prey cloud wall self.prey_col = self.space.add_collision_handler(2, 3) self.prey_col.begin = self._prey_touch_mouth self.prey_col2 = self.space.add_collision_handler(2, 6) self.prey_col2.begin = self._prey_touch_body self.pred_col = self.space.add_collision_handler(5, 3) self.pred_col.begin = self.touch_predator self.pred_col2 = self.space.add_collision_handler(5, 6) self.pred_col2.begin = self.touch_predator self.edge_col = self.space.add_collision_handler(1, 3) if not self.env_variables["arena_wall_bounce"]: self.edge_col.begin = self._touch_wall else: self.edge_col.begin = self._touch_wall_reflect self.edge_pred_col = self.space.add_collision_handler(1, 5) self.edge_pred_col.begin = self._remove_predator self.prey_pred_col = self.space.add_collision_handler(2, 5) self.prey_pred_col.begin = self.no_collision # To prevent the differential wall being hit by fish self.fish_prey_wall = self.space.add_collision_handler(3, 7) self.fish_prey_wall.begin = self.no_collision self.fish_prey_wall2 = self.space.add_collision_handler(6, 7) self.fish_prey_wall2.begin = self.no_collision self.pred_prey_wall2 = self.space.add_collision_handler(5, 7) self.pred_prey_wall2.begin = self.no_collision def _clear_environmental_features(self): """Removes all prey and predators from the simulation""" for i, shp in enumerate(self.prey_shapes): self.space.remove(shp, shp.body) for i, shp in enumerate(self.prey_cloud_wall_shapes): self.space.remove(shp) self.prey_cloud_wall_shapes = [] if self.predator_shape is not None: self._remove_predator() self.predator_location = None self.prey_shapes = [] self.prey_bodies = [] self.prey_identifiers = [] self.paramecia_gaits = [] if self.env_variables["prey_reproduction_mode"]: self.prey_ages = [] self.total_prey_created = 0 def _reproduce_prey(self): """ Handle the reproduction of prey in the environment. """ num_prey = len(self.prey_bodies) p_prey_birth = self.env_variables["prey_birth_rate"] * (self.env_variables["prey_num"] - num_prey) for cloud in self.prey_cloud_locations: if self.rng.random(1) < p_prey_birth: if not self._check_proximity(cloud, self.env_variables["prey_cloud_region_size"]): new_location = ( self.rng.integers(low=cloud[0] - (self.env_variables["prey_cloud_region_size"] / 2), high=cloud[0] + (self.env_variables["prey_cloud_region_size"] / 2)), self.rng.integers(low=cloud[1] - (self.env_variables["prey_cloud_region_size"] / 2), high=cloud[1] + (self.env_variables["prey_cloud_region_size"] / 2)) ) self._create_prey(new_location) self.available_prey += 1 def _reset_salt_gradient(self, salt_source=None): """ Reset the salt gradient in the environment. """ if salt_source is None: salt_source_x = self.rng.integers(0, self.env_variables['arena_width'] - 1) salt_source_y = self.rng.integers(0, self.env_variables['arena_height'] - 1) else: salt_source_x = salt_source[0] salt_source_y = salt_source[1] self.salt_location = [salt_source_x, salt_source_y] salt_distance = (((salt_source_x - self.xp[:, None]) ** 2 + ( salt_source_y - self.yp[None, :]) ** 2) ** 0.5) # Measure of distance from source at every point. self.salt_gradient = np.exp(-self.env_variables["salt_concentration_decay"] * salt_distance) def _create_prey_cloud_walls(self): """ Create walls around each prey cloud to contain the prey. """ for i in self.prey_cloud_locations: half_cloud_size = (self.env_variables["prey_cloud_region_size"] / 2) wall_edges = [ pymunk.Segment( self.space.static_body, (i[0] - half_cloud_size, i[1] - half_cloud_size), (i[0] - half_cloud_size, i[1] + half_cloud_size), 1), pymunk.Segment( self.space.static_body, (i[0] - half_cloud_size, i[1] + half_cloud_size), (i[0] + half_cloud_size, i[1] + half_cloud_size), 1), pymunk.Segment( self.space.static_body, (i[0] + half_cloud_size, i[1] + half_cloud_size), (i[0] + half_cloud_size, i[1] - half_cloud_size), 1), pymunk.Segment( self.space.static_body, (i[0] - half_cloud_size, i[1] - half_cloud_size), (i[0] + half_cloud_size, i[1] - half_cloud_size), 1) ] for s in wall_edges: s.friction = 1. s.group = 1 s.collision_type = 7 self.space.add(s) self.prey_cloud_wall_shapes.append(s) def _create_walls(self): """ Create walls around the arena to contain the fish. """ wall_width = 5 static = [ pymunk.Segment( self.space.static_body, (0, wall_width), (0, self.env_variables['arena_height']), wall_width), pymunk.Segment( self.space.static_body, (wall_width, self.env_variables['arena_height']), (self.env_variables['arena_width'], self.env_variables['arena_height']), wall_width), pymunk.Segment( self.space.static_body, (self.env_variables['arena_width'] - wall_width, self.env_variables['arena_height']), (self.env_variables['arena_width'] - wall_width, wall_width), wall_width), pymunk.Segment( self.space.static_body, (wall_width, wall_width), (self.env_variables['arena_width'], wall_width), wall_width) ] for s in static: s.friction = 1. s.group = 1 s.collision_type = 1 self.space.add(s)
[docs] @staticmethod def no_collision(arbiter, space, data): return False
def _touch_wall_reflect(self, arbiter, space, data): """ Handle the reflection of the fish when it touches a wall. """ new_position_x, new_position_y = self.fish.body.position if new_position_x < 40: # Wall d new_position_x = 40 + self.env_variables["fish_head_radius"] + \ self.env_variables["fish_tail_length"] elif new_position_x > self.env_variables['arena_width'] - 40: # wall b new_position_x = self.env_variables['arena_width'] - ( 40 + self.env_variables["fish_head_radius"] + self.env_variables["fish_tail_length"]) if new_position_y < 40: # wall a new_position_y = 40 + self.env_variables["fish_head_radius"] + \ self.env_variables["fish_tail_length"] elif new_position_y > self.env_variables['arena_height'] - 40: # wall c new_position_y = self.env_variables['arena_height'] - ( 40 + self.env_variables["fish_head_radius"] + self.env_variables["fish_tail_length"]) new_position = pymunk.Vec2d(new_position_x, new_position_y) self.fish.body.position = new_position self.fish.body.velocity = (0, 0) if self.fish.body.angle < np.pi: self.fish.body.angle += np.pi else: self.fish.body.angle -= np.pi self.fish.touched_edge = True return True def _touch_wall(self, arbiter, space, data): """ Handle the collision of the fish with a wall. """ x, y = self.fish.body.position w = self.env_variables["arena_width"] h = self.env_variables["arena_height"] LOWER_BOUND = 8 RESET = 10 new_x = x new_y = y # Check X boundaries if x < LOWER_BOUND: new_x = RESET elif x > w - (LOWER_BOUND - 1): new_x = w - (RESET - 1) # Check Y boundaries if y < LOWER_BOUND: new_y = RESET elif y > h - (LOWER_BOUND - 1): new_y = h - (RESET - 1) # Since this is a collision handler, we assume we stop the fish regardless self.fish.body.velocity = (0, 0) # Only update position if we actually moved the fish if new_x != x or new_y != y: self.fish.body.position = pymunk.Vec2d(new_x, new_y) self.fish.touched_edge = True self.fish.touched_edge_this_step = True return True def _create_prey(self, prey_position=None, prey_orientation=None, prey_gait=None, prey_age=None): """ Create a new prey entity in the simulation environment. This method initializes a prey object with physical properties, position, orientation, and behavioral characteristics. The prey can be created at a random location or at a specified position, and can belong to different movement gaits (slow, normal, fast). Args: prey_position (tuple, optional): The (x, y) coordinates for the prey's initial position. If None, a random position is generated either within the arena or within a prey cloud region based on env_variables['prey_cloud_num']. Defaults to None. prey_orientation (float, optional): The initial angle/orientation of the prey in radians. If None, the prey is considered newly created and receives a unique identifier. Defaults to None. prey_gait (int, optional): The movement gait type (0=normal, 1=slow, 2=fast). Only used when prey_orientation is not None. Defaults to None. prey_age (int, optional): The age of the prey, used when prey_reproduction_mode is enabled. Only used when prey_orientation is not None. Defaults to None. Returns: None Side Effects: - Appends a new pymunk.Body to self.prey_bodies - Appends a new pymunk.Circle to self.prey_shapes - Updates self.prey_identifiers when creating new prey - Updates self.total_prey_created counter for new prey - Appends gait type to self.paramecia_gaits - Appends age to self.prey_ages if reproduction mode is enabled - Adds the prey body and shape to the pymunk space Notes: - New prey (prey_orientation=None) are assigned random positions with margins from arena edges - When prey_cloud_num > 0, new prey spawn within randomly selected cloud regions - Prey shapes have elasticity set to 1.0 and collision_type set to 2 - Gait probabilities are determined by env_variables['prey_p_fast'] and env_variables['prey_p_slow'] """ self.prey_bodies.append(pymunk.Body(SMALL_MASS, self.prey_inertia)) self.prey_shapes.append(pymunk.Circle(self.prey_bodies[-1], self.env_variables['prey_radius'])) self.prey_shapes[-1].elasticity = 1.0 self.prey_bodies[-1].angle = self.rng.uniform(0, np.pi * 2) if prey_position is None: if self.env_variables["prey_cloud_num"] == 0: self.prey_bodies[-1].position = ( self.rng.integers( self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'] + 40, self.env_variables['arena_width'] - ( self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'] + 40)), self.rng.integers( self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'] + 40, self.env_variables['arena_height'] - ( self.env_variables['prey_radius'] + self.env_variables['fish_mouth_radius'] + 40))) else: cloud = self.rng.choice(self.prey_cloud_locations) self.prey_bodies[-1].position = ( self.rng.integers(low=cloud[0] - (self.env_variables["prey_cloud_region_size"] / 2), high=cloud[0] + (self.env_variables["prey_cloud_region_size"] / 2)), self.rng.integers(low=cloud[1] - (self.env_variables["prey_cloud_region_size"] / 2), high=cloud[1] + (self.env_variables["prey_cloud_region_size"] / 2)) ) else: self.prey_bodies[-1].position = prey_position if prey_orientation is not None: self.prey_bodies[-1].angle = prey_orientation if prey_orientation is None: # When is a new prey being created self.prey_identifiers.append(self.total_prey_created) self.total_prey_created += 1 self.paramecia_gaits.append( self.rng.choice([0, 1, 2], 1, p=[1 - (self.env_variables["prey_p_fast"] + self.env_variables["prey_p_slow"]), self.env_variables["prey_p_slow"], self.env_variables["prey_p_fast"]])[0]) if self.env_variables["prey_reproduction_mode"]: self.prey_ages.append(0) else: self.paramecia_gaits.append(int(prey_gait)) if self.env_variables["prey_reproduction_mode"]: self.prey_ages.append(int(prey_age)) self.prey_shapes[-1].collision_type = 2 self.space.add(self.prey_bodies[-1], self.prey_shapes[-1]) def _check_proximity(self, feature_position, sensing_distance): """ Check if the fish is within the sensing distance of a feature. """ sensing_area = [[feature_position[0] - sensing_distance, feature_position[0] + sensing_distance], [feature_position[1] - sensing_distance, feature_position[1] + sensing_distance]] is_in_area = sensing_area[0][0] <= self.fish.body.position[0] <= sensing_area[0][1] and \ sensing_area[1][0] <= self.fish.body.position[1] <= sensing_area[1][1] if is_in_area: return True else: return False def _check_proximity_all_prey(self, sensing_distance): """ Check if the fish is within the sensing distance of any prey. """ all_prey_positions = np.array([pr.position for pr in self.prey_bodies]) fish_position = np.expand_dims(np.array(self.fish.body.position), 0) fish_prey_vectors = all_prey_positions - fish_position fish_prey_distances = ((fish_prey_vectors[:, 0] ** 2) + (fish_prey_vectors[:, 1] ** 2)) ** 0.5 within_range = fish_prey_distances < sensing_distance return within_range def _move_prey(self, micro_step): """ Move all prey in the environment according to their current gaits and behaviors. This method updates prey positions and orientations based on their movement patterns, which include slow, fast, and stationary gaits. Prey can also respond to stimuli such as being touched by the fish or sensing the fish nearby. Parameters ---------- micro_step : int The current micro-step within a full simulation step. Certain operations like gait switching and angle changes only occur when micro_step is 0. Notes ----- The method performs the following operations: - Applies impulses based on current gait (slow, fast, or stationary) - Adds jump impulse if prey was touched by fish (when fish hasn't consumed prey) - Updates gaits probabilistically once per step (micro_step == 0) - Applies random angle changes including occasional large turns - Checks proximity to fish for escape behavior - Triggers escape impulses when prey sense nearby fish The prey movement is deterministic within each micro-step but probabilistic across full steps, with behavior controlled by various environment variables (prey_p_switch, prey_p_slow, prey_p_fast, prey_max_turning_angle, prey_p_large_turn, etc.). """ if len(self.prey_bodies) == 0: return # Generate impulses impulse_types = [0, self.prey_impulse_slow, self.prey_impulse_fast] impulses = [impulse_types[gait] for gait in self.paramecia_gaits] if not self.fish.prey_consumed: for touched_index in self.touched_prey_indices: # Impulse from being touched by fish impulses[touched_index] += self.prey_impulse_jump # Do once per step. if micro_step == 0: gaits_to_switch = self.rng.random(len(self.prey_shapes)) < self.env_variables["prey_p_switch"] switch_to = self.rng.choice([0, 1, 2], len(self.prey_shapes), p=[1 - (self.env_variables["prey_p_slow"] + self.env_variables["prey_p_fast"]), self.env_variables["prey_p_slow"], self.env_variables["prey_p_fast"]]) self.paramecia_gaits = [switch_to[i] if gaits_to_switch[i] else old_gait for i, old_gait in enumerate(self.paramecia_gaits)] # Angles of change angle_changes = self.rng.uniform(-self.env_variables['prey_max_turning_angle'], self.env_variables['prey_max_turning_angle'], len(self.prey_shapes)) # Large angle changes large_turns = self.rng.uniform(-np.pi, np.pi, len(self.prey_shapes)) large_turns_implemented = self.rng.random(len(self.prey_shapes)) < self.env_variables["prey_p_large_turn"] angle_changes = angle_changes + (large_turns * large_turns_implemented) self.prey_within_range = self._check_proximity_all_prey(self.env_variables["prey_sensing_distance"]) for i, prey_body in enumerate(self.prey_bodies): if micro_step == 0: prey_body.angle = prey_body.angle + angle_changes[i] prey_body.apply_impulse_at_local_point((impulses[i], 0)) if self.prey_within_range[i]: # Motion from prey escape if self.rng.random() < self._prey_escape_p_per_physics_step: prey_body.apply_impulse_at_local_point((self.prey_impulse_jump, 0)) def _prey_touch_mouth(self, arbiter, space, data): """Handle the event when prey touches the mouth of the fish, specifically check if the capture is valid.""" valid_capture = False for i, shp in enumerate(self.prey_shapes): if shp == arbiter.shapes[0]: touched_prey_index = i break if self.fish.capture_possible: # Check if angles line up. prey_position = self.prey_bodies[touched_prey_index].position fish_position = self.fish.body.position # vector from fish to prey vector = prey_position - fish_position # Taking fish as origin raw_diff = np.arctan2(vector[1], vector[0]) - self.fish.body.angle angle_diff = np.arctan2(np.sin(raw_diff), np.cos(raw_diff)) if np.abs(angle_diff) < self.env_variables["capture_swim_permissive_angle"]: valid_capture = True self._remove_prey(touched_prey_index) if valid_capture: self.prey_caught += 1 self.fish.prey_consumed = True self.prey_consumed_this_step = True return False else: self.touched_prey_indices.append(touched_prey_index) return True def _prey_touch_body(self, arbiter, space, data): """Handle the event when prey touches the body of the fish (for potential escape).""" touched_prey_index = None for i, shp in enumerate(self.prey_shapes): if shp == arbiter.shapes[0]: touched_prey_index = i break if touched_prey_index is None: # already removed (prey touched mouth first) return True self.touched_prey_indices.append(touched_prey_index) return True def _remove_prey(self, prey_index): """Remove prey from the simulation.""" self.space.remove(self.prey_shapes[prey_index], self.prey_shapes[prey_index].body) del self.prey_shapes[prey_index] del self.prey_bodies[prey_index] if self.env_variables["prey_reproduction_mode"]: del self.prey_ages[prey_index] del self.paramecia_gaits[prey_index] del self.prey_identifiers[prey_index] while True: if prey_index in self.touched_prey_indices: self.touched_prey_indices.remove(prey_index) else: break def _move_predator(self): """Move the predator towards its target (original fish location).""" if self._check_predator_at_target() or self._check_predator_outside_walls(): self._remove_predator() else: self.predator_body.angle = np.pi / 2 - np.arctan2( self.predator_target[0] - self.predator_body.position[0], self.predator_target[1] - self.predator_body.position[1]) self.predator_body.apply_impulse_at_local_point((self.predator_impulse, 0))
[docs] def touch_predator(self, arbiter, space, data): """Handle the event when the fish touches the predator.""" if self.env_variables["test_sensory_system"]: self._remove_predator() if self.num_steps > self.env_variables['predator_immunity_steps']: self.fish.touched_predator = True return False else: return True
[docs] def get_predator_angles_distance(self): """ Get the angles and distance to the predator for visual processing. """ predator_position = None if self.predator_stim_location is not None: this_predator_stim_location = self.predator_stim_location[0, self.num_steps-1, :] if (this_predator_stim_location[0] != 0) or (this_predator_stim_location[1] != 0): predator_position = this_predator_stim_location if self.predator_body is not None: predator_position = self.predator_body.position if predator_position is None: return np.nan, np.nan, np.nan fish_position = self.fish.body.position distance = np.sqrt( (predator_position[0] - fish_position[0]) ** 2 + (predator_position[1] - fish_position[1]) ** 2) predator_half_angular_size = np.arctan2(self.env_variables['predator_radius'], distance) distance -= self.env_variables['predator_radius'] predator_vector = predator_position - fish_position # Taking fish as origin # Will generate values between -pi/2 and pi/2 which require adjustment depending on quadrant. angle = np.arctan2(predator_vector[1], predator_vector[0]) left_edge = angle + predator_half_angular_size right_edge = angle - predator_half_angular_size left_edge = np.arctan2(np.sin(left_edge), np.cos(left_edge)) # Normalise to -pi to pi right_edge = np.arctan2(np.sin(right_edge), np.cos(right_edge)) # Normalise to -pi to pi return left_edge, right_edge, distance
def _get_fish_proximity_to_walls(self): """ Get the distances from the fish to each wall. """ fish_position = self.fish.body.position left_distance = fish_position[0] right_distance = self.env_variables["arena_width"] - fish_position[0] bottom_distance = self.env_variables["arena_height"] - fish_position[1] top_distance = fish_position[1] return left_distance, bottom_distance, right_distance, top_distance def _select_predator_angle_of_attack(self): """ Select the angle of attack for the predator based on the fish's proximity to the walls. """ left_dist, bottom_dist, right_dist, top_dist = self._get_fish_proximity_to_walls() left = left_dist < self.env_variables["predator_distance_from_fish"] right = right_dist < self.env_variables["predator_distance_from_fish"] top = top_dist < self.env_variables["predator_distance_from_fish"] bottom = bottom_dist < self.env_variables["predator_distance_from_fish"] if left and top: angle_from_fish = self.rng.integers(90, 180) elif left and bottom: angle_from_fish = self.rng.integers(0, 90) elif right and top: angle_from_fish = self.rng.integers(180, 270) elif right and bottom: angle_from_fish = self.rng.integers(270, 360) elif left: angle_from_fish = self.rng.integers(0, 180) elif top: angle_from_fish = self.rng.integers(90, 270) elif bottom: angles = [self.rng.integers(270, 360), self.rng.integers(0, 90)] angle_from_fish = self.rng.choice(angles) elif right: angle_from_fish = self.rng.integers(180, 360) else: angle_from_fish = self.rng.integers(0, 360) angle_from_fish = np.radians(angle_from_fish) return angle_from_fish def _create_predator(self): """Create the predator.""" self.predator_body = pymunk.Body(LARGE_MASS, self.predator_inertia) self.predator_shape = pymunk.Circle(self.predator_body, self.env_variables['predator_radius']) self.predator_shape.elasticity = 1.0 fish_position = self.fish.body.position if self.env_variables["test_sensory_system"]: angle_from_fish = np.radians(300) else: angle_from_fish = self._select_predator_angle_of_attack() dy = self.env_variables["predator_distance_from_fish"] * np.cos(angle_from_fish) dx = self.env_variables["predator_distance_from_fish"] * np.sin(angle_from_fish) x_position = fish_position[0] + dx y_position = fish_position[1] + dy self.predator_body.position = (x_position, y_position) self.predator_target = fish_position self.predator_location = (x_position, y_position) self.predator_shape.collision_type = 5 self.predator_shape.filter = pymunk.ShapeFilter( mask=pymunk.ShapeFilter.ALL_MASKS() ^ 2) # Category 2 objects cant collide with predator self.space.add(self.predator_body, self.predator_shape) def _check_predator_outside_walls(self): x_position, y_position = self.predator_body.position[0], self.predator_body.position[1] if x_position < 0: return True elif x_position > self.env_variables["arena_width"]: return True if y_position < 0: return True elif y_position > self.env_variables["arena_height"]: return True def _check_predator_at_target(self): if (round(self.predator_body.position[0]), round(self.predator_body.position[1])) == ( round(self.predator_target[0]), round(self.predator_target[1])): self.predator_attacks_avoided += 1 return True else: return False def _remove_predator(self, arbiter=None, space=None, data=None): if self.predator_body is not None: self.space.remove(self.predator_shape, self.predator_shape.body) self.predator_shape = None self.predator_body = None self.predator_location = None self.predator_target = None if not self.fish.touched_predator: self.survived_attack = True return False def _bring_fish_in_bounds(self): """Bring the fish back into the arena bounds if it goes out.""" if self.fish.body.position[0] < 4 or self.fish.body.position[1] < 4 or \ self.fish.body.position[0] > self.env_variables["arena_width"] - 4 or \ self.fish.body.position[1] > self.env_variables["arena_height"] - 4: new_position = pymunk.Vec2d(np.clip(self.fish.body.position[0], 6, self.env_variables["arena_width"] - 30), np.clip(self.fish.body.position[1], 6, self.env_variables["arena_height"] - 30)) self.fish.body.position = new_position
[docs] def get_info(self): """Get the current information of the simulation for logging or debugging.""" info_dict = { 'fish_x': [self.fish.body.position[0]], 'fish_y': [self.fish.body.position[1]], 'fish_angle': [self.fish.body.angle], 'prey_x': [[pr.position[0] for pr in self.prey_bodies]], 'prey_y': [[pr.position[1] for pr in self.prey_bodies]], 'predator_x': [self.predator_body.position[0]] if self.predator_body else [0], 'predator_y': [self.predator_body.position[1]] if self.predator_body else [0], 'event_consumed_prey': [self.prey_consumed_this_step], 'event_survived_predator': [self.event_survived_predator], 'event_captured_by_predator': [self.event_captured_by_predator], } return info_dict
[docs] def step(self, action: int) -> dm_env.TimeStep: """Performs a step in the simulation with the given action.""" self.event_survived_predator = False self.event_captured_by_predator = False self.action_used[action] += 1 if self._reset_next_step: return self.reset() self.fish.making_capture = False self.prey_consumed_this_step = False self.last_action = action reward = 0 self.fish.take_action(action) done = False if self._is_predator_spawned(): self._create_predator() for micro_step in range(self.env_variables['phys_steps_per_sim_step']): self.touched_prey_indices = [] if self.fish.making_capture and self.capture_start <= micro_step <= self.capture_end: self.fish.capture_possible = True else: self.fish.capture_possible = False self.space.step(self.phys_dt) self._move_prey(micro_step) if self.predator_body is not None: self._move_predator() if self.fish.prey_consumed: if len(self.prey_shapes) == 0: done = True self.recent_cause_of_death = "Prey-All-Eaten" self.fish.prey_consumed = False if self.fish.touched_edge: self.fish.touched_edge = False if self.fish.touched_predator: self.event_captured_by_predator = True reward += self.env_variables['reward_predator_caught'] self.survived_attack = False self.predator_associated_reward += self.env_variables["reward_predator_caught"] self.total_attacks_captured += 1 self._remove_predator() self.fish.touched_predator = False if (self.predator_body is None) and self.survived_attack: self.event_survived_predator = True reward += self.env_variables["reward_predator_avoidance"] self.predator_associated_reward += self.env_variables["reward_predator_avoidance"] self.total_attacks_avoided += 1 self.survived_attack = False self._bring_fish_in_bounds() # Energy level energy_reward = self.fish.update_energy_level(self.prey_consumed_this_step) reward += energy_reward if self.prey_consumed_this_step: reward += self.env_variables["reward_consumption"] self.consumption_associated_reward += self.env_variables["reward_consumption"] self.energy_associated_reward += energy_reward self.energy_level_log.append(self.fish.energy_level) if self.fish.energy_level < 0: print("Fish ran out of energy") done = True self.recent_cause_of_death = "Starvation" # Salt if self.env_variables["salt_enabled"]: self.salt_concentration = self.salt_gradient[int(self.fish.body.position[0]), int(self.fish.body.position[1])] reward += self.env_variables["reward_salt_factor"] * self.salt_concentration self.salt_associated_reward += self.env_variables["reward_salt_factor"] * self.salt_concentration else: self.salt_concentration = 0 if self.fish.touched_edge_this_step: reward += self.env_variables["reward_wall_touch"] self.wall_associated_reward += self.env_variables["reward_wall_touch"] self.fish.touched_edge_this_step = False if self.env_variables["prey_reproduction_mode"] and self.env_variables["prey_cloud_num"] > 0 and not self.env_variables["test_sensory_system"]: self._reproduce_prey() self.prey_ages = [age + 1 for age in self.prey_ages] for i, age in enumerate(self.prey_ages): if age > self.env_variables["prey_safe_duration"] and\ self.rng.random(1) < self.env_variables["prey_p_death"]: if not self._check_proximity(self.prey_bodies[i].position, 200): self._remove_prey(i) self.available_prey -= 1 self.num_steps += 1 if self.num_steps >= self.env_variables["max_sim_steps_per_episode"]: print("Fish ran out of time") done = True self.recent_cause_of_death = "Time" observation = self.get_observation(action, reward) if done: self._reset_next_step = True return dm_env.termination(reward=reward, observation=observation) else: return dm_env.transition(reward=reward, observation=observation)
[docs] def observation_spec(self) -> specs.BoundedArray: """Returns the observation spec.""" len_internal_state = 3 vis_shape = (len(self.fish.left_eye.interpolated_observation_angles), 3, 2) obs_spec = [specs.Array(shape=vis_shape, dtype='float32', name="visual_input"), specs.Array(shape=(len_internal_state,), dtype='float32', name="internal_state")] return OAR(observation=obs_spec, action=specs.Array(shape=(), dtype=int), reward=specs.Array(shape=(), dtype=np.float64), )
[docs] def action_spec(self) -> specs.DiscreteArray: """Returns the action spec.""" return specs.DiscreteArray( dtype=int, num_values=self.num_actions, name="action")
[docs] def get_observation(self, action, reward): self.arena.red_FOV.update_field_of_view(self.fish.body.position) self.arena.uv_FOV.update_field_of_view(self.fish.body.position) visual_input = self.resolve_visual_input() # print minimal and maximal values of visual input: visual_input = visual_input.astype(np.float32) # Calculate internal state is_in_light = self.fish.body.position[1] > self.dark_row internal_state = np.array([is_in_light, self.fish.energy_level, self.salt_concentration], dtype=np.float32) return OAR(observation=[visual_input, internal_state],action=action, reward=reward)
def _is_predator_spawned(self): if self.env_variables["test_sensory_system"]: if self.num_steps > 10 and not self.tested_predator: self.tested_predator = True return True else: if self.predator_location is None and self.rng.random() < self.predator_prob[self.num_steps]: buffer_region = self.env_variables["predator_radius"] * 1.5 left_dist, bottom_dist, right_dist, top_dist = self.get_fish_proximity_to_walls() near_wall = left_dist < buffer_region or right_dist < buffer_region or top_dist < buffer_region or bottom_dist < buffer_region if not near_wall: return True return False
[docs] def resolve_visual_input(self): """ Compute the visual input from both eyes of the fish. This method calculates the photon readings from the left and right eyes based on the fish's current position, orientation, and the environment state. It accounts for prey locations, predator positions, sediment masking, and UV luminance. Returns ------- np.ndarray A 3D array containing stacked photon readings from the left and right eyes. Shape is (height, width, 2) where the last dimension represents [left_eye, right_eye] photon values. Notes ----- - Eye positions are calculated relative to the fish's body axis and are translated by 'eyes_biasx' parameter - Prey locations include both actual prey bodies and stimulated prey locations from the current timestep - Predator information (angles and distance) is computed if a predator exists - The method uses masked sediment and UV luminance masks from the arena - Prey locations are transformed relative to the fish's position and shifted by max_uv_range for coordinate alignment """ # Relative eye positions to center of FOV right_eye_pos = ( -np.cos(np.pi / 2 - self.fish.body.angle) * self.env_variables[ 'eyes_biasx'], +np.sin(np.pi / 2 - self.fish.body.angle) * self.env_variables[ 'eyes_biasx']) left_eye_pos = ( +np.cos(np.pi / 2 - self.fish.body.angle) * self.env_variables[ 'eyes_biasx'], -np.sin(np.pi / 2 - self.fish.body.angle) * self.env_variables[ 'eyes_biasx']) if self.predator_body is not None: predator_bodies = np.array([self.predator_body.position]) else: predator_bodies = np.array([]) prey_locations = [i.position for i in self.prey_bodies] if self.prey_stim_locations is not None: step_stim_locations = [self.prey_stim_locations[ii, self.num_steps-1, :] for ii in range(self.prey_stim_locations.shape[0])] # remove prey_locations.extend(step_stim_locations) masked_sediment = self.arena.get_masked_sediment() uv_luminance_mask = self.arena.get_uv_luminance_mask() if len(prey_locations) > 0: prey_locations_array = np.array(prey_locations) - np.array(self.fish.body.position) + self.max_uv_range else: prey_locations_array = np.array([]) predator_left, predator_right, predator_distance = self.get_predator_angles_distance() left_photons = self.fish.left_eye.read(masked_sediment, left_eye_pos[0], left_eye_pos[1], self.fish.body.angle, uv_luminance_mask, prey_locations_array, predator_left, predator_right, predator_distance) right_photons = self.fish.right_eye.read(masked_sediment, right_eye_pos[0], right_eye_pos[1], self.fish.body.angle, uv_luminance_mask, prey_locations_array, predator_left, predator_right, predator_distance) observation = np.dstack((left_photons, right_photons)) return observation