import h5py
from matplotlib.colors import LinearSegmentedColormap
import numpy as np
from scipy.stats import zscore, multivariate_normal
from scipy import odr
import matplotlib.pyplot as plt
# Export SVG text as text objects rather than outlines so it stays editable in Illustrator.
plt.rcParams['svg.fonttype'] = 'none'

# Millimetres per tracking-camera pixel (value taken from the original analysis code).
PIXEL_SIZE = 0.058

# Bout-type labels. IS_TURN, IS_CAPTURE, COLORS and BOUT_ENERGY below are parallel
# sequences indexed in this same order.
ALL_BOUT_NAMES = ['SCS', 'LCS', 'BS', 'O-bend', 'J-turn', 'SLC', 'Slow1', 'RT',
                  'Slow2', 'LLC', 'AS', 'SAT', 'HAT']

# Per-bout boolean flags, derived from explicit name sets so they stay aligned
# with ALL_BOUT_NAMES.
IS_TURN = [name in {'O-bend', 'J-turn', 'SLC', 'RT', 'LLC', 'SAT', 'HAT'}
           for name in ALL_BOUT_NAMES]
IS_CAPTURE = [name in {'SCS', 'LCS'} for name in ALL_BOUT_NAMES]

# One colour per bout type from the Tableau 20 palette, rescaled from 0-255 to 0-1.
COLORS = [tuple(channel / 255 for channel in rgb) for rgb in (
    (31, 119, 180), (255, 127, 14), (44, 160, 44), (214, 39, 40),
    (148, 103, 189), (140, 86, 75), (227, 119, 194), (127, 127, 127),
    (188, 189, 34), (23, 190, 207), (174, 199, 232), (255, 187, 120),
    (152, 223, 138),
)]

# Per-bout-type energy value aligned with ALL_BOUT_NAMES; NaN where none is given.
BOUT_ENERGY = np.array([
    0.03,    # SCS
    np.nan,  # LCS
    np.nan,  # BS
    np.nan,  # O-bend
    0.04,    # J-turn
    np.nan,  # SLC
    0.01,    # Slow1
    0.15,    # RT
    0.15,    # Slow2
    np.nan,  # LLC
    0.025,   # AS
    np.nan,  # SAT
    0.15,    # HAT
])
def get_angles_and_distances(times, head_pos, orientation, pixel_size=None,
                             frame_offset=9, max_frames=175, min_distance=0.01):
    """Compute the net displacement and heading change of each bout.

    For bout ``i`` the end frame is ``int(times[i, 1] - times[i, 0]) + frame_offset``.
    Displacement and heading change are measured between frame 0 and that end frame.

    Args:
        times: 2-D array whose rows hold each bout's (start, end) times. The
            difference is used directly as a frame count (assumed to be in
            frames — TODO confirm against the data source).
        head_pos: Array indexed ``[bout, frame, coord]``; coords 0 and 1 are the
            x and y head position in tracking-camera pixels.
        orientation: Array indexed ``[bout, frame]`` of headings in radians.
        pixel_size: Millimetres per pixel. ``None`` uses the module-level
            ``PIXEL_SIZE``.
        frame_offset: Frames added to the bout duration to obtain the end frame.
        max_frames: End frames at or beyond this are treated as invalid. The
            number of frames actually present in ``head_pos``/``orientation``
            also caps it.
        min_distance: Bouts whose displacement (mm) is below this are treated
            as invalid.

    Returns:
        ``(distance, ori_change)``: 1-D float arrays with one entry per bout.
        ``distance`` is in mm and ``ori_change`` is wrapped to [-pi, pi]. Both
        are NaN for invalid bouts.
    """
    if pixel_size is None:
        pixel_size = PIXEL_SIZE
    n_bouts = times.shape[0]
    distance = np.full(n_bouts, np.nan)
    ori_change = np.full(n_bouts, np.nan)
    # The end frame must exist in both arrays. Indexing past them would raise,
    # and a negative end frame would silently wrap to the end of the window.
    n_frames = min(max_frames, head_pos.shape[1], orientation.shape[1])
    for i in range(n_bouts):
        end = int(times[i, 1] - times[i, 0]) + frame_offset
        if not 0 <= end < n_frames:
            continue
        delta = head_pos[i, end, :] - head_pos[i, 0, :]
        dist = pixel_size * (delta[0] ** 2 + delta[1] ** 2) ** 0.5
        if dist < min_distance:
            continue  # Displacement below min_distance: leave both values NaN.
        turn = orientation[i, end] - orientation[i, 0]
        distance[i] = dist
        ori_change[i] = np.arctan2(np.sin(turn), np.cos(turn))  # Wrap to [-pi, pi]
    return distance, ori_change
