Skip to content

Commit

Permalink
Merge pull request #76 from EducationalTestingService/feature/choose-…
Browse files Browse the repository at this point in the history
…svd-method

Add multiple SVD methods and address other issues
  • Loading branch information
desilinguist authored Mar 26, 2021
2 parents 38827d5 + 2bfd6ed commit 289450b
Show file tree
Hide file tree
Showing 11 changed files with 403 additions and 62 deletions.
87 changes: 65 additions & 22 deletions factor_analyzer/factor_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from sklearn.utils import check_array
from sklearn.utils.validation import check_is_fitted

POSSIBLE_SVDS = ['randomized', 'lapack']

POSSIBLE_IMPUTATIONS = ['mean', 'median', 'drop']

Expand Down Expand Up @@ -176,10 +177,19 @@ class FactorAnalyzer(BaseEstimator, TransformerMixin):
If missing values are present in the data, either use
list-wise deletion ('drop') or impute the column median
('median') or column mean ('mean').
Defaults to 'median'
is_corr_matrix : bool, optional
Set to True if the `data` is the correlation
matrix.
Defaults to False.
svd_method : {'lapack', 'randomized'}, optional
The SVD method to use when ``method='principal'``.
If 'lapack', use standard SVD from ``numpy.linalg``.
If 'randomized', use faster ``randomized_svd``
function from scikit-learn. The latter should only
be used if the number of columns is greater than or
equal to the number of rows in the dataset.
Defaults to 'randomized'.
rotation_kwargs, optional
Additional key word arguments
are passed to the rotation method.
Expand Down Expand Up @@ -249,32 +259,18 @@ def __init__(self,
is_corr_matrix=False,
bounds=(0.005, 1),
impute='median',
svd_method='randomized',
rotation_kwargs=None):

rotation = rotation.lower() if isinstance(rotation, str) else rotation
if rotation not in POSSIBLE_ROTATIONS + [None]:
raise ValueError(f"The rotation must be one of the following: {POSSIBLE_ROTATIONS + [None]}")

method = method.lower()
if method not in POSSIBLE_METHODS:
raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS + [None]}")

impute = impute.lower()
if impute not in POSSIBLE_IMPUTATIONS:
raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS + [None]}")

if method == 'principal' and is_corr_matrix:
raise ValueError('The principal method is only implemented using '
'the full data set, not the correlation matrix.')

self.n_factors = n_factors
self.rotation = rotation
self.method = method
self.use_smc = use_smc
self.bounds = bounds
self.impute = impute
self.is_corr_matrix = is_corr_matrix
self.rotation_kwargs = {} if rotation_kwargs is None else rotation_kwargs
self.svd_method = svd_method
self.rotation_kwargs = rotation_kwargs

# default matrices to None
self.mean_ = None
Expand All @@ -288,6 +284,34 @@ def __init__(self,
self.rotation_matrix_ = None
self.weights_ = None

def _arg_checker(self):
    """
    Validate and normalize the estimator's parameters.

    String values of ``rotation``, ``method``, ``impute`` and
    ``svd_method`` are lower-cased in place and then checked against
    the allowed options; a ``rotation_kwargs`` of None is replaced
    with an empty dictionary. The checks live in this method so that
    the FactorAnalyzer class can be properly cloned when used with
    grid search CV, for example.

    Raises
    ------
    ValueError
        If ``rotation``, ``method``, ``impute`` or ``svd_method`` is
        not one of the allowed options, or if ``method`` is
        'principal' while ``is_corr_matrix`` is set.
    """
    def _lowercase_if_str(value):
        # non-string values (e.g. None) are passed through untouched
        return value.lower() if isinstance(value, str) else value

    # each parameter is normalized first and validated immediately after,
    # so a failure leaves the later parameters exactly as the user set them
    self.rotation = _lowercase_if_str(self.rotation)
    valid_rotations = POSSIBLE_ROTATIONS + [None]
    if self.rotation not in valid_rotations:
        raise ValueError(f"The rotation must be one of the following: {valid_rotations}")

    self.method = _lowercase_if_str(self.method)
    if self.method not in POSSIBLE_METHODS:
        raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS}")

    self.impute = _lowercase_if_str(self.impute)
    if self.impute not in POSSIBLE_IMPUTATIONS:
        raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS}")

    self.svd_method = _lowercase_if_str(self.svd_method)
    if self.svd_method not in POSSIBLE_SVDS:
        raise ValueError(f"The SVD method must be one of the following: {POSSIBLE_SVDS}")

    # the principal method needs the raw observations
    principal_on_corr = self.method == 'principal' and self.is_corr_matrix
    if principal_on_corr:
        raise ValueError('The principal method is only implemented using '
                         'the full data set, not the correlation matrix.')

    # downstream code can rely on a dictionary being present
    if self.rotation_kwargs is None:
        self.rotation_kwargs = {}

