Skip to content

Commit

Permalink
delete init state
Browse files Browse the repository at this point in the history
  • Loading branch information
aramoto99 committed Oct 24, 2023
1 parent db217c1 commit abc4738
Show file tree
Hide file tree
Showing 8 changed files with 55 additions and 337 deletions.
32 changes: 0 additions & 32 deletions aiaccel/scheduler/abstract_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,18 +130,6 @@ def pre_process(self) -> None:
self.write_random_seed_to_debug_log()
self.resume()

self.algorithm = RandomSampling(self.config)
self.change_state_finished_trials()

runnings = self.storage.trial.get_running()
for running in runnings:
job = self.start_job(running)
self.logger.info(f"restart hp files in previous running directory: {running}")

while job.get_state_name() != "Scheduling":
job.main()
job.schedule()

def post_process(self) -> None:
"""Post-procedure after executed processes.
Expand Down Expand Up @@ -183,26 +171,6 @@ def inner_loop_main_process(self) -> bool:
if self.buff.d[job.trial_id].has_difference():
self.logger.info(f"name: {job.trial_id}, state: {state_name}")


# scheduled_candidates = []
# for job in self.jobs:
# if job.get_state_name() == "Scheduling":
# scheduled_candidates.append(job)

# selected_jobs = self.algorithm.select_hp(scheduled_candidates, self.available_resource, rng=self._rng)

# if len(selected_jobs) > 0:
# for job in selected_jobs:
# if job.get_state_name() == "Scheduling":
# self._serialize(job.trial_id)
# job.schedule()
# self.logger.debug(f"trial id: {job.trial_id} has been scheduled.")
# selected_jobs.remove(job)

# for job in self.jobs:
# job.main()
# self.logger.info(f"name: {job.trial_id}, state: {job.get_state_name()}")

self.get_stats()
self.update_resource()
self.print_dict_state()
Expand Down
30 changes: 10 additions & 20 deletions aiaccel/scheduler/job/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@


JOB_STATES = [
{"name": "init"},
{"name": "ready"},
{"name": "running"},
{"name": "finished"},
Expand All @@ -30,13 +29,6 @@


JOB_TRANSITIONS: list[dict[str, str | list[str]]] = [
{
"trigger": "next_state",
"source": "init",
"dest": "ready",
"before": "before_ready",
"after": "after_ready",
},
{
"trigger": "next_state",
"source": "ready",
Expand All @@ -59,12 +51,12 @@
},
{
"trigger": "expire",
"source": ["init", "ready", "running", "finished"],
"source": ["ready", "running", "finished"],
"dest": "failure",
},
{
"trigger": "timeout",
"source": ["init", "ready", "running", "finished"],
"source": ["ready", "running", "finished"],
"dest": "timeout",
"before": "before_timeout",
"after": "after_timeout",
Expand Down Expand Up @@ -92,6 +84,7 @@ def __init__(self, config: DictConfig, scheduler: AbstractScheduler, model: Abst
self.trial_id = trial_id
self.content = self.storage.get_hp_dict(self.trial_id)
self.scheduler = scheduler
self.goals: list[str] = self.config.optimize.goal
self.model = model
if self.model is None:
raise ValueError(
Expand Down Expand Up @@ -157,9 +150,7 @@ def get_job_elapsed_time_in_seconds(self) -> float:
"""
if self.start_time is None:
return 0.0
if self.end_time is None:
return 0.0
return (self.end_time - self.start_time).total_seconds()
return (datetime.now() - self.start_time).total_seconds()

