11from dataclasses import dataclass
2- import warnings
32
43import numpy as np
54import pandas as pd
@@ -61,7 +60,7 @@ def sensitivity_indices(
6160 Sensitivity indices, combined effect of each input.
6261 foe : ndarray of shape (n_factors, 1)
6362 First-order effects (also called 'main' or 'individual').
64- soe : ndarray of shape (n_factors, 1 )
63+ soe : ndarray of shape (n_factors, n_factors )
6564 Second-order effects (also called 'interaction').
6665
6766 Examples
@@ -96,15 +95,21 @@ def sensitivity_indices(
9695 array([0.43157591, 0.44241433, 0.11767249])
9796
9897 """
98+ # Handle inputs conversion
9999 if isinstance (inputs , pd .DataFrame ):
100100 cat_columns = inputs .select_dtypes (["category" , "O" ]).columns
101101 inputs [cat_columns ] = inputs [cat_columns ].apply (
102102 lambda x : x .astype ("category" ).cat .codes
103103 )
104104 inputs = inputs .to_numpy ()
105- if isinstance (output , pd .DataFrame ):
105+
106+ # Handle output conversion first, then flatten
107+ if isinstance (output , (pd .DataFrame , pd .Series )):
106108 output = output .to_numpy ()
107109
110+ # Flatten output if it's (N, 1)
111+ output = output .flatten ()
112+
108113 n_runs , n_factors = inputs .shape
109114 n_bins_foe , n_bins_soe = number_of_bins (n_runs , n_factors )
110115
@@ -116,55 +121,64 @@ def sensitivity_indices(
116121 soe = np .zeros ((n_factors , n_factors ))
117122
118123 for i in range (n_factors ):
119- # first order
124+ # 1. First- order effects (FOE)
120125 xi = inputs [:, i ]
121126
122127 bin_avg , _ , binnumber = stats .binned_statistic (
123- x = xi , values = output , bins = n_bins_foe
128+ x = xi , values = output , bins = n_bins_foe , statistic = "mean"
124129 )
125- # can have NaN in the average but no corresponding binnumber
126- bin_avg = bin_avg [~ np .isnan (bin_avg )]
127- bin_counts = np .unique (binnumber , return_counts = True )[1 ]
128130
129- # weighted variance and divide by the overall variance of the output
130- foe [i ] = _weighted_var (bin_avg , weights = bin_counts ) / var_y
131+ # Filter empty bins and get weights (counts)
132+ mask_foe = ~ np .isnan (bin_avg )
133+ mean_i_foe = bin_avg [mask_foe ]
134+ # binnumber starts at 1; 0 is for values outside range
135+ bin_counts_foe = np .unique (binnumber [binnumber > 0 ], return_counts = True )[1 ]
136+
137+ foe [i ] = _weighted_var (mean_i_foe , weights = bin_counts_foe ) / var_y
131138
132- # second order
139+ # 2. Second- order effects (SOE)
133140 for j in range (n_factors ):
134- if i == j or j < i :
141+ if j <= i :
135142 continue
136143
137144 xj = inputs [:, j ]
138145
139- bin_avg , * edges , binnumber = stats .binned_statistic_2d (
146+ # 2D Binned Statistic for Var(E[Y|Xi, Xj])
147+ bin_avg_ij , x_edges , y_edges , binnumber_ij = stats .binned_statistic_2d (
140148 x = xi , y = xj , values = output , bins = n_bins_soe , expand_binnumbers = False
141149 )
142150
143- mean_ij = bin_avg [~ np .isnan (bin_avg )]
144- bin_counts = np .unique (binnumber , return_counts = True )[1 ]
145- var_ij = _weighted_var (mean_ij , weights = bin_counts )
146-
147- # expand_binnumbers here
148- nbin = np .array ([len (edges_ ) + 1 for edges_ in edges ])
149- binnumbers = np .asarray (np .unravel_index (binnumber , nbin ))
150-
151- bin_counts_i = np .unique (binnumbers [0 ], return_counts = True )[1 ]
152- bin_counts_j = np .unique (binnumbers [1 ], return_counts = True )[1 ]
151+ mask_ij = ~ np .isnan (bin_avg_ij )
152+ mean_ij = bin_avg_ij [mask_ij ]
153+ counts_ij = np .unique (binnumber_ij [binnumber_ij > 0 ], return_counts = True )[1 ]
154+ var_ij = _weighted_var (mean_ij , weights = counts_ij )
153155
154- # handle NaNs
155- with warnings .catch_warnings ():
156- warnings .simplefilter ("ignore" , RuntimeWarning )
157- mean_i = np .nanmean (bin_avg , axis = 1 )
158- mean_i = mean_i [~ np .isnan (mean_i )]
159- mean_j = np .nanmean (bin_avg , axis = 0 )
160- mean_j = mean_j [~ np .isnan (mean_j )]
161-
162- var_i = _weighted_var (mean_i , weights = bin_counts_i )
163- var_j = _weighted_var (mean_j , weights = bin_counts_j )
164-
165- soe [i , j ] = (var_ij - var_i - var_j ) / var_y
166-
167- soe = np .where (soe == 0 , soe .T , soe )
168- si [i ] = foe [i ] + soe [:, i ].sum () / 2
156+ # Marginal Var(E[Y|Xi]) using n_bins_soe to match MATLAB logic
157+ bin_avg_i_soe , _ , binnumber_i_soe = stats .binned_statistic (
158+ x = xi , values = output , bins = n_bins_soe , statistic = "mean"
159+ )
160+ mask_i = ~ np .isnan (bin_avg_i_soe )
161+ counts_i = np .unique (
162+ binnumber_i_soe [binnumber_i_soe > 0 ], return_counts = True
163+ )[1 ]
164+ var_i_soe = _weighted_var (bin_avg_i_soe [mask_i ], weights = counts_i )
165+
166+ # Marginal Var(E[Y|Xj]) using n_bins_soe to match MATLAB logic
167+ bin_avg_j_soe , _ , binnumber_j_soe = stats .binned_statistic (
168+ x = xj , values = output , bins = n_bins_soe , statistic = "mean"
169+ )
170+ mask_j = ~ np .isnan (bin_avg_j_soe )
171+ counts_j = np .unique (
172+ binnumber_j_soe [binnumber_j_soe > 0 ], return_counts = True
173+ )[1 ]
174+ var_j_soe = _weighted_var (bin_avg_j_soe [mask_j ], weights = counts_j )
175+
176+ soe [i , j ] = (var_ij - var_i_soe - var_j_soe ) / var_y
177+
178+ # Mirror SOE and calculate Combined Effect (SI)
179+ # SI is FOE + half of all interactions associated with that variable
180+ soe = soe + soe .T
181+ for k in range (n_factors ):
182+ si [k ] = foe [k ] + (soe [:, k ].sum () / 2 )
169183
170184 return SensitivityAnalysisResult (si , foe , soe )
0 commit comments