diff --git a/.gitignore b/.gitignore index db06e01..e962159 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ -src/isd/_version.py +isd/_version.py # Byte-compiled / optimized / DLL files __pycache__/ diff --git a/src/isd/__init__.py b/isd/__init__.py similarity index 100% rename from src/isd/__init__.py rename to isd/__init__.py diff --git a/isd/batch.py b/isd/batch.py new file mode 100644 index 0000000..9cb600f --- /dev/null +++ b/isd/batch.py @@ -0,0 +1,53 @@ +import gzip +from io import BytesIO +from pathlib import Path +from dataclasses import dataclass +from typing import List, TYPE_CHECKING, Union, Optional +import datetime as dt + +from isd.record import Record + +if TYPE_CHECKING: + import pandas as pd + + +@dataclass +class Batch: + records: List[Record] + + @classmethod + def from_path(cls, path: Union[str, Path]) -> "Batch": + """Opens a local ISD file and returns an iterator over its records. + + If the path has a .gz extension, this function will assume it has gzip + compression and will attempt to open it using `gzip.open`. + """ + path = Path(path) + if path.suffix == ".gz": + with gzip.open(path) as gzip_file: + return cls([Record.from_string(gzip_line.decode("utf-8")) for gzip_line in gzip_file]) + else: + with open(path) as uncompressed_file: + return cls([Record.from_string(uncompressed_line) for uncompressed_line in uncompressed_file]) + + @classmethod + def from_string(cls, string: Union[str, BytesIO]) -> "Batch": + """Reads records from a text io stream.""" + if isinstance(string, BytesIO): + string = string.read().decode("utf-8") + return cls([Record.from_string(line) for line in string.splitlines()]) + + def filter_by_datetime(self, start_date: Optional[dt.datetime] = None, end_date: Optional[dt.datetime] = None, + ) -> List[Record]: + """Returns an iterator over records filtered by start and end datetimes (both optional).""" + return [ + record + for record in self.records + if (not start_date or record.datetime() >= start_date) + and (not end_date or record.datetime() < end_date) + ] + + def to_df(self) -> "pd.DataFrame": + """Reads a local ISD file into a DataFrame.""" + import pandas as pd + return pd.DataFrame([record.to_dict() for record in self.records]) diff --git a/src/isd/cli.py b/isd/cli.py similarity index 53% rename from src/isd/cli.py rename to isd/cli.py index 533fd04..41c17b8 100644 --- a/src/isd/cli.py +++ b/isd/cli.py @@ -1,13 +1,9 @@ # type: ignore -import dataclasses -import itertools -import json - import click from click import ClickException -import isd.io +from isd.batch import Batch @click.group() @@ -20,9 +16,9 @@ def main() -> None: @click.option("-i", "--index", default=0) def record(infile: str, index: int) -> None: """Prints a single record to standard output in JSON format.""" - with isd.io.open(infile) as records: - record = next(itertools.islice(records, index, None), None) - if record: - print(json.dumps(dataclasses.asdict(record), indent=4)) - else: - raise ClickException(f"No record with index {index}") + batch = Batch.from_path(infile) + try: + record_ = batch.records[index] + print(record_.to_json()) + except IndexError: + raise ClickException(f"No record with index {index}") diff --git a/src/isd/errors.py b/isd/errors.py similarity index 100% rename from src/isd/errors.py rename to isd/errors.py diff --git a/isd/record.py b/isd/record.py new file mode 100644 index 0000000..3fe0a5c --- /dev/null +++ b/isd/record.py @@ -0,0 +1,223 @@ +import datetime +import json +from dataclasses import dataclass +from io import BytesIO +from typing import Any, Callable, List, Optional, Tuple, Union, Dict + +from isd.errors import IsdError + +MIN_LINE_LENGTH = 105 + + +@dataclass +class Record: + """A single string of an ISD file.""" + + usaf_id: str + ncei_id: str + year: int + month: int + day: int + hour: int + minute: int + data_source: str + latitude: Optional[float] + longitude: Optional[float] + report_type: Optional[str] + elevation: Optional[float] + call_letters: Optional[str] + quality_control_process: str + wind_direction: Optional[int] + wind_direction_quality_code: str + wind_observation_type: Optional[str] + wind_speed: Optional[float] + wind_speed_quality_code: str + ceiling: Optional[int] + ceiling_quality_code: str + ceiling_determination_code: Optional[str] + cavok_code: Optional[str] + visibility: Optional[int] + visibility_quality_code: str + visibility_variability_code: Optional[str] + visibility_variability_quality_code: str + air_temperature: Optional[float] + air_temperature_quality_code: str + dew_point_temperature: Optional[float] + dew_point_temperature_quality_code: str + sea_level_pressure: Optional[float] + sea_level_pressure_quality_code: str + additional_data: str + remarks: str + element_quality_data: str + original_observation_data: str + + @classmethod + def from_string(cls, string: Union[str, BytesIO]) -> "Record": + """Parses an ISD string into a record.""" + if isinstance(string, BytesIO): + string = string.read().decode("utf-8") + if len(string) < MIN_LINE_LENGTH: + raise IsdError(f"Invalid ISD string (too short): {string}") + string = string.strip() + usaf_id = string[4:10] + ncei_id = string[10:15] + year = int(string[15:19]) + month = int(string[19:21]) + day = int(string[21:23]) + hour = int(string[23:25]) + minute = int(string[25:27]) + data_source = string[27] + # TODO test missing latitudes and longitudes + latitude = cls._transform_value(string[28:34], "+99999", lambda s: float(s) / 1000) + longitude = cls._transform_value(string[34:41], "+999999", lambda s: float(s) / 1000) + report_type = cls._transform_value(string[41:46], "99999") + elevation = cls._transform_value(string[46:51], "+9999", lambda s: float(s)) + call_letters = cls._transform_value(string[51:56], "99999") + quality_control_process = string[56:60] + wind_direction = cls._transform_value(string[60:63], "999", lambda s: int(s)) + wind_direction_quality_code = string[63] + wind_observation_type = cls._transform_value(string[64], "9") + wind_speed = cls._transform_value(string[65:69], "9999", lambda s: float(s) / 10) + wind_speed_quality_code = string[69] + ceiling = cls._transform_value(string[70:75], "99999", lambda s: int(s)) + ceiling_quality_code = string[75] + ceiling_determination_code = cls._transform_value(string[76], "9") + cavok_code = cls._transform_value(string[77], "9") + visibility = cls._transform_value(string[78:84], "999999", lambda s: int(s)) + visibility_quality_code = string[84] + visibility_variability_code = cls._transform_value(string[85], "9") + visibility_variability_quality_code = string[86] + air_temperature = cls._transform_value(string[87:92], "+9999", lambda s: float(s) / 10) + air_temperature_quality_code = string[92] + dew_point_temperature = cls._transform_value(string[93:98], "+9999", lambda s: float(s) / 10) + dew_point_temperature_quality_code = string[98] + sea_level_pressure = cls._transform_value(string[99:104], "99999", lambda s: float(s) / 10) + sea_level_pressure_quality_code = string[104] + additional_data, remainder = cls._extract_data( + string[105:], "ADD", ["REM", "EQD", "QNN"] + ) + remarks, remainder = cls._extract_data(remainder, "REM", ["EQD", "QNN"]) + element_quality_data, remainder = cls._extract_data(remainder, "EQD", ["QNN"]) + original_observation_data, remainder = cls._extract_data(remainder, "QNN", []) + + if remainder: + raise IsdError(f"Invalid remainder after parsing: {remainder}") + + return cls( + usaf_id=usaf_id, + ncei_id=ncei_id, + year=year, + month=month, + day=day, + hour=hour, + minute=minute, + data_source=data_source, + latitude=latitude, + longitude=longitude, + report_type=report_type, + elevation=elevation, + call_letters=call_letters, + quality_control_process=quality_control_process, + wind_direction=wind_direction, + wind_direction_quality_code=wind_direction_quality_code, + wind_observation_type=wind_observation_type, + wind_speed=wind_speed, + wind_speed_quality_code=wind_speed_quality_code, + ceiling=ceiling, + ceiling_quality_code=ceiling_quality_code, + ceiling_determination_code=ceiling_determination_code, + cavok_code=cavok_code, + visibility=visibility, + visibility_quality_code=visibility_quality_code, + visibility_variability_code=visibility_variability_code, + visibility_variability_quality_code=visibility_variability_quality_code, + air_temperature=air_temperature, + air_temperature_quality_code=air_temperature_quality_code, + dew_point_temperature=dew_point_temperature, + dew_point_temperature_quality_code=dew_point_temperature_quality_code, + sea_level_pressure=sea_level_pressure, + sea_level_pressure_quality_code=sea_level_pressure_quality_code, + additional_data=additional_data, + remarks=remarks, + element_quality_data=element_quality_data, + original_observation_data=original_observation_data, + ) + + @classmethod + def _extract_data(cls, message: str, tag: str, later_tags: List[str]) -> Tuple[str, str]: + if message.startswith(tag): + index = None + for other_tag in later_tags: + try: + index = message.find(other_tag) + except ValueError: + continue + break + if index != -1: + data = message[len(tag): index] + tail = message[index:] + return data, tail + else: + return message[len(tag):], "" + else: + return "", message + + @classmethod + def _transform_value( + cls, string: str, missing_value: str, transform: Optional[Callable[[str], Any]] = None + ) -> Any: + if string == missing_value: + return None + elif transform: + return transform(string) + else: + return string + + def datetime(self) -> datetime.datetime: + """Returns this record's datetime.""" + return datetime.datetime( + self.year, self.month, self.day, self.hour, self.minute + ) + + def to_dict(self) -> Dict[str, Any]: + """Returns a dictionary representation of this record.""" + return { + "usaf_id": self.usaf_id, + "ncei_id": self.ncei_id, + # use datetime instead of year, month, day, hour, minute + "datetime": self.datetime(), + "data_source": self.data_source, + "latitude": self.latitude, + "longitude": self.longitude, + "report_type": self.report_type, + "elevation": self.elevation, + "call_letters": self.call_letters, + "quality_control_process": self.quality_control_process, + "wind_direction": self.wind_direction, + "wind_direction_quality_code": self.wind_direction_quality_code, + "wind_observation_type": self.wind_observation_type, + "wind_speed": self.wind_speed, + "wind_speed_quality_code": self.wind_speed_quality_code, + "ceiling": self.ceiling, + "ceiling_quality_code": self.ceiling_quality_code, + "ceiling_determination_code": self.ceiling_determination_code, + "cavok_code": self.cavok_code, + "visibility": self.visibility, + "visibility_quality_code": self.visibility_quality_code, + "visibility_variability_code": self.visibility_variability_code, + "visibility_variability_quality_code": self.visibility_variability_quality_code, + "air_temperature": self.air_temperature, + "air_temperature_quality_code": self.air_temperature_quality_code, + "dew_point_temperature": self.dew_point_temperature, + "dew_point_temperature_quality_code": self.dew_point_temperature_quality_code, + "sea_level_pressure": self.sea_level_pressure, + "sea_level_pressure_quality_code": self.sea_level_pressure_quality_code, + "additional_data": self.additional_data, + "remarks": self.remarks, + "element_quality_data": self.element_quality_data, + "original_observation_data": self.original_observation_data, + } + + def to_json(self, indent: int = 4) -> str: + """Returns a JSON representation of this record.""" + return json.dumps(self.to_dict(), indent=indent) diff --git a/pyproject.toml b/pyproject.toml index 7bf401c..0a09c83 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,5 +7,5 @@ requires = [ build-backend = "setuptools.build_meta" [tool.setuptools_scm] -write_to = "src/isd/_version.py" +write_to = "isd/_version.py" local_scheme = "no-local-version" diff --git a/requirements-dev.txt b/requirements-dev.txt index e4abbcc..35bd2a5 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ mypy +pandas~=1.3 pandas-stubs pre-commit pytest diff --git a/setup.cfg b/setup.cfg index 55d639d..1917195 100644 --- a/setup.cfg +++ b/setup.cfg @@ -20,8 +20,6 @@ packages = find: python_requires = >=3.7 install_requires = click ~= 8.0 - pandas ~= 1.3 - [options.packages.find] where = src diff --git a/src/isd/io.py b/src/isd/io.py deleted file mode 100644 index f35621f..0000000 --- a/src/isd/io.py +++ /dev/null @@ -1,48 +0,0 @@ -import datetime -import gzip -import os.path -from contextlib import contextmanager -from typing import Generator, Iterable, Iterator, Optional, TextIO - -from pandas import DataFrame - -from . import pandas as isd_pandas -from .record import Record - -builtin_open = open - - -@contextmanager -def open(path: str) -> Generator[Iterable[Record], None, None]: - """Opens a local ISD file and returns an iterator over its records. - - If the path has a .gz extension, this function will assume it has gzip - compression and will attempt to open it using `gzip.open`. - """ - if os.path.splitext(path)[1] == ".gz": - with gzip.open(path) as gzip_file: - yield (Record.parse(gzip_line.decode("utf-8")) for gzip_line in gzip_file) - else: - with builtin_open(path) as uncompressed_file: - yield ( - Record.parse(uncompressed_line) - for uncompressed_line in uncompressed_file - ) - - -def from_text_io(text_io: TextIO) -> Iterator[Record]: - """Reads records from a text io stream.""" - while True: - line = text_io.readline() - if not line: - break - else: - yield Record.parse(line) - - -def read_to_data_frame( - path: str, since: Optional[datetime.datetime] = None -) -> DataFrame: - """Reads a local ISD file into a DataFrame.""" - with open(path) as file: - return isd_pandas.data_frame(file, since=since) diff --git a/src/isd/pandas.py b/src/isd/pandas.py deleted file mode 100644 index a4b28ec..0000000 --- a/src/isd/pandas.py +++ /dev/null @@ -1,165 +0,0 @@ -import datetime -from typing import Iterable, Optional - -import pandas -from pandas import CategoricalDtype, DataFrame - -from isd import Record - -DataSourceDtype = CategoricalDtype( - [ - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "8", - "A", - "B", - "C", - "D", - "E", - "F", - "G", - "H", - "I", - "J", - "K", - "L", - "M", - "N", - "O", - ] -) -ReportTypeDtype = CategoricalDtype( - [ - "AERO", - "AUST", - "AUTO", - "BOGUS", - "BRAZ", - "COOPD", - "COOPS", - "CRB", - "CRN05", - "CRN15", - "FM-12", - "FM-13", - "FM-14", - "FM-15", - "FM-16", - "FM-18", - "GREEN", - "MESOH", - "MESOS", - "MESOW", - "MEXIC", - "NSRDB", - "PCP15", - "PCP60", - "S-S-A", - "SA-AU", - "SAO", - "SAOSP", - "SHEF", - "SMARS", - "SOD", - "SOM", - "SURF", - "SY-AE", - "SY-AU", - "SY-MT", - "SY-SA", - "WBO", - "WNO", - ] -) -QualityControlProcessDtype = CategoricalDtype(["V01", "V02", "V03"]) -QualityCodeDtype = CategoricalDtype( - [ - "0", - "1", - "2", - "3", - "4", - "5", - "6", - "7", - "9", - "A", - "U", - "P", - "I", - "M", - "C", - "R", - ] -) -WindObservationTypeDtype = CategoricalDtype( - ["A", "B", "C", "H", "N", "R", "Q", "T", "V"] -) -CeilingDeterminationCodeDtype = CategoricalDtype( - ["A", "B", "C", "D", "E", "M", "P", "R", "S", "U", "V", "W"] -) -CavokCodeDtype = CategoricalDtype(["N", "Y"]) -VisibilityVariabilityCodeDtype = CategoricalDtype(["N", "V"]) - - -def data_frame( - records: Iterable[Record], since: Optional[datetime.datetime] = None -) -> DataFrame: - """Constructs a pandas data frame from an iterable of Records. - - Uses appropriate datatypes and categorical variables. - """ - data_frame = DataFrame(records).astype( - { - "usaf_id": "string", - "ncei_id": "string", - "year": "UInt16", - "month": "UInt8", - "day": "UInt8", - "hour": "UInt8", - "minute": "UInt8", - "data_source": DataSourceDtype, - "latitude": "float", - "longitude": "float", - "report_type": ReportTypeDtype, - "elevation": "Int16", - "call_letters": "string", - "quality_control_process": QualityControlProcessDtype, - "wind_direction": "UInt16", - "wind_direction_quality_code": QualityCodeDtype, - "wind_observation_type": WindObservationTypeDtype, - "wind_speed": "float", - "wind_speed_quality_code": QualityCodeDtype, - "ceiling": "float", - "ceiling_quality_code": QualityCodeDtype, - "ceiling_determination_code": CeilingDeterminationCodeDtype, - "cavok_code": CavokCodeDtype, - "visibility": "UInt32", - "visibility_quality_code": QualityCodeDtype, - "visibility_variability_code": VisibilityVariabilityCodeDtype, - "visibility_variability_quality_code": QualityCodeDtype, - "air_temperature": "float", - "air_temperature_quality_code": QualityCodeDtype, - "dew_point_temperature": "float", - "dew_point_temperature_quality_code": QualityCodeDtype, - "sea_level_pressure": "float", - "sea_level_pressure_quality_code": QualityCodeDtype, - "additional_data": "string", - "remarks": "string", - "element_quality_data": "string", - "original_observation_data": "string", - } - ) - timestamp = pandas.to_datetime( - data_frame[["year", "month", "day", "hour", "minute"]] - ) - data_frame["timestamp"] = timestamp - if since: - return data_frame[data_frame["timestamp"] > since] - else: - return data_frame diff --git a/src/isd/record.py b/src/isd/record.py deleted file mode 100644 index e0cc94c..0000000 --- a/src/isd/record.py +++ /dev/null @@ -1,174 +0,0 @@ -import datetime -from dataclasses import dataclass -from typing import Any, Callable, List, Optional, Tuple - -from isd.errors import IsdError - -MIN_LINE_LENGTH = 105 - - -@dataclass -class Record: - """A single line of an ISD file.""" - - usaf_id: str - ncei_id: str - year: int - month: int - day: int - hour: int - minute: int - data_source: str - latitude: Optional[float] - longitude: Optional[float] - report_type: Optional[str] - elevation: Optional[float] - call_letters: Optional[str] - quality_control_process: str - wind_direction: Optional[int] - wind_direction_quality_code: str - wind_observation_type: Optional[str] - wind_speed: Optional[float] - wind_speed_quality_code: str - ceiling: Optional[int] - ceiling_quality_code: str - ceiling_determination_code: Optional[str] - cavok_code: Optional[str] - visibility: Optional[int] - visibility_quality_code: str - visibility_variability_code: Optional[str] - visibility_variability_quality_code: str - air_temperature: Optional[float] - air_temperature_quality_code: str - dew_point_temperature: Optional[float] - dew_point_temperature_quality_code: str - sea_level_pressure: Optional[float] - sea_level_pressure_quality_code: str - additional_data: str - remarks: str - element_quality_data: str - original_observation_data: str - - @classmethod - def parse(cls, line: str) -> "Record": - """Parses an ISD line into a record.""" - if len(line) < MIN_LINE_LENGTH: - raise IsdError(f"Invalid ISD line (too short): {line}") - line = line.strip() - usaf_id = line[4:10] - ncei_id = line[10:15] - year = int(line[15:19]) - month = int(line[19:21]) - day = int(line[21:23]) - hour = int(line[23:25]) - minute = int(line[25:27]) - data_source = line[27] - # TODO test missing latitudes and longitudes - latitude = optional(line[28:34], "+99999", lambda s: float(s) / 1000) - longitude = optional(line[34:41], "+999999", lambda s: float(s) / 1000) - report_type = optional(line[41:46], "99999") - elevation = optional(line[46:51], "+9999", lambda s: float(s)) - call_letters = optional(line[51:56], "99999") - quality_control_process = line[56:60] - wind_direction = optional(line[60:63], "999", lambda s: int(s)) - wind_direction_quality_code = line[63] - wind_observation_type = optional(line[64], "9") - wind_speed = optional(line[65:69], "9999", lambda s: float(s) / 10) - wind_speed_quality_code = line[69] - ceiling = optional(line[70:75], "99999", lambda s: int(s)) - ceiling_quality_code = line[75] - ceiling_determination_code = optional(line[76], "9") - cavok_code = optional(line[77], "9") - visibility = optional(line[78:84], "999999", lambda s: int(s)) - visibility_quality_code = line[84] - visibility_variability_code = optional(line[85], "9") - visibility_variability_quality_code = line[86] - air_temperature = optional(line[87:92], "+9999", lambda s: float(s) / 10) - air_temperature_quality_code = line[92] - dew_point_temperature = optional(line[93:98], "+9999", lambda s: float(s) / 10) - dew_point_temperature_quality_code = line[98] - sea_level_pressure = optional(line[99:104], "99999", lambda s: float(s) / 10) - sea_level_pressure_quality_code = line[104] - additional_data, remainder = extract_data( - line[105:], "ADD", ["REM", "EQD", "QNN"] - ) - remarks, remainder = extract_data(remainder, "REM", ["EQD", "QNN"]) - element_quality_data, remainder = extract_data(remainder, "EQD", ["QNN"]) - original_observation_data, remainder = extract_data(remainder, "QNN", []) - assert not remainder - - return cls( - usaf_id=usaf_id, - ncei_id=ncei_id, - year=year, - month=month, - day=day, - hour=hour, - minute=minute, - data_source=data_source, - latitude=latitude, - longitude=longitude, - report_type=report_type, - elevation=elevation, - call_letters=call_letters, - quality_control_process=quality_control_process, - wind_direction=wind_direction, - wind_direction_quality_code=wind_direction_quality_code, - wind_observation_type=wind_observation_type, - wind_speed=wind_speed, - wind_speed_quality_code=wind_speed_quality_code, - ceiling=ceiling, - ceiling_quality_code=ceiling_quality_code, - ceiling_determination_code=ceiling_determination_code, - cavok_code=cavok_code, - visibility=visibility, - visibility_quality_code=visibility_quality_code, - visibility_variability_code=visibility_variability_code, - visibility_variability_quality_code=visibility_variability_quality_code, - air_temperature=air_temperature, - air_temperature_quality_code=air_temperature_quality_code, - dew_point_temperature=dew_point_temperature, - dew_point_temperature_quality_code=dew_point_temperature_quality_code, - sea_level_pressure=sea_level_pressure, - sea_level_pressure_quality_code=sea_level_pressure_quality_code, - additional_data=additional_data, - remarks=remarks, - element_quality_data=element_quality_data, - original_observation_data=original_observation_data, - ) - - def datetime(self) -> datetime.datetime: - """Returns this record's datetime.""" - return datetime.datetime( - self.year, self.month, self.day, self.hour, self.minute - ) - - -def extract_data(message: str, tag: str, later_tags: List[str]) -> Tuple[str, str]: - if message.startswith(tag): - index = None - for other_tag in later_tags: - try: - index = message.find(other_tag) - except ValueError: - continue - break - if index != -1: - data = message[len(tag) : index] - tail = message[index:] - return data, tail - else: - return message[len(tag) :], "" - else: - return "", message - - -def optional( - string: str, missing_value: str, transform: Optional[Callable[[str], Any]] = None -) -> Any: - if string == missing_value: - return None - elif transform: - return transform(string) - else: - return string diff --git a/src/isd/utils.py b/src/isd/utils.py deleted file mode 100644 index 849009f..0000000 --- a/src/isd/utils.py +++ /dev/null @@ -1,18 +0,0 @@ -import datetime -from typing import Iterable, Iterator, Optional - -from isd.record import Record - - -def filter_by_datetime( - records: Iterable[Record], - start: Optional[datetime.datetime] = None, - end: Optional[datetime.datetime] = None, -) -> Iterator[Record]: - """Returns an iterator over records filtered by start and end datetimes (both optional).""" - return ( - record - for record in records - if (not start or record.datetime() >= start) - and (not end or record.datetime() < end) - ) diff --git a/tests/conftest.py b/tests/conftest.py index e660335..c490ce4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -50,7 +50,7 @@ def half_path() -> str: @pytest.fixture def records() -> List[Record]: with open_data_file(VANCE_BRAND_FILE_NAME) as f: - return [Record.parse(line) for line in f] + return [Record.from_string(line) for line in f] @pytest.fixture diff --git a/tests/test_batch.py b/tests/test_batch.py new file mode 100644 index 0000000..eebbfa5 --- /dev/null +++ b/tests/test_batch.py @@ -0,0 +1,37 @@ +import datetime as dt +from typing import List + +from isd.batch import Batch +from isd import Record + + +def test_bach_from_uncompressed(uncompressed_path: str) -> None: + batch = Batch.from_path(uncompressed_path) + assert len(batch.records) == 500 + + +def test_batch_from_compressed(compressed_path: str) -> None: + batch = Batch.from_path(compressed_path) + assert len(batch.records) == 24252 + + +def test_from_string(uncompressed_path: str) -> None: + with open(uncompressed_path) as file: + batch = Batch.from_string(file.read()) + assert len(batch.records) == 500 + + +def test_filter_by_datetime(records: List[Record]) -> None: + batch = Batch(records) + assert len(batch.filter_by_datetime(start_date=dt.datetime(2021, 1, 1, 3, 30))) == 490 + assert len(batch.filter_by_datetime(end_date=dt.datetime(2021, 1, 1, 3, 30))) == 10 + assert len(batch.filter_by_datetime(start_date=dt.datetime(2021, 1, 1, 3, 30), end_date=dt.datetime(2021, 1, 1, 3, 55))) == 1 + assert len(batch.filter_by_datetime(start_date=dt.datetime(2021, 1, 1, 3, 30), end_date=dt.datetime(2021, 1, 1, 3, 56))) == 2 + + +def test_batch_to_df(uncompressed_path: str) -> None: + batch = Batch.from_path(uncompressed_path) + datetime_min = dt.datetime(2021, 1, 5) + df = batch.to_df() + df = df[df["datetime"] >= datetime_min] + assert len(df) == 212 diff --git a/tests/test_io.py b/tests/test_io.py deleted file mode 100644 index ed44c24..0000000 --- a/tests/test_io.py +++ /dev/null @@ -1,28 +0,0 @@ -import datetime - -import isd.io - - -def test_open_uncompressed(uncompressed_path: str) -> None: - with isd.io.open(uncompressed_path) as generator: - records = list(generator) - assert len(records) == 500 - - -def test_open_compressed(compressed_path: str) -> None: - with isd.io.open(compressed_path) as generator: - records = list(generator) - assert len(records) == 24252 - - -def test_read_to_data_frame_since(uncompressed_path: str) -> None: - data_frame = isd.io.read_to_data_frame( - uncompressed_path, since=datetime.datetime(2021, 1, 5) - ) - assert len(data_frame) == 212 - - -def test_from_text_io(uncompressed_path: str) -> None: - with open(uncompressed_path) as file: - records = list(isd.io.from_text_io(file)) - assert len(records) == 500 diff --git a/tests/test_pandas.py b/tests/test_pandas.py deleted file mode 100644 index a294581..0000000 --- a/tests/test_pandas.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import List - -import isd.pandas -from isd import Record - - -def test_data_frame(records: List[Record]) -> None: - isd.pandas.data_frame(records) diff --git a/tests/test_record.py b/tests/test_record.py index b7f0826..5248e13 100644 --- a/tests/test_record.py +++ b/tests/test_record.py @@ -4,7 +4,7 @@ def test_parse(record_line: str) -> None: - record = Record.parse(record_line) + record = Record.from_string(record_line) assert record.usaf_id == "720538" assert record.ncei_id == "00164" assert record.year == 2021 @@ -51,4 +51,4 @@ def test_parse(record_line: str) -> None: def test_line_too_short() -> None: with pytest.raises(IsdError): - Record.parse("") + Record.from_string("") diff --git a/tests/test_utils.py b/tests/test_utils.py deleted file mode 100644 index 169e034..0000000 --- a/tests/test_utils.py +++ /dev/null @@ -1,52 +0,0 @@ -import datetime -from typing import List - -import isd.utils -from isd.record import Record - - -def test_filter_by_datetime(records: List[Record]) -> None: - assert ( - len( - list( - isd.utils.filter_by_datetime( - records, start=datetime.datetime(2021, 1, 1, 3, 30) - ) - ) - ) - == 490 - ) - assert ( - len( - list( - isd.utils.filter_by_datetime( - records, end=datetime.datetime(2021, 1, 1, 3, 30) - ) - ) - ) - == 10 - ) - assert ( - len( - list( - isd.utils.filter_by_datetime( - records, - start=datetime.datetime(2021, 1, 1, 3, 30), - end=datetime.datetime(2021, 1, 1, 3, 55), - ) - ) - ) - == 1 - ) - assert ( - len( - list( - isd.utils.filter_by_datetime( - records, - start=datetime.datetime(2021, 1, 1, 3, 30), - end=datetime.datetime(2021, 1, 1, 3, 56), - ) - ) - ) - == 2 - )