Coverage for app/plot.py: 100%
67 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
1"""plot module"""
3import importlib.util
4import itertools
5import math
6import typing as tp
7from pathlib import Path
9import hiplot
10import hiplot.streamlit_helpers
11import numpy as np
12import plotly.graph_objects as go
13import plotly.subplots as ps
14from streamlit import runtime
# Base trace palette, tiled 100 times so that COLORS[i] can be indexed directly for up to 900 traces.
_PALETTE = ("red", "green", "blue", "black", "pink", "purple", "yellow", "grey", "brown")
COLORS = list(_PALETTE) * 100
def subplots(
    n: int,
    m: int | None = None,
    **kwargs,
) -> tuple[go.Figure, list[tuple[int, int]]]:
    """Create a grid of subplots able to hold n plots
    :param n: number of subplots
    :param m: if int, maximum number of columns (values below 1 are treated as 1)
    :param kwargs: keyword arguments passed to plotly.subplots.make_subplots
    :return: the figure and the (row, col) position of each of the n subplots, in row-major order

    The grid has floor(sqrt(n)) columns (capped at m, at least 1) and as many rows as needed (at least 1).

    Examples
    --------
    >>> subplots(3)[1]
    [(1, 1), (2, 1), (3, 1)]
    >>> subplots(9, 2)[1]
    [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1)]
    >>> subplots(0)[1]
    []"""

    # math.isqrt is exact for any int (int(np.sqrt(n)) goes through a float)
    nb_cols = math.isqrt(n)
    if isinstance(m, int) and nb_cols > m:
        nb_cols = m
    # Keep at least one column and one row: avoids a division by zero when n == 0 or m < 1,
    # and make_subplots requires a grid of at least 1x1.
    nb_cols = max(nb_cols, 1)
    nb_rows = max(math.ceil(n / nb_cols), 1)
    positions = list(itertools.product(range(1, nb_rows + 1), range(1, nb_cols + 1)))[:n]
    return ps.make_subplots(rows=nb_rows, cols=nb_cols, **kwargs), positions
def _yscale_menu(yaxis_layout: dict) -> dict:
    """Build the update menu toggling the y-axis between linear and logarithmic scale
    :param yaxis_layout: current y-axis layout (as returned by ``to_plotly_json``), re-sent with each button
    :return: update menu dict suitable for ``figure.update_layout(updatemenus=[...])``"""

    # The scale-specific keys are written after the copied layout so that they always take
    # precedence, even if yaxis_layout already carries a "type", "tickformat" or "dtick".
    linear_axis = {**yaxis_layout, "type": "linear"}
    log_axis = {**yaxis_layout, "type": "log", "tickformat": ".0e", "dtick": 1}
    return dict(
        active=0,
        buttons=[
            dict(label="Linear Scale", method="relayout", args=[{"yaxis": linear_axis}]),
            dict(label="Log Scale", method="relayout", args=[{"yaxis": log_axis}]),
        ],
        x=0.9,
        y=1.15,
    )


