Coverage for app/plot.py: 100%

67 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-28 09:13 +0000

1"""plot module""" 

2 

3import importlib.util 

4import itertools 

5import math 

6import typing as tp 

7from pathlib import Path 

8 

9import hiplot 

10import hiplot.streamlit_helpers 

11import numpy as np 

12import plotly.graph_objects as go 

13import plotly.subplots as ps 

14from streamlit import runtime 

15 

# Named colors used for traces, picked by trace index (see plot_decays). The 9-color palette is
# repeated 100 times so that direct indexing keeps working for up to 900 traces.
COLORS = 100 * ["red", "green", "blue", "black", "pink", "purple", "yellow", "grey", "brown"]

17 

18 

def subplots(
    n: int,
    m: int | None = None,
    **kwargs,
) -> tuple[go.Figure, list[tuple[int, int]]]:
    """Create a figure holding n subplots arranged on a grid

    The number of columns is the integer square root of n, capped by m when m is an int.
    The number of rows is the smallest one providing at least n cells, and the cells are
    attributed row by row.

    :param n: number of subplots (at least 1)
    :param m: if int, maximum number of columns (at least 1)
    :param kwargs: keyword arguments passed to plotly.subplots.make_subplots
    :return: the figure and the list of the 1-based (row, column) position of each subplot
    :raises ValueError: if n, or m when it is an int, is smaller than 1

    Examples
    --------
    >>> subplots(3)[1]
    [(1, 1), (2, 1), (3, 1)]
    >>> subplots(9, 2)[1]
    [(1, 1), (1, 2), (2, 1), (2, 2), (3, 1), (3, 2), (4, 1), (4, 2), (5, 1)]"""

    # Reject sizes which would otherwise end up in a division by zero or an empty grid
    if n < 1:
        raise ValueError(f"The number of subplots must be at least 1 (got {n})")
    if isinstance(m, int) and m < 1:
        raise ValueError(f"The maximum number of columns must be at least 1 (got {m})")

    nb_cols = math.isqrt(n)
    if isinstance(m, int) and nb_cols > m:
        nb_cols = m
    nb_rows = math.ceil(n / nb_cols)

    # Row-major list of grid cells, truncated to the number of requested subplots
    positions = list(itertools.product(range(1, nb_rows + 1), range(1, nb_cols + 1)))[:n]
    return ps.make_subplots(rows=nb_rows, cols=nb_cols, **kwargs), positions

42 

43 

