Coverage for app/fitting.py: 100%

66 statements  

coverage.py v7.10.5, created at 2025-08-28 18:30 +0000

1"""fitting module""" 


from typing import Callable

import numpy as np
import scipy.optimize as sco

from utility.dict import merge_dicts



def normalize_to_unit(number: float) -> tuple[float, int]:
    """Normalize a number to the range [0.1, 1], adjusting the power (exponent) accordingly.
    The function returns the normalized value and the exponent.

    Examples
    --------
    >>> normalize_to_unit(1.433364345e9)
    (0.1433364345, 10)
    >>> normalize_to_unit(14e-6)
    (0.13999999999999999, -4)
    >>> normalize_to_unit(-14e-6)
    (-0.13999999999999999, -4)"""

    exponent = 0  # exponent value
    value = number  # normalized value

    if value == 0.0:
        value = 0.0
        exponent = 0

    elif abs(value) < 1:
        while abs(value) < 0.1:
            value *= 10.0
            exponent -= 1

    elif abs(value) > 1:
        while abs(value) > 1:
            value /= 10.0
            exponent += 1

    return value, exponent



class Fit(object):
    """Simultaneous least-squares fit of a model function to one or more datasets,
    with shared, detached (per-dataset) and fixed parameters."""


    def __init__(
        self,
        xs_data: list[np.ndarray],
        ys_data: list[np.ndarray],
        function: Callable,
        p0: dict,
        detached_parameters: list[str],
        fixed_parameters: list[dict],
        **kwargs,
    ):
        """Object constructor
        :param xs_data: array or list of arrays associated with the x data
        :param ys_data: array or list of arrays associated with the y data
        :param function: function used to fit the data
        :param p0: initial guess for the parameters; must contain every non-fixed parameter
        :param detached_parameters: list of parameters that are not shared between y data
        :param fixed_parameters: fixed parameters dict for each y data
        :param kwargs: keyword arguments passed to the scipy.optimize.least_squares function"""


        # Store the input arguments
        self.xs_data = xs_data
        self.ys_data = ys_data
        self.function = function
        self.p0 = p0
        self.fixed_parameters = fixed_parameters
        self.kwargs = kwargs

        # Make sure that the fixed parameter keys are all the same and that the lengths match
        assert all(d.keys() == fixed_parameters[0].keys() for d in fixed_parameters)
        assert len(self.fixed_parameters) == len(self.xs_data) == len(self.ys_data)

        # Remove fixed parameters from the detached parameters
        self.detached_parameters = [f for f in detached_parameters if f not in self.fixed_parameters[0]]

        # Remove fixed parameters from the initial guess
        self.p0 = {key: value for key, value in self.p0.items() if key not in self.fixed_parameters[0]}
        self.keys = list(self.p0.keys())  # all keys except for the ones fixed

        # Normalise the initial guess values to unit to facilitate the optimisation
        p0_split = {key: normalize_to_unit(self.p0[key]) for key in self.p0}
        self.p0_mantissa, self.p0_factors = [{key: val[i] for key, val in p0_split.items()} for i in (0, 1)]

        # Generate the positive bounds
        self.bounds = {key: [0, np.inf] for key in self.keys}

        # Convert the guess values and bounds to lists for the scipy curve fitting
        self.p0_list = [self.p0_mantissa[key] for key in self.keys]
        self.bounds_list = [self.bounds[key] for key in self.keys]

        # Add the detached parameters to the list
        self.n = len(self.xs_data)  # number of xs arrays
        self.p0_list += [self.p0_mantissa[key] for key in self.detached_parameters] * (self.n - 1)
        self.bounds_list += [self.bounds[key] for key in self.detached_parameters] * (self.n - 1)
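
        # Illustrative layout of the flat parameter list (hypothetical names, not
        # taken from the original module): with keys = ["a", "k"],
        # detached_parameters = ["a"] and n = 3 datasets,
        #
        #   p0_list = [a_1, k, a_2, a_3]
        #
        # i.e. the shared block first, then one copy of each detached parameter
        # for every dataset beyond the first. list_to_dicts() below reverses
        # this mapping into one parameter dict per dataset.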


    def list_to_dicts(self, alist: list) -> list[dict]:
        """Convert the list of guess values (including detached parameters) into a list of dictionaries
        :param alist: list of guess values associated with the keys"""

        # Create the base dict (with any potential first instance of the detached parameters)
        base_dict = dict(zip(self.keys, alist))

        # If there are detached parameters, extract them from the list and add them to the dictionaries
        if self.detached_parameters:
            # Determine the values of the detached parameters
            supp = np.array(alist[len(self.keys) :]).reshape((-1, len(self.detached_parameters)))
            # Add the detached values to each dictionary
            dicts = [merge_dicts(dict(zip(self.detached_parameters, p)), base_dict) for p in supp]
            # Add the base dictionary to the list
            dicts = [base_dict] + dicts

        # If there are no detached parameters, just repeat the base dictionary
        else:
            dicts = [base_dict] * self.n

        # Return the values multiplied by their factor
        return [{k: ps[k] * 10 ** self.p0_factors[k] for k in ps} for ps in dicts]


    def error_function(self, alist: list) -> np.ndarray:
        """Calculate the difference between the output of the fit function and each array for the given list of
        parameter values.
        :param alist: list of parameter values"""

        params_list = self.list_to_dicts(alist)
        errors = []
        for params, x, y, fp in zip(params_list, self.xs_data, self.ys_data, self.fixed_parameters):
            errors.append(self.function(x, **merge_dicts(params, fp)) - y)
        return np.concatenate(errors)


    def fit(self) -> list[dict]:
        """Fit the data and return one optimised parameter dict per dataset (including the fixed parameters)"""

        popts = sco.least_squares(
            self.error_function,
            self.p0_list,
            bounds=np.transpose(self.bounds_list),
            jac="3-point",
            **self.kwargs,
        ).x
        popts_dicts = self.list_to_dicts(popts)
        return [merge_dicts(i, j) for i, j in zip(popts_dicts, self.fixed_parameters)]


    def calculate_rss(self, y2: list[np.ndarray]) -> float:
        """Calculate the coefficient of determination (R²), i.e. one minus the residual sum of squares
        divided by the total sum of squares, over all datasets concatenated.
        :param y2: list of fitted y arrays, one per dataset"""

        y1 = np.concatenate(self.ys_data)
        y2 = np.concatenate(y2)
        return 1.0 - np.sum((y1 - y2) ** 2, axis=-1) / np.sum((y1 - np.mean(y1)) ** 2)


    def calculate_fits(self, popts: list[dict]) -> list[np.ndarray]:
        """Calculate the fits
        :param popts: list of optimised parameter dicts"""

        return [self.function(x, **popt) for x, popt in zip(self.xs_data, popts)]



class FitFailedException(Exception):
    """Raised when a fit fails"""

    pass
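

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): fits two
# noisy exponential decays that share the decay rate "k", have an independent
# amplitude "a" per dataset (detached parameter) and a fixed offset "c".
# The model function, parameter names and values below are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    def model(x: np.ndarray, a: float, k: float, c: float) -> np.ndarray:
        """Exponential decay used as an example fit function."""
        return a * np.exp(-k * x) + c

    rng = np.random.default_rng(0)
    xs = [np.linspace(0.0, 5.0, 100)] * 2
    ys = [
        model(xs[0], a=2.0, k=1.3, c=0.5) + 0.01 * rng.normal(size=100),
        model(xs[1], a=4.0, k=1.3, c=0.5) + 0.01 * rng.normal(size=100),
    ]

    fit = Fit(
        xs_data=xs,
        ys_data=ys,
        function=model,
        p0={"a": 1.0, "k": 1.0},  # initial guess for the non-fixed parameters
        detached_parameters=["a"],  # "a" is fitted separately for each dataset
        fixed_parameters=[{"c": 0.5}, {"c": 0.5}],  # "c" is held fixed per dataset
    )
    popts = fit.fit()  # one optimised parameter dict per dataset
    r_squared = fit.calculate_rss(fit.calculate_fits(popts))
    print(popts, r_squared)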