diff --git a/factor_analyzer/factor_analyzer.py b/factor_analyzer/factor_analyzer.py index 93f9064..8804ff7 100644 --- a/factor_analyzer/factor_analyzer.py +++ b/factor_analyzer/factor_analyzer.py @@ -31,6 +31,7 @@ from sklearn.utils import check_array from sklearn.utils.validation import check_is_fitted +POSSIBLE_SVDS = ['randomized', 'lapack'] POSSIBLE_IMPUTATIONS = ['mean', 'median', 'drop'] @@ -176,10 +177,19 @@ class FactorAnalyzer(BaseEstimator, TransformerMixin): If missing values are present in the data, either use list-wise deletion ('drop') or impute the column median ('median') or column mean ('mean'). + Defaults to 'median' use_corr_matrix : bool, optional Set to true if the `data` is the correlation matrix. Defaults to False. + svd_method : {‘lapack’, ‘randomized’} + The SVD method to use when ``method='principal'``. + If 'lapack', use standard SVD from ``scipy.linalg``. + If 'randomized', use faster ``randomized_svd`` + function from scikit-learn. The latter should only + be used if the number of columns is greater than or + equal to the number of rows in in the dataset. + Defaults to 'randomized' rotation_kwargs, optional Additional key word arguments are passed to the rotation method. @@ -249,24 +259,9 @@ def __init__(self, is_corr_matrix=False, bounds=(0.005, 1), impute='median', + svd_method='randomized', rotation_kwargs=None): - rotation = rotation.lower() if isinstance(rotation, str) else rotation - if rotation not in POSSIBLE_ROTATIONS + [None]: - raise ValueError(f"The rotation must be one of the following: {POSSIBLE_ROTATIONS + [None]}") - - method = method.lower() - if method not in POSSIBLE_METHODS: - raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS + [None]}") - - impute = impute.lower() - if impute not in POSSIBLE_IMPUTATIONS: - raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS + [None]}") - - if method == 'principal' and is_corr_matrix: - raise ValueError('The principal method is only implemented using ' - 'the full data set, not the correlation matrix.') - self.n_factors = n_factors self.rotation = rotation self.method = method @@ -274,7 +269,8 @@ def __init__(self, self.bounds = bounds self.impute = impute self.is_corr_matrix = is_corr_matrix - self.rotation_kwargs = {} if rotation_kwargs is None else rotation_kwargs + self.svd_method = svd_method + self.rotation_kwargs = rotation_kwargs # default matrices to None self.mean_ = None @@ -288,6 +284,34 @@ def __init__(self, self.rotation_matrix_ = None self.weights_ = None + def _arg_checker(self): + """ + Check the input parameters to make sure they're properly formattted. + We need to do this to ensure that the FactorAnalyzer class can be properly + cloned when used with grid search CV, for example. + """ + self.rotation = self.rotation.lower() if isinstance(self.rotation, str) else self.rotation + if self.rotation not in POSSIBLE_ROTATIONS + [None]: + raise ValueError(f"The rotation must be one of the following: {POSSIBLE_ROTATIONS + [None]}") + + self.method = self.method.lower() if isinstance(self.method, str) else self.method + if self.method not in POSSIBLE_METHODS: + raise ValueError(f"The method must be one of the following: {POSSIBLE_METHODS}") + + self.impute = self.impute.lower() if isinstance(self.impute, str) else self.impute + if self.impute not in POSSIBLE_IMPUTATIONS: + raise ValueError(f"The imputation must be one of the following: {POSSIBLE_IMPUTATIONS}") + + self.svd_method = self.svd_method.lower() if isinstance(self.svd_method, str) else self.svd_method + if self.svd_method not in POSSIBLE_SVDS: + raise ValueError(f"The SVD method must be one of the following: {POSSIBLE_SVDS}") + + if self.method == 'principal' and self.is_corr_matrix: + raise ValueError('The principal method is only implemented using ' + 'the full data set, not the correlation matrix.') + + self.rotation_kwargs = {} if self.rotation_kwargs is None else self.rotation_kwargs + @staticmethod def _fit_uls_objective(psi, corr_mtx, n_factors): """ @@ -472,8 +496,21 @@ def _fit_principal(self, X): X = X.copy() X = (X - X.mean(0)) / X.std(0) + # if the number of rows is less than the number of columns, + # warn the user that the number of factors will be constrained + nrows, ncols = X.shape + if nrows < ncols and self.n_factors >= nrows: + warnings.warn('The number of factors will be ' + 'constrained to min(n_samples, n_features)' + '={}.'.format(min(nrows, ncols))) + # perform the randomized singular value decomposition - U, S, V = randomized_svd(X, self.n_factors) + if self.svd_method == 'randomized': + U, S, V = randomized_svd(X, self.n_factors) + # otherwise, perform the full SVD + else: + U, S, V = np.linalg.svd(X, full_matrices=False) + corr_mtx = np.dot(X, V.T) loadings = np.array([[pearsonr(x, c)[0] for c in corr_mtx.T] for x in X.T]) return loadings @@ -577,6 +614,9 @@ def fit(self, X, y=None): [ 0.81533404, -0.12494695, 0.17639683]]) """ + # check the input arguments + self._arg_checker() + # check if the data is a data frame, # so we can convert it to an array if isinstance(X, pd.DataFrame): @@ -650,11 +690,14 @@ def fit(self, X, y=None): phi = np.dot(np.dot(np.diag(signs), phi), np.diag(signs)) structure = np.dot(loadings, phi) if self.rotation in OBLIQUE_ROTATIONS else None - # resort the factors according to their variance - variance = self._get_factor_variance(loadings)[0] - new_order = list(reversed(np.argsort(variance))) - loadings = loadings[:, new_order].copy() + # resort the factors according to their variance, + # unless the method is principal + if self.method != 'principal': + variance = self._get_factor_variance(loadings)[0] + new_order = list(reversed(np.argsort(variance))) + loadings = loadings[:, new_order].copy() + # if the structure matrix exists, reorder if structure is not None: structure = structure[:, new_order].copy() diff --git a/factor_analyzer/test_utils.py b/factor_analyzer/test_utils.py index 091cb7a..97ab8a9 100644 --- a/factor_analyzer/test_utils.py +++ b/factor_analyzer/test_utils.py @@ -38,6 +38,7 @@ def calculate_py_output(test_name, factors, method, rotation, + svd_method='randomized', use_corr_matrix=False, top_dir=None): """ @@ -54,6 +55,9 @@ def calculate_py_output(test_name, The rotation method rotation : str The type of rotation + svd_method : str, optional + The SVD method to use + Defaults to 'randomized' use_corr_matrix : bool, optional Whether to use the correlation matrix. Defaults to False. @@ -81,7 +85,7 @@ def calculate_py_output(test_name, rotation = None if rotation == 'none' else rotation method = {'uls': 'minres'}.get(method, method) - fa = FactorAnalyzer(n_factors=factors, method=method, + fa = FactorAnalyzer(n_factors=factors, method=method, svd_method=svd_method, rotation=rotation, is_corr_matrix=use_corr_matrix) fa.fit(X) @@ -228,6 +232,11 @@ def check_close(data1, data2, rel_tol=0.0, abs_tol=0.1, data1 = normalize(data1, absolute) data2 = normalize(data2, absolute) + print(data1) + print() + print(data2) + print('------') + err_msg = 'r - py: {} != {}' assert data1.shape == data2.shape, err_msg.format(data1.shape, data2.shape) @@ -253,6 +262,7 @@ def check_scenario(test_name, check_scores=False, check_structure=False, use_corr_matrix=False, + svd_method='randomized', data_dir=None, expected_dir=None, rel_tol=0, @@ -321,8 +331,10 @@ def check_scenario(test_name, if check_structure: output_types.append('structure') - r_output = collect_r_output(test_name, factors, method, rotation, output_types, expected_dir) - py_output = calculate_py_output(test_name, factors, method, rotation, use_corr_matrix, data_dir) + r_output = collect_r_output(test_name, factors, method, rotation, + output_types, expected_dir) + py_output = calculate_py_output(test_name, factors, method, rotation, svd_method, + use_corr_matrix, data_dir) for output_type in output_types: diff --git a/tests/data/test15.csv b/tests/data/test15.csv new file mode 100644 index 0000000..a4cedbd --- /dev/null +++ b/tests/data/test15.csv @@ -0,0 +1,10 @@ +pers01,pers02,pers03,pers04,pers05,pers06,pers07,pers08,pers09,pers10,pers11,pers12,pers13,pers14,pers15,pers16,pers17,pers18,pers19,pers20,pers21,pers22,pers23,pers24,pers25,pers26,pers27,pers28,pers29,pers30,pers31,pers32,pers33,pers34,pers35,pers36,pers37,pers38,pers39,pers40,pers41,pers42,pers43,pers44 +5,4,5,1,4,3,3,1,2,3,2,4,5,4,4,3,5,1,3,4,2,4,2,3,2,5,4,5,4,2,5,4,5,3,1,3,4,4,4,3,3,3,5,4 +1,1,5,2,1,2,5,1,5,1,5,3,5,4,1,2,1,3,5,1,5,5,1,3,1,2,3,5,3,3,5,5,5,5,5,3,1,1,2,3,3,5,2,1 +4,1,5,3,3,4,5,3,1,4,2,1,5,4,3,2,5,1,5,4,4,4,3,2,4,1,3,4,5,4,3,3,5,5,4,2,3,1,5,4,2,3,5,3 +4,2,5,1,4,3,4,4,4,5,4,1,4,5,3,3,4,2,5,4,4,4,3,2,4,2,1,5,3,4,4,4,4,4,3,3,3,4,5,5,3,5,2,4 +2,3,5,1,2,4,5,2,3,3,4,2,5,3,3,3,5,1,4,2,5,5,1,3,2,2,4,4,4,3,4,4,5,4,4,2,5,5,4,3,2,4,3,2 +1,1,5,4,3,4,4,2,1,4,3,3,5,5,3,2,4,1,5,3,5,4,1,3,3,1,2,5,5,4,5,3,4,2,3,1,1,3,5,4,2,4,3,5 +3,2,5,1,2,1,1,2,5,4,4,1,5,1,1,4,5,5,1,3,2,5,1,4,2,1,1,5,1,2,1,5,4,5,5,5,2,5,1,1,4,5,1,1 +5,2,4,2,4,1,4,3,3,5,4,1,4,3,2,4,4,4,2,5,2,4,3,3,1,1,3,4,3,3,2,5,4,4,2,4,2,4,2,3,2,5,4,4 +5,1,4,3,2,1,4,4,2,3,4,1,4,4,2,5,5,5,4,4,2,5,2,4,2,4,4,4,3,5,2,5,4,2,4,5,1,2,4,3,2,5,4,2 \ No newline at end of file diff --git a/tests/expected/test15/communalities_principal_none_20_test15.csv b/tests/expected/test15/communalities_principal_none_20_test15.csv new file mode 100644 index 0000000..a32bd14 --- /dev/null +++ b/tests/expected/test15/communalities_principal_none_20_test15.csv @@ -0,0 +1,45 @@ +,x +pers01,2.147636456 +pers02,3.05860696 +pers03,2.634980867 +pers04,4.167044602 +pers05,3.239628161 +pers06,2.855673949 +pers07,2.846150806 +pers08,2.342990801 +pers09,3.030327258 +pers10,2.975790993 +pers11,2.395566544 +pers12,2.220826511 +pers13,2.038321064 +pers14,3.265232908 +pers15,2.425621063 +pers16,3.061295019 +pers17,2.017025954 +pers18,3.049338556 +pers19,3.122061642 +pers20,2.728356141 +pers21,2.594510111 +pers22,3.782658794 +pers23,3.070056277 +pers24,4.111231609 +pers25,3.064079366 +pers26,1.832456264 +pers27,2.162362221 +pers28,1.758259243 +pers29,3.225984183 +pers30,3.166512309 +pers31,2.790080337 +pers32,3.175142177 +pers33,2.391267257 +pers34,2.403458399 +pers35,2.514121013 +pers36,2.983086234 +pers37,2.746601113 +pers38,3.495223869 +pers39,2.864928829 +pers40,3.571532646 +pers41,2.654887486 +pers42,2.126933791 +pers43,2.482551325 +pers44,3.105868335 \ No newline at end of file diff --git a/tests/expected/test15/evalues_principal_none_20_test15.csv b/tests/expected/test15/evalues_principal_none_20_test15.csv new file mode 100644 index 0000000..81eafc8 --- /dev/null +++ b/tests/expected/test15/evalues_principal_none_20_test15.csv @@ -0,0 +1,45 @@ +,x +1,14.25162169 +2,9.294533097 +3,6.498852485 +4,4.620861291 +5,3.059319317 +6,2.625005382 +7,1.991825606 +8,1.657981133 +9,2.69E-15 +10,1.67E-15 +11,1.22E-15 +12,9.84E-16 +13,6.64E-16 +14,5.94E-16 +15,5.04E-16 +16,4.43E-16 +17,4.28E-16 +18,3.86E-16 +19,3.62E-16 +20,2.83E-16 +21,1.92E-16 +22,1.41E-16 +23,1.16E-16 +24,7.91E-17 +25,2.42E-17 +26,1.45E-17 +27,-1.52E-17 +28,-3.58E-17 +29,-9.31E-17 +30,-2.07E-16 +31,-2.23E-16 +32,-2.53E-16 +33,-3.39E-16 +34,-4.00E-16 +35,-4.38E-16 +36,-4.54E-16 +37,-5.18E-16 +38,-5.69E-16 +39,-7.36E-16 +40,-8.59E-16 +41,-9.67E-16 +42,-1.20E-15 +43,-1.50E-15 +44,-3.05E-15 \ No newline at end of file diff --git a/tests/expected/test15/loading_principal_none_20_test15.csv b/tests/expected/test15/loading_principal_none_20_test15.csv new file mode 100644 index 0000000..14be558 --- /dev/null +++ b/tests/expected/test15/loading_principal_none_20_test15.csv @@ -0,0 +1,45 @@ +,MR1,MR2,MR3,MR4,MR5,MR6,MR7,MR8,MR9 +pers01,-0.04828462,0.872685807,0.239695145,0.1783874,-0.070266058,0.233832136,-0.127745213,0.266230396,0.109949125 +pers02,0.065050861,0.087111805,0.917108625,0.09829389,0.045889312,0.259212544,0.240458465,-0.101551737,-0.438491274 +pers03,0.335344379,-0.692295104,0.352908838,-0.448246064,-0.105717303,-0.075752288,0.072948636,0.246102405,-0.026867838 +pers04,0.369038127,0.065051128,-0.659481751,0.216562195,0.122770915,-0.550263386,-0.238255766,-0.055995953,0.650878605 +pers05,0.460360115,0.673786618,0.357376302,-0.321981352,0.22082246,0.141451804,-0.140351837,-0.119249891,0.088835419 +pers06,0.87171974,-0.336822521,0.112582984,-0.145152604,-0.211905593,-0.080225718,0.198171197,-0.047949025,0.530776467 +pers07,0.576151802,-0.206133915,-0.510224335,0.311050617,-0.18298515,0.43465408,0.004627584,-0.214586587,0.581738459 +pers08,0.013465102,0.72669014,-0.534236413,-0.20871943,-0.19468342,0.119898964,0.269297945,0.134048523,0.395215471 +pers09,-0.748916302,-0.361166248,0.065624042,-0.340057899,0.067146842,0.42644553,0.043870457,0.021173528,-0.771983064 +pers10,0.137766271,0.729061174,0.055714597,-0.601408949,-0.108200226,-0.133706349,0.019605779,-0.233932375,0.119752545 +pers11,-0.675924401,-0.273483264,-0.433155063,-0.064071358,0.143265802,0.359476466,0.248424031,-0.255228802,-0.544707831 +pers12,0.355725628,-0.531383218,0.427897805,0.322997529,0.543285483,-0.056315934,-0.072021539,-0.012227634,-0.211014824 +pers13,0.191984172,-0.752696697,0.397452888,0.0537552,-0.217179515,-0.397121986,-0.166976097,0.05456656,0.03461962 +pers14,0.761628997,0.018888667,-0.389796375,0.02405038,0.439146501,0.204641161,0.08654672,0.15756543,0.476955155 +pers15,0.841854905,0.251761235,0.412010313,0.053732916,0.042505833,0.029407325,0.226082677,0.038366163,0.402422072 +pers16,-0.64430474,0.655136206,0.025197088,0.259591771,0.01078642,-0.024935079,0.293030345,0.032253453,-0.309892468 +pers17,0.133885448,0.599835612,0.405812074,0.039838088,-0.411211008,-0.402493253,0.338007923,0.103231512,0.273106526 +pers18,-0.899018777,0.306544161,-0.263593624,0.086024622,0.073959402,-0.061269493,-0.070891304,0.081636764,-0.419225537 +pers19,0.690934528,-0.329647629,-0.551624951,0.0182762,0.064019853,0.196192485,0.15420495,0.207240475,0.606635059 +pers20,0.155163821,0.957389706,0.108910582,-0.089609788,0.022891194,-0.02826736,-0.193441382,0.026364585,0.239325582 +pers21,0.485799705,-0.70406665,-0.364281338,-0.193071399,-0.097510953,0.090140204,0.1946385,-0.206866024,0.285201141 +pers22,-0.681026118,-0.43653348,-0.152606385,0.349435475,-0.227476649,-0.024178065,0.372173269,0.096979075,-0.32989024 +pers23,0.298251743,0.738460588,-0.104605406,-0.18269442,-0.149645341,0.429224993,-0.312246211,0.131441227,0.397777692 +pers24,-0.735050328,0.022838942,0.070148594,0.42301257,0.17584153,-0.452732468,0.197949991,-0.015741936,-0.467726762 +pers25,0.612945178,0.179319088,-0.143944907,-0.536182675,-0.246095779,-0.144123534,0.178983737,0.412990281,0.615785245 +pers26,0.075493078,0.164186546,0.361855321,0.60623896,0.363141969,0.212132397,0.266854716,0.469888221,-0.107909185 +pers27,0.222367812,0.025808466,0.118818476,0.937412093,-0.130937606,0.1757883,0.031728968,-0.089298973,0.271281581 +pers28,-0.066730222,-0.434588822,0.269651574,-0.478297293,0.639249191,-0.095051394,-0.032966158,0.294003909,-0.528846167 +pers29,0.943968963,-0.072840218,-0.105035164,0.233586597,-0.059723312,-0.116892943,-0.06291015,-0.129743188,0.741046671 +pers30,0.329973251,0.309685066,-0.828180961,0.108154302,-0.016580396,-0.057603016,0.255376313,0.169769485,0.679006138 +pers31,0.663368245,-0.546831933,0.099996746,0.073101125,0.436956147,0.219166869,0.07969368,-0.016075172,0.068306007 +pers32,-0.896103875,0.087076721,-0.032225109,0.233865369,0.167003644,0.322598699,-0.024188096,-0.033746387,-0.67512874 +pers33,0.368927175,-0.55505924,0.330315318,0.404087257,-0.358903058,0.316887491,-0.186347221,0.139473596,0.227565878 +pers34,-0.297896347,-0.356547533,0.108006974,-0.380935237,-0.55636427,0.376008538,-0.417555873,0.045584353,-0.143322194 +pers35,-0.466026995,-0.521274333,-0.486631795,-0.090454241,-0.437389454,-0.144597074,0.10906771,0.204904425,0.03733149 +pers36,-0.849143547,0.423510789,0.024744724,0.136594471,0.030216843,0.09148576,-0.026121108,0.265251737,-0.457345574 +pers37,0.341668592,0.007337606,0.66756131,-0.008476775,-0.46117769,0.331304008,0.328795419,-0.083337506,0.029768105 +pers38,-0.263669214,0.20346738,0.640196465,-0.321078152,0.009810988,-0.038680089,0.470802886,-0.391010377,-0.606810668 +pers39,0.882113868,0.17203835,-0.16887364,-0.023210193,-0.013321485,-0.040546264,0.34238953,0.210163886,0.712370485 +pers40,0.791007623,0.146938,-0.379995394,-0.239363727,0.157728703,0.334133052,0.109024279,0.051138425,0.546598278 +pers41,-0.56246979,-0.273713014,0.436724979,-0.447447065,0.177598914,0.035751419,-0.068604277,0.424553004,-0.688649095 +pers42,-0.742317116,0.071452222,-0.469083975,-0.218919485,0.263651376,0.217525709,0.182249823,-0.160778989,-0.53833395 +pers43,0.587977857,0.446104906,0.144926452,0.575703527,-0.127295674,0.010140071,-0.290820054,0.044169119,0.615268623 +pers44,0.697943291,0.446797496,0.096467909,-0.230890491,0.408446905,-0.09793622,-0.066987721,-0.264051038,0.260242219 \ No newline at end of file diff --git a/tests/expected/test15/uniquenesses_principal_none_20_test15.csv b/tests/expected/test15/uniquenesses_principal_none_20_test15.csv new file mode 100644 index 0000000..81d7f07 --- /dev/null +++ b/tests/expected/test15/uniquenesses_principal_none_20_test15.csv @@ -0,0 +1,45 @@ +,x +pers01,-1.147636456 +pers02,-2.05860696 +pers03,-1.634980867 +pers04,-3.167044602 +pers05,-2.239628161 +pers06,-1.855673949 +pers07,-1.846150806 +pers08,-1.342990801 +pers09,-2.030327258 +pers10,-1.975790993 +pers11,-1.395566544 +pers12,-1.220826511 +pers13,-1.038321064 +pers14,-2.265232908 +pers15,-1.425621063 +pers16,-2.061295019 +pers17,-1.017025954 +pers18,-2.049338556 +pers19,-2.122061642 +pers20,-1.728356141 +pers21,-1.594510111 +pers22,-2.782658794 +pers23,-2.070056277 +pers24,-3.111231609 +pers25,-2.064079366 +pers26,-0.832456264 +pers27,-1.162362221 +pers28,-0.758259243 +pers29,-2.225984183 +pers30,-2.166512309 +pers31,-1.790080337 +pers32,-2.175142177 +pers33,-1.391267257 +pers34,-1.403458399 +pers35,-1.514121013 +pers36,-1.983086234 +pers37,-1.746601113 +pers38,-2.495223869 +pers39,-1.864928829 +pers40,-2.571532646 +pers41,-1.654887486 +pers42,-1.126933791 +pers43,-1.482551325 +pers44,-2.105868335 \ No newline at end of file diff --git a/tests/expected/test15/value_principal_none_20_test15.csv b/tests/expected/test15/value_principal_none_20_test15.csv new file mode 100644 index 0000000..e6907c8 --- /dev/null +++ b/tests/expected/test15/value_principal_none_20_test15.csv @@ -0,0 +1,45 @@ +"","x" +"1",14.2508647856251 +"2",9.29367124190705 +"3",6.4978611722933 +"4",4.61980443930738 +"5",3.05823573966641 +"6",2.62400357827091 +"7",1.99077380581108 +"8",1.65687368306576 +"9",-0.000465138925331381 +"10",-0.000509491259929655 +"11",-0.000562525789463678 +"12",-0.000599325267054469 +"13",-0.000613092567529119 +"14",-0.00064614659249048 +"15",-0.000676111579105095 +"16",-0.000708313491160807 +"17",-0.000715379290410044 +"18",-0.000747364267679113 +"19",-0.000765485638504557 +"20",-0.000771763049723453 +"21",-0.000780639793752711 +"22",-0.000787738940817513 +"23",-0.000811768446807304 +"24",-0.000833395490831647 +"25",-0.000855737969618182 +"26",-0.000869111350151491 +"27",-0.000880381608326505 +"28",-0.000884006458516127 +"29",-0.000897360466110942 +"30",-0.000908607779136124 +"31",-0.0009294527264382 +"32",-0.000937746402008354 +"33",-0.000967473509608267 +"34",-0.001000012611829 +"35",-0.00102204157655196 +"36",-0.00103468539237349 +"37",-0.00106447356990858 +"38",-0.00110038333562804 +"39",-0.00116702339612245 +"40",-0.00119205924575606 +"41",-0.00123049833566492 +"42",-0.00131644463729588 +"43",-0.00133847837950501 +"44",-0.00149878680581881 diff --git a/tests/generate_r_output.r b/tests/generate_r_output.r index d9640b2..20a97c3 100644 --- a/tests/generate_r_output.r +++ b/tests/generate_r_output.r @@ -2,7 +2,8 @@ packages <- c('argparse', 'psych', 'GPArotation') new_packages <- packages[!(packages %in% installed.packages()[,"Package"])] if(length(new_packages)) { - install.packages(new_packages) + install.packages(new_packages, + repos = "http://cran.us.r-project.org") } library('argparse') @@ -19,18 +20,20 @@ mapping2 <- list(geominT = 'geomin_ort', # argument parser parser <- ArgumentParser(description='Fit some factor models') -parser$add_argument('-n', '--n_factors', type='integer', nargs='+', +parser$add_argument('-n', '--n_factors', type='integer', default=2, help='integer(s) specifying the number of factors') -parser$add_argument('-f', '--fit_methods', type='character', nargs='+', +parser$add_argument('-f', '--fit_methods', type='character', default='minres', help='Fit method(s)') -parser$add_argument('-r', '--rotations', type="character", nargs='+', +parser$add_argument('-r', '--rotations', type="character", default='promax', help='Rotaton(s)') -parser$add_argument('-t', '--test_file', type="character", nargs='+', +parser$add_argument('-t', '--test_file', type="character", default='test02.csv', help='Test file') +parser$add_argument('-o', '--output_dir', type="character", + default=NULL, help='Output directory') # parse the arguments into a list @@ -38,7 +41,13 @@ args <- parser$parse_args() # get the input path and directory path <- args$test_file -dir <- dirname(path) +if (is.null(args$output_dir)) { + dir <- dirname(path) +} else { + dir <- args$output_dir + dir.create(dir, showWarnings = FALSE) +} + filename <- basename(path) # read in the data @@ -63,25 +72,29 @@ for (n in args$n_factors) { rot_name <- mapping2[[rot]] rot_name <- if (length(rot_name) == 0) rot else rot_name - # write out the loadings - loadings_file <- paste('loading', - fm_name, - rot_name, - as.character(n), - filename, - sep='_') - loadings_file <- file.path(dir, loadings_file) - write.csv(res$loadings, loadings_file) - - # write out the communalities - communalities_file <- paste('communalities', - fm_name, - rot_name, - as.character(n), - filename, - sep='_') - communalities_file <- file.path(dir, communalities_file) - write.csv(res$communalities, communalities_file) + # get outputs + loadings <- res$loadings; + values <- res$values; + evalues <- res$e.values; + uniquenesses <- res$uniquenesses; + communalities <- res$communalities; + + info <- list('loading' = loadings, + 'value' = values, + 'evalues' = evalues, + 'uniquenesses' = uniquenesses, + 'communalities' = communalities) + for (name in names(info)) { + df_temp <- info[[name]] + out <- paste(name, + fm_name, + rot_name, + toString(n), + filename, + sep='_') + out_file <- file.path(dir, out) + write.csv(df_temp, out_file) + } } } } \ No newline at end of file diff --git a/tests/test_expected_factor_analyzer.py b/tests/test_expected_factor_analyzer.py index 3bccfb7..ab0fcb9 100644 --- a/tests/test_expected_factor_analyzer.py +++ b/tests/test_expected_factor_analyzer.py @@ -461,3 +461,20 @@ def test_02_none_principal(): ignore_value=True, ignore_communalities=True): assert check > THRESHOLD + +def test_15_none_principal(): + + test_name = 'test15' + factors = 20 + method = 'principal' + rotation = 'none' + svd_method = 'lapack' + + for check in check_scenario(test_name, + factors, + method, + rotation, + svd_method=svd_method, + ignore_value=True, + ignore_communalities=True): + assert check > THRESHOLD \ No newline at end of file diff --git a/tests/test_factor_analyzer.py b/tests/test_factor_analyzer.py index 9ea833f..2a468bc 100644 --- a/tests/test_factor_analyzer.py +++ b/tests/test_factor_analyzer.py @@ -13,6 +13,10 @@ from numpy.testing import assert_array_almost_equal from pandas.util.testing import assert_almost_equal +from sklearn.model_selection import GridSearchCV +from sklearn.pipeline import make_pipeline +from sklearn.tree import DecisionTreeClassifier + from factor_analyzer.utils import smc from factor_analyzer.factor_analyzer import FactorAnalyzer from factor_analyzer.factor_analyzer import (calculate_kmo, @@ -50,6 +54,28 @@ def test_calculate_kmo(): assert_almost_equal(kmo_overall, expected_overall) +def test_gridsearch(): + # make sure this doesn't fail + + X = pd.DataFrame(np.random.randn(1000).reshape(100, 10)) + y = pd.Series(np.random.choice([1, 0], size=100)) + + grid = {'factoranalyzer__n_factors': [5, 7], + 'factoranalyzer__rotation': [None, 'varimax'], + 'decisiontreeclassifier__max_depth': [2, 5]} + + fa = FactorAnalyzer() + decisiontree = DecisionTreeClassifier(random_state=123) + pipe = make_pipeline(fa, decisiontree) + + gridsearch = GridSearchCV(pipe, + grid, + scoring='f1', + cv=3, + verbose=0) + gridsearch.fit(X, y) + + class TestFactorAnalyzer: def test_analyze_weights(self): @@ -112,24 +138,19 @@ def test_analyze_impute_drop(self): assert_array_almost_equal(fa.corr_, expected_corr) @raises(ValueError) - def test_analyze_impute_value_error(self): - - data = pd.DataFrame({'A': [2, 4, 5, 6, 8, 9], - 'B': [4, 8, np.nan, 10, 16, 18], - 'C': [6, 12, 15, 12, 26, 27]}) + def test_analyze_bad_svd_method(self): + fa = FactorAnalyzer(svd_method='foo') + fa.fit(np.random.randn(500).reshape(100, 5)) + @raises(ValueError) + def test_analyze_impute_value_error(self): fa = FactorAnalyzer(rotation=None, impute='blah', n_factors=1) - fa.fit(data) + fa.fit(np.random.randn(500).reshape(100, 5)) @raises(ValueError) def test_analyze_rotation_value_error(self): - - data = pd.DataFrame({'A': [2, 4, 5, 6, 8, 9], - 'B': [4, 8, np.nan, 10, 16, 18], - 'C': [6, 12, 15, 12, 26, 27]}) - fa = FactorAnalyzer(rotation='blah', n_factors=1) - fa.fit(data) + fa.fit(np.random.randn(500).reshape(100, 5)) @raises(ValueError) def test_analyze_infinite(self):