Skip to content

Commit

Permalink
Simplify run_solve method
Browse files Browse the repository at this point in the history
  • Loading branch information
reinterpretcat committed Dec 12, 2023
1 parent a29ddf8 commit e8dda99
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 64 deletions.
144 changes: 81 additions & 63 deletions vrp-cli/src/commands/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,97 +308,52 @@ pub fn run_solve(
matches: &ArgMatches,
out_writer_func: fn(Option<File>) -> BufWriter<Box<dyn Write>>,
) -> Result<(), GenericError> {
let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?;

let environment = get_environment(matches, max_time)?;

let environment = get_environment(matches)?;
let formats = get_formats(matches, environment.random.clone());

// required
let problem_path = matches.get_one::<String>(PROBLEM_ARG_NAME).unwrap();
let problem_format = matches.get_one::<String>(FORMAT_ARG_NAME).unwrap();
let problem_file = open_file(problem_path, "problem");

// optional
let max_generations = parse_int_value::<usize>(matches, GENERATIONS_ARG_NAME, "max generations")?;
let telemetry_mode = if matches.get_one::<bool>(LOG_ARG_NAME).copied().unwrap_or(false) {
get_default_telemetry_mode(environment.logger.clone())
} else {
TelemetryMode::None
};

let is_check_requested = matches.get_one::<bool>(CHECK_ARG_NAME).copied().unwrap_or(false);
let min_cv = get_min_cv(matches)?;
let init_solution = matches.get_one::<String>(INIT_SOLUTION_ARG_NAME).map(|path| open_file(path, "init solution"));
let init_size = get_init_size(matches)?;
let config = matches.get_one::<String>(CONFIG_ARG_NAME).map(|path| open_file(path, "config"));
let matrix_files = get_matrix_files(matches);
let out_result = matches.get_one::<String>(OUT_RESULT_ARG_NAME).map(|path| create_file(path, "out solution"));
let out_geojson = matches.get_one::<String>(GEO_JSON_ARG_NAME).map(|path| create_file(path, "out geojson"));

let is_get_locations_set = matches.get_one::<bool>(GET_LOCATIONS_ARG_NAME).copied().unwrap_or(false);
let mode = matches.get_one::<String>(SEARCH_MODE_ARG_NAME);
let is_check_requested = matches.get_one::<bool>(CHECK_ARG_NAME).copied().unwrap_or(false);

match formats.get(problem_format.as_str()) {
Some((problem_reader, init_reader, solution_writer, locations_writer)) => {
Some((
ProblemReader(problem_reader),
init_reader,
SolutionWriter(solution_writer),
LocationWriter(locations_writer),
)) => {
let out_buffer = out_writer_func(out_result);
let geo_buffer = out_geojson.map(|geojson| create_write_buffer(Some(geojson)));

if is_get_locations_set {
locations_writer.0(problem_file, out_buffer)
locations_writer(problem_file, out_buffer)
.map_err(|err| GenericError::from(format!("cannot get locations '{err}'")))
} else {
match problem_reader.0(problem_file, matrix_files) {
match problem_reader(problem_file, matrix_files) {
Ok(problem) => {
let problem = Arc::new(problem);
let solutions = init_solution
.map(|file| {
init_reader.0(file, problem.clone())
.map_err(|err| format!("cannot read initial solution '{err}'"))
.map(|solution| {
vec![InsertionContext::new_from_solution(
problem.clone(),
(solution, None),
environment.clone(),
)]
})
})
.unwrap_or_else(|| Ok(Vec::new()))?;
let init_solutions = init_solution
.map(|file| read_init_solution(problem.clone(), environment.clone(), file, init_reader))
.unwrap_or_else(|| Ok(Vec::default()))?;

let solver = if let Some(config) = config {
create_builder_from_config_file(problem.clone(), solutions, BufReader::new(config))
.and_then(|builder| builder.build())
.map(|config| Solver::new(problem.clone(), config))
.map_err(|err| format!("cannot read config: '{err}'"))?
from_config_parameters(problem.clone(), init_solutions, config)?
} else {
let builder = create_default_config_builder(
problem.clone(),
environment.clone(),
telemetry_mode.clone(),
)
.with_init_solutions(solutions, init_size)
.with_max_generations(max_generations)
.with_max_time(max_time)
.with_min_cv(min_cv, "min_cv".to_string())
.with_context(RefinementContext::new(
problem.clone(),
get_population(mode, problem.goal.clone(), environment.clone()),
telemetry_mode,
environment.clone(),
));

let config = if cfg!(feature = "async-evolution") && environment.is_experimental {
builder.with_strategy(get_async_evolution(problem.clone(), environment.clone())?)
} else {
builder.with_heuristic(get_heuristic(matches, problem.clone(), environment)?)
}
.build()?;

Solver::new(problem.clone(), config)
from_cli_parameters(problem.clone(), environment, init_solutions, matches)?
};

let solution = solver.solve().map_err(|err| format!("cannot find any solution: '{err}'"))?;

solution_writer.0(&problem, solution, out_buffer, geo_buffer).unwrap();
solution_writer(&problem, solution, out_buffer, geo_buffer).unwrap();

if is_check_requested {
check_pragmatic_solution_with_args(matches)?;
Expand All @@ -417,6 +372,68 @@ pub fn run_solve(
}
}

fn read_init_solution(
problem: Arc<Problem>,
environment: Arc<Environment>,
file: File,
InitSolutionReader(init_reader): &InitSolutionReader,
) -> Result<Vec<InsertionContext>, GenericError> {
init_reader(file, problem.clone())
.map_err(|err| format!("cannot read initial solution '{err}'").into())
.map(|solution| vec![InsertionContext::new_from_solution(problem.clone(), (solution, None), environment)])
}

fn from_config_parameters(
problem: Arc<Problem>,
init_solutions: Vec<InsertionContext>,
config: File,
) -> Result<Solver, GenericError> {
create_builder_from_config_file(problem.clone(), init_solutions, BufReader::new(config))
.and_then(|builder| builder.build())
.map(|config| Solver::new(problem.clone(), config))
.map_err(|err| format!("cannot read config: '{err}'").into())
}

fn from_cli_parameters(
problem: Arc<Problem>,
environment: Arc<Environment>,
init_solutions: Vec<InsertionContext>,
matches: &ArgMatches,
) -> Result<Solver, GenericError> {
let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?;

let max_generations = parse_int_value::<usize>(matches, GENERATIONS_ARG_NAME, "max generations")?;
let telemetry_mode = if matches.get_one::<bool>(LOG_ARG_NAME).copied().unwrap_or(false) {
get_default_telemetry_mode(environment.logger.clone())
} else {
TelemetryMode::None
};
let min_cv = get_min_cv(matches)?;
let init_size = get_init_size(matches)?;
let mode = matches.get_one::<String>(SEARCH_MODE_ARG_NAME);

let builder = create_default_config_builder(problem.clone(), environment.clone(), telemetry_mode.clone())
.with_init_solutions(init_solutions, init_size)
.with_max_generations(max_generations)
.with_max_time(max_time)
.with_min_cv(min_cv, "min_cv".to_string())
.with_context(RefinementContext::new(
problem.clone(),
get_population(mode, problem.goal.clone(), environment.clone()),
telemetry_mode,
environment.clone(),
));

let config = if cfg!(feature = "async-evolution") && environment.is_experimental {
builder.with_strategy(get_async_evolution(problem.clone(), environment.clone())?)
} else {
builder.with_heuristic(get_heuristic(matches, problem.clone(), environment)?)
}
.build()?;

Ok(Solver::new(problem.clone(), config))
}

fn get_min_cv(matches: &ArgMatches) -> Result<Option<(String, usize, f64, bool)>, GenericError> {
let err_result = Err("cannot parse min_cv parameter".into());
matches
Expand Down Expand Up @@ -451,7 +468,8 @@ fn get_init_size(matches: &ArgMatches) -> Result<Option<usize>, GenericError> {
.unwrap_or(Ok(None))
}

fn get_environment(matches: &ArgMatches, max_time: Option<usize>) -> Result<Arc<Environment>, GenericError> {
fn get_environment(matches: &ArgMatches) -> Result<Arc<Environment>, GenericError> {
let max_time = parse_int_value::<usize>(matches, TIME_ARG_NAME, "max time")?;
let quota = Some(create_interruption_quota(max_time));
let is_experimental = matches.get_one::<bool>(EXPERIMENTAL_ARG_NAME).copied().unwrap_or(false);

Expand Down
2 changes: 1 addition & 1 deletion vrp-cli/tests/unit/commands/solve_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ fn can_specify_parallelism() {
] {
let matches = get_solomon_matches(params.as_slice());

let thread_pool_size = get_environment(&matches, None).map(|e| e.parallelism.thread_pool_size());
let thread_pool_size = get_environment(&matches).map(|e| e.parallelism.thread_pool_size());

assert_eq!(thread_pool_size, result);
}
Expand Down

0 comments on commit e8dda99

Please sign in to comment.