Skip to content

Commit

Permalink
fix evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
aramoto99 committed Oct 23, 2023
1 parent 783b366 commit a87e3db
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 10 deletions.
4 changes: 2 additions & 2 deletions aiaccel/cli/start.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,8 @@ def main() -> None: # pragma: no cover
config_name = Path(args.config).name
shutil.copy(Path(args.config), dst / config_name)

if os.path.exists(workspace.final_result_file):
with open(workspace.final_result_file, "r") as f:
if os.path.exists(workspace.best_result_file):
with open(workspace.best_result_file, "r") as f:
final_results: list[dict[str, Any]] = yaml.load(f, Loader=yaml.UnsafeLoader)

for i, final_result in enumerate(final_results):
Expand Down
12 changes: 8 additions & 4 deletions aiaccel/scheduler/abstract_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,15 @@ def evaluate(self) -> None:

best_trial_ids, _ = self.storage.get_best_trial(self.goals)
if best_trial_ids is None:
self.logger.error(f"Failed to output {self.workspace.final_result_file}.")
self.logger.error(f"Failed to output {self.workspace.best_result_file}.")
return
hp_results = []
for best_trial_id in best_trial_ids:
hp_results.append(self.storage.get_hp_dict(best_trial_id))
create_yaml(self.workspace.final_result_file, hp_results, self.workspace.lock)
self.logger.info("Best hyperparameter is followings:")
self.logger.info(hp_results)

create_yaml(self.workspace.best_result_file, hp_results, self.workspace.lock)

finished = self.storage.get_num_finished()
if self.config.optimize.trial_number >= finished:
self.logger.info("Best hyperparameter is followings:")
self.logger.info(hp_results)
3 changes: 2 additions & 1 deletion aiaccel/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@ def __init__(self, base_path: str):
self.storage = self.path / dict_storage
self.tensorboard = self.path / dict_tensorboard
self.timestamp = self.path / dict_timestamp

self.consists = [
self.alive,
self.error,
Expand All @@ -103,6 +102,8 @@ def __init__(self, base_path: str):
self.retults_csv_file = self.path / "results.csv"
self.final_result_file = self.path / "final_result.result"
self.storage_file_path = self.storage / "storage.db"
self.best_result_file = self.path / "best_result.yaml"


def create(self) -> bool:
"""Create a work directory.
Expand Down
3 changes: 0 additions & 3 deletions tests/integration/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,3 @@ def evaluate(self, config, is_multi_objective=False):
assert finished == config.optimize.trial_number
assert ready == 0
assert running == 0

if not is_multi_objective:
assert workspace.final_result_file.exists()

0 comments on commit a87e3db

Please sign in to comment.