Coverage for app/fitting.py: 100%
66 statements
« prev ^ index » next coverage.py v7.10.5, created at 2025-08-28 18:30 +0000
1"""fitting module"""
3from typing import Callable
5import numpy as np
6import scipy.optimize as sco
8from utility.dict import merge_dicts
def normalize_to_unit(number: float) -> tuple[float, int]:
    """Normalize a number to the range [0.1, 1] (in magnitude), adjusting the power (exponent) accordingly.
    The function returns the normalized value and the exponent, such that
    ``number ≈ value * 10 ** exponent``. The sign of the number is kept, and zero is returned as ``(0.0, 0)``.

    :param number: finite number to normalize
    :return: tuple of the normalized value and the power of ten
    :raises ValueError: if the number is infinite or NaN (an infinite value cannot be brought into range by
        repeated division and would otherwise loop forever)

    Examples
    --------
    >>> normalize_to_unit(1.433364345e9)
    (0.1433364345, 10)
    >>> normalize_to_unit(14e-6)
    (0.13999999999999999, -4)
    >>> normalize_to_unit(-14e-6)
    (-0.13999999999999999, -4)
    >>> normalize_to_unit(float("inf"))
    Traceback (most recent call last):
    ...
    ValueError: Cannot normalize a non-finite number: inf"""

    import math

    if not math.isfinite(number):
        raise ValueError(f"Cannot normalize a non-finite number: {number!r}")

    if number == 0.0:
        return 0.0, 0

    exponent = 0  # exponent value
    value = number  # normalized value

    # Scale up small magnitudes until they reach 0.1
    if abs(value) < 1:
        while abs(value) < 0.1:
            value *= 10.0
            exponent -= 1

    # Scale down large magnitudes until they are at most 1
    elif abs(value) > 1:
        while abs(value) > 1:
            value /= 10.0
            exponent += 1

    return value, exponent
