From e8dda9939e301ca9a3b0880024b57a997adebd8d Mon Sep 17 00:00:00 2001 From: reinterpretcat Date: Tue, 12 Dec 2023 23:30:55 +0100 Subject: [PATCH] Simplify run_solve method --- vrp-cli/src/commands/solve.rs | 144 ++++++++++++---------- vrp-cli/tests/unit/commands/solve_test.rs | 2 +- 2 files changed, 82 insertions(+), 64 deletions(-) diff --git a/vrp-cli/src/commands/solve.rs b/vrp-cli/src/commands/solve.rs index ac975855b..829e62a3e 100644 --- a/vrp-cli/src/commands/solve.rs +++ b/vrp-cli/src/commands/solve.rs @@ -308,97 +308,52 @@ pub fn run_solve( matches: &ArgMatches, out_writer_func: fn(Option<File>) -> BufWriter<Box<dyn Write>>, ) -> Result<(), GenericError> { - let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?; - - let environment = get_environment(matches, max_time)?; - + let environment = get_environment(matches)?; let formats = get_formats(matches, environment.random.clone()); - // required let problem_path = matches.get_one::<String>(PROBLEM_ARG_NAME).unwrap(); let problem_format = matches.get_one::<String>(FORMAT_ARG_NAME).unwrap(); let problem_file = open_file(problem_path, "problem"); - // optional - let max_generations = parse_int_value::<usize>(matches, GENERATIONS_ARG_NAME, "max generations")?; - let telemetry_mode = if matches.get_one::<bool>(LOG_ARG_NAME).copied().unwrap_or(false) { - get_default_telemetry_mode(environment.logger.clone()) - } else { - TelemetryMode::None - }; - - let is_check_requested = matches.get_one::<bool>(CHECK_ARG_NAME).copied().unwrap_or(false); - let min_cv = get_min_cv(matches)?; let init_solution = matches.get_one::<String>(INIT_SOLUTION_ARG_NAME).map(|path| open_file(path, "init solution")); - let init_size = get_init_size(matches)?; let config = matches.get_one::<String>(CONFIG_ARG_NAME).map(|path| open_file(path, "config")); let matrix_files = get_matrix_files(matches); let out_result = matches.get_one::<String>(OUT_RESULT_ARG_NAME).map(|path| create_file(path, "out solution")); let out_geojson = matches.get_one::<String>(GEO_JSON_ARG_NAME).map(|path| create_file(path, "out 
geojson")); + let is_get_locations_set = matches.get_one::<bool>(GET_LOCATIONS_ARG_NAME).copied().unwrap_or(false); - let mode = matches.get_one::<String>(SEARCH_MODE_ARG_NAME); + let is_check_requested = matches.get_one::<bool>(CHECK_ARG_NAME).copied().unwrap_or(false); match formats.get(problem_format.as_str()) { - Some((problem_reader, init_reader, solution_writer, locations_writer)) => { + Some(( + ProblemReader(problem_reader), + init_reader, + SolutionWriter(solution_writer), + LocationWriter(locations_writer), + )) => { let out_buffer = out_writer_func(out_result); let geo_buffer = out_geojson.map(|geojson| create_write_buffer(Some(geojson))); if is_get_locations_set { - locations_writer.0(problem_file, out_buffer) + locations_writer(problem_file, out_buffer) .map_err(|err| GenericError::from(format!("cannot get locations '{err}'"))) } else { - match problem_reader.0(problem_file, matrix_files) { + match problem_reader(problem_file, matrix_files) { Ok(problem) => { let problem = Arc::new(problem); - let solutions = init_solution - .map(|file| { - init_reader.0(file, problem.clone()) - .map_err(|err| format!("cannot read initial solution '{err}'")) - .map(|solution| { - vec![InsertionContext::new_from_solution( - problem.clone(), - (solution, None), - environment.clone(), - )] - }) - }) - .unwrap_or_else(|| Ok(Vec::new()))?; + let init_solutions = init_solution + .map(|file| read_init_solution(problem.clone(), environment.clone(), file, init_reader)) + .unwrap_or_else(|| Ok(Vec::default()))?; let solver = if let Some(config) = config { - create_builder_from_config_file(problem.clone(), solutions, BufReader::new(config)) - .and_then(|builder| builder.build()) - .map(|config| Solver::new(problem.clone(), config)) - .map_err(|err| format!("cannot read config: '{err}'"))? + from_config_parameters(problem.clone(), init_solutions, config)? 
} else { - let builder = create_default_config_builder( - problem.clone(), - environment.clone(), - telemetry_mode.clone(), - ) - .with_init_solutions(solutions, init_size) - .with_max_generations(max_generations) - .with_max_time(max_time) - .with_min_cv(min_cv, "min_cv".to_string()) - .with_context(RefinementContext::new( - problem.clone(), - get_population(mode, problem.goal.clone(), environment.clone()), - telemetry_mode, - environment.clone(), - )); - - let config = if cfg!(feature = "async-evolution") && environment.is_experimental { - builder.with_strategy(get_async_evolution(problem.clone(), environment.clone())?) - } else { - builder.with_heuristic(get_heuristic(matches, problem.clone(), environment)?) - } - .build()?; - - Solver::new(problem.clone(), config) + from_cli_parameters(problem.clone(), environment, init_solutions, matches)? }; let solution = solver.solve().map_err(|err| format!("cannot find any solution: '{err}'"))?; - solution_writer.0(&problem, solution, out_buffer, geo_buffer).unwrap(); + solution_writer(&problem, solution, out_buffer, geo_buffer).unwrap(); if is_check_requested { check_pragmatic_solution_with_args(matches)?; @@ -417,6 +372,68 @@ pub fn run_solve( } } +fn read_init_solution( + problem: Arc<Problem>, + environment: Arc<Environment>, + file: File, + InitSolutionReader(init_reader): &InitSolutionReader, +) -> Result<Vec<InsertionContext>, GenericError> { + init_reader(file, problem.clone()) + .map_err(|err| format!("cannot read initial solution '{err}'").into()) + .map(|solution| vec![InsertionContext::new_from_solution(problem.clone(), (solution, None), environment)]) +} + +fn from_config_parameters( + problem: Arc<Problem>, + init_solutions: Vec<InsertionContext>, + config: File, +) -> Result<Solver, GenericError> { + create_builder_from_config_file(problem.clone(), init_solutions, BufReader::new(config)) + .and_then(|builder| builder.build()) + .map(|config| Solver::new(problem.clone(), config)) + .map_err(|err| format!("cannot read config: '{err}'").into()) +} + +fn from_cli_parameters( + problem: Arc<Problem>, + 
environment: Arc<Environment>, + init_solutions: Vec<InsertionContext>, + matches: &ArgMatches, +) -> Result<Solver, GenericError> { + let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?; + + let max_generations = parse_int_value::<usize>(matches, GENERATIONS_ARG_NAME, "max generations")?; + let telemetry_mode = if matches.get_one::<bool>(LOG_ARG_NAME).copied().unwrap_or(false) { + get_default_telemetry_mode(environment.logger.clone()) + } else { + TelemetryMode::None + }; + let min_cv = get_min_cv(matches)?; + let init_size = get_init_size(matches)?; + let mode = matches.get_one::<String>(SEARCH_MODE_ARG_NAME); + + let builder = create_default_config_builder(problem.clone(), environment.clone(), telemetry_mode.clone()) + .with_init_solutions(init_solutions, init_size) + .with_max_generations(max_generations) + .with_max_time(max_time) + .with_min_cv(min_cv, "min_cv".to_string()) + .with_context(RefinementContext::new( + problem.clone(), + get_population(mode, problem.goal.clone(), environment.clone()), + telemetry_mode, + environment.clone(), + )); + + let config = if cfg!(feature = "async-evolution") && environment.is_experimental { + builder.with_strategy(get_async_evolution(problem.clone(), environment.clone())?) + } else { + builder.with_heuristic(get_heuristic(matches, problem.clone(), environment)?) 
+ } + .build()?; + + Ok(Solver::new(problem.clone(), config)) +} + fn get_min_cv(matches: &ArgMatches) -> Result<Option<(String, usize, f64, bool)>, GenericError> { let err_result = Err("cannot parse min_cv parameter".into()); matches @@ -451,7 +468,8 @@ fn get_init_size(matches: &ArgMatches) -> Result<Option<usize>, GenericError> { .unwrap_or(Ok(None)) } -fn get_environment(matches: &ArgMatches, max_time: Option<usize>) -> Result<Arc<Environment>, GenericError> { +fn get_environment(matches: &ArgMatches) -> Result<Arc<Environment>, GenericError> { + let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?; let quota = Some(create_interruption_quota(max_time)); let is_experimental = matches.get_one::<bool>(EXPERIMENTAL_ARG_NAME).copied().unwrap_or(false); diff --git a/vrp-cli/tests/unit/commands/solve_test.rs b/vrp-cli/tests/unit/commands/solve_test.rs index 7ee895c92..af53655a4 100644 --- a/vrp-cli/tests/unit/commands/solve_test.rs +++ b/vrp-cli/tests/unit/commands/solve_test.rs @@ -110,7 +110,7 @@ fn can_specify_parallelism() { ] { let matches = get_solomon_matches(params.as_slice()); - let thread_pool_size = get_environment(&matches, None).map(|e| e.parallelism.thread_pool_size()); + let thread_pool_size = get_environment(&matches).map(|e| e.parallelism.thread_pool_size()); assert_eq!(thread_pool_size, result); }