class CannotSatisfyProbeCountConstraintError(Exception):
    """Raised when a search's probe count exceeds the given constraint.

    The parameter search raises this when the total probe count for the
    parameter values it found is larger than the requested maximum. The
    message passed by the raiser explains the likely causes.
    """
Also, note that the search interpolates probe " + "counts between precomputed parameter values (%d may be an " + "interpolated count) and, if the precomputed parameter values " + "are too sparse (i.e., too few actual probe counts were " + "input), it may be underestimating the true number of probes " + "required." % (tpc, max_total_count, max_total_count, tpc)) + raise CannotSatisfyProbeCountConstraintError(msg) # Keep decreasing parameters while satisfying the constraint. # In particular, choose to decrease the parameter whose reduction @@ -685,4 +712,23 @@ def higher_dimensional_search(param_names, probe_counts, max_total_count, probe_counts, interp_fn_type='nd')(x_sol) x_sol_loss = loss_fn(x_sol, 0) + # Verify that the probe count satisfies the constraint + if x_sol_count > max_total_count: + msg = ("The total probe count based on parameter values found " + "in the search (%d) exceeds the given limit (%d). This " + "is likely to happen if the range of the precomputed " + "parameter values is not as large as it needs to be to " + "satisfy the constraint. That is, one or more parameter " + "values may need to be more loose to obtain %d probes. To " + "fix this, try inputting probe counts for a larger range " + "(in particular, less stringent choices) of parameter " + "values. Also, note that the search interpolates probe " + "counts between precomputed parameter values (%d may be an " + "interpolated count) and, if the precomputed parameter values " + "are too sparse (i.e., too few actual probe counts were " + "input), it may be underestimating the true number of probes " + "required." 
def _search_vwafr_too_small_counts(self, search_fn):
    """Check that search_fn raises for each of several probe count limits.

    Args:
        search_fn: callable that takes a maximum total probe count and
            runs a parameter search with it

    Every limit tried here is expected to make the search raise
    param_search.CannotSatisfyProbeCountConstraintError.
    """
    self.assertEqual(self.param_names_vwafr,
                     ('mismatches', 'cover_extension'))
    for max_total_count in (1, 1000, 10000):
        # Run each limit as its own subtest so that a failure names the
        # offending limit and does not stop the remaining limits from
        # being checked
        # NOTE(review): assumes the enclosing class is a
        # unittest.TestCase (it already uses assertEqual/assertRaises)
        with self.subTest(max_total_count=max_total_count):
            with self.assertRaises(
                    param_search.CannotSatisfyProbeCountConstraintError):
                search_fn(max_total_count)


def test_standard_search_vwafr_too_small_counts(self):
    """Standard search on V-WAfr probe counts raises for small limits."""
    def search_fn(max_total_count):
        return param_search.standard_search(self.probe_counts_vwafr,
                                            max_total_count)
    self._search_vwafr_too_small_counts(search_fn)


def test_higher_dimensional_search_vwafr_too_small_counts(self):
    """Higher dimensional search on V-WAfr counts raises for small limits."""
    def search_fn(max_total_count):
        return param_search.higher_dimensional_search(
            self.param_names_vwafr, self.probe_counts_vwafr,
            max_total_count, loss_coeffs=(1.0, 1.0/100.0))
    self._search_vwafr_too_small_counts(search_fn)