Coverage for tests/test_plot.py: 100%

97 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-08-28 09:13 +0000

1"""Test module for the functions in the `plot.py` module. 

2 

3This module contains unit tests for the functions implemented in the `plot.py` module. The purpose of these tests is to 

4ensure the correct functionality of each function in different scenarios and to validate that the expected outputs are 

5returned. 

6 

7Tests should cover various edge cases, valid inputs, and any other conditions that are necessary to confirm the 

8robustness of the functions.""" 

9 

10import hiplot 

11import numpy as np 

12import plotly.graph_objects as go 

13import pytest 

14 

15from app.plot import parallel_plot, plot_carrier_concentrations, plot_decays, subplots 

16 

17 

class TestSubplots:
    """Tests for `subplots`, which returns a figure plus the (row, col) cell of each panel."""

    def test_simple_subplot_creation(self) -> None:
        figure, cells = subplots(3)
        assert isinstance(figure, go.Figure)
        # Three panels with no column limit are stacked in a single column.
        assert cells == [(row, 1) for row in range(1, 4)]

    def test_subplot_with_max_columns(self) -> None:
        figure, cells = subplots(9, 2)
        assert isinstance(figure, go.Figure)
        # Row-major fill, two columns per row; the ninth panel opens row 5.
        full_rows = [(row, col) for row in range(1, 5) for col in (1, 2)]
        assert cells == full_rows + [(5, 1)]

    def test_subplot_with_kwargs(self) -> None:
        titles = ["A", "B", "C", "D"]
        figure, cells = subplots(4, subplot_titles=titles)
        assert isinstance(figure, go.Figure)
        # Four panels are laid out as a 2 x 2 grid.
        assert cells == [(row, col) for row in (1, 2) for col in (1, 2)]

34 

35 

class TestPlotDecays:
    """Tests for `plot_decays` on TRPL/TRMC data, with and without overlaid fit curves."""

    @pytest.fixture
    def decay_data(self) -> dict[str, list]:
        """Two short decays ("Sample A", "Sample B") on a shared time axis, plus fitted curves."""
        return {
            "xs_data": [np.array([1, 2, 3]), np.array([1, 2, 3])],
            "ys_data": [np.array([10, 5, 2]), np.array([8, 4, 1])],
            "ys_data_fit": [np.array([9, 5, 2.5]), np.array([8.5, 4.2, 1.1])],
            "labels": ["Sample A", "Sample B"],
        }

    def test_plot_decays_trpl(self, decay_data: dict[str, list]) -> None:
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRPL",
            labels=decay_data["labels"],
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2  # Two traces for two datasets
        assert fig.data[0].name == "Sample A"
        assert fig.data[1].name == "Sample B"
        assert "Time (ns)" in fig.layout.xaxis.title.text
        assert "Intensity (a.u.)" in fig.layout.yaxis.title.text

    def test_plot_decays_trmc(self, decay_data: dict[str, list]) -> None:
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRMC",
            labels=decay_data["labels"],
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2
        assert "Time (ns)" in fig.layout.xaxis.title.text
        assert "Intensity (cm<sup>2</sup>/(Vs))" in fig.layout.yaxis.title.text

    def test_plot_decays_trmc_auto_labels(self, decay_data: dict[str, list]) -> None:
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRMC",
        )
        # Compare the complete list of names. Iterating with zip() would stop at the
        # shorter sequence, so missing or extra traces would go unnoticed.
        assert [trace.name for trace in fig.data] == ["Decay 1", "Decay 2"]

    def test_plot_decays_with_fit(self, decay_data: dict[str, list]) -> None:
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRPL",
            ys_data2=decay_data["ys_data_fit"],
            labels=decay_data["labels"],
            label2=" (fitted)",
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 4  # Four traces: two for data and two for fit
        assert fig.data[0].name == "Sample A"
        assert fig.data[1].name == "Sample A (fitted)"
        assert fig.data[2].name == "Sample B"
        assert fig.data[3].name == "Sample B (fitted)"
        # NOTE: `"dash" in trace.line` is always True in plotly. Its __contains__ only
        # checks that "dash" is a valid property name, not that a value was set.
        # Check the actual dash value, and that each fit is drawn differently from its data.
        for data_trace, fit_trace in ((fig.data[0], fig.data[1]), (fig.data[2], fig.data[3])):
            assert fit_trace.line.dash is not None
            assert fit_trace.line.dash != data_trace.line.dash

106 

107 

class SimpleCarrierModel:
    """Minimal stand-in carrier model supplying per-carrier display labels and colours for tests."""

    def __init__(self) -> None:
        # Legend labels keyed by carrier type ("e" = electrons, "h" = holes).
        self.CONC_LABELS_HTML = dict(e="Electrons", h="Holes")
        # Line colours keyed by the same carrier types.
        self.CONC_COLORS = dict(e="blue", h="red")

114 

115 

class TestPlotCarrierConcentrations:
    """Tests for `plot_carrier_concentrations` driven by the `SimpleCarrierModel` stub."""

    @pytest.fixture
    def carrier_data(self) -> dict:
        """Electron ("e") and hole ("h") concentrations for two samples, plus plot metadata."""
        return dict(
            model=SimpleCarrierModel(),
            xs_data=[np.array([1, 2, 3]), np.array([1, 2, 3])],
            ys_data=[
                {"e": np.array([10, 5, 2]), "h": np.array([10, 6, 3])},
                {"e": np.array([8, 4, 1]), "h": np.array([8, 5, 2])},
            ],
            N_0s=[1e15, 1e16],
            titles=["Sample A", "Sample B"],
            xlabel="Time (ns)",
        )

    def test_plot_carrier_concentrations(self, carrier_data) -> None:
        positional = ("xs_data", "ys_data", "N_0s", "titles", "xlabel", "model")
        fig = plot_carrier_concentrations(*[carrier_data[key] for key in positional])

        assert isinstance(fig, go.Figure)
        # Two samples, each with one electron trace and one hole trace.
        assert len(fig.data) == 4

        # Group traces by their legend name.
        by_name = {
            name: [trace for trace in fig.data if trace.name == name]
            for name in ("Electrons", "Holes")
        }
        assert len(by_name["Electrons"]) == 2
        assert len(by_name["Holes"]) == 2

        expected_colours = {"Electrons": "blue", "Holes": "red"}
        for name, traces in by_name.items():
            assert all(trace.line.color == expected_colours[name] for trace in traces)
            # Only the first trace of each carrier type adds a legend entry.
            first, second = traces
            assert first.showlegend is True
            assert second.showlegend is False

174 

175 

class TestParallelPlot:
    """Tests for `parallel_plot`, which builds a `hiplot.Experiment` from a list of records."""

    def test_basic_plot(self) -> None:
        experiment = parallel_plot([{"a": 1, "b": 2}, {"a": 3, "b": 4}], ["b"])

        assert isinstance(experiment, hiplot.Experiment)
        assert len(experiment.datapoints) == 2
        # Each datapoint receives an "ID" value: 1, 2, ... in input order.
        assert [point.values["ID"] for point in experiment.datapoints] == [1, 2]

    def test_hide_and_order(self) -> None:
        experiment = parallel_plot([{"x": 10, "y": 20, "z": 30}], ["y"])
        settings = experiment.display_data(hiplot.Displays.PARALLEL_PLOT)

        # The requested hidden key is followed by "uid" in the hide list.
        assert settings["hide"] == ["y", "uid"]
        # Expected axis order of the parallel plot, with "ID" last.
        assert settings["order"] == ["z", "x", "ID"]