@staticmethod
def _fit_uls_objective(psi, corr_mtx, n_factors):
"""
Expand Down Expand Up @@ -472,8 +496,21 @@ def _fit_principal(self, X):
X = X.copy()
X = (X - X.mean(0)) / X.std(0)

# if the number of rows is less than the number of columns,
# warn the user that the number of factors will be constrained
nrows, ncols = X.shape
if nrows < ncols and self.n_factors >= nrows:
warnings.warn('The number of factors will be '
'constrained to min(n_samples, n_features)'
'={}.'.format(min(nrows, ncols)))

# perform the randomized singular value decomposition
U, S, V = randomized_svd(X, self.n_factors)
if self.svd_method == 'randomized':
U, S, V = randomized_svd(X, self.n_factors)
# otherwise, perform the full SVD
else:
U, S, V = np.linalg.svd(X, full_matrices=False)

corr_mtx = np.dot(X, V.T)
loadings = np.array([[pearsonr(x, c)[0] for c in corr_mtx.T] for x in X.T])
return loadings
Expand Down Expand Up @@ -577,6 +614,9 @@ def fit(self, X, y=None):
[ 0.81533404, -0.12494695, 0.17639683]])
"""

# check the input arguments
self._arg_checker()

# check if the data is a data frame,
# so we can convert it to an array
if isinstance(X, pd.DataFrame):
Expand Down Expand Up @@ -650,11 +690,14 @@ def fit(self, X, y=None):
phi = np.dot(np.dot(np.diag(signs), phi), np.diag(signs))
structure = np.dot(loadings, phi) if self.rotation in OBLIQUE_ROTATIONS else None

# resort the factors according to their variance
variance = self._get_factor_variance(loadings)[0]
new_order = list(reversed(np.argsort(variance)))
loadings = loadings[:, new_order].copy()
# resort the factors according to their variance,
# unless the method is principal
if self.method != 'principal':
variance = self._get_factor_variance(loadings)[0]
new_order = list(reversed(np.argsort(variance)))
loadings = loadings[:, new_order].copy()

# if the structure matrix exists, reorder
if structure is not None:
structure = structure[:, new_order].copy()