class Actions:
    """A list of named 2-D Gaussian actions with plotting and HDF5 persistence.

    Each action is a dict with keys ``'name'``, ``'mean'`` (length-2 array),
    ``'cov'`` (2x2 array), ``'is_turn'``, ``'is_capture'`` and ``'color'`` (RGB
    triple). ``display_actions`` plots component 0 as a distance in mm. It plots
    component 1 as an angle, converted from radians to degrees.
    """

    def __init__(self, h5_file_path=None, bouts_to_save=None):
        """Create an action set.

        Args:
            h5_file_path: Bout-data HDF5 file to extract actions from, or None
                to start with an empty action list.
            bouts_to_save: Forwarded unchanged to ``get_extracted_actions``.
        """
        if h5_file_path is None:
            self.actions = []
        else:
            # NOTE(review): get_extracted_actions is not defined in this class as
            # shown; confirm it is provided elsewhere before using this path.
            self.actions = self.get_extracted_actions(h5_file_path, bouts_to_save=bouts_to_save)

    def get_action(self, action_name):
        """Return the first action whose name is ``action_name``.

        Raises:
            ValueError: if no action has that name.
        """
        for action in self.actions:
            if action['name'] == action_name:
                return action
        raise ValueError(f"Action {action_name} not found in actions.")

    def display_actions(self):
        """Plot every action's distribution as filled contours in a new figure.

        Each non-Null action is drawn between half of its peak density and the
        peak. The colormap fades from transparent to the action colour. The
        Null action is drawn as a single marker at the origin. Actions with no
        finite, positive density on the plotted grid are skipped.

        In the legend, left/right variants share one entry: names containing
        '_L' are truncated there, and names containing '_R' are omitted.

        Returns:
            The matplotlib Figure that was created.
        """
        import matplotlib.patches as mpatches

        xx, yy = np.mgrid[0:20:.1, -180:180:.5]
        pos = np.dstack((xx, yy))
        # Per-axis scale from (mm, rad) to (mm, degrees); the outer product scales
        # var(angle) by k**2 and the distance/angle covariance by k.
        scale = np.array([1.0, 180 / np.pi])
        cov_scale = np.outer(scale, scale)
        fig = plt.figure()
        for action in self.actions:
            color = action['color']
            if action['name'] == 'Null':
                plt.scatter(0, 0, color=color, label='Null', s=100, edgecolors='black', zorder=5)
                continue
            mean = np.asarray(action['mean'], dtype=float) * scale
            cov = np.asarray(action['cov'], dtype=float) * cov_scale
            pdf = multivariate_normal(mean, cov, allow_singular=True).pdf(pos)
            peak = np.max(pdf)
            # contourf requires strictly increasing levels, which cannot be built
            # when the density is zero or non-finite everywhere on this grid.
            if not np.isfinite(peak) or peak <= 0:
                continue
            levels = np.linspace(peak / 2, peak, 10)
            rgb = tuple(color[:3])
            # Same RGB throughout; only alpha ramps from 0 to 1.
            cmap = LinearSegmentedColormap.from_list('this_cmap', [(*rgb, 0), (*rgb, 1)])
            plt.contourf(xx, yy, pdf, levels=levels, cmap=cmap, antialiased=True)
        plt.xlabel('Distance (mm)')
        plt.ylabel('Angle (degrees)')
        patches = []
        for action in self.actions:
            name = action['name']
            if '_L' in name:
                name = name.split('_L')[0]
            elif '_R' in name:
                continue
            patches.append(mpatches.Patch(color=action['color'], label=name))
        plt.legend(handles=patches, loc='upper right', borderaxespad=0.)
        plt.title('Action Distributions')
        return fig

    def get_all_actions(self):
        """Return the underlying action list (the list itself, not a copy)."""
        return self.actions

    def sharpen_distributions(self, narrowing_coefficient=3, capture_narrowing_coefficient=10):
        """Narrow every action's distribution by scaling down its covariance.

        For non-capture actions, the whole covariance is divided by
        ``narrowing_coefficient``. For capture actions, it is divided by
        ``capture_narrowing_coefficient``, and the angle variance (``cov[1, 1]``)
        is divided by it a second time. This makes the actions more distinct.
        """
        for action in self.actions:
            # Rebind instead of dividing in place. A covariance array shared
            # between actions is then not narrowed more than once, and integer
            # arrays are promoted to float instead of raising.
            if action['is_capture']:
                cov = np.asarray(action['cov']) / capture_narrowing_coefficient
                cov[1, 1] /= capture_narrowing_coefficient  # extra angle narrowing
            else:
                cov = np.asarray(action['cov']) / narrowing_coefficient
            action['cov'] = cov

    def to_hdf5(self, file_path):
        """Save the actions to ``file_path``, overwriting any existing file.

        Each action becomes a group named after it, holding ``mean`` and ``cov``
        datasets. The group also carries ``is_turn``, ``is_capture``, ``id``
        (list position) and ``color`` attributes. Because groups are keyed by
        name, action names must be unique.
        """
        with h5py.File(file_path, 'w') as f:
            for i, action in enumerate(self.actions):
                group = f.create_group(action['name'])
                group.create_dataset('mean', data=action['mean'])
                group.create_dataset('cov', data=action['cov'])
                group.attrs['is_turn'] = action['is_turn']
                group.attrs['is_capture'] = action['is_capture']
                group.attrs['id'] = i
                group.attrs['color'] = action['color']

    def from_hdf5(self, file_path):
        """Replace ``self.actions`` with the actions stored by ``to_hdf5``.

        Groups are re-sorted by their ``id`` attribute to restore the saved
        order. By default, h5py does not iterate groups in insertion order.
        ``self.actions`` is only replaced once the whole file has been read.
        """
        loaded = []
        with h5py.File(file_path, 'r') as f:
            for group_name, group in f.items():
                action = {
                    'name': group_name,
                    'mean': group['mean'][:],
                    'cov': group['cov'][:],
                    # Attributes come back as NumPy scalars/arrays; convert them to
                    # the plain bool/tuple types used for in-memory actions.
                    'is_turn': bool(group.attrs['is_turn']),
                    'is_capture': bool(group.attrs['is_capture']),
                    'color': tuple(float(c) for c in group.attrs['color']),
                }
                loaded.append((int(group.attrs['id']), action))
        loaded.sort(key=lambda pair: pair[0])
        self.actions = [action for _, action in loaded]
        print(f"Loaded {len(self.actions)} actions from {file_path}")

    def get_opposing_dict(self):
        """Map the index of each turn action to the index of its mirror action.

        The mirror name swaps '_R' for '_L' (or, if there is no '_R', '_L' for
        '_R'). A turn action whose name contains neither maps to itself.

        Raises:
            ValueError: if a turn action's mirror name is not among the actions.
        """
        # First index for each name, equivalent to a front-to-back linear search.
        index_by_name = {}
        for idx, action in enumerate(self.actions):
            index_by_name.setdefault(action['name'], idx)
        opposing_dict = {}
        for idx, action in enumerate(self.actions):
            if not action['is_turn']:
                continue
            name = action['name']
            if '_R' in name:
                opposing_name = name.replace('_R', '_L')
            else:
                opposing_name = name.replace('_L', '_R')
            try:
                opposing_dict[idx] = index_by_name[opposing_name]
            except KeyError:
                raise ValueError(
                    f"Opposing action {opposing_name} for {name} not found in actions.") from None
        return opposing_dict

    def add_null_action(self):
        """Append a 'Null' action with zero mean and zero covariance (no movement)."""
        null_action = {
            'name': 'Null',
            'mean': np.array([0.0, 0.0]),
            'cov': np.array([[0.0, 0.0], [0.0, 0.0]]),
            'is_turn': False,
            'is_capture': False,
            'color': (0, 0, 0)
        }
        self.actions.append(null_action)
if __name__ == "__main__":
    # Bout data obtained from https://www.pnas.org/doi/10.1073/pnas.2410254121
    # (dataset: https://doi.org/10.5281/zenodo.13605471).
    h5_file_path = "./external_data/filtered_jmpool_kin.h5"  # Bout data file path
    output_path = "actions_all_bouts_with_null.h5"

    actions = Actions(h5_file_path, bouts_to_save=None)  # Use None to extract all bouts
    actions.sharpen_distributions(narrowing_coefficient=3, capture_narrowing_coefficient=10)
    actions.display_actions()
    actions.add_null_action()
    print(f'opposing_dict: {actions.get_opposing_dict()}')
    # Save, reload from the saved file, and plot again (presumably as a visual
    # check that the round-trip preserves the actions).
    actions.to_hdf5(output_path)
    actions.from_hdf5(output_path)
    actions.display_actions()
    plt.show()