def plot_decays(
    xs_data: list[np.ndarray],
    ys_data: list[np.ndarray],
    quantity: str,
    ys_data2: list[np.ndarray] | None = None,
    labels: list[str] | None = None,
    label2: str = " (fit)",
    **kwargs,
) -> go.Figure:
    """Plot decays
    :param xs_data: list of np.ndarray associated with the x-axis (shown as "Time (ns)")
    :param ys_data: list of np.ndarray associated with the y-axis (raw data)
    :param quantity: quantity fitted: 'TRPL' or 'TRMC' (anything other than 'TRPL' gets the TRMC axis label)
    :param ys_data2: optional list of np.ndarray associated with the y-axis, drawn as dashed lines
    :param labels: list of labels or floats; defaults to "Decay 1", "Decay 2", ...
    :param label2: optional string appended to the labels of ys_data2
    :param kwargs: keyword arguments passed to the update_layout method of the figure
    :return: figure with a button menu toggling the y-axis between linear and log scale"""

    figure = go.Figure()

    if labels is None:
        labels = [f"Decay {i + 1}" for i in range(len(xs_data))]

    for i, x_data in enumerate(xs_data):
        # Cycle through the palette so any number of decays can be plotted;
        # a decay and its ys_data2 counterpart share the same colour.
        color = COLORS[i % len(COLORS)]
        figure.add_trace(go.Scatter(x=x_data, y=ys_data[i], name=labels[i], line=dict(color=color)))
        if ys_data2 is not None:
            figure.add_trace(
                go.Scatter(
                    x=x_data,
                    y=ys_data2[i],
                    name=labels[i] + label2,
                    line=dict(color=color, dash="dash"),
                )
            )

    # Axes
    font = dict(size=16, color="black")
    xaxis_settings = dict(
        title_text="Time (ns)",
        tickformat=",",
        title_font=font,
        tickfont=font,
        showgrid=True,
        gridcolor="lightgray",
    )
    if len(xs_data) > 0:
        # Pad the x range by 1 % of the data span on each side
        x_min = min(np.min(x_data) for x_data in xs_data)
        x_max = max(np.max(x_data) for x_data in xs_data)
        pad = 0.01 * (x_max - x_min)
        xaxis_settings["range"] = [x_min - pad, x_max + pad]
    figure.update_xaxes(**xaxis_settings)

    ylabel = "Intensity (a.u.)" if quantity == "TRPL" else "Intensity (cm<sup>2</sup>/(Vs))"
    figure.update_yaxes(
        title_text=ylabel,
        title_font=font,
        tickfont=font,
        showgrid=True,
        gridcolor="lightgray",
    )
    # Snapshot of the styled y-axis, re-applied by the scale buttons so the title/fonts/grid are kept
    yaxis_layout = figure.layout.yaxis.to_plotly_json()

    figure.update_layout(
        margin=dict(l=10, r=10, t=40, b=10, pad=0),
        plot_bgcolor="#f0f4f8",
        updatemenus=[_yscale_menu(yaxis_layout)],
        **kwargs,
    )

    return figure
def plot_carrier_concentrations(
    xs_data: list[np.ndarray],
    ys_data: list[dict[str, np.ndarray]],
    N0s: list[float],
    titles: list[str],
    xlabel: str,
    model,
) -> go.Figure:
    """Plot all the charge carrier concentrations, one subplot per initial carrier concentration
    :param xs_data: x-axis data, one array per subplot
    :param ys_data: list of dicts of concentration arrays for each initial carrier concentration
    :param N0s: list of initial carrier concentration; each curve of a subplot is divided by its N0
    :param titles: list of titles of each subplot
    :param xlabel: x-axis label
    :param model: model used; its CONC_LABELS_HTML and CONC_COLORS mappings are indexed with the keys of ys_data
    :return: figure with the subplots laid out on at most 2 columns"""

    figure, positions = subplots(len(N0s), 2, subplot_titles=titles)
    font = dict(size=16, color="black")

    for i, (N0, (row, col), x_data, y_data) in enumerate(zip(N0s, positions, xs_data, ys_data)):
        # Only the first subplot contributes legend entries, so each concentration is listed once
        showlegend = i == 0
        for key, concentration in y_data.items():
            scatter = go.Scatter(
                x=x_data,
                y=concentration / N0,
                name=model.CONC_LABELS_HTML[key],
                showlegend=showlegend,
                line=dict(color=model.CONC_COLORS[key]),
            )
            figure.add_trace(scatter, row=row, col=col)

        figure.update_xaxes(
            title_text=xlabel,
            row=row,
            col=col,
            tickformat=",",
            title_font=font,
            tickfont=font,
            showgrid=True,
            gridcolor="lightgray",
        )
        figure.update_yaxes(
            title_text="Concentration (N<sub>0</sub>)",
            row=row,
            col=col,
            title_font=font,
            tickfont=font,
            showgrid=True,
            gridcolor="lightgray",
        )

    figure.update_layout(
        height=900,
        margin=dict(l=0, r=0, t=40, b=0, pad=0),
        plot_bgcolor="#f0f4f8",
        # Restyle the annotations (the subplot titles built from `titles`) and nudge them up slightly.
        # NOTE(review): each dict only sets font and y; this relies on plotly merging them into the
        # existing annotations element-wise so the title text is kept — confirm.
        annotations=[dict(font=font, y=anno["y"] + 0.01) for anno in figure.layout.annotations],
    )
    return figure
