diff --git a/fmtm_splitter/splitter.py b/fmtm_splitter/splitter.py index 570c816..25979fe 100755 --- a/fmtm_splitter/splitter.py +++ b/fmtm_splitter/splitter.py @@ -33,7 +33,6 @@ from osm_rawdata.postgres import PostgresClient from psycopg2.extensions import connection from shapely.geometry import Polygon, box, shape -from shapely.geometry.geo import mapping from shapely.ops import unary_union from fmtm_splitter.db import ( @@ -198,12 +197,18 @@ def meters_to_degrees( def splitBySquare( # noqa: N802 self, meters: int, + db: Union[str, connection], extract_geojson: Optional[Union[dict, FeatureCollection]] = None, ) -> FeatureCollection: """Split the polygon into squares. Args: meters (int): The size of each task square in meters. + db (str, psycopg2.extensions.connection): The db url, format: + postgresql://myusername:mypassword@myhost:5432/mydatabase + OR a psycopg2 connection object that is reused. + Passing a connection object prevents requiring additional + database connections to be spawned. extract_geojson (dict, FeatureCollection): an OSM extract geojson, containing building polygons, or linestrings. 
@@ -221,35 +226,120 @@ def splitBySquare( # noqa: N802 cols = np.arange(xmin, xmax + width_deg, width_deg) rows = np.arange(ymin, ymax + length_deg, length_deg) - extract_geoms = [] - if extract_geojson: - features = ( - extract_geojson.get("features", extract_geojson) - if isinstance(extract_geojson, dict) - else extract_geojson.features - ) - extract_geoms = [shape(feature["geometry"]) for feature in features] - - # Generate grid polygons and clip them by AOI - polygons = [] - for x in cols[:-1]: - for y in rows[:-1]: - grid_polygon = box(x, y, x + width_deg, y + length_deg) - clipped_polygon = grid_polygon.intersection(self.aoi) - - if clipped_polygon.is_empty: - continue - - # Check intersection with extract geometries if available - if extract_geoms: - if any(geom.within(clipped_polygon) for geom in extract_geoms): - polygons.append(clipped_polygon) - else: - polygons.append(clipped_polygon) - - self.split_features = FeatureCollection( - [Feature(geometry=mapping(poly)) for poly in polygons] - ) + with create_connection(db) as conn: + with conn.cursor() as cur: + # Create temporary table + cur.execute(""" + CREATE TEMP TABLE temp_polygons ( + id SERIAL PRIMARY KEY, + geom GEOMETRY(GEOMETRY, 4326), + area DOUBLE PRECISION + ) + """) + + extract_geoms = [] + if extract_geojson: + features = ( + extract_geojson.get("features", extract_geojson) + if isinstance(extract_geojson, dict) + else extract_geojson.features + ) + extract_geoms = [shape(feature["geometry"]) for feature in features] + + # Generate grid polygons and clip them by AOI + polygons = [] + for x in cols[:-1]: + for y in rows[:-1]: + grid_polygon = box(x, y, x + width_deg, y + length_deg) + clipped_polygon = grid_polygon.intersection(self.aoi) + + if clipped_polygon.is_empty: + continue + + # Keep only squares containing an extract geometry's centroid + if extract_geoms: + if any( + geom.centroid.within(clipped_polygon) + for geom in extract_geoms + ): + polygons.append( + (clipped_polygon.wkt, 
clipped_polygon.wkt) + ) + + else: + polygons.append((clipped_polygon.wkt, clipped_polygon.wkt)) + + insert_query = """ + INSERT INTO temp_polygons (geom, area) + SELECT ST_GeomFromText(%s, 4326), + ST_Area(ST_GeomFromText(%s, 4326)::geography) + """ + + if polygons: + cur.executemany(insert_query, polygons) + + area_threshold = 0.35 * (meters**2) + + cur.execute( + """ + DO $$ + DECLARE + small_polygon RECORD; + nearest_neighbor RECORD; + BEGIN + CREATE TEMP TABLE small_polygons As + SELECT id, geom, area + FROM temp_polygons + WHERE area < %s; + FOR small_polygon IN SELECT * FROM small_polygons + LOOP + FOR nearest_neighbor IN + SELECT id, + lp.geom AS large_geom, + ST_LENGTH2D( + ST_INTERSECTION(small_polygon.geom, geom) + ) AS shared_bound + FROM temp_polygons lp + WHERE id NOT IN (SELECT id FROM small_polygons) + AND ST_Touches(small_polygon.geom, lp.geom) + AND ST_Touches(small_polygon.geom, lp.geom) + AND ST_GEOMETRYTYPE( + ST_INTERSECTION(small_polygon.geom, geom) + ) != 'ST_Point' + ORDER BY shared_bound DESC + LIMIT 1 + LOOP + UPDATE temp_polygons + SET geom = ST_UNION(small_polygon.geom, geom) + WHERE id = nearest_neighbor.id; + + DELETE FROM temp_polygons WHERE id = small_polygon.id; + EXIT; + END LOOP; + END LOOP; + END $$; + """, + (area_threshold,), + ) + + cur.execute( + """ + SELECT + JSONB_BUILD_OBJECT( + 'type', 'FeatureCollection', + 'features', JSONB_AGG(feature) + ) + FROM( + SELECT JSONB_BUILD_OBJECT( + 'type', 'Feature', + 'properties', JSONB_BUILD_OBJECT('area', (t.area)), + 'geometry', ST_ASGEOJSON(t.geom)::json + ) AS feature + FROM temp_polygons as t + ) AS features; + """ + ) + self.split_features = cur.fetchone()[0] return self.split_features def splitBySQL( # noqa: N802 @@ -444,6 +534,7 @@ def outputGeojson( # noqa: N802 def split_by_square( aoi: Union[str, FeatureCollection], + db: Union[str, connection], meters: int = 100, osm_extract: Union[str, FeatureCollection] = None, outfile: Optional[str] = None, @@ -453,6 +544,11 @@ def 
split_by_square( Args: aoi(str, FeatureCollection): Input AOI, either a file path, GeoJSON string, or FeatureCollection object. + db (str, psycopg2.extensions.connection): The db url, format: + postgresql://myusername:mypassword@myhost:5432/mydatabase + OR a psycopg2 connection object that is reused. + Passing a connection object prevents requiring additional + database connections to be spawned. meters(str, optional): Specify the square size for the grid. Defaults to 100m grid. osm_extract (str, FeatureCollection): an OSM extract geojson, @@ -479,6 +575,7 @@ def split_by_square( for index, feat in enumerate(feat_array): featcol = split_by_square( FeatureCollection(features=[feat]), + db, meters, None, f"{Path(outfile).stem}_{index}.geojson)" if outfile else None, @@ -489,7 +586,7 @@ def split_by_square( split_features = FeatureCollection(features) else: splitter = FMTMSplitter(aoi_featcol) - split_features = splitter.splitBySquare(meters, extract_geojson) + split_features = splitter.splitBySquare(meters, db, extract_geojson) if not split_features: msg = "Failed to generate split features." log.error(msg) @@ -795,6 +892,7 @@ def main(args_list: list[str] | None = None): if args.meters: split_by_square( args.boundary, + db=args.dburl, meters=args.meters, outfile=args.outfile, osm_extract=args.extract,