# Source code for openaq_engine.config.model_settings

import os
from dataclasses import field
from typing import Any, Dict, List, Sequence

import boto3
from pydantic import StrictStr
from pydantic.dataclasses import dataclass


@dataclass
class ModelVisualizerConfig:
    """Plotting-related settings: an on/off flag, metric names and table names.

    Every attribute here carries a type annotation, so each one is a
    dataclass field and can be overridden through the constructor.
    """

    PLOT: bool = True
    # A factory is used so that every instance gets its own list.
    PLOT_METRICS: Sequence[str] = field(
        default_factory=lambda: ["mean"],
    )
    # Table / schema names are left empty by default.
    PLOTS_TABLE_NAME: str = ""
    PLOTS_SCHEMA_NAME: str = ""
    RESULTS_TABLE_NAME: str = "results"
@dataclass
class MatrixGeneratorConfig:
    """Settings for matrix generation: algorithm code and identifier columns.

    ``ALGORITHM`` has no annotation, so it is a plain class attribute and
    not a constructor argument; ``ID_COLUMN_LIST`` is a dataclass field.
    """

    ALGORITHM = "RFR"

    # NOTE(review): this list uses "locationId" while
    # ModelTrainerConfig.ID_COLS_TO_REMOVE uses "location_id" — confirm
    # which spelling the data actually carries.
    ID_COLUMN_LIST: Sequence[str] = field(
        default_factory=lambda: [
            "locationId",
            "cohort",
            "cohort_type",
        ],
    )
@dataclass
class ModelTrainerConfig:
    """Constants for model training: model codes, id columns, seed, features.

    None of the attributes is annotated, so the dataclass has no fields;
    all values are class attributes shared by every instance.
    """

    # Other model codes listed by the author: "DTC", "MNB", "RFC", "MLR"
    MODEL_NAMES_LIST = ["RFR"]

    ID_COLS_TO_REMOVE = ["location_id", "cohort", "cohort_type"]

    RANDOM_STATE = 99

    # The mixed-case spelling "All_" is the public attribute name; it is kept
    # as-is so existing lookups continue to work.
    All_MODEL_FEATURES = [
        # the same band names appear as defaults in EEConfig
        "Optical_Depth_047",
        "B4",
        "B3",
        "B2",
        "avg_rad",
        "temperature_2m_above_ground",
        "relative_humidity_2m_above_ground",
        "total_precipitation_surface",
        "total_cloud_cover_entire_atmosphere",
        "u_component_of_wind_10m_above_ground",
        "v_component_of_wind_10m_above_ground",
        "basic_demographic_characteristics",
        "discrete_classification",
    ]
@dataclass
class HyperparamConfig:
    """Hyperparameter value lists, keyed by model code.

    Both attributes are unannotated, so they are class attributes shared
    by all instances rather than dataclass fields.
    """

    MODEL_TYPES = ["DTC", "MNB", "RFR", "XGB"]

    # NOTE(review): "MLR" has an entry below but is absent from MODEL_TYPES —
    # confirm whether that is intentional.
    # The trailing comments reproduce alternative values noted by the author.
    MODEL_HYPERPARAMS = {
        "DTC": {
            "max_depth": [5, 10, 20, 30, 40],
        },  # 5, 50, 500, 10000 50, 100, 200, 300
        "RFR": {
            "n_estimators": [500, 800],  # 100, 500, 800, 1000
            "max_depth": [10, 50, 70],  # 5, 50, 80, 500, 10000 100, 200, 300
        },
        "XGB": {
            "max_depth": [5, 150, 200, 250, 300],
            "learning_rate": [0.1, 0.5, 1],
        },
        "MNB": {
            "alpha": [0, 0.05],
        },  # 0.1, 0.5, 0.8, 1
        "MLR": {
            "penalty": ["l2"],
            "C": [1, 0.1, 0.01],
            "solver": ["saga"],
            "max_iter": [2000],
        },
    }
