Coverage for tests/test_plot.py: 100%
97 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
1"""Test module for the functions in the `plot.py` module.
3This module contains unit tests for the functions implemented in the `plot.py` module. The purpose of these tests is to
4ensure the correct functionality of each function in different scenarios and to validate that the expected outputs are
5returned.
7Tests should cover various edge cases, valid inputs, and any other conditions that are necessary to confirm the
8robustness of the functions."""
10import hiplot
11import numpy as np
12import plotly.graph_objects as go
13import pytest
15from app.plot import parallel_plot, plot_carrier_concentrations, plot_decays, subplots
class TestSubplots:
    """Checks the figure object and grid positions returned by ``subplots``."""

    def test_simple_subplot_creation(self) -> None:
        """Three subplots with no column limit are stacked in one column."""
        figure, grid = subplots(3)
        expected_grid = [(1, 1), (2, 1), (3, 1)]

        assert isinstance(figure, go.Figure)
        assert grid == expected_grid

    def test_subplot_with_max_columns(self) -> None:
        """Nine subplots limited to two columns fill rows left to right."""
        figure, grid = subplots(9, 2)
        expected_grid = [
            (1, 1), (1, 2),
            (2, 1), (2, 2),
            (3, 1), (3, 2),
            (4, 1), (4, 2),
            (5, 1),
        ]

        assert isinstance(figure, go.Figure)
        assert grid == expected_grid

    def test_subplot_with_kwargs(self) -> None:
        """Extra keyword arguments are accepted; four subplots form a 2x2 grid."""
        figure, grid = subplots(4, subplot_titles=list("ABCD"))
        expected_grid = [(1, 1), (1, 2), (2, 1), (2, 2)]

        assert isinstance(figure, go.Figure)
        assert grid == expected_grid
