diff --git a/crystal_toolkit/components/phonon.py b/crystal_toolkit/components/phonon.py index 5ef76d90..0f805f54 100644 --- a/crystal_toolkit/components/phonon.py +++ b/crystal_toolkit/components/phonon.py @@ -9,17 +9,20 @@ from dash import dcc, html from dash.dependencies import Component, Input, Output, State from dash.exceptions import PreventUpdate -from dash_mp_components import CrystalToolkitAnimationScene, CrystalToolkitScene +from dash_mp_components import CrystalToolkitScene, PhononAnimationScene +from emmet.core.phonon import PhononBS # crystal animation algo from pymatgen.analysis.graphs import StructureGraph from pymatgen.analysis.local_env import CrystalNN +from pymatgen.core import Species from pymatgen.ext.matproj import MPRester from pymatgen.phonon.bandstructure import PhononBandStructureSymmLine from pymatgen.phonon.dos import CompletePhononDos from pymatgen.phonon.plotter import PhononBSPlotter from pymatgen.transformations.standard_transformations import SupercellTransformation +from crystal_toolkit.core.legend import Legend from crystal_toolkit.core.mpcomponent import MPComponent from crystal_toolkit.core.panelcomponent import PanelComponent from crystal_toolkit.core.scene import Convex, Cylinders, Lines, Scene, Spheres @@ -34,9 +37,14 @@ MARKER_COLOR = "red" MARKER_SIZE = 12 MARKER_SHAPE = "x" -MAX_MAGNITUDE = 300 +MAX_MAGNITUDE = 500 MIN_MAGNITUDE = 0 +DEFAULTS: dict[str, str | bool] = { + "color_scheme": "VESTA", +} + + # TODOs: # - look for additional projection methods in phonon DOS (currently only atom # projections supported) @@ -66,13 +74,6 @@ def __init__( **kwargs, ) - bs, _ = PhononBandstructureAndDosComponent._get_ph_bs_dos( - self.initial_data["default"] - ) - self.create_store("bs-store", bs) - self.create_store("bs", None) - self.create_store("dos", None) - @property def _sub_layouts(self) -> dict[str, Component]: # defaults @@ -80,11 +81,16 @@ def _sub_layouts(self) -> dict[str, Component]: fig = PhononBandstructureAndDosComponent.get_figure(None, None) # Main plot - graph = dcc.Graph( - figure=fig, - config={"displayModeBar": False}, - responsive=False, - id=self.id("ph-bsdos-graph"), + graph = html.Div( + [ + dcc.Graph( + figure=fig, + config={"displayModeBar": False}, + responsive=True, + id=self.id("ph-bsdos-graph"), + style={"height": "400px"}, + ) + ] ) # Brillouin zone @@ -153,77 +159,124 @@ def _sub_layouts(self) -> dict[str, Component]: summary_dict = self._get_data_list_dict(None, None) summary_table = get_data_list(summary_dict) - # crystal visualization - - tip = html.H5( - "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", + tip = html.Div( + html.Span( + "💡 Tips: Click different q-points and bands in the dispersion diagram to see the crystal vibration!", + style={ + "border": "0.5px dashed black", + "display": "inline-flex", + "alignItems": "center", + "justifyContent": "center", + "textAlign": "center", + }, + ), + style={ + "display": "flex", + "justifyContent": "center", + }, ) + # crystal visualization crystal_animation = html.Div( - CrystalToolkitAnimationScene( - data={}, - sceneSize="500px", + # CrystalToolkitAnimationScene( + PhononAnimationScene( + data={"app": "phonon"}, + sceneSize="400px", id=self.id("crystal-animation"), settings={"defaultZoom": 1.2}, axisView="SW", showControls=False, # disable download for now - ), - style={"width": "60%"}, + ) + ) + + hr = html.Hr( + style={ + "backgroundColor": "#C5C5C6", + "border": "none", + "margin": "8px 0", + } ) - crystal_animation_controls = html.Div( + crystal_animation_controls = html.Details( [ - html.Br(), - html.Div(tip, style={"textAlign": "center"}), - html.Br(), - html.H5("Control Panel", style={"textAlign": "center"}), - html.H6("Supercell modification"), - html.Br(), + html.Summary("Control Panel"), html.Div( [ - self.get_numerical_input( - kwarg_label="scale-x", - default=1, - is_int=True, - label="x", - min=1, - style={"width": "5rem"}, + html.H6( + "Supercell modification", style={"textAlign": "center"} + ), + html.Div( + [ + self.get_numerical_input( + kwarg_label="scale-x", + default=1, + persistence_type="session", + is_int=True, + label="x", + min=1, + style={"height": "16px"}, + ), + self.get_numerical_input( + kwarg_label="scale-y", + default=1, + persistence_type="session", + is_int=True, + label="y", + min=1, + style={"height": "16px"}, + ), + self.get_numerical_input( + kwarg_label="scale-z", + default=1, + persistence_type="session", + is_int=True, + label="z", + min=1, + style={"height": "16px"}, + ), + ], + style={ + "display": "flex", + "justify-content": "center", + "gap": "16px", + }, ), - self.get_numerical_input( - kwarg_label="scale-y", - default=1, - is_int=True, - label="y", - min=1, - style={"width": "5rem"}, + hr, + html.Div( + self.get_slider_input( + kwarg_label="magnitude", + default=0.5, + step=0.01, + domain=[0, 1], + label="Vibration magnitude", + # styleInput={"height": "40px"}, + ), ), - self.get_numerical_input( - kwarg_label="scale-z", - default=1, - is_int=True, - label="z", - min=1, - style={"width": "5rem"}, + hr, + html.Div( + self.get_slider_input( + kwarg_label="velocity", + default=0.5, + step=0.01, + domain=[0, 1], + label="Velocity", + ) ), - html.Button( - "Update", - id=self.id("supercell-controls-btn"), - style={"height": "40px"}, + hr, + html.Div( + html.Button( + "Update", + id=self.id("supercell-controls-btn"), + style={"height": "40px"}, + ), + style={"textAlign": "center", "width": "100%"}, ), ], - style={"display": "flex"}, - ), - html.Br(), - html.Div( - self.get_slider_input( - kwarg_label="magnitude", - default=0.5, - step=0.01, - domain=[0, 1], - label="Vibration magnitude", - ) + style={ + "width": "100%", + }, ), - ], + ] ) return { @@ -244,15 +297,22 @@ def _get_animation_panel(self): [ Column( [ + sub_layouts["tip"], + html.Br(), Columns( [ sub_layouts["crystal-animation"], sub_layouts["crystal-animation-controls"], - ] - ) - ] + ], + style={ + "display": "flex", + "justify-content": "center", + "gap": "10px", + }, + ), + ], ), - ] + ], ) def layout(self) -> html.Div: @@ -280,7 +340,28 @@ def layout(self) -> html.Div: return html.Div([graph, crystal_animation, controls, brillouin_zone]) @staticmethod - def _get_eigendisplacement( + def _complex_vectors_serialization(vectors): + # `ph_bs.eigendisplacements[band][qpoint]` is np.complex which is not serializable + # this function transfer complex eigenvector to a list of Re and Im + # For example, + # vectors = [(np.complex128(3.0634449212096337e-09+0j), + # np.complex128(-3.720119057521199e-08+0j), + # np.complex128(-0.0016537315137792753+0j)), + # (np.complex128(3.063444921240483e-09+0j), + # np.complex128(-3.720119057492181e-08+0j), + # np.complex128(-0.0016537315137792735+0j))] + # output: + # [[[3.0634449212096337e-09, 0.0], + # [-3.720119057521199e-08, 0.0], + # [-0.0016537315137792753, 0.0]], + # [[3.063444921240483e-09, 0.0], + # [-3.720119057492181e-08, 0.0], + # [-0.0016537315137792735, 0.0]]] + arr = np.asarray(vectors, dtype=np.complex128) + return np.stack([arr.real, arr.imag], axis=-1).astype(float).tolist() + + @staticmethod + def _get_time_function_json( ph_bs: BandStructureSymmLine, json_data: dict, band: int = 0, @@ -288,125 +369,78 @@ def _get_eigendisplacement( precision: int = 15, magnitude: int = MAX_MAGNITUDE / 2, total_repeat_cell_cnt: int = 1, + velocity: float = 1.0, ) -> dict: if not ph_bs or not json_data: return {} - assert json_data["contents"][0]["name"] == "atoms" assert json_data["contents"][1]["name"] == "bonds" rdata = deepcopy(json_data) - def calc_max_displacement(idx: int) -> list: - """ - Retrieve the eigendisplacement for a given atom index from `ph_bs` and compute its maximum displacement. - - Parameters: - idx (int): The atom index. - - Returns: - list: The maximum displacement vector in the form [x_max_displacement, y_max_displacement, z_max_displacement] - - This function extracts the real component of the atom's eigendisplacement, - scales it by the specified magnitude, and returns the resulting vector. - """ - - # get the atom index - assert total_repeat_cell_cnt != 0 - - modified_idx = ( - int(idx // total_repeat_cell_cnt) if total_repeat_cell_cnt else idx - ) - - return [ - round(complex(vec).real * magnitude, precision) - for vec in ph_bs.eigendisplacements[band][qpoint][modified_idx] - ] - - def calc_animation_step(max_displacement: list, coef: int) -> list: - """ - Calculate the displacement for an animation frame based on the given coefficient. - - Parameters: - max_displacement (list): A list of maximum displacements along each axis, - formatted as [x_max_displacement, y_max_displacement, z_max_displacement]. - coef (int): A coefficient indicating the motion direction. - - 0: no movement - - 1: forward movement - - -1: backward movement - - Returns: - list: The displacement vector [x_displacement, y_displacement, z_displacement]. - - This function generates oscillatory motion by scaling the maximum displacement - with the provided coefficient. - """ - return [round(coef * md, precision) for md in max_displacement] - - # Compute per-frame atomic motion. - # `rcontent["animate"]` stores the displacement (distance difference) from the previous coordinates. + # atoms contents0 = json_data["contents"][0]["contents"] - for cidx, content in enumerate(contents0): - max_displacement = calc_max_displacement(content["_meta"][0]) + for cidx, _ in enumerate(contents0): rcontent = rdata["contents"][0]["contents"][cidx] - # put animation frame to the given atom index - rcontent["animate"] = [ - calc_animation_step(max_displacement, coef) for coef in DISPLACE_COEF - ] - rcontent["keyframes"] = list(range(len(DISPLACE_COEF))) - rcontent["animateType"] = "displacement" - # Compute per-frame bonding motion. - # Explanation: - # Each bond connects two atoms, `u` and `v`, represented as (u)----(v) - # To model the bond motion, it is divided into two segments: - # from `u` to the midpoint and from the midpoint to `v`, i.e., (u)--(mid)--(v) - # Thus, two cylinders are created: one for (u)--(mid) and another for (v)--(mid). - # For each cylinder, displacements are assigned to the endpoints — for example, - # the (u)--(mid) cylinder uses: - # [ - # [u_x_displacement, u_y_displacement, u_z_displacement], - # [mid_x_displacement, mid_y_displacement, mid_z_displacement] - # ]. - contents1 = json_data["contents"][1]["contents"] + # put required data to the given atom index + rcontent[ + "animate" + ] = [] # we just need `animate` field indicating animtaion rendering + # bonds + contents1 = json_data["contents"][1]["contents"] for cidx, content in enumerate(contents1): - bond_animation = [] assert len(content["_meta"]) == len(content["positionPairs"]) + rcontent = rdata["contents"][1]["contents"][cidx] + rcontent["animate"] = [] + + # remove unused sense (polyhedra and magmoms) + del rdata["contents"][2:4] + + # displacement formula: u(R,t) = A * e^(i(q⋅R-ωt)) + rdata["app"] = "phonon" + + # omega (ω) + rdata["omega"] = ph_bs.frequencies[band][qpoint] + + # Take mp-149 as an example: + # ph_bs.qpoints is "frac_coords of the given lattice by default (from Pymatgen)" + # transfer from frac_coords to cart_coords + # the size of ph_bs.structure.lattice.matrix: (3, 3) (lattice size) + # the size of ph_bs.qpoints: (149, 3) (wave vector for each qpoint) + # the size of q: (149, 3) + # q: + q = np.einsum( + "ij,kj->ik", + ph_bs.structure.lattice.reciprocal_lattice.matrix, + np.array(ph_bs.qpoints), + ).T + + # phases (q⋅R): should be a number + # we calculate the phase with all atoms and qpoints here + # the size of q: (149, 3) + # the size of ph_bs.structure.cart_coords: (2, 3) (the coordinate of two atoms in the unit cell) + # the size of phase: (149, 2) + phases = np.einsum( + "ij,kj->ik", + q, + ph_bs.structure.cart_coords, + ) + rdata["phases"] = phases[qpoint].tolist() - for atom_idx_pair in content["_meta"]: - max_displacements = list( - map(calc_max_displacement, atom_idx_pair) - ) # max displacement for u and v - - u_to_middle_bond_animation = [] - - for coef in DISPLACE_COEF: - # Calculate the midpoint displacement between atom u and v for each animation frame. - u_displacement, v_displacement = [ - np.array(calc_animation_step(max_displacement, coef)) - for max_displacement in max_displacements - ] - middle_end_displacement = np.add(u_displacement, v_displacement) / 2 - - u_to_middle_bond_animation.append( - [ - u_displacement, # u atom displacement - [ - round(dis, precision) for dis in middle_end_displacement - ], # middle point displacement - ] - ) - - bond_animation.append(u_to_middle_bond_animation) + # amplitude (A) + rdata["amplitude"] = magnitude - rdata["contents"][1]["contents"][cidx]["animate"] = bond_animation - rdata["contents"][1]["contents"][cidx]["keyframes"] = list( - range(len(DISPLACE_COEF)) + # eigenVectors + rdata["eigenVectors"] = ( + PhononBandstructureAndDosComponent._complex_vectors_serialization( + ph_bs.eigendisplacements[band][qpoint] ) - rdata["contents"][1]["contents"][cidx]["animateType"] = "displacement" + ) + + # velocity + rdata["velocity"] = velocity - # remove unused sense - for i in range(2, 4): - rdata["contents"][i]["visible"] = False + rdata["name"] = "StructureGraphPhonon" return rdata @@ -646,7 +680,6 @@ def get_ph_dos_traces(dos: CompletePhononDos, freq_range: tuple[float, float]): "xaxis": "x2", "yaxis": "y2", } - dos_traces.append(trace_tdos) # Projected DOS @@ -820,6 +853,44 @@ def get_figure( return figure + def _make_legend(self, legend): + # this is copied and customized from crystal_toolkit.components.structure.StructureMoleculeComponent + # in order to get the consistent legend with the structure viewer + if not legend: + return html.Div(id=self.id("legend")) + + def get_font_color(hex_code): + # ensures contrasting font color for background color + c = tuple(int(hex_code[1:][i : i + 2], 16) for i in (0, 2, 4)) + return ( + "black" + if 1 - (c[0] * 0.299 + c[1] * 0.587 + c[2] * 0.114) / 255 < 0.5 + else "white" + ) + + legend_colors = { + key: self._legend.get_color(Species(key)) + for key, val in legend["composition"].items() + } + + legend_elements = [ + html.Span( + html.Span( + name, className="icon", style={"color": get_font_color(color)} + ), + className="button is-static is-rounded", + style={"backgroundColor": color}, + ) + for name, color in legend_colors.items() + ] + + return html.Div( + legend_elements, + id=self.id("legend"), + style={"display": "flex"}, + className="buttons", + ) + def generate_callbacks(self, app, cache) -> None: @app.callback( Output(self.id("ph-bsdos-graph"), "figure"), @@ -831,6 +902,7 @@ def generate_callbacks(self, app, cache) -> None: ) def update_graph(bs, dos, nclick): if isinstance(bs, dict): + # bs = PhononBS.from_pmg(bs) bs = PhononBandStructureSymmLine.from_dict(bs) if isinstance(dos, dict): dos = CompletePhononDos.from_dict(dos) @@ -889,21 +961,29 @@ def highlight_bz_on_hover_bs(hover_data, click_data, label_select): @app.callback( Output(self.id("crystal-animation"), "data"), + Output(self.id("crystal-animation"), "children"), Input(self.id("ph-bsdos-graph"), "clickData"), Input(self.id("ph_bs"), "data"), Input(self.id("supercell-controls-btn"), "n_clicks"), - Input(self.get_kwarg_id("magnitude"), "value"), + State(self.get_kwarg_id("magnitude"), "value"), State(self.get_kwarg_id("scale-x"), "value"), State(self.get_kwarg_id("scale-y"), "value"), State(self.get_kwarg_id("scale-z"), "value"), + State(self.get_kwarg_id("velocity"), "value"), # prevent_initial_call=True ) def update_crystal_animation( - cd, bs, sueprcell_update, magnitude_fraction, scale_x, scale_y, scale_z + cd, + bs, + sueprcell_update, + magnitude_fraction, + scale_x, + scale_y, + scale_z, + velocity, ): # Avoids using `get_all_kwargs_id` for all `Input`; instead, uses `State` to prevent flickering when users modify `scale_x`, `scale_y`, or `scale_z` fields, # ensuring updates occur only after the `supercell-controls-btn`` is clicked. - if not bs: raise PreventUpdate @@ -915,12 +995,15 @@ def update_crystal_animation( scale_x = kwargs.get("scale-x") scale_y = kwargs.get("scale-y") scale_z = kwargs.get("scale-z") + velocity = kwargs.get("velocity") if isinstance(bs, dict): - bs = PhononBandStructureSymmLine.from_dict(bs) + bs = PhononBS.from_pmg(bs) + # bs = PhononBandStructureSymmLine.from_dict(bs) struct = bs.structure total_repeat_cell_cnt = 1 + # update structure if the controls got triggered if sueprcell_update: total_repeat_cell_cnt = scale_x * scale_y * scale_z @@ -932,11 +1015,34 @@ def update_crystal_animation( struct = trans.apply_transformation(struct) struc_graph = StructureGraph.from_local_env_strategy(struct, CrystalNN()) + + # legend + legend = Legend( + struc_graph.structure, + color_scheme=DEFAULTS["color_scheme"], + # radius_scheme=radius_strategy, + cmap_range=None, + ) + self._legend = legend + legend_layout = html.Div(self._make_legend(legend.get_legend())) + + # scene scene = struc_graph.get_scene( draw_image_atoms=False, bonded_sites_outside_unit_cell=False, - site_get_scene_kwargs={"retain_atom_idx": True}, + site_get_scene_kwargs={ + "retain_atom_idx": True, + "total_repeat_cell_cnt": total_repeat_cell_cnt, + }, + legend=legend, ) + + # axis + axes = struct.lattice._axes_from_lattice() + axes.visible = True + scene.contents.append(axes) + + # json_data = scene.to_json() qpoint = 0 @@ -951,14 +1057,15 @@ def update_crystal_animation( MAX_MAGNITUDE - MIN_MAGNITUDE ) * magnitude_fraction + MIN_MAGNITUDE - return PhononBandstructureAndDosComponent._get_eigendisplacement( + return PhononBandstructureAndDosComponent._get_time_function_json( ph_bs=bs, json_data=json_data, band=band_num, qpoint=qpoint, total_repeat_cell_cnt=total_repeat_cell_cnt, magnitude=magnitude, - ) + velocity=velocity, + ), [None, legend_layout] class PhononBandstructureAndDosPanelComponent(PanelComponent): diff --git a/crystal_toolkit/core/scene.py b/crystal_toolkit/core/scene.py index ad2d0621..daa473c0 100644 --- a/crystal_toolkit/core/scene.py +++ b/crystal_toolkit/core/scene.py @@ -98,6 +98,8 @@ def remove_defaults(scene_dict): """Reduce file size of JSON by removing any key which is just its default value.""" trimmed_dict = {} for key, val in scene_dict.items(): + # if key == "_meta": + # trimmed_dict[key] = val if isinstance(val, dict): val = remove_defaults(val) # noqa: PLW2901 elif isinstance(val, list): diff --git a/crystal_toolkit/renderables/site.py b/crystal_toolkit/renderables/site.py index bab81bfd..ac9579e9 100644 --- a/crystal_toolkit/renderables/site.py +++ b/crystal_toolkit/renderables/site.py @@ -48,6 +48,7 @@ def get_site_scene( magmom_scale: float = 1.0, legend: Legend | None = None, retain_atom_idx: bool = False, + total_repeat_cell_cnt: int = 1, ) -> Scene: """Get a Scene object for a Site. @@ -72,6 +73,7 @@ def get_site_scene( magmom_scale (float, optional): Defaults to 1.0. legend (Legend | None, optional): Defaults to None. retain_atom_idx (bool, optional): Defaults to False. + total_repeat_cell_cnt (int, optional): Defaults to 1. Returns: Scene: The scene object containing atoms, bonds, polyhedra, magmoms. @@ -137,7 +139,15 @@ def get_site_scene( phiEnd=phiEnd, clickable=True, tooltip=name, - _meta=[site_idx] if retain_atom_idx else None, + _meta=[ + { + "unit_cell_atom_idx": [site_idx // total_repeat_cell_cnt], + "atom_idx": [site_idx], + } + ] + if retain_atom_idx + else None, + # _meta=[site_idx // total_repeat_cell_cnt] if retain_atom_idx else None, ) atoms.append(sphere) @@ -210,7 +220,17 @@ def get_site_scene( radius=bond_radius / 2, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + # _meta=[site_idx // total_repeat_cell_cnt, connected_site.index // total_repeat_cell_cnt] + # if retain_atom_idx + # else None, + _meta={ + "unit_cell_atom_idx": [ + site_idx // total_repeat_cell_cnt, + connected_site.index + // total_repeat_cell_cnt, + ], + "atom_idx": [site_idx, connected_site.index], + } if retain_atom_idx else None, ) @@ -224,7 +244,16 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] + # _meta=[site_idx // total_repeat_cell_cnt, connected_site.index // total_repeat_cell_cnt] + # if retain_atom_idx + # else None, + _meta={ + "unit_cell_atom_idx": [ + site_idx // total_repeat_cell_cnt, + connected_site.index // total_repeat_cell_cnt, + ], + "atom_idx": [site_idx, connected_site.index], + } if retain_atom_idx else None, ) @@ -237,7 +266,16 @@ def get_site_scene( radius=bond_radius, clickable=True, tooltip=name_cyl, - _meta=[site_idx, connected_site.index] if retain_atom_idx else None, + # _meta=[site_idx // total_repeat_cell_cnt, connected_site.index // total_repeat_cell_cnt] if retain_atom_idx else None, + _meta={ + "unit_cell_atom_idx": [ + site_idx // total_repeat_cell_cnt, + connected_site.index // total_repeat_cell_cnt, + ], + "atom_idx": [site_idx, connected_site.index], + } + if retain_atom_idx + else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) @@ -261,7 +299,16 @@ def get_site_scene( positionPairs=[[position, bond_midpoint.tolist()]], color=color, radius=bond_radius, - _meta=[site_idx, connected_site.index] if retain_atom_idx else None, + # _meta=[site_idx // total_repeat_cell_cnt, connected_site.index // total_repeat_cell_cnt] if retain_atom_idx else None, + _meta={ + "unit_cell_atom_idx": [ + site_idx // total_repeat_cell_cnt, + connected_site.index // total_repeat_cell_cnt, + ], + "atom_idx": [site_idx, connected_site.index], + } + if retain_atom_idx + else None, ) bonds.append(cylinder) all_positions.append(connected_position.tolist()) diff --git a/pyproject.toml b/pyproject.toml index 5e38d4a4..2da1dc29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,7 @@ requires-python = ">=3.10" authors = [{ name = "Matt Horton", email = "mkhorton@lbl.gov" }] dependencies = [ - "dash-mp-components>=0.5.0rc0", + "dash-mp-components==0.5.0rc1", "dash>=2.11.0", "flask-caching", "frozendict",