@dataclass
class BuildFeaturesConfig:
    """Feature-building settings: target column and feature-name lists.

    ``TARGET_COL``, ``CATEGORICAL_FEATURES`` and ``CORE_FEATURES`` are
    annotated dataclass fields (constructor arguments, in that order).
    The remaining attributes are unannotated and therefore plain class
    attributes shared by all instances.
    """

    TARGET_COL: str = "value"
    TARGET_VARIABLE = "pm25"
    COUNTRY = ""
    CITY = ""
    CATEGORICAL_FEATURES: List[StrictStr] = field(default_factory=list)
    CORE_FEATURES: List[StrictStr] = field(
        default_factory=lambda: [
            "city",
            "country",
            "pca_lat",
            "pca_lng",
            "sourcetype",
            "mobile",
        ]
    )
    SATELLITE_FEATURES = []

    @property
    def ALL_MODEL_FEATURES(self) -> List[str]:
        """Return the de-duplicated core + categorical feature names.

        The result keeps first-seen order (core features first, then any
        categorical features not already listed). The previous
        ``list(set(...))`` form returned the same names in an order that
        could differ between interpreter runs because of string hash
        randomisation.

        NOTE(review): ``SATELLITE_FEATURES`` is not included here, exactly
        as in the original — confirm that is intended.
        """
        combined = [*self.CORE_FEATURES, *self.CATEGORICAL_FEATURES]
        # dict keys are unique and keep insertion order.
        return list(dict.fromkeys(combined))
@dataclass
class EEConfig:
    """Satellite image-collection settings plus storage/credential placeholders.

    For each of six sources the class holds a collection id, a list of
    band names, a period and a resolution. Annotated attributes are
    dataclass fields (constructor arguments); unannotated ones are class
    attributes shared by all instances.
    """

    LOOKBACK_N = 1
    DATE_COL: str = "timestamp_utc"
    TABLE_NAME = "cohorts"

    # Satellite configurations
    AOD_IMAGE_COLLECTION: str = "MODIS/006/MCD19A2_GRANULES"
    AOD_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: ["Optical_Depth_047"]
    )
    AOD_IMAGE_PERIOD = 2
    AOD_IMAGE_RES = 1000

    LANDSAT_IMAGE_COLLECTION: str = "LANDSAT/LC08/C01/T1"
    LANDSAT_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: ["B4", "B3", "B2"]
    )
    LANDSAT_PERIOD = 8
    LANDSAT_RES = 30

    NIGHTTIME_LIGHT_IMAGE_COLLECTION: str = "NOAA/VIIRS/DNB/MONTHLY_V1/VCMCFG"
    NIGHTTIME_LIGHT_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: ["avg_rad"]
    )
    NIGHTTIME_LIGHT_PERIOD = 30
    NIGHTTIME_LIGHT_RES = 463.83

    METEROLOGICAL_IMAGE_COLLECTION: str = "NOAA/GFS0P25"
    METEROLOGICAL_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: [
            "temperature_2m_above_ground",
            "relative_humidity_2m_above_ground",
            "total_precipitation_surface",
            "total_cloud_cover_entire_atmosphere",
            "u_component_of_wind_10m_above_ground",
            "v_component_of_wind_10m_above_ground",
        ]
    )
    METEROLOGICAL_IMAGE_PERIOD = 1
    METEROLOGICAL_IMAGE_RES = 27830

    POPULATION_IMAGE_COLLECTION: str = (
        "CIESIN/GPWv411/GPW_Basic_Demographic_Characteristics"
    )
    POPULATION_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: ["basic_demographic_characteristics"]
    )
    POPULATION_PERIOD = 1100
    POPULATION_IMAGE_RES = 1000

    LAND_COVER_IMAGE_COLLECTION: str = (
        "COPERNICUS/Landcover/100m/Proba-V-C3/Global"
    )
    LAND_COVER_IMAGE_BAND: Sequence[str] = field(
        default_factory=lambda: ["discrete_classification"]
    )
    LAND_COVER_IMAGE_RES = 100
    LAND_COVER_PERIOD = 1500

    # please provide the path to the private key for the service account
    PATH_TO_PRIVATE_KEY = ""
    # please provide the bucket name
    # (An earlier assignment of "earthengine-bucket" in this class body was
    # always overwritten by this one, so the dead assignment was removed;
    # the effective value is unchanged.)
    BUCKET_NAME = ""
    # please provide the service account
    SERVICE_ACCOUNT = ""

    @property
    def ALL_SATELLITES(self) -> zip:
        """Return an iterator with one tuple per satellite source.

        Each tuple is ``(collection_id, band_names, period, resolution)``,
        in the order AOD, Landsat, night-time light, meteorological,
        population, land cover. A fresh ``zip`` iterator is built on every
        access, so it can be consumed once per access.
        """
        collections = [
            self.AOD_IMAGE_COLLECTION,
            self.LANDSAT_IMAGE_COLLECTION,
            self.NIGHTTIME_LIGHT_IMAGE_COLLECTION,
            self.METEROLOGICAL_IMAGE_COLLECTION,
            self.POPULATION_IMAGE_COLLECTION,
            self.LAND_COVER_IMAGE_COLLECTION,
        ]
        bands = [
            self.AOD_IMAGE_BAND,
            self.LANDSAT_IMAGE_BAND,
            self.NIGHTTIME_LIGHT_IMAGE_BAND,
            self.METEROLOGICAL_IMAGE_BAND,
            self.POPULATION_IMAGE_BAND,
            self.LAND_COVER_IMAGE_BAND,
        ]
        periods = [
            self.AOD_IMAGE_PERIOD,
            self.LANDSAT_PERIOD,
            self.NIGHTTIME_LIGHT_PERIOD,
            self.METEROLOGICAL_IMAGE_PERIOD,
            self.POPULATION_PERIOD,
            self.LAND_COVER_PERIOD,
        ]
        resolutions = [
            self.AOD_IMAGE_RES,
            self.LANDSAT_RES,
            self.NIGHTTIME_LIGHT_RES,
            self.METEROLOGICAL_IMAGE_RES,
            self.POPULATION_IMAGE_RES,
            self.LAND_COVER_IMAGE_RES,
        ]
        return zip(collections, bands, periods, resolutions)