def plot_decays(
    xs_data: list[np.ndarray],
    ys_data: list[np.ndarray],
    quantity: str,
    ys_data2: list[np.ndarray] | None = None,
    labels: list[str] | None = None,
    label2: str = " (fit)",
    **kwargs,
) -> go.Figure:
    """Plot decays

    One solid trace is added per decay. If ys_data2 is given, a dashed trace of the same color
    is added for each decay. The figure also gets a menu switching the y-axis between linear
    and log scale.

    :param xs_data: list of np.ndarray associated with the x-axis
    :param ys_data: list of np.ndarray associated with the y-axis (raw data)
    :param quantity: quantity fitted: 'TRPL' or 'TRMC'. 'TRPL' gives a y-axis in arbitrary units,
        any other value gives a y-axis in cm2/(Vs)
    :param ys_data2: optional list of np.ndarray associated with the y-axis
    :param labels: list of labels or floats. Default to 'Decay 1', 'Decay 2'...
    :param label2: optional string appended to the labels of ys_data2
    :param kwargs: keyword arguments passed to the update_layout method of the figure
    :return: the plotly figure"""

    figure = go.Figure()

    if labels is None:
        labels = [f"Decay {i + 1}" for i in range(len(xs_data))]

    for i, x_data in enumerate(xs_data):
        color = COLORS[i]
        figure.add_trace(
            go.Scatter(
                x=x_data,
                y=ys_data[i],
                name=labels[i],
                line={"color": color},
            )
        )
        if ys_data2 is not None:
            figure.add_trace(
                go.Scatter(
                    x=x_data,
                    y=ys_data2[i],
                    # Formatting instead of concatenating so that non-string labels (e.g. floats) work
                    name=f"{labels[i]}{label2}",
                    line={"color": color, "dash": "dash"},
                )
            )

    # Axes
    font = {"size": 16, "color": "black"}
    x_min = min(np.min(x_data) for x_data in xs_data)
    x_max = max(np.max(x_data) for x_data in xs_data)
    x_pad = 0.01 * (x_max - x_min)  # 1 % margin on each side of the data
    figure.update_xaxes(
        title_text="Time (ns)",
        tickformat=",",
        title_font=font,
        tickfont=font,
        showgrid=True,
        gridcolor="lightgray",
        range=[x_min - x_pad, x_max + x_pad],
    )

    if quantity == "TRPL":
        ylabel = "Intensity (a.u.)"
    else:
        ylabel = "Intensity (cm<sup>2</sup>/(Vs))"

    figure.update_yaxes(
        title_text=ylabel,
        title_font=font,
        tickfont=font,
        showgrid=True,
        gridcolor="lightgray",
    )

    # The y-axis settings are re-sent with each scale change so that the relayout keeps them
    yaxes_layout = figure.layout.yaxis.to_plotly_json()
    linear_button = {
        "label": "Linear Scale",
        "method": "relayout",
        "args": [{"yaxis": {"type": "linear", **yaxes_layout}}],
    }
    log_button = {
        "label": "Log Scale",
        "method": "relayout",
        "args": [{"yaxis": {"type": "log", "tickformat": ".0e", "dtick": 1, **yaxes_layout}}],
    }

    figure.update_layout(
        margin={"l": 10, "r": 10, "t": 40, "b": 10, "pad": 0},
        plot_bgcolor="#f0f4f8",
        updatemenus=[
            {
                "active": 0,
                "buttons": [linear_button, log_button],
                "x": 0.9,
                "y": 1.15,
            }
        ],
        **kwargs,
    )

    return figure

151 

152 

def plot_carrier_concentrations(
    xs_data: list[np.ndarray],
    ys_data: list[dict[str, np.ndarray]],
    N0s: list[float],
    titles: list[str],
    xlabel: str,
    model,
) -> go.Figure:
    """Plot the charge carrier concentrations, one subplot per initial carrier concentration

    Each concentration is divided by the initial carrier concentration of its subplot. Only the
    traces of the first subplot are shown in the legend.

    :param xs_data: x-axis data
    :param ys_data: list of dicts of concentration arrays for each initial carrier concentration
    :param N0s: list of initial carrier concentration
    :param titles: list of titles of each subplot
    :param xlabel: x-axis label
    :param model: model used (provides the CONC_LABELS_HTML and CONC_COLORS mappings)
    :return: the plotly figure"""

    figure, positions = subplots(len(N0s), 2, subplot_titles=titles)
    font = {"size": 16, "color": "black"}
    # Settings shared by the x- and y-axis of every subplot
    axis_style = {"title_font": font, "tickfont": font, "showgrid": True, "gridcolor": "lightgray"}

    panels = zip(N0s, positions, xs_data, ys_data)
    for index, (N0, (row, col), x_data, concentrations) in enumerate(panels):
        in_legend = index == 0

        for key in concentrations:
            trace = go.Scatter(
                x=x_data,
                y=concentrations[key] / N0,
                name=model.CONC_LABELS_HTML[key],
                showlegend=in_legend,
                line={"color": model.CONC_COLORS[key]},
            )
            figure.add_trace(trace, row=row, col=col)

        figure.update_xaxes(title_text=xlabel, tickformat=",", row=row, col=col, **axis_style)
        figure.update_yaxes(title_text="Concentration (N<sub>0</sub>)", row=row, col=col, **axis_style)

    # Restyle the subplot titles and move them slightly upward
    shifted_titles = [{"font": font, "y": anno["y"] + 0.01} for anno in figure.layout.annotations]
    figure.update_layout(
        height=900,
        margin={"l": 0, "r": 0, "t": 40, "b": 0, "pad": 0},
        plot_bgcolor="#f0f4f8",
        annotations=shifted_titles,
    )
    return figure