class Fit:
    """Fit one function to several datasets simultaneously with ``scipy.optimize.least_squares``.

    Each non-fixed parameter is either shared between all datasets or detached (one independent value per
    dataset). Fixed parameters are held at the value given for each dataset. The optimiser works on mantissas
    obtained with ``normalize_to_unit``. The powers of ten derived from the initial guess are kept constant and
    re-applied when building the parameter dicts. Every optimised mantissa is bounded to [0, inf]."""

    def __init__(
        self,
        xs_data: list[np.ndarray],
        ys_data: list[np.ndarray],
        function: Callable,
        p0: dict,
        detached_parameters: list[str],
        fixed_parameters: list[dict],
        **kwargs,
    ):
        """Object constructor
        :param xs_data: list of arrays associated with the x data (one array per dataset)
        :param ys_data: list of arrays associated with the y data (one array per dataset)
        :param function: function used to fit the data, called as ``function(x, **parameters)``
        :param p0: parameters initial guess. Need to contain any non-fixed parameters. As the optimised values are
            bounded to [0, inf] and the normalisation keeps the sign, the guesses must be non-negative.
        :param detached_parameters: list of parameters that are not shared between y data
        :param fixed_parameters: fixed parameters dict for each y data. All dicts must have the same keys.
        :param kwargs: keyword arguments passed to the scipy.optimize.least_squares function
        :raises ValueError: if no dataset is given, if the numbers of x arrays, y arrays and fixed parameter dicts
            differ, or if the fixed parameter dicts do not all have the same keys"""

        # Store the input arguments
        self.xs_data = xs_data
        self.ys_data = ys_data
        self.function = function
        self.p0 = p0
        self.fixed_parameters = fixed_parameters
        self.kwargs = kwargs

        # Validate the inputs with explicit exceptions (assert statements are stripped when running with -O)
        if not self.fixed_parameters:
            raise ValueError("At least one dataset (with its fixed parameters dict) is required")
        if not len(self.fixed_parameters) == len(self.xs_data) == len(self.ys_data):
            raise ValueError(
                "xs_data, ys_data and fixed_parameters must have the same length "
                f"(got {len(self.xs_data)}, {len(self.ys_data)} and {len(self.fixed_parameters)})"
            )
        fixed_keys = self.fixed_parameters[0].keys()
        if any(d.keys() != fixed_keys for d in self.fixed_parameters):
            raise ValueError("All fixed parameters dicts must have the same keys")

        # Remove fixed parameters from the detached parameters
        self.detached_parameters = [f for f in detached_parameters if f not in fixed_keys]

        # Remove fixed parameters from initial guess
        self.p0 = {key: value for key, value in self.p0.items() if key not in fixed_keys}
        self.keys = list(self.p0.keys())  # all keys except for the ones fixed

        # Normalise the initial guess values to unit to facilitate the optimisation. The mantissas are optimised,
        # while the powers of ten are kept constant and re-applied in list_to_dicts.
        p0_split = {key: normalize_to_unit(value) for key, value in self.p0.items()}
        self.p0_mantissa = {key: mantissa for key, (mantissa, _) in p0_split.items()}
        self.p0_factors = {key: exponent for key, (_, exponent) in p0_split.items()}

        # Generate the positive bounds
        self.bounds = {key: [0, np.inf] for key in self.keys}

        # Convert the guess values and bounds to list for the scipy curve fitting
        self.p0_list = [self.p0_mantissa[key] for key in self.keys]
        self.bounds_list = [self.bounds[key] for key in self.keys]

        # Add one extra group of detached parameters for each dataset after the first one. The first dataset uses
        # the values already present in the shared part of the list.
        self.n = len(self.xs_data)  # number of xs arrays
        self.p0_list += [self.p0_mantissa[key] for key in self.detached_parameters] * (self.n - 1)
        self.bounds_list += [self.bounds[key] for key in self.detached_parameters] * (self.n - 1)

    def list_to_dicts(self, alist: list) -> list[dict]:
        """Convert the flat list of mantissas (including detached parameters) into one parameter dict per dataset
        :param alist: flat list of mantissas. The first ``len(self.keys)`` values are associated with
            ``self.keys``; they also provide the first dataset's detached values. They are followed by one group of
            detached values for each additional dataset.
        :return: list of parameter dicts whose values are multiplied by their power of ten"""

        # Create the base dict (with any potential first instance of the detached parameter)
        base_dict = dict(zip(self.keys, alist))

        # If detached parameters, extract them from the list and add them to the dictionaries
        if self.detached_parameters:
            # One row of detached values per dataset after the first one
            supp = np.array(alist[len(self.keys) :]).reshape((-1, len(self.detached_parameters)))
            dicts = [base_dict] + [merge_dicts(dict(zip(self.detached_parameters, p)), base_dict) for p in supp]

        # If no detached parameters, all datasets share the same values
        else:
            dicts = [base_dict] * self.n

        # Return the values multiplied by their factor
        return [{k: ps[k] * 10 ** self.p0_factors[k] for k in ps} for ps in dicts]

    def error_function(self, alist: list) -> np.ndarray:
        """Calculate the difference between the output of the fit function and each array for the given list of
        parameter values.
        :param alist: flat list of parameter mantissas (see list_to_dicts)
        :return: concatenation of the residuals of all datasets"""

        params_list = self.list_to_dicts(alist)
        errors = []
        for params, x, y, fp in zip(params_list, self.xs_data, self.ys_data, self.fixed_parameters):
            errors.append(self.function(x, **merge_dicts(params, fp)) - y)
        return np.concatenate(errors)

    def fit(self) -> list[dict]:
        """Fit the data
        :return: list of optimised parameter dicts (one per dataset), merged with each dataset's fixed parameters"""

        # A 3-point finite-difference Jacobian is used by default. Callers may override it through kwargs instead
        # of hitting a duplicate keyword argument error.
        options = {"jac": "3-point", **self.kwargs}
        popts = sco.least_squares(
            self.error_function,
            self.p0_list,
            bounds=np.transpose(self.bounds_list),
            **options,
        ).x
        popts_dicts = self.list_to_dicts(popts)
        return [merge_dicts(i, j) for i, j in zip(popts_dicts, self.fixed_parameters)]

    def calculate_rss(self, y2: list[np.ndarray]) -> float:
        """Calculate the coefficient of determination (R²) of the fits over all datasets.
        Despite the method name, the returned value is ``1 - RSS / TSS``. RSS is the residual sum of squares between
        the data and ``y2``; TSS is the total sum of squares of the data around its mean. All datasets are
        concatenated first. Constant data (TSS = 0) leads to a division by zero.
        :param y2: list of fitted arrays, in the same order as the y data
        :return: R² value (1 for a perfect fit)"""

        y1 = np.concatenate(self.ys_data)
        y2 = np.concatenate(y2)
        return 1.0 - np.sum((y1 - y2) ** 2, axis=-1) / np.sum((y1 - np.mean(y1)) ** 2)

    def calculate_fits(self, popts: list[dict]) -> list[np.ndarray]:
        """Calculate the fits
        :param popts: list of optimised parameter dicts, one per dataset (e.g. as returned by fit)
        :return: list of fit function outputs evaluated on each x array"""

        return [self.function(x, **popt) for x, popt in zip(self.xs_data, popts)]
class FitFailedException(Exception):
    """Exception signalling that a fit could not be completed successfully."""