Expand Down
18 changes: 15 additions & 3 deletions factor_analyzer/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def calculate_py_output(test_name,
factors,
method,
rotation,
svd_method='randomized',
use_corr_matrix=False,
top_dir=None):
"""
Expand All @@ -54,6 +55,9 @@ def calculate_py_output(test_name,
The rotation method
rotation : str
The type of rotation
svd_method : str, optional
The SVD method to use
Defaults to 'randomized'
use_corr_matrix : bool, optional
Whether to use the correlation matrix.
Defaults to False.
Expand Down Expand Up @@ -81,7 +85,7 @@ def calculate_py_output(test_name,
rotation = None if rotation == 'none' else rotation
method = {'uls': 'minres'}.get(method, method)

fa = FactorAnalyzer(n_factors=factors, method=method,
fa = FactorAnalyzer(n_factors=factors, method=method, svd_method=svd_method,
rotation=rotation, is_corr_matrix=use_corr_matrix)
fa.fit(X)

Expand Down Expand Up @@ -228,6 +232,11 @@ def check_close(data1, data2, rel_tol=0.0, abs_tol=0.1,
data1 = normalize(data1, absolute)
data2 = normalize(data2, absolute)

print(data1)
print()
print(data2)
print('------')

err_msg = 'r - py: {} != {}'
assert data1.shape == data2.shape, err_msg.format(data1.shape, data2.shape)

Expand All @@ -253,6 +262,7 @@ def check_scenario(test_name,
check_scores=False,
check_structure=False,
use_corr_matrix=False,
svd_method='randomized',
data_dir=None,
expected_dir=None,
rel_tol=0,
Expand Down Expand Up @@ -321,8 +331,10 @@ def check_scenario(test_name,
if check_structure:
output_types.append('structure')

r_output = collect_r_output(test_name, factors, method, rotation, output_types, expected_dir)
py_output = calculate_py_output(test_name, factors, method, rotation, use_corr_matrix, data_dir)
r_output = collect_r_output(test_name, factors, method, rotation,
output_types, expected_dir)
py_output = calculate_py_output(test_name, factors, method, rotation, svd_method,
use_corr_matrix, data_dir)

for output_type in output_types:

Expand Down
10 changes: 10 additions & 0 deletions tests/data/test15.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
pers01,pers02,pers03,pers04,pers05,pers06,pers07,pers08,pers09,pers10,pers11,pers12,pers13,pers14,pers15,pers16,pers17,pers18,pers19,pers20,pers21,pers22,pers23,pers24,pers25,pers26,pers27,pers28,pers29,pers30,pers31,pers32,pers33,pers34,pers35,pers36,pers37,pers38,pers39,pers40,pers41,pers42,pers43,pers44
5,4,5,1,4,3,3,1,2,3,2,4,5,4,4,3,5,1,3,4,2,4,2,3,2,5,4,5,4,2,5,4,5,3,1,3,4,4,4,3,3,3,5,4
1,1,5,2,1,2,5,1,5,1,5,3,5,4,1,2,1,3,5,1,5,5,1,3,1,2,3,5,3,3,5,5,5,5,5,3,1,1,2,3,3,5,2,1
4,1,5,3,3,4,5,3,1,4,2,1,5,4,3,2,5,1,5,4,4,4,3,2,4,1,3,4,5,4,3,3,5,5,4,2,3,1,5,4,2,3,5,3
4,2,5,1,4,3,4,4,4,5,4,1,4,5,3,3,4,2,5,4,4,4,3,2,4,2,1,5,3,4,4,4,4,4,3,3,3,4,5,5,3,5,2,4
2,3,5,1,2,4,5,2,3,3,4,2,5,3,3,3,5,1,4,2,5,5,1,3,2,2,4,4,4,3,4,4,5,4,4,2,5,5,4,3,2,4,3,2
1,1,5,4,3,4,4,2,1,4,3,3,5,5,3,2,4,1,5,3,5,4,1,3,3,1,2,5,5,4,5,3,4,2,3,1,1,3,5,4,2,4,3,5
3,2,5,1,2,1,1,2,5,4,4,1,5,1,1,4,5,5,1,3,2,5,1,4,2,1,1,5,1,2,1,5,4,5,5,5,2,5,1,1,4,5,1,1
5,2,4,2,4,1,4,3,3,5,4,1,4,3,2,4,4,4,2,5,2,4,3,3,1,1,3,4,3,3,2,5,4,4,2,4,2,4,2,3,2,5,4,4
5,1,4,3,2,1,4,4,2,3,4,1,4,4,2,5,5,5,4,4,2,5,2,4,2,4,4,4,3,5,2,5,4,2,4,5,1,2,4,3,2,5,4,2
45 changes: 45 additions & 0 deletions tests/expected/test15/communalities_principal_none_20_test15.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
,x
pers01,2.147636456
pers02,3.05860696
pers03,2.634980867
pers04,4.167044602
pers05,3.239628161
pers06,2.855673949
pers07,2.846150806
pers08,2.342990801
pers09,3.030327258
pers10,2.975790993
pers11,2.395566544
pers12,2.220826511
pers13,2.038321064
pers14,3.265232908
pers15,2.425621063
pers16,3.061295019
pers17,2.017025954
pers18,3.049338556
pers19,3.122061642
pers20,2.728356141
pers21,2.594510111
pers22,3.782658794
pers23,3.070056277
pers24,4.111231609
pers25,3.064079366
pers26,1.832456264
pers27,2.162362221
pers28,1.758259243
pers29,3.225984183
pers30,3.166512309
pers31,2.790080337
pers32,3.175142177
pers33,2.391267257
pers34,2.403458399
pers35,2.514121013
pers36,2.983086234
pers37,2.746601113
pers38,3.495223869
pers39,2.864928829
pers40,3.571532646
pers41,2.654887486
pers42,2.126933791
pers43,2.482551325
pers44,3.105868335
45 changes: 45 additions & 0 deletions tests/expected/test15/evalues_principal_none_20_test15.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
,x
1,14.25162169
2,9.294533097
3,6.498852485
4,4.620861291
5,3.059319317
6,2.625005382
7,1.991825606
8,1.657981133
9,2.69E-15
10,1.67E-15
11,1.22E-15
12,9.84E-16
13,6.64E-16
14,5.94E-16
15,5.04E-16
16,4.43E-16
17,4.28E-16
18,3.86E-16
19,3.62E-16
20,2.83E-16
21,1.92E-16
22,1.41E-16
23,1.16E-16
24,7.91E-17
25,2.42E-17
26,1.45E-17
27,-1.52E-17
28,-3.58E-17
29,-9.31E-17
30,-2.07E-16
31,-2.23E-16
32,-2.53E-16
33,-3.39E-16
34,-4.00E-16
35,-4.38E-16
36,-4.54E-16
37,-5.18E-16
38,-5.69E-16
39,-7.36E-16
40,-8.59E-16
41,-9.67E-16
42,-1.20E-15
43,-1.50E-15
44,-3.05E-15
45 changes: 45 additions & 0 deletions tests/expected/test15/loading_principal_none_20_test15.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
,MR1,MR2,MR3,MR4,MR5,MR6,MR7,MR8,MR9
pers01,-0.04828462,0.872685807,0.239695145,0.1783874,-0.070266058,0.233832136,-0.127745213,0.266230396,0.109949125
pers02,0.065050861,0.087111805,0.917108625,0.09829389,0.045889312,0.259212544,0.240458465,-0.101551737,-0.438491274
pers03,0.335344379,-0.692295104,0.352908838,-0.448246064,-0.105717303,-0.075752288,0.072948636,0.246102405,-0.026867838
pers04,0.369038127,0.065051128,-0.659481751,0.216562195,0.122770915,-0.550263386,-0.238255766,-0.055995953,0.650878605
pers05,0.460360115,0.673786618,0.357376302,-0.321981352,0.22082246,0.141451804,-0.140351837,-0.119249891,0.088835419
pers06,0.87171974,-0.336822521,0.112582984,-0.145152604,-0.211905593,-0.080225718,0.198171197,-0.047949025,0.530776467
pers07,0.576151802,-0.206133915,-0.510224335,0.311050617,-0.18298515,0.43465408,0.004627584,-0.214586587,0.581738459
pers08,0.013465102,0.72669014,-0.534236413,-0.20871943,-0.19468342,0.119898964,0.269297945,0.134048523,0.395215471
pers09,-0.748916302,-0.361166248,0.065624042,-0.340057899,0.067146842,0.42644553,0.043870457,0.021173528,-0.771983064
pers10,0.137766271,0.729061174,0.055714597,-0.601408949,-0.108200226,-0.133706349,0.019605779,-0.233932375,0.119752545
pers11,-0.675924401,-0.273483264,-0.433155063,-0.064071358,0.143265802,0.359476466,0.248424031,-0.255228802,-0.544707831
pers12,0.355725628,-0.531383218,0.427897805,0.322997529,0.543285483,-0.056315934,-0.072021539,-0.012227634,-0.211014824
pers13,0.191984172,-0.752696697,0.397452888,0.0537552,-0.217179515,-0.397121986,-0.166976097,0.05456656,0.03461962
pers14,0.761628997,0.018888667,-0.389796375,0.02405038,0.439146501,0.204641161,0.08654672,0.15756543,0.476955155
pers15,0.841854905,0.251761235,0.412010313,0.053732916,0.042505833,0.029407325,0.226082677,0.038366163,0.402422072
pers16,-0.64430474,0.655136206,0.025197088,0.259591771,0.01078642,-0.024935079,0.293030345,0.032253453,-0.309892468
pers17,0.133885448,0.599835612,0.405812074,0.039838088,-0.411211008,-0.402493253,0.338007923,0.103231512,0.273106526
pers18,-0.899018777,0.306544161,-0.263593624,0.086024622,0.073959402,-0.061269493,-0.070891304,0.081636764,-0.419225537
pers19,0.690934528,-0.329647629,-0.551624951,0.0182762,0.064019853,0.196192485,0.15420495,0.207240475,0.606635059
pers20,0.155163821,0.957389706,0.108910582,-0.089609788,0.022891194,-0.02826736,-0.193441382,0.026364585,0.239325582
pers21,0.485799705,-0.70406665,-0.364281338,-0.193071399,-0.097510953,0.090140204,0.1946385,-0.206866024,0.285201141
pers22,-0.681026118,-0.43653348,-0.152606385,0.349435475,-0.227476649,-0.024178065,0.372173269,0.096979075,-0.32989024
pers23,0.298251743,0.738460588,-0.104605406,-0.18269442,-0.149645341,0.429224993,-0.312246211,0.131441227,0.397777692
pers24,-0.735050328,0.022838942,0.070148594,0.42301257,0.17584153,-0.452732468,0.197949991,-0.015741936,-0.467726762
pers25,0.612945178,0.179319088,-0.143944907,-0.536182675,-0.246095779,-0.144123534,0.178983737,0.412990281,0.615785245
pers26,0.075493078,0.164186546,0.361855321,0.60623896,0.363141969,0.212132397,0.266854716,0.469888221,-0.107909185
pers27,0.222367812,0.025808466,0.118818476,0.937412093,-0.130937606,0.1757883,0.031728968,-0.089298973,0.271281581
pers28,-0.066730222,-0.434588822,0.269651574,-0.478297293,0.639249191,-0.095051394,-0.032966158,0.294003909,-0.528846167
pers29,0.943968963,-0.072840218,-0.105035164,0.233586597,-0.059723312,-0.116892943,-0.06291015,-0.129743188,0.741046671
pers30,0.329973251,0.309685066,-0.828180961,0.108154302,-0.016580396,-0.057603016,0.255376313,0.169769485,0.679006138
pers31,0.663368245,-0.546831933,0.099996746,0.073101125,0.436956147,0.219166869,0.07969368,-0.016075172,0.068306007
pers32,-0.896103875,0.087076721,-0.032225109,0.233865369,0.167003644,0.322598699,-0.024188096,-0.033746387,-0.67512874
pers33,0.368927175,-0.55505924,0.330315318,0.404087257,-0.358903058,0.316887491,-0.186347221,0.139473596,0.227565878
pers34,-0.297896347,-0.356547533,0.108006974,-0.380935237,-0.55636427,0.376008538,-0.417555873,0.045584353,-0.143322194
pers35,-0.466026995,-0.521274333,-0.486631795,-0.090454241,-0.437389454,-0.144597074,0.10906771,0.204904425,0.03733149
pers36,-0.849143547,0.423510789,0.024744724,0.136594471,0.030216843,0.09148576,-0.026121108,0.265251737,-0.457345574
pers37,0.341668592,0.007337606,0.66756131,-0.008476775,-0.46117769,0.331304008,0.328795419,-0.083337506,0.029768105
pers38,-0.263669214,0.20346738,0.640196465,-0.321078152,0.009810988,-0.038680089,0.470802886,-0.391010377,-0.606810668
pers39,0.882113868,0.17203835,-0.16887364,-0.023210193,-0.013321485,-0.040546264,0.34238953,0.210163886,0.712370485
pers40,0.791007623,0.146938,-0.379995394,-0.239363727,0.157728703,0.334133052,0.109024279,0.051138425,0.546598278
pers41,-0.56246979,-0.273713014,0.436724979,-0.447447065,0.177598914,0.035751419,-0.068604277,0.424553004,-0.688649095
pers42,-0.742317116,0.071452222,-0.469083975,-0.218919485,0.263651376,0.217525709,0.182249823,-0.160778989,-0.53833395
pers43,0.587977857,0.446104906,0.144926452,0.575703527,-0.127295674,0.010140071,-0.290820054,0.044169119,0.615268623
pers44,0.697943291,0.446797496,0.096467909,-0.230890491,0.408446905,-0.09793622,-0.066987721,-0.264051038,0.260242219
Loading

0 comments on commit 289450b

Please sign in to comment.