@dataclass
class CohortBuilderConfig:
    """Cohort-building settings: id/date columns, storage names and filters.

    ``ENTITY_ID_COLS``, ``DATE_COL``, ``SCHEMA_NAME`` and ``FILTER_DICT``
    are annotated dataclass fields (constructor arguments, in that order).
    Everything else is an unannotated class attribute.
    """

    ENTITY_ID_COLS: Sequence[str] = field(
        default_factory=lambda: ["unique_id"],
    )
    DATE_COL: str = "date.utc"

    CITY = ""
    SENSOR_TYPE = "reference grade"
    REGION = "us-east-1"

    # Both lookups run once, when the class body is executed; each is None
    # if the environment variable is not set at that moment.
    S3_BUCKET = os.getenv("S3_BUCKET_OPENAQ")
    S3_OUTPUT = os.getenv("S3_OUTPUT_OPENAQ")

    TABLE_NAME = ""
    SCHEMA_NAME: str = ""

    # Maps a filter name to a list of column names; a new dict is built for
    # every instance.
    FILTER_DICT: Dict[str, Any] = field(
        default_factory=lambda: {
            "filter_pollutant": ["parameter"],
            "filter_non_null_values": ["value"],
            "filter_extreme_values": ["value"],
            "filter_no_coordinates": ["coordinates"],
            "filter_countries": ["country"],
            "filter_cities": ["city"],
        },
    )

    TARGET_VARIABLE = ""
    COUNTRY = ""
    SOURCE = ""
    LOCAL_DATA = ""
@dataclass
class TimeSplitterConfig:
    """Time-splitting settings: window sizes, storage names and AWS handles.

    ``DATE_COL``, ``TIME_WINDOW_LENGTH``, ``WITHIN_WINDOW_SAMPLER``,
    ``WINDOW_COUNT``, ``TABLE_NAME`` and ``TRAIN_VALIDATION_DICT`` are
    annotated dataclass fields (constructor arguments, in that order).
    Everything else is an unannotated class attribute.
    """

    DATE_COL: str = "date.utc"

    TARGET_VARIABLE = ""
    COUNTRY = ""
    CITY = ""
    SENSOR_TYPE = "reference grade"
    SOURCE = ""
    LOCAL_DATA = ""

    TIME_WINDOW_LENGTH: int = 4
    WITHIN_WINDOW_SAMPLER: int = 4
    # this will increase for more than one split
    WINDOW_COUNT: int = 10
    TABLE_NAME: str = ""

    REGION = ""

    # The environment is read once, when the class body is executed; each
    # value is None if its variable is not set at that moment.
    DATABASE = os.getenv("DB_NAME_OPENAQ")
    AWS_ACCESS_KEY = os.getenv("AWS_ACCESS_KEY")
    AWS_SECRET_ACCESS_KEY = os.getenv("AWS_SECRET_ACCESS_KEY")
    S3_BUCKET = os.getenv("S3_BUCKET_OPENAQ")
    S3_OUTPUT = os.getenv("S3_OUTPUT_OPENAQ")

    # NOTE(review): this call also runs when the class body is executed,
    # i.e. as a side effect of importing the module.
    RESOURCE = boto3.resource("s3")

    # A new dict with two empty lists is built for every instance.
    TRAIN_VALIDATION_DICT: Dict[str, List[Any]] = field(
        default_factory=lambda: {
            "validation": [],
            "training": [],
        },
    )