def parallel_plot(
    popts: list[dict[str, np.ndarray | float]],
    hidden_keys: list[str],
) -> hiplot.Experiment:
    """Plot data in a parallel plot
    :param popts: list of dicts, one per data point; they are not modified
    :param hidden_keys: list of keys not displayed in the parallel plot
    :return: HiPlot experiment built from copies of popts, each with an added 1-based "ID" key"""

    # Determine the key order: "ID" first, then the visible keys of the first dict
    # ("ID" is excluded from the second part so that it never appears twice)
    order = ["ID"] + [key for key in popts[0] if key not in hidden_keys and key != "ID"]

    # Add ID to copies of the data: mutating the caller's dicts would leak "ID" into their data
    # (and into `order` on a later call with the same dicts)
    data = [{**popt, "ID": i + 1} for i, popt in enumerate(popts)]

    # Plot the data
    xp = hiplot.Experiment.from_iterable(data)
    for key in data[0]:
        xp.parameters_definition[key].label_html = key
    xp.display_data(hiplot.Displays.PARALLEL_PLOT).update({"hide": hidden_keys + ["uid"], "order": order[::-1]})
    xp.display_data(hiplot.Displays.TABLE).update({"hide": ["from_uid", "uid"]})
    return xp
class _StreamlitHelpers:  # pragma: no cover
    """Local replacement for hiplot's ``_StreamlitHelpers``.

    This module assigns it to ``hiplot.streamlit_helpers._StreamlitHelpers``."""

    # Streamlit component cached by create_component (None until it has been created)
    component: tp.Optional[tp.Callable[..., tp.Any]] = None

    @staticmethod
    def is_running_within_streamlit() -> bool:
        """Check if the code is running within a Streamlit environment.
        :return: False if streamlit cannot be imported, otherwise whether a Streamlit runtime exists"""

        try:
            import streamlit  # noqa: F401  # pylint: disable=unused-import,import-outside-toplevel
        except Exception:  # pylint: disable=broad-except
            # Best effort: any failure to import streamlit means we are not inside a Streamlit app.
            # Exception (not a bare except) so KeyboardInterrupt/SystemExit still propagate.
            return False
        return bool(runtime.exists())

    @classmethod
    def create_component(cls) -> tp.Optional[tp.Callable[..., tp.Any]]:
        """Create and return a Streamlit component for rendering HiPlot visualizations.
        The component is created on the first call and cached on the class.
        :raises RuntimeError: if streamlit is too old to support components, no Streamlit runtime exists,
            hiplot cannot be located, or its built Streamlit component is missing"""

        if cls.component is not None:
            return cls.component
        import streamlit as st

        try:
            import streamlit.components.v1 as components
        except ModuleNotFoundError as e:
            raise RuntimeError(
                f"""Your streamlit version ({st.__version__}) is too old and does not support components.
Please update streamlit with `pip install -U streamlit`"""
            ) from e
        # Explicit checks instead of `assert`, which is stripped when Python runs with -O
        if not runtime.exists():
            raise RuntimeError("HiPlot's Streamlit component can only be created within a running Streamlit app")

        # Locate HiPlot module and resolve the path to its static build
        spec = importlib.util.find_spec("hiplot")
        if spec is None or not spec.origin:
            raise RuntimeError("HiPlot module could not be found. Ensure it is installed.")
        hiplot_path = Path(spec.origin).parent
        built_path = (hiplot_path / "static" / "built" / "streamlit_component").resolve()

        if not (built_path / "index.html").is_file():
            raise RuntimeError(
                f"""HiPlot component does not appear to exist in {built_path}
If you did not install HiPlot using official channels (pip, conda...), maybe you forgot to build the JavaScript files?
See https://facebookresearch.github.io/hiplot/contributing.html#building-javascript-bundle"""
            )

        cls.component = components.declare_component("hiplot", path=str(built_path))
        return cls.component
# Fix hiplot compatibility issues: monkeypatch hiplot so that hiplot.streamlit_helpers._StreamlitHelpers
# refers to the class defined in this module (runs once, at import time of this module).
# NOTE(review): only code that looks the class up through hiplot.streamlit_helpers at call time
# sees the replacement — confirm hiplot does not import the name directly.
hiplot.streamlit_helpers._StreamlitHelpers = _StreamlitHelpers