def is_timeout(self) -> bool:
"""Check if a job is timeout.
Expand All @@ -169,8 +160,6 @@ def is_timeout(self) -> bool:
"""
if self.start_time is None:
return False
if self.end_time is None:
return False
elapsed_time = self.get_job_elapsed_time_in_seconds()
if elapsed_time > self.config.generic.batch_job_timeout:
return True
Expand All @@ -184,10 +173,13 @@ def main(self) -> None:
"""

state = self.machine.get_state(self.model.state)
if state in ["failure", "success", "timeout"]:
return
if self.is_timeout():
self.model.timeout(self)
return
try:
if state.name.lower() == "init":
self.model.next_state(self)
elif state.name.lower() == "ready":
if state.name.lower() == "ready":
self.model.next_state(self)
elif state.name.lower() == "running":
self.model.next_state(self)
Expand All @@ -197,6 +189,4 @@ def main(self) -> None:
self.logger.error(f"An error occurred in the job thread. {e}")
self.model.expire(self)
return
if self.is_timeout():
self.model.timeout(self)
return
11 changes: 3 additions & 8 deletions aiaccel/scheduler/job/model/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@ class AbstractModel(object):
next_state: Any
timeout: Any

# ready
def before_ready(self, obj: Job) -> None:
self.runner_create(obj)

def after_ready(self, obj: Job) -> None:
obj.write_start_time_to_storage()
self.job_submitted(obj)

def runner_create(self, obj: Job) -> None: # noqa: U100
...

Expand All @@ -33,6 +25,9 @@ def job_submitted(self, obj: Job) -> None: # noqa: U100

# running
def before_running(self, obj: Job) -> None:
self.runner_create(obj)
self.job_submitted(obj)
obj.write_start_time_to_storage()
obj.write_state_to_storage("running")

def after_running(self, obj: Job) -> None: # noqa: U100
Expand Down
22 changes: 16 additions & 6 deletions aiaccel/scheduler/job/model/local_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,20 +105,30 @@ def write_results_to_database(self, obj: "Job") -> None:
objectives: list[str] = []

if len(stdouts) > 0:
objective = stdouts[-1] # TODO: fix
objective = objective.strip("[]")
objective = objective.replace(" ", "")
objectives = objective.split(",")
if len(stdouts) >= len(obj.goals):
objectives = stdouts[-len(obj.goals) :]
elif len(stdouts) == 1:
objective.append(stdouts[0])
elif len(stdouts) > 1:
for i in range(len(obj.goals)):
o_index = len(stdouts) - len(obj.goals) + i
objectives.append(stdouts[o_index])
else:
raise NotImplementedError("Not Readched")
if len(stdouts) < len(obj.goals):
obj.logger.warning(
f"Number of objectives is less than the number of goals. "
f"Number of objectives: {len(stdouts)}, "
f"Number of goals: {len(obj.goals)}"
)

error = "\n".join(stderrs)

args = {
"storage_file_path": str(obj.workspace.storage_file_path),
"trial_id": str(trial_id),
"error": error,
"returncode": returncode,
}

if len(error) == 0:
del args["error"]

Expand Down
33 changes: 0 additions & 33 deletions aiaccel/scheduler/pylocal_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,6 @@ def inner_loop_main_process(self) -> bool:
for trial_id, xs, ys, err, start_time, end_time in self.pool.imap_unordered(execute, args):
self.report(trial_id, ys, err, start_time, end_time)
self.storage.trial.set_any_trial_state(trial_id=trial_id, state="finished")

self.create_result_file(trial_id, xs, ys, err, start_time, end_time)

return True

def post_process(self) -> None:
Expand Down Expand Up @@ -120,36 +117,6 @@ def get_result_file_path(self) -> Path:
"""
return self.workspace.get_any_result_file_path(self.trial_id.get())

def create_result_file(
self, trial_id: int, xs: dict[str, Any], ys: list[Any], error: str, start_time: str, end_time: str
) -> None:
args = {
"file": self.workspace.get_any_result_file_path(trial_id),
"trial_id": str(trial_id),
"config": self.config.config_path,
"start_time": start_time,
"end_time": end_time,
"error": error,
}

if len(error) == 0:
del args["error"]

commands = ["aiaccel-set-result"]
for key in args.keys():
commands.append(f"--{key}={str(args[key])}")

commands.append("--objective")
for y in ys:
commands.append(str(y))

for key in xs.keys():
commands.append(f"--{key}={str(xs[key])}")

self.processes.append(Popen(commands))

return None

def __getstate__(self) -> dict[str, Any]:
obj = super().__getstate__()
del obj["run"]
Expand Down
29 changes: 19 additions & 10 deletions aiaccel/util/aiaccel.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,22 +145,22 @@ def execute(
if self.workspace is not None and self.args.trial_id is not None:
set_logging_file_for_trial_id(self.workspace, self.args.trial_id)

y = None
ys = None
err = ""

start_time = datetime.now().strftime(datetime_format)

try:
y = cast_y(func(xs), y_data_type)
ys = cast_y(func(xs), y_data_type)
except BaseException:
err = str(traceback.format_exc())
y = None
ys = None
else:
err = ""

end_time = datetime.now().strftime(datetime_format)

return xs, y, err, start_time, end_time
return xs, ys, err, start_time, end_time

def execute_and_report(
self, func: Callable[[dict[str, float | int | str]], float], y_data_type: str | None = None
Expand Down Expand Up @@ -189,21 +189,30 @@ def func(p: dict[str, Any]) -> float:
"""

xs = self.args.get_xs_from_args()
y: Any = None
_, y, err, _, _ = self.execute(func, xs, y_data_type)
ys: Any = None
_, ys, err, _, _ = self.execute(func, xs, y_data_type)

self.report(y, err)
self.report(ys, err)

def report(self, y: Any, err: str) -> None:
def report(self, ys: Any, err: str) -> None:
"""Save the results to a text file.
Args:
y (Any): Objective value.
err (str): Error string.
"""

if y is not None:
sys.stdout.write(f"\n{y}\n")
if ys is not None:
if isinstance(ys, str):
ys = ys.replace(" ", "")
ys = ys.split(",")
for y in ys:
sys.stdout.write(f"{y}\n")
elif isinstance(ys, (list, tuple)):
for y in ys:
sys.stdout.write(f"{y}\n")
else:
sys.stdout.write(f"{ys}\n")
if err != "":
sys.stderr.write(f"{err}\n")
exit(1)
Expand Down
Loading

0 comments on commit abc4738

Please sign in to comment.