Skip to content

Commit

Permalink
Dev opt dataclass - add agent time periods (#3)
Browse files Browse the repository at this point in the history
* v1.0.0a3 fix bugs

* add time periods in agent generation

---------

Co-authored-by: xyluo25 <[email protected]>
  • Loading branch information
xyluo25 and xyluo25 authored Oct 16, 2024
1 parent 3fe572b commit 9ae7b4e
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 10 deletions.
15 changes: 10 additions & 5 deletions grid2demand/_grid2demand.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def run_gravity_model(self,
print(" : Successfully generated OD demands.")
return None

def gen_agent_based_demand(self) -> None:
def gen_agent_based_demand(self, time_periods: str = "0700-0800") -> None:
"""generate agent-based demand
Args:
Expand All @@ -680,7 +680,9 @@ def gen_agent_based_demand(self) -> None:
node_dict = self.node_dict

self.df_agent = gen_agent_based_demand(node_dict, self.zone_dict,
df_demand=self.df_demand, verbose=self.verbose)
df_demand=self.df_demand,
time_periods=time_periods,
verbose=self.verbose)
return None

def save_results_to_csv(self, output_dir: str = "",
Expand All @@ -690,9 +692,9 @@ def save_results_to_csv(self, output_dir: str = "",
node: bool = True, # save updated node
poi: bool = True, # save updated poi
agent: bool = False, # save agent-based demand
agent_time_period: str = "0700-0800",
zone_od_dist_table: bool = False,
zone_od_dist_matrix: bool = False,
is_demand_with_geometry: bool = False,
overwrite_file: bool = True) -> None:
"""save results to csv files
Expand All @@ -713,7 +715,7 @@ def save_results_to_csv(self, output_dir: str = "",
self.output_dir = path2linux(output_dir)

if demand:
save_demand(self, overwrite_file=overwrite_file, is_demand_with_geometry=is_demand_with_geometry)
save_demand(self, overwrite_file=overwrite_file)

if zone:
save_zone(self, overwrite_file=overwrite_file)
Expand All @@ -731,7 +733,10 @@ def save_results_to_csv(self, output_dir: str = "",
save_zone_od_dist_matrix(self, overwrite_file=overwrite_file)

if agent:
self.gen_agent_based_demand()
if agent_time_period:
self.gen_agent_based_demand(time_periods=agent_time_period)
else:
self.gen_agent_based_demand()
save_agent(self, overwrite_file=overwrite_file)

return None
29 changes: 26 additions & 3 deletions grid2demand/func_lib/gen_agent_demand.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from random import choice, uniform
import math
import re

import pandas as pd
from pyufunc import gmns_geo
Expand All @@ -18,6 +19,7 @@ def gen_agent_based_demand(node_dict: dict, zone_dict: dict,
path_demand: str = "",
df_demand: pd.DataFrame = "",
agent_type: str = "v",
time_period: str = "0000-2359",
verbose: bool = False) -> pd.DataFrame:
"""Generate agent-based demand data
Expand All @@ -32,7 +34,28 @@ def gen_agent_based_demand(node_dict: dict, zone_dict: dict,
Returns:
pd.DataFrame: _description_
"""
# either path_demand or df_demand must be provided
# Validate time_period format
time_period_pattern = re.compile(r"^\d{4}-\d{4}$")
if not isinstance(time_period, str) or not time_period_pattern.match(time_period):
raise ValueError(
f"Error: time_period '{time_period}' must be a string in the format 'HHMM-HHMM'.")

start_time_str, end_time_str = time_period.split('-')

# Validate that the times are within the valid range
try:
start_time = int(start_time_str[:2]) * 60 + int(start_time_str[2:])
end_time = int(end_time_str[:2]) * 60 + int(end_time_str[2:])
except ValueError as e:
raise ValueError("Error: time_period contains non-numeric values.") from e

if not (0 <= start_time <= 1440) or not (0 <= end_time <= 1440):
raise ValueError(
"Error: time_period must be between '0000' and '2400'.")

if start_time >= end_time:
raise ValueError(
"Error: start_time must be less than end_time in time_period.")

# if path_demand is provided, read demand data from path_demand
if path_demand:
Expand All @@ -53,8 +76,8 @@ def gen_agent_based_demand(node_dict: dict, zone_dict: dict,
d_node_id = choice(zone_dict[d_zone_id]["node_id_list"] + [""])

if o_node_id and d_node_id:
# Change the range to 0 to 1440
rand_time = math.ceil(uniform(0, 1440))
# Generate a random time within the specified time period
rand_time = math.ceil(uniform(start_time, end_time))

# Calculate hours and minutes from rand_time
hours = rand_time // 60
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "grid2demand"
version = "1.0.0a2"
version = "1.0.0a3"
description = "A tool for generating zone-to-zone travel demand based on grid cells or TAZs and gravity model"
authors = [
{name = "Xiangyong Luo", email = "[email protected]"},
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

setuptools.setup(
name="grid2demand", # Replace with your own username
version="1.0.0a2",
version="1.0.0a3",
author="Xiangyong Luo, Dr.Xuesong (Simon) Zhou",
author_email="[email protected], [email protected]",
description="A tool for generating zone-to-zone travel demand based on grid cells or TAZs and gravity model",
Expand Down

0 comments on commit 9ae7b4e

Please sign in to comment.