Coverage for tests/test_main.py: 100%
248 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 09:13 +0000
1"""Test module for the functions in the `main.py` module.
3This module contains unit tests for the functions implemented in the `main.py` module. The purpose of these tests is to
4ensure the correct functionality of each function in different scenarios and to validate that the expected outputs are
5returned.
7Tests should cover various edge cases, valid inputs, and any other conditions that are necessary to confirm the
8robustness of the functions."""
10import copy
11import os
12from io import BytesIO
13from unittest.mock import MagicMock, patch
15import numpy as np
16from streamlit.testing.v1 import AppTest
18from app.resources import APP_MODES, BTD_TRMC_DATA, BTD_TRPL_DATA, BT_TRMC_DATA, BT_TRPL_DATA
19from app.utility.data import are_close
22class TestApp:
24 main_path = "app/main.py"
26 def teardown_method(self) -> None:
27 """Teardown method that runs after each test."""
29 # Make sure that no exception happened
30 assert len(self.at.exception) == 0
32 def test_default(self) -> None:
34 # Start the app and run it
35 self.at = AppTest(self.main_path, default_timeout=100).run()
36 assert len(self.at.error) == 0
37 assert self.at.expander[-1].label == "License & Disclaimer"
39 def get_widget_by_key(
40 self,
41 widget: str,
42 key: str,
43 verbose: bool = False,
44 ):
45 """Get a widget given its key
46 :param widget: widget type
47 :param key: key
48 :param verbose: if True, print the index of the widget"""
50 keys = [wid.key for wid in getattr(self.at, widget)]
51 if verbose:
52 print(keys) # pragma: no cover
53 index = keys.index(key)
54 if verbose:
55 print(index) # pragma: no cover
56 return getattr(self.at, widget)[index]
58 def set_period(self, value: str):
59 """Set the repetition period"""
61 self.get_widget_by_key("text_input", "period_input_").set_value(value).run()
63 def set_app_mode(self, value: str):
64 """Set the app mode"""
66 self.get_widget_by_key("selectbox", "fit_mode_").set_value(value).run()
68 def set_quantity(self, value: str):
69 """Set the fit quantity"""
71 self.get_widget_by_key("radio", "quantity_input_").set_value(value).run()
73 def set_model(self, value: str):
74 """Set the model"""
76 self.get_widget_by_key("selectbox", "model_name_").set_value(value).run()
78 def run(self):
79 """Press the run button"""
81 self.get_widget_by_key("button", "run_button").click().run()
83 def set_data_delimiter(self, value: str):
84 """Set the data delimiter"""
86 self.get_widget_by_key("radio", "data_delimiter").set_value(value).run()
88 def set_data_format(self, value: str):
89 """Set the data format"""
91 self.get_widget_by_key("radio", "data_format_").set_value(value).run()
93 def set_preprocess(self, value: bool):
94 """Set the pre-processing option"""
96 self.get_widget_by_key("checkbox", "preprocess_").set_value(value).run()
98 def set_fixed_parameter(self, key: str, value: str, model: str = "BTA"):
99 """Set the fixed parameter value"""
101 self.get_widget_by_key("text_input", model + key + "fixed").set_value(value).run()
103 def set_guess_parameter(self, key: str, value: str, model: str = "BTA"):
104 """Set the guess parameter value"""
106 self.get_widget_by_key("text_input", model + key + "guess").set_value(value).run()
108 def set_guesses_parameter(self, key: str, value: str, model: str = "BTA"):
109 """Set the guess value range"""
111 self.get_widget_by_key("text_input", model + key + "guesses").set_value(value).run()
113 def set_matching_input(self, value):
114 """Set the matching input"""
116 self.get_widget_by_key("text_input", "matching_input").set_value(value).run()
118 def fill_N0s(self, N0s: list[float | str]) -> None:
119 """Fill the photoexcited carrier concentrations inputs
120 :param N0s: photoexcited carrier concentration values"""
122 widgets = [self.get_widget_by_key("text_input", f"fluence_{i}") for i in range(len(N0s))]
123 assert len(widgets) == len(N0s)
125 for text_input, N0 in zip(widgets, N0s):
126 text_input.input(str(N0))
127 self.at.run()
129 @staticmethod
130 def create_mock_file(mock_file_uploader: MagicMock, data: np.ndarray) -> None:
131 """Create a temporary CSV file with uneven columns and mock file upload.
132 :param mock_file_uploader: MagicMock
133 :param data: data to be uploaded"""
135 # Save the data
136 temp_path = "_temp.csv"
137 np.savetxt(temp_path, data, fmt="%s", delimiter=",")
139 # Load the data to the mock file
140 with open(temp_path, "rb") as f:
141 mock_file_uploader.return_value = BytesIO(f.read())
143 # Remove the file
144 os.remove(temp_path)
146 # -------------------------------------------------- TEST FITTING --------------------------------------------------
148 @patch("streamlit.sidebar.file_uploader")
149 def _test_fitting(
150 self,
151 dataset: tuple[list[np.ndarray], list[np.ndarray], list[float]],
152 quantity: str,
153 model: str,
154 expected_output: dict,
155 preprocess: bool,
156 mock_file_uploader: MagicMock,
157 ) -> None:
159 # Load the data
160 data = np.transpose([dataset[0][0]] + dataset[1])
161 self.create_mock_file(mock_file_uploader, data)
163 # Start the app and run it
164 self.at = AppTest(self.main_path, default_timeout=100).run()
166 # Select the correct quantity
167 self.set_quantity(quantity)
169 # Pre-processing
170 if preprocess:
171 self.set_preprocess(True)
173 # Check the number of fluence inputs
174 self.fill_N0s(dataset[2])
176 # Select the model
177 self.set_model(model)
179 # Click on the button and assert the fit results
180 self.run()
182 assert are_close(self.at.session_state["results"][0]["popts"][0], expected_output["popt"])
183 assert are_close(self.at.session_state["results"][0]["contributions"], expected_output["contributions"])
185 self.set_period("200")
187 assert are_close(self.at.session_state["carrier_accumulation"]["CA"], expected_output["ca"])
188 assert len(self.at.error) == 0
190 def test_bt_trpl(self) -> None:
191 expected = {
192 "popt": {
193 "I": 1.0,
194 "N_0": 1000000000000000.0,
195 "k_A": 0.0,
196 "k_B": 5.036092379781616e-19,
197 "k_T": 0.01013367471651538,
198 "y_0": 0.0,
199 },
200 "contributions": {
201 "A": [0.0, 0.0, 0.0],
202 "B": [3.21140477, 24.33019951, 74.24002711],
203 "T": [96.78859523, 75.66980049, 25.75997289],
204 },
205 "ca": [0.20075745903090358, 0.9503120802981269, 0.6080441775140444],
206 }
207 self._test_fitting(BT_TRPL_DATA, "TRPL", "BTA", expected, False)
209 def test_btd_trpl(self) -> None:
210 expected = {
211 "popt": {
212 "I": 1.0,
213 "N_0": 51000000000000.0,
214 "N_T": 59388399740318.6,
215 "k_B": 5.07768333654429e-19,
216 "k_D": 7.695781718031608e-19,
217 "k_T": 1.1840584706593167e-16,
218 "p_0": 71425163414191.94,
219 "y_0": 0.0,
220 },
221 "contributions": {
222 "B": [1.97464011, 25.01522655, 62.98962314, 86.000232, 92.82901167],
223 "D": [0.63268051, 8.28295762, 8.90099192, 4.14960937, 1.75052228],
224 "T": [97.39267938, 66.70181583, 28.10938494, 9.85015863, 5.42046605],
225 },
226 "ca": [47.390257057209226, 17.86713876214925, 14.169801097560752, 25.484481563736466, 20.172770374561566],
227 }
228 self._test_fitting(BTD_TRPL_DATA, "TRPL", "BTD", expected, False)
230 def test_bt_trmc(self) -> None:
231 expected = {
232 "popt": {
233 "N_0": 1000000000000000.0,
234 "k_A": 0.0,
235 "k_B": 4.9191697581358015e-19,
236 "k_T": 0.010046921491709,
237 "mu": 10.021005406705859,
238 "y_0": 0.0,
239 },
240 "contributions": {
241 "A": [0.0, 0.0, 0.0],
242 "B": [2.37117424, 18.60734623, 63.77640104],
243 "T": [97.62882576, 81.39265377, 36.22359896],
244 },
245 "ca": [0.17048546661642128, 0.8133504291988225, 0.5281220522039942],
246 }
247 self._test_fitting(BT_TRMC_DATA, "TRMC", "BTA", expected, False)
249 def test_btd_trmc(self) -> None:
250 expected = {
251 "popt": {
252 "N_0": 51000000000000.0,
253 "N_T": 59033305468562.54,
254 "k_B": 5.067910058596048e-19,
255 "k_D": 8.038666341707501e-19,
256 "k_T": 1.1646600592587665e-16,
257 "mu_e": 20.316487864059788,
258 "mu_h": 29.881798526675844,
259 "p_0": 64418783311953.03,
260 "y_0": 0.0,
261 },
262 "contributions": {
263 "B": [1.71049246, 22.07876196, 53.65212212, 76.01675686, 85.26077901],
264 "D": [56.99175709, 44.95497873, 25.57202925, 11.08623379, 5.1548549],
265 "T": [41.29775044, 32.96625931, 20.77584864, 12.89700935, 9.58436609],
266 },
267 "ca": [17.754055972544826, 6.065924114804922, 9.763315562751995, 17.252574397074692, 16.546710504457472],
268 }
269 self._test_fitting(BTD_TRMC_DATA, "TRMC", "BTD", expected, False)
271 def test_bt_trpl_preprocess(self) -> None:
272 expected = {
273 "popt": {
274 "k_T": 0.01027692599435874,
275 "k_B": 5.223353716145634e-19,
276 "k_A": 0.0,
277 "y_0": 0.0,
278 "I": 1.0,
279 "N_0": 1000000000000000.0,
280 },
281 "contributions": {
282 "T": np.array([96.71939708, 75.26854715, 25.35779815]),
283 "B": np.array([3.28060292, 24.73145285, 74.64220185]),
284 "A": np.array([0.0, 0.0, 0.0]),
285 },
286 "ca": [0.19826769835378788, 0.9274715441877357, 0.5803950947538994],
287 }
288 data = list(copy.deepcopy(BT_TRPL_DATA))
289 x0 = np.linspace(0, 50, 51)
290 data[0] = [np.concatenate([x0, x + x0[-1]]) for x in data[0]]
291 data[1] = [np.concatenate([np.zeros(len(x0)), x]) for x in data[1]]
292 self._test_fitting(BT_TRPL_DATA, "TRPL", "BTA", expected, True)
294 # -------------------------------------------------- TEST INVALID --------------------------------------------------
296 @patch("streamlit.sidebar.file_uploader")
297 def test_invalid_carrier_concentrations(self, mock_file_uploader: MagicMock) -> None:
299 # Load the data
300 data = np.random.randn(10, 3)
301 self.create_mock_file(mock_file_uploader, data)
303 # Start the app and run it
304 self.at = AppTest(self.main_path, default_timeout=100).run()
306 # Check the number of fluence inputs
307 self.fill_N0s(["f", "3"])
309 expected = "Uh-oh! The initial carrier concentrations input is not valid"
310 assert self.at.error[0].value == expected
312 @patch("streamlit.sidebar.file_uploader")
313 def test_invalid_fixed_guess_value(self, mock_file_uploader: MagicMock) -> None:
315 # Load the data
316 data = np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1])
317 self.create_mock_file(mock_file_uploader, data)
319 # Start the app and run it
320 self.at = AppTest(self.main_path, default_timeout=100).run()
322 # Check the number of fluence inputs
323 self.fill_N0s(BT_TRPL_DATA[2])
325 # Change the fixed value to an incorrect value
326 self.set_fixed_parameter("k_T", "3")
327 self.set_fixed_parameter("k_T", "f")
328 assert self.at.session_state["models"]["BTA"]["TRPL"].fvalues["k_T"] == 3
330 # Change the guess value to an incorrect value
331 self.set_fixed_parameter("k_T", "") # reset the fixed value
332 self.set_guess_parameter("k_T", "3") # set the guess value
333 self.set_guess_parameter("k_T", "f") # guess value with incorrect string
334 assert self.at.session_state["models"]["BTA"]["TRPL"].gvalues["k_T"] == 3
336 # Change the fixed value range to an incorrect value
337 self.set_app_mode(APP_MODES[1])
338 self.set_guesses_parameter("k_T", "2,5,6")
339 self.set_guesses_parameter("k_T", "2,5,f")
340 assert self.at.session_state["models"]["BTA"]["TRPL"].gvalues_range["k_T"] == [2.0, 5.0, 6.0]
342 @patch("streamlit.sidebar.file_uploader")
343 def test_bad_fitting(self, mock_file_uploader: MagicMock) -> None:
345 # Load the data
346 data = np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1])
347 self.create_mock_file(mock_file_uploader, data)
349 # Start the app and run it
350 self.at = AppTest(self.main_path, default_timeout=100).run()
352 # Check the number of fluence inputs
353 self.fill_N0s(BT_TRPL_DATA[2])
355 # Change the fixed value to an incorrect value
356 self.set_guess_parameter("k_T", "-1")
358 # Click on the button and assert the fit results
359 self.run()
361 expected = "The data could not be fitted. Try changing the parameter guess or fixed values."
362 assert self.at.error[0].value == expected
364 @patch("streamlit.sidebar.file_uploader")
365 def test_uneven_column_file(self, mock_file_uploader: MagicMock):
367 x = np.linspace(0, 10, 50)
368 y = np.cos(x[:-10]) # Make y shorter than x
369 y_str = [str(_y) for _y in y] + [""] * (len(x) - len(y))
370 temp_path = "_temp.csv"
371 np.savetxt(temp_path, np.transpose([x, y_str]), fmt="%s", delimiter=",")
372 with open("_temp.csv", "rb") as f:
373 mock_file_uploader.return_value = BytesIO(f.read())
374 os.remove(temp_path)
376 # Start the app and run it
377 self.at = AppTest(self.main_path, default_timeout=100).run()
379 expected = "Uh-oh! The data could not be loaded. Error: Mismatch at index 1: x and y columns must have the same length."
380 assert self.at.error[0].value == expected
381 assert len(self.at.error) == 1
383 @patch("streamlit.sidebar.file_uploader")
384 def test_uneven_column_file2(self, mock_file_uploader: MagicMock):
386 # Create uneven data
387 x = np.linspace(0, 10, 50)
388 y = np.cos(x[:-10])
389 y_str = [str(_y) for _y in y] + [""] * (len(x) - len(y))
391 # Load the data
392 self.create_mock_file(mock_file_uploader, np.transpose([x, y_str]))
394 # Start the app and run it
395 self.at = AppTest(self.main_path, default_timeout=100).run()
397 expected = "Uh-oh! The data could not be loaded. Error: Mismatch at index 1: x and y columns must have the same length."
398 assert self.at.error[0].value == expected
399 assert len(self.at.error) == 1
401 @patch("streamlit.sidebar.file_uploader")
402 def test_column_file(self, mock_file_uploader: MagicMock):
404 # Create uneven data
405 x = [np.linspace(0, 10, 50)] * 3
406 y = [np.cos(x[0])] * 2
407 data = np.transpose([x[0], y[0], x[1], y[1], x[2]])
409 # Load the data
410 self.create_mock_file(mock_file_uploader, data)
412 # Start the app and run it
413 self.at = AppTest(self.main_path, default_timeout=100).run()
415 # Change the data format
416 self.set_data_format("X1/Y1/X2/Y2...")
418 expected = "Uh-oh! The data could not be loaded. Error: Mismatch: x data and y data must have the same number of columns."
419 assert self.at.error[0].value == expected
420 assert len(self.at.error) == 1
422 @patch("streamlit.sidebar.file_uploader")
423 def test_incorrect_delimiter(self, mock_file_uploader: MagicMock):
425 # Load the data
426 data = np.random.randn(10, 3)
427 self.create_mock_file(mock_file_uploader, data)
429 # Start the app and run it
430 self.at = AppTest(self.main_path, default_timeout=100).run()
432 # Change the delimiter
433 self.set_data_delimiter(";")
435 expected = "Uh-oh! The data could not be loaded. Error: Unknown error, Check that the correct delimiter has been selected."
436 assert self.at.error[0].value == expected
437 assert len(self.at.error) == 1
439 @patch("streamlit.sidebar.file_uploader")
440 def test_failed_ca(self, mock_file_uploader: MagicMock) -> None:
442 # Save and load the data
443 self.create_mock_file(mock_file_uploader, np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1]))
445 # Start the app and run it
446 self.at = AppTest(self.main_path, default_timeout=100).run()
448 # Check the number of fluence inputs
449 self.fill_N0s(BT_TRPL_DATA[2])
451 # Set the period
452 self.set_period("0.0001")
454 # Click on the button and assert the fit results
455 self.run()
457 assert self.at.session_state["carrier_accumulation"] == {}
458 expected = "Carrier accumulation could not be calculated due to excessive computational requirements."
459 assert self.at.warning[0].value == expected
461 # Set the app mode
462 self.set_app_mode(APP_MODES[1])
464 # Click on the run button and assert the fit results
465 self.run()
467 assert self.at.session_state["carrier_accumulation"] == []
469 @patch("streamlit.sidebar.file_uploader")
470 def test_failed_matching(self, mock_file_uploader: MagicMock) -> None:
472 # Save and load the data
473 self.create_mock_file(mock_file_uploader, np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1]))
475 # Start the app and run it
476 self.at = AppTest(self.main_path, default_timeout=100).run()
478 # Check the number of fluence inputs
479 self.fill_N0s(BT_TRPL_DATA[2])
481 # Click on the button and assert the fit results
482 self.run()
484 self.set_matching_input("f")
486 assert self.at.warning[0].value == "Please input correct values."
488 @patch("streamlit.sidebar.file_uploader")
489 def test_bad_grid_fitting(self, mock_file_uploader: MagicMock) -> None:
491 # Load the data
492 data = np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1])
493 self.create_mock_file(mock_file_uploader, data)
495 # Start the app and run it
496 self.at = AppTest(self.main_path, default_timeout=100).run()
498 self.set_app_mode(APP_MODES[1])
500 # Check the number of fluence inputs
501 self.fill_N0s(BT_TRPL_DATA[2])
503 # Change the fixed value to an incorrect value
504 self.set_guesses_parameter("k_T", "-1, -2")
506 # Click on the button and assert the fit results
507 self.run()
509 expected = "The data could not be fitted. Try changing the parameter guess or fixed values."
510 assert self.at.error[0].value == expected
512 # ----------------------------------------------------- OTHERS -----------------------------------------------------
514 @patch("streamlit.sidebar.file_uploader")
515 def test_settings_changed(self, mock_file_uploader: MagicMock) -> None:
517 # Load the data
518 data = np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1])
519 self.create_mock_file(mock_file_uploader, data)
521 # Start the app and run it
522 self.at = AppTest(self.main_path, default_timeout=100).run()
524 # Check the number of fluence inputs
525 self.fill_N0s(BT_TRPL_DATA[2])
527 # Click on the button and assert the fit results
528 self.run()
530 # Change settings
531 self.set_guess_parameter("k_T", "1")
533 expected = "You have changed some of the input settings. Press 'Run' to apply these changes."
534 assert self.at.warning[0].value == expected
536 @patch("streamlit.sidebar.file_uploader")
537 def test_stored_ca(self, mock_file_uploader: MagicMock) -> None:
539 # Load the data
540 data = np.transpose([BT_TRPL_DATA[0][0]] + BT_TRPL_DATA[1])
541 self.create_mock_file(mock_file_uploader, data)
543 self.at = AppTest(self.main_path, default_timeout=100).run() # start the app and run it
544 self.fill_N0s(BT_TRPL_DATA[2]) # check the number of fluence inputs
545 self.run() # click on the run button
546 self.set_period("100") # set the period
548 # Rerun
549 self.at.run()
550 assert self.at.session_state.carrier_accumulation is not None
551 assert len(self.at.error) == 0
553 # ------------------------------------------------ TEST GRID FITTING -----------------------------------------------
555 @patch("streamlit.sidebar.file_uploader")
556 def _test_grid_fitting(
557 self,
558 dataset: tuple[list[np.ndarray], list[np.ndarray], list[float]],
559 quantity: str,
560 model: str,
561 expected_output: dict,
562 mock_file_uploader: MagicMock,
563 ) -> None:
565 # Save and load the data
566 self.create_mock_file(mock_file_uploader, np.transpose([dataset[0][0]] + dataset[1]))
568 # Start the app and run it
569 self.at = AppTest(self.main_path, default_timeout=100).run()
571 # Set the app mode
572 self.set_app_mode(APP_MODES[1])
574 # Select the correct quantity
575 self.set_quantity(quantity)
577 # Check the number of fluence inputs
578 self.fill_N0s(dataset[2])
580 # Select the model
581 self.set_model(model)
583 # Check that the run button is present
584 assert len(self.at.sidebar.button) == 1
585 assert self.at.sidebar.button[0].label == "Run"
587 # Click on the button and assert the fit results
588 self.run()
590 assert self.at.markdown[1].value == "## Parallel plot"
591 self.set_period("200")
593 ca = [f["CA"] for f in self.at.session_state["carrier_accumulation"]]
594 assert are_close(ca, expected_output)
595 assert len(self.at.error) == 0
597 def test_bt_trpl_grid(self) -> None:
598 expected = [
599 [0.20075751831579725, 0.9503122259985342, 0.6080441291965166],
600 [0.20075752653634926, 0.9503122402142683, 0.6080441126882885],
601 [0.2007576800148292, 0.9503126618979119, 0.608044067096708],
602 [0.2007575570692477, 0.9503123227196708, 0.6080440988083968],
603 ]
604 self._test_grid_fitting(BT_TRPL_DATA, "TRPL", "BTA", expected)
606 def test_bt_trmc_grid(self) -> None:
607 expected = [
608 [0.1704854700170455, 0.8133504278268588, 0.5281220312685408],
609 [0.17048547467589104, 0.8133504527513102, 0.5281220530171438],
610 [0.17048546359939576, 0.8133504112995238, 0.5281220345702942],
611 [0.17048545828642347, 0.8133503975122747, 0.5281220391482655],
612 [0.17048548691176446, 0.8133504687864224, 0.5281220151080401],
613 [0.17048546618355087, 0.8133504175408257, 0.528122032816325],
614 [0.1704854670983691, 0.8133504197771313, 0.5281220324267089],
615 [0.17048547165501304, 0.8133504321135743, 0.5281220299483136],
616 ]
618 self._test_grid_fitting(BT_TRMC_DATA, "TRMC", "BTA", expected)