-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathplotly_interactive_plot.py
More file actions
50 lines (42 loc) · 1.11 KB
/
plotly_interactive_plot.py
File metadata and controls
50 lines (42 loc) · 1.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import numpy as np
import plotly.graph_objects as go
# Build an interactive 3-D scatter plot of embedding vectors, one trace per task.
#
# Input files: a label array (compared against the integer keys of task_names)
# and an embedding matrix whose rows are aligned with those labels.
metadata_path = 'metadata_5_HTrMRL.npy'
embeddings_path = 'embeddings_5_HTrMRL.npy'

# Task index (as stored in the metadata file) -> legend label.
# Only these indices are plotted; rows carrying any other label are ignored.
task_names = {
    1: 'push-v2',
    3: 'door-open-v2',
    4: 'drawer-close-v2',
    5: 'button-press-topdown-v2',
    7: 'window-open-v2',
}


def load_data(metadata_path=metadata_path, embeddings_path=embeddings_path):
    """Load the label array and the embedding matrix from ``.npy`` files.

    Parameters
    ----------
    metadata_path, embeddings_path : str
        Paths passed to ``np.load``; default to the module-level paths.

    Returns
    -------
    tuple
        ``(metadata, embeddings)`` exactly as stored on disk.
    """
    # np.load keeps its default allow_pickle=False, so only plain arrays load.
    metadata = np.load(metadata_path)
    embeddings = np.load(embeddings_path)
    return metadata, embeddings


def validate_embeddings(embeddings):
    """Return *embeddings* as an array, requiring 2-D shape with >= 3 columns.

    Raises
    ------
    ValueError
        If the array is not 2-D or has fewer than three columns (the plot
        reads columns 0, 1 and 2 as x, y, z).
    """
    embeddings = np.asarray(embeddings)
    if embeddings.ndim != 2 or embeddings.shape[1] < 3:
        raise ValueError(
            "expected a 2-D embedding array with at least 3 columns, "
            f"got shape {embeddings.shape}"
        )
    return embeddings


def labels_from_metadata(metadata, n_rows):
    """Return *metadata* as a 1-D label vector with one entry per embedding row.

    A column vector ``(n_rows, 1)`` or row vector ``(1, n_rows)`` is flattened
    so it can serve as a row mask; 1-D input keeps its shape.

    Raises
    ------
    ValueError
        If the flattened metadata does not have exactly *n_rows* entries.
    """
    labels = np.ravel(metadata)
    if labels.shape[0] != n_rows:
        raise ValueError(
            f"metadata has {labels.shape[0]} entries after flattening "
            f"(original shape {np.shape(metadata)}), expected one per "
            f"embedding row ({n_rows})"
        )
    return labels


def select_task_points(embeddings, labels, task_index):
    """Return the rows of *embeddings* whose entry in *labels* equals *task_index*.

    The result has shape ``(k, embeddings.shape[1])``; ``k`` is 0 when no row
    carries that label.
    """
    return embeddings[labels == task_index]


def build_figure(metadata, embeddings, task_names=task_names):
    """Build a 3-D scatter figure with one marker trace per entry of *task_names*.

    Columns 0, 1 and 2 of each selected embedding row are plotted as x, y, z;
    any further columns are ignored.

    Raises
    ------
    ValueError
        If the embeddings or metadata have incompatible shapes.
    """
    embeddings = validate_embeddings(embeddings)
    labels = labels_from_metadata(metadata, embeddings.shape[0])

    fig = go.Figure()
    for task_index, task_name in task_names.items():
        points = select_task_points(embeddings, labels, task_index)
        # A task with no matching rows still gets a trace, just with no points.
        fig.add_trace(go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            name=task_name,
        ))

    fig.update_layout(
        title="3D Embeddings Visualization",
        scene=dict(
            xaxis_title='X Axis',
            yaxis_title='Y Axis',
            zaxis_title='Z Axis',
        ),
        legend_title="Task Names",
        hovermode="closest",
    )
    return fig


def main():
    """Load the data files, build the figure and display it with ``fig.show()``."""
    metadata, embeddings = load_data()
    fig = build_figure(metadata, embeddings)
    fig.show()


if __name__ == "__main__":
    main()