Skip to content

Commit d256198

Browse files
committed
Fix NSD eye-tracking trace figure
1 parent 0513b8d commit d256198

6 files changed

Lines changed: 1621 additions & 2 deletions

File tree

711 KB
Loading
-3.7 KB
Loading
-3.33 KB
Loading
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
"""Generate NSD eye-tracking teaching figures.
2+
3+
This script intentionally uses the small cached overlay table generated during
4+
the feasibility check. The key coordinate check is:
5+
6+
x_plot = (x + 2) / 4
7+
y_plot = (2 - y) / 4
8+
9+
That means the gaze coordinates match a 4.0 x 4.0 degree target-image square,
10+
not the larger 8.4 x 8.4 fMRI stimulus-frame assumption.
11+
"""
12+
13+
from __future__ import annotations
14+
15+
from pathlib import Path
16+
17+
import matplotlib.pyplot as plt
18+
import numpy as np
19+
import pandas as pd
20+
from matplotlib.collections import LineCollection
21+
from matplotlib.patches import Rectangle
22+
from PIL import Image
23+
24+
25+
SCRIPT_DIR = Path(__file__).resolve().parent
26+
REPO_ROOT = SCRIPT_DIR.parents[1]
27+
ASSET_DIR = REPO_ROOT / "assets" / "img" / "projects"
28+
TARGET_IMAGE = (
29+
REPO_ROOT
30+
/ "data"
31+
/ "nsd_eyetracking"
32+
/ "nsddata"
33+
/ "experiments"
34+
/ "nsdimagery"
35+
/ "rawtargetimages"
36+
/ "setB"
37+
/ "shared0385_nsd28752.png"
38+
)
39+
40+
CACHED_POINTS = SCRIPT_DIR / "nsd_eye_overlay_points.csv"
41+
TMP_POINTS = REPO_ROOT / "tmp" / "nsd_eye_overlay_points.csv"
42+
43+
44+
def load_points() -> pd.DataFrame:
45+
"""Load the cached overlay points and verify the 4-degree mapping."""
46+
source = CACHED_POINTS if CACHED_POINTS.exists() else TMP_POINTS
47+
if not source.exists():
48+
raise FileNotFoundError(
49+
"Expected nsd_eye_overlay_points.csv in "
50+
f"{CACHED_POINTS} or {TMP_POINTS}"
51+
)
52+
53+
points = pd.read_csv(source)
54+
required = {
55+
"window_id",
56+
"seconds_in_image_window",
57+
"x",
58+
"y",
59+
"velocity",
60+
"fixation_candidate",
61+
"x_plot",
62+
"y_plot",
63+
}
64+
missing = required.difference(points.columns)
65+
if missing:
66+
raise ValueError(f"Missing required columns: {sorted(missing)}")
67+
68+
x_error = np.nanmax(np.abs(((points["x"] + 2) / 4) - points["x_plot"]))
69+
y_error = np.nanmax(np.abs(((2 - points["y"]) / 4) - points["y_plot"]))
70+
if x_error > 1e-9 or y_error > 1e-9:
71+
raise ValueError(
72+
"Overlay points do not match the 4.0-degree image-square mapping: "
73+
f"x_error={x_error:.3g}, y_error={y_error:.3g}"
74+
)
75+
76+
return points.sort_values(["window_id", "seconds_in_image_window"])
77+
78+
79+
def add_time_colored_trace(ax, data: pd.DataFrame, cmap, norm, linewidth: float = 1.8):
80+
"""Draw one gaze trace colored by seconds after image onset."""
81+
xy = data[["x", "y"]].to_numpy()
82+
if len(xy) < 2:
83+
return
84+
85+
segments = np.stack([xy[:-1], xy[1:]], axis=1)
86+
lines = LineCollection(segments, cmap=cmap, norm=norm, linewidth=linewidth, alpha=0.82)
87+
lines.set_array(data["seconds_in_image_window"].iloc[1:].to_numpy())
88+
ax.add_collection(lines)
89+
90+
ax.scatter(
91+
data["x"].iloc[0],
92+
data["y"].iloc[0],
93+
s=24,
94+
color="white",
95+
edgecolor="black",
96+
linewidth=0.7,
97+
zorder=5,
98+
)
99+
ax.scatter(
100+
data["x"].iloc[-1],
101+
data["y"].iloc[-1],
102+
s=24,
103+
color="black",
104+
edgecolor="white",
105+
linewidth=0.7,
106+
zorder=5,
107+
)
108+
109+
110+
def format_image_axis(ax):
111+
"""Use the verified 4.0 x 4.0 degree image square on an axis."""
112+
ax.add_patch(Rectangle((-2, -2), 4, 4, fill=False, edgecolor="black", linewidth=0.9))
113+
ax.set_xlim(-2.05, 2.05)
114+
ax.set_ylim(-2.05, 2.05)
115+
ax.set_aspect("equal", adjustable="box")
116+
ax.grid(color="white", linewidth=0.35, alpha=0.25)
117+
118+
119+
def save_repetition_trace_check(points: pd.DataFrame, image: Image.Image, output: Path) -> None:
120+
"""Save the six-panel repeated-presentation trace check."""
121+
window_ids = sorted(points["window_id"].unique())
122+
cmap = plt.get_cmap("viridis")
123+
norm = plt.Normalize(0, 3)
124+
125+
fig, axes = plt.subplots(2, 3, figsize=(12.5, 8.2), dpi=180, sharex=True, sharey=True)
126+
fig.subplots_adjust(left=0.06, right=0.84, top=0.87, bottom=0.12, wspace=0.16, hspace=0.20)
127+
fig.suptitle(
128+
"NSD repeated target-image presentations: trace check by window",
129+
fontsize=14,
130+
fontweight="bold",
131+
y=0.965,
132+
)
133+
134+
for ax, window_id in zip(axes.flat, window_ids):
135+
data = points[points["window_id"] == window_id].reset_index(drop=True)
136+
ax.imshow(image, extent=(-2, 2, -2, 2), origin="upper", alpha=0.92)
137+
add_time_colored_trace(ax, data, cmap, norm)
138+
format_image_axis(ax)
139+
ax.set_title(f"Window {window_id}: {len(data)} samples", fontsize=10, pad=7)
140+
141+
for ax in axes[-1, :]:
142+
ax.set_xlabel("Horizontal gaze coordinate (degrees)", fontsize=9)
143+
for ax in axes[:, 0]:
144+
ax.set_ylabel("Vertical gaze coordinate (degrees)", fontsize=9)
145+
146+
colorbar_axis = fig.add_axes([0.875, 0.22, 0.025, 0.52])
147+
colorbar = fig.colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), cax=colorbar_axis)
148+
colorbar.set_label("Seconds after image onset", fontsize=9)
149+
fig.text(
150+
0.45,
151+
0.055,
152+
"White dot = first usable sample; black dot = last usable sample. "
153+
"This uses the 4.0-degree helper mapping.",
154+
ha="center",
155+
fontsize=9,
156+
color="#333333",
157+
)
158+
fig.savefig(output, bbox_inches="tight")
159+
plt.close(fig)
160+
161+
162+
def save_dimension_compare(points: pd.DataFrame, image: Image.Image, output: Path) -> None:
163+
"""Save a two-panel check comparing 4.0-degree and 8.4-degree assumptions."""
164+
window_ids = sorted(points["window_id"].unique())
165+
colors = plt.get_cmap("tab10")
166+
window_color = {window_id: colors(i % 10) for i, window_id in enumerate(window_ids)}
167+
168+
fig, axes = plt.subplots(1, 2, figsize=(12.5, 6.2), dpi=180)
169+
fig.suptitle(
170+
"NSD eye-tracking dimension check: same traces, different image-size assumptions",
171+
fontsize=14,
172+
fontweight="bold",
173+
y=0.98,
174+
)
175+
176+
panels = [
177+
(
178+
axes[0],
179+
2.0,
180+
"Image treated as 4.0 deg wide\n(matches the helper x_plot/y_plot columns)",
181+
),
182+
(
183+
axes[1],
184+
4.2,
185+
"Image treated as 8.4 deg wide\n(full NSD fMRI stimulus-frame assumption)",
186+
),
187+
]
188+
for ax, half_width, title in panels:
189+
ax.imshow(image, extent=(-half_width, half_width, -half_width, half_width), origin="upper", alpha=0.96)
190+
ax.add_patch(
191+
Rectangle(
192+
(-half_width, -half_width),
193+
2 * half_width,
194+
2 * half_width,
195+
fill=False,
196+
edgecolor="black",
197+
linewidth=1.2,
198+
)
199+
)
200+
for window_id in window_ids:
201+
data = points[points["window_id"] == window_id]
202+
ax.plot(
203+
data["x"],
204+
data["y"],
205+
color=window_color[window_id],
206+
linewidth=1.25,
207+
alpha=0.78,
208+
label=f"window {window_id}",
209+
)
210+
ax.scatter(
211+
data["x"].iloc[0],
212+
data["y"].iloc[0],
213+
s=18,
214+
color=window_color[window_id],
215+
edgecolor="white",
216+
linewidth=0.5,
217+
zorder=4,
218+
)
219+
pad = half_width * 0.04
220+
ax.set_xlim(-half_width - pad, half_width + pad)
221+
ax.set_ylim(-half_width - pad, half_width + pad)
222+
ax.set_aspect("equal", adjustable="box")
223+
ax.set_title(title, fontsize=10.5)
224+
ax.set_xlabel("Horizontal gaze coordinate (degrees from center)")
225+
ax.set_ylabel("Vertical gaze coordinate (degrees from center)")
226+
ax.grid(color="white", linewidth=0.4, alpha=0.28)
227+
228+
handles, labels = axes[0].get_legend_handles_labels()
229+
fig.legend(handles, labels, loc="lower center", ncol=len(window_ids), frameon=False, fontsize=9)
230+
fig.text(
231+
0.5,
232+
0.055,
233+
f"Target PNG: {image.width} x {image.height} px. "
234+
f"Trace range: x {points.x.min():.2f} to {points.x.max():.2f} deg, "
235+
f"y {points.y.min():.2f} to {points.y.max():.2f} deg.",
236+
ha="center",
237+
fontsize=9,
238+
color="#333333",
239+
)
240+
fig.tight_layout(rect=(0, 0.09, 1, 0.94))
241+
fig.savefig(output, bbox_inches="tight")
242+
plt.close(fig)
243+
244+
245+
def main() -> None:
246+
ASSET_DIR.mkdir(parents=True, exist_ok=True)
247+
points = load_points()
248+
image = Image.open(TARGET_IMAGE).convert("RGB")
249+
250+
repetition_output = ASSET_DIR / "nsd_eye_tracking_repetition_trace_check.png"
251+
dimension_output = ASSET_DIR / "nsd_eye_tracking_overlay_dimension_compare.png"
252+
legacy_output = ASSET_DIR / "nsd_eye_tracking_fixation_candidates.png"
253+
254+
save_repetition_trace_check(points, image, repetition_output)
255+
save_repetition_trace_check(points, image, legacy_output)
256+
save_dimension_compare(points, image, dimension_output)
257+
258+
print(f"Wrote {repetition_output}")
259+
print(f"Wrote {legacy_output}")
260+
print(f"Wrote {dimension_output}")
261+
print("Verified mapping: x_plot = (x + 2) / 4 and y_plot = (2 - y) / 4")
262+
263+
264+
if __name__ == "__main__":
265+
main()

0 commit comments

Comments
 (0)