class TestPlotDecays:
    """Tests for ``plot_decays``: axis titles, trace naming and fitted overlays."""

    @pytest.fixture
    def decay_data(self) -> dict[str, list]:
        """Example data: two decays on the same x values, their fits and labels."""
        return {
            "xs_data": [np.array([1, 2, 3]), np.array([1, 2, 3])],
            "ys_data": [np.array([10, 5, 2]), np.array([8, 4, 1])],
            "ys_data_fit": [np.array([9, 5, 2.5]), np.array([8.5, 4.2, 1.1])],
            "labels": ["Sample A", "Sample B"],
        }

    def test_plot_decays_trpl(self, decay_data: dict[str, list]) -> None:
        """TRPL mode: one named trace per dataset, time and intensity axis titles."""
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRPL",
            labels=decay_data["labels"],
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2  # Two traces for two datasets
        assert fig.data[0].name == "Sample A"
        assert fig.data[1].name == "Sample B"
        assert "Time (ns)" in fig.layout.xaxis.title.text
        assert "Intensity (a.u.)" in fig.layout.yaxis.title.text

    def test_plot_decays_trmc(self, decay_data: dict[str, list]) -> None:
        """TRMC mode uses the cm<sup>2</sup>/(Vs) y-axis title."""
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRMC",
            labels=decay_data["labels"],
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 2
        assert "Time (ns)" in fig.layout.xaxis.title.text
        assert "Intensity (cm<sup>2</sup>/(Vs))" in fig.layout.yaxis.title.text

    def test_plot_decays_trmc_auto_labels(self, decay_data: dict[str, list]) -> None:
        """Without ``labels``, traces are named "Decay 1", "Decay 2", ..."""
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRMC",
        )
        # Compare the complete list of names: iterating with zip() would pass
        # vacuously if no traces (or only some of them) were produced.
        assert [trace.name for trace in fig.data] == ["Decay 1", "Decay 2"]

    def test_plot_decays_with_fit(self, decay_data: dict[str, list]) -> None:
        """Fitted curves are interleaved after their data trace and drawn dashed."""
        fig = plot_decays(
            decay_data["xs_data"],
            decay_data["ys_data"],
            "TRPL",
            ys_data2=decay_data["ys_data_fit"],
            labels=decay_data["labels"],
            label2=" (fitted)",
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 4  # Four traces: two for data and two for fit
        assert fig.data[0].name == "Sample A"
        assert fig.data[1].name == "Sample A (fitted)"
        assert fig.data[2].name == "Sample B"
        assert fig.data[3].name == "Sample B (fitted)"
        # `"dash" in trace.line` only tests that `dash` is a valid property of
        # plotly's Line object, so it is always true. Check that a dash style
        # was actually set on the fitted traces.
        assert fig.data[1].line.dash is not None
        assert fig.data[3].line.dash is not None
class SimpleCarrierModel:
    """Minimal test stand-in exposing per-carrier HTML labels and line colors."""

    def __init__(self) -> None:
        # Keys "e"/"h" index both mappings; each instance gets fresh dicts.
        self.CONC_LABELS_HTML = dict(e="Electrons", h="Holes")
        self.CONC_COLORS = dict(e="blue", h="red")
class TestPlotCarrierConcentrations:
    """Tests for ``plot_carrier_concentrations`` driven by ``SimpleCarrierModel``."""

    @pytest.fixture
    def carrier_data(self) -> dict:
        """Two samples of electron/hole concentrations plus plot metadata."""
        carrier_model = SimpleCarrierModel()
        times = [np.array([1, 2, 3]), np.array([1, 2, 3])]
        concentrations = [
            {"e": np.array([10, 5, 2]), "h": np.array([10, 6, 3])},
            {"e": np.array([8, 4, 1]), "h": np.array([8, 5, 2])},
        ]
        return dict(
            model=carrier_model,
            xs_data=times,
            ys_data=concentrations,
            N_0s=[1e15, 1e16],
            titles=["Sample A", "Sample B"],
            xlabel="Time (ns)",
        )

    def test_plot_carrier_concentrations(self, carrier_data) -> None:
        """One trace per sample and carrier, colored from the model, one legend entry each."""
        data = carrier_data
        fig = plot_carrier_concentrations(
            data["xs_data"],
            data["ys_data"],
            data["N_0s"],
            data["titles"],
            data["xlabel"],
            data["model"],
        )
        assert isinstance(fig, go.Figure)
        assert len(fig.data) == 4  # 2 samples x 2 carrier types = 4 traces

        # Group traces by their carrier label (taken from CONC_LABELS_HTML).
        electrons, holes = (
            [trace for trace in fig.data if trace.name == label]
            for label in ("Electrons", "Holes")
        )
        assert len(electrons) == 2
        assert len(holes) == 2

        # Colors must come from the model's CONC_COLORS mapping.
        assert all(trace.line.color == "blue" for trace in electrons)
        assert all(trace.line.color == "red" for trace in holes)

        # Only the first trace of each carrier type is shown in the legend.
        for group in (electrons, holes):
            first, second = group
            assert first.showlegend is True
            assert second.showlegend is False
class TestParallelPlot:
    """Tests for ``parallel_plot``, which builds a HiPlot experiment."""

    def test_basic_plot(self) -> None:
        """Two input rows yield two datapoints whose "ID" values are 1 and 2."""
        rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}]

        experiment = parallel_plot(rows, ["b"])

        assert isinstance(experiment, hiplot.Experiment)
        assert len(experiment.datapoints) == 2
        assert [point.values["ID"] for point in experiment.datapoints] == [1, 2]

    def test_hide_and_order(self) -> None:
        """Checks the hidden columns and column order of the parallel-plot display."""
        experiment = parallel_plot([{"x": 10, "y": 20, "z": 30}], ["y"])

        settings = experiment.display_data(hiplot.Displays.PARALLEL_PLOT)

        assert settings["hide"] == ["y", "uid"]
        assert settings["order"] == ["z", "x", "ID"]