212 

213 

def parallel_plot(
    popts: list[dict[str, np.ndarray | float]],
    hidden_keys: list[str],
) -> hiplot.Experiment:
    """Plot data in a parallel plot

    The input dicts are not modified: the 1-based "ID" is added to copies of them.

    :param popts: non-empty list of dicts
    :param hidden_keys: list of keys not displayed
    :return: the hiplot experiment with its parallel plot and table configured"""

    # Determine the key order: "ID" then the displayed keys of the first dict. "ID" is skipped
    # in the comprehension so that it is not listed twice if the input already contains it.
    order = ["ID"] + [key for key in popts[0] if key != "ID" and key not in hidden_keys]

    # Add ID to shallow copies of the data to leave the dicts of the caller untouched
    data = [{**popt, "ID": i + 1} for i, popt in enumerate(popts)]

    # Plot the data
    xp = hiplot.Experiment.from_iterable(data)
    for key in data[0]:
        xp.parameters_definition[key].label_html = key
    xp.display_data(hiplot.Displays.PARALLEL_PLOT).update({"hide": hidden_keys + ["uid"], "order": order[::-1]})
    xp.display_data(hiplot.Displays.TABLE).update({"hide": ["from_uid", "uid"]})
    return xp

239 

240 

class _StreamlitHelpers:  # pragma: no cover
    """Streamlit helpers assigned to hiplot.streamlit_helpers._StreamlitHelpers by this module

    The running state is obtained from streamlit.runtime.exists(), and the built HiPlot
    component is located from the installed hiplot package."""

    # Streamlit component declared by create_component, stored so that it is declared only once
    component: tp.Optional[tp.Callable[..., tp.Any]] = None

    @staticmethod
    def is_running_within_streamlit() -> bool:
        """Check if the code is running within a Streamlit environment.
        :return: False if streamlit cannot be imported, else whether a Streamlit runtime exists"""

        try:
            import streamlit  # noqa: F401  # pylint: disable=unused-import,import-outside-toplevel
        except ImportError:
            return False
        return bool(runtime.exists())

    @classmethod
    def create_component(cls) -> tp.Optional[tp.Callable[..., tp.Any]]:
        """Create and return a Streamlit component for rendering HiPlot visualizations.
        :return: the component, created on the first call and reused afterwards
        :raises RuntimeError: if streamlit does not support components, if no Streamlit runtime
            exists, if hiplot cannot be located or if its built component files are missing"""

        if cls.component is not None:
            return cls.component
        import streamlit as st

        try:
            import streamlit.components.v1 as components
        except ModuleNotFoundError as e:
            raise RuntimeError(
                f"Your streamlit version ({st.__version__}) is too old and does not support components.\n"
                "Please update streamlit with `pip install -U streamlit`"
            ) from e

        # Explicit errors instead of assert statements, which are skipped when running with -O
        if not runtime.exists():
            raise RuntimeError("The HiPlot component can only be created within a running Streamlit runtime.")

        # Locate HiPlot module and resolve the path to its static build
        spec = importlib.util.find_spec("hiplot")
        if spec is None or not spec.origin:
            raise RuntimeError("HiPlot module could not be found. Ensure it is installed.")
        hiplot_path = Path(spec.origin).parent
        built_path = (hiplot_path / "static" / "built" / "streamlit_component").resolve()

        if not (built_path / "index.html").is_file():
            raise RuntimeError(
                f"HiPlot component does not appear to exist in {built_path}\n"
                "If you did not install HiPlot using official channels (pip, conda...), "
                "maybe you forgot to build the JavaScript files?\n"
                "See https://facebookresearch.github.io/hiplot/contributing.html#building-javascript-bundle"
            )

        cls.component = components.declare_component("hiplot", path=str(built_path))
        return cls.component

286 

287 

# Fix hiplot compatibility issues: the _StreamlitHelpers attribute of the hiplot.streamlit_helpers
# module is replaced by the _StreamlitHelpers class of this module (runs when this module is imported)
hiplot.streamlit_helpers._StreamlitHelpers = _StreamlitHelpers