diff --git a/cjdb/cli.py b/cjdb/cli.py index 11c2eed..9edc2d4 100644 --- a/cjdb/cli.py +++ b/cjdb/cli.py @@ -8,6 +8,12 @@ from cjdb.modules.utils import get_db_engine, get_db_psycopg_conn from cjdb.resources import strings as s +def get_password(ctx, param, value): + if value is None and 'PGPASSWORD' in os.environ: + return os.environ['PGPASSWORD'] + else: + return click.prompt("Password for database user", hide_input=True) + @click.group() @click.version_option( @@ -33,8 +39,9 @@ def cjdb(ctx): @click.option("--user", "-U", type=str, required=True, help=s.user_help) @click.password_option( help=s.password_help, - prompt="Password for database user", - confirmation_prompt=False + prompt=False, + confirmation_prompt=False, + callback=get_password ) @click.option("--database", "-d", type=str, @@ -83,6 +90,14 @@ def cjdb(ctx): default=False, help=s.transform_help, ) +@click.option( + "--skip-post-import", + "-S", + "skip_post_import", + is_flag=True, + default=False, + help="Skip post import functions", +) def import_cj( filepath, host, @@ -96,7 +111,8 @@ def import_cj( partial_indexed_attributes, ignore_repeated_file, overwrite, - transform + transform, + skip_post_import # add this parameter ): """Import CityJSONL files to a PostgreSQL database. Example of cli command: @@ -115,7 +131,8 @@ def import_cj( partial_indexed_attributes, ignore_repeated_file, overwrite, - transform + transform, + skip_post_import ) as imp: imp.run_import() @@ -127,8 +144,9 @@ def import_cj( @click.option("--user", "-U", type=str, default="postgres", help=s.user_help) @click.password_option( help=s.password_help, - prompt="Password for database user", - confirmation_prompt=False + prompt=False, + confirmation_prompt=False, + callback=get_password ) @click.option("--database", "-d", type=str, diff --git a/cjdb/modules/importer.py b/cjdb/modules/importer.py index d454f92..aa9a6f6 100644 --- a/cjdb/modules/importer.py +++ b/cjdb/modules/importer.py @@ -43,7 +43,7 @@ def __init__(self, file="stdin"): class Importer: def __init__(self, engine, filepath, db_schema, input_srid, indexed_attributes, partial_indexed_attributes, - ignore_repeated_file, overwrite, transform): + ignore_repeated_file, overwrite, transform, skip_post_import): self.engine = engine self.filepath = filepath self.db_schema = db_schema @@ -55,6 +55,7 @@ def __init__(self, engine, filepath, db_schema, input_srid, self.max_id = 0 self.processed = dict() self.transform = transform + self.skip_post_import = skip_post_import # get allowed types for validation self.city_object_types = get_city_object_types() @@ -74,9 +75,11 @@ def run_import(self) -> None: self.prepare_database() self.max_id = CjObjectModel.get_max_id(self.session) self.parse_cityjson() - self.session.commit() # post import operations like clustering, indexing... - self.post_import() + if not self.skip_post_import: + self.post_import() + else: + logger.info("Post import was skipped.") self.session.commit() def prepare_database(self) -> None: