# Release version of this exporter module.
__version__ = "0.0.1"

from typing import NamedTuple, Iterable

from argparse import ArgumentParser
from threading import Lock
import datetime
from traceback import print_exc
import signal

from isodate import parse_duration
import requests

from prometheus_client.core import REGISTRY, Metric, GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client import (
    start_http_server,
    PROCESS_COLLECTOR,
    PLATFORM_COLLECTOR,
    GC_COLLECTOR,
)


class Forecast(NamedTuple):
    """
    One forecast interval from the Solcast rooftop-sites API.

    The interval ends at ``period_end`` and lasts ``period``. The three
    estimates are the values reported for that interval, in kilowatts
    (they are converted to watts when exported).
    """

    pv_estimate: float  # exported with centile label ""
    pv_estimate10: float  # exported with centile label "10"
    pv_estimate90: float  # exported with centile label "90"
    period_end: datetime.datetime
    period: datetime.timedelta

    @property
    def period_start(self) -> datetime.datetime:
        """Moment the interval begins, i.e. ``period_end`` minus ``period``."""
        start = self.period_end - self.period
        return start


class SolcastCollector(Collector):
    """
    Prometheus collector exposing Solcast rooftop-site PV forecasts.

    Forecasts are fetched from the Solcast API during a scrape and cached for
    ``max_age_minutes``. Scrapes within that window reuse the cached data, so
    the API is not queried on every scrape. If a refresh fails, the cache is
    discarded and the next scrape tries again.

    Two gauge families are exported:

    * ``solcast_forecast_watts`` -- one sample per forecast period and
      centile label (``""`` for ``pv_estimate``, ``"10"`` for
      ``pv_estimate10``, ``"90"`` for ``pv_estimate90``). The ``time_offset``
      label is the rounded number of minutes between the end of that period
      and the end of the period containing the current time.
    * ``solcast_up`` -- true when the refresh and metric construction
      succeeded, false otherwise.
    """

    _resource_id: str
    _api_key: str
    _max_age_minutes: float
    _request_timeout: float

    # Held for the whole refresh-and-build step, so concurrent scrapes never
    # issue parallel API requests or observe a half-updated cache.
    _lock: Lock

    # Cached forecasts and the UTC time of the request that produced them.
    # Both are None when nothing valid is cached.
    _forecasts: list[Forecast] | None
    _last_update: datetime.datetime | None

    def __init__(
        self,
        resource_id: str,
        api_key: str,
        max_age_minutes: float,
        request_timeout: float = 30.0,
    ) -> None:
        """
        :param resource_id: Solcast rooftop-site resource ID. Also used as the
            ``resource_id`` label value.
        :param api_key: Solcast API key.
        :param max_age_minutes: How long fetched forecasts are reused before
            the API is queried again.
        :param request_timeout: Timeout in seconds for the Solcast HTTP
            request. Without one, a stalled connection would block this
            scrape forever -- and every later scrape too, since they all
            wait on the lock.
        """
        super().__init__()
        self._resource_id = resource_id
        self._api_key = api_key
        self._max_age_minutes = max_age_minutes
        self._request_timeout = request_timeout

        self._lock = Lock()
        self._forecasts = None
        self._last_update = None

    def _fetch_forecasts(self) -> list[Forecast]:
        """
        Request this site's forecasts from the Solcast API and parse them.

        Exceptions propagate unchanged. Examples:

        * ``requests.RequestException`` for connection errors, timeouts and
          (via ``raise_for_status``) HTTP error responses.
        * ``KeyError`` / ``ValueError`` for a response body that lacks the
          expected fields or holds unparseable timestamps or durations.
        """
        print("Requesting data")
        response = requests.get(
            f"https://api.solcast.com.au/rooftop_sites/{self._resource_id}/forecasts",
            params={
                "format": "json",
                "api_key": self._api_key,
            },
            timeout=self._request_timeout,
        )
        response.raise_for_status()
        return [
            Forecast(
                pv_estimate=f["pv_estimate"],
                pv_estimate10=f["pv_estimate10"],
                pv_estimate90=f["pv_estimate90"],
                period_end=datetime.datetime.fromisoformat(f["period_end"]),
                period=parse_duration(f["period"]),
            )
            for f in response.json()["forecasts"]
        ]

    def _refresh_data_if_stale(self) -> None:
        """
        Re-fetch forecasts if nothing is cached or the cache is at least
        ``max_age_minutes`` old.

        On any failure the cache is cleared, so the next call retries, and
        the exception is re-raised.
        """
        now = datetime.datetime.now(datetime.timezone.utc)
        if (
            self._last_update is not None
            and (now - self._last_update).total_seconds() < self._max_age_minutes * 60
        ):
            return  # Cache is still fresh.

        try:
            self._forecasts = self._fetch_forecasts()
            self._last_update = now
        except Exception:
            self._forecasts = None
            self._last_update = None
            raise

    def _build_forecast_metric(
        self, forecasts: list[Forecast], now: datetime.datetime
    ) -> GaugeMetricFamily:
        """
        Build the ``solcast_forecast_watts`` family from ``forecasts``.

        Periods that ended before ``now`` are skipped. Offsets are measured
        from the end of the period containing ``now``.

        :raises ValueError: if no period contains ``now``. Supporting that
            case would require extrapolating the current interval. It is not
            implemented because the API response is expected to always cover
            the present.
        """
        # Periods that have not ended yet, in chronological order.
        upcoming = sorted(
            (f for f in forecasts if f.period_end >= now),
            key=lambda f: f.period_end,
        )
        if not upcoming or upcoming[0].period_start > now:
            raise ValueError("Forecast doesn't include data for now.")

        current_end = upcoming[0].period_end
        metric = GaugeMetricFamily(
            "solcast_forecast_watts",
            "Solcast forecasted solar output (watts)",
            labels=["resource_id", "time_offset", "centile"],
        )
        for forecast in upcoming:
            time_offset = round((forecast.period_end - current_end).total_seconds() / 60)
            for value, centile in (
                (forecast.pv_estimate, ""),
                (forecast.pv_estimate10, "10"),
                (forecast.pv_estimate90, "90"),
            ):
                metric.add_metric(
                    [self._resource_id, str(time_offset), centile],
                    value * 1000,  # Convert from kW
                )
        return metric

    def collect(self) -> Iterable[Metric]:
        """
        Yield the forecast metric (when available), followed by ``solcast_up``.

        The cache is refreshed first if it is stale. Any exception raised
        while refreshing or building the forecast metric is printed with its
        traceback and reported as ``solcast_up`` = false instead of
        propagating.
        """
        metrics: list[Metric] = []
        up = True
        with self._lock:
            try:
                self._refresh_data_if_stale()
                if self._forecasts is not None:
                    now = datetime.datetime.now(datetime.timezone.utc)
                    metrics.append(self._build_forecast_metric(self._forecasts, now))
            except Exception:
                up = False
                print_exc()

        # Report current up-status
        up_metric = GaugeMetricFamily(
            "solcast_up",
            "Is the Solcast API responding successfully?",
            labels=["resource_id"],
        )
        up_metric.add_metric([self._resource_id], up)
        metrics.append(up_metric)

        # Yield only after the lock is released and outside the try. A slow
        # or abandoned consumer then cannot keep other scrapes blocked, and
        # exceptions thrown into the generator are not swallowed.
        yield from metrics


def run_exporter_until_terminated(*args, **kwargs) -> None:
    """
    Run :py:func:`prometheus_client.start_http_server` until SIGINT (Ctrl+C)
    or SIGTERM, then shut the server down gracefully, close its listening
    socket and return.

    All arguments are passed through unchanged to ``start_http_server``.
    """
    server, thread = start_http_server(*args, **kwargs)

    # The handler runs in the main thread (blocked in join() below), not in
    # the serving thread. That is what makes the blocking shutdown() call
    # safe here.
    for sig in (signal.SIGINT, signal.SIGTERM):
        signal.signal(sig, lambda *_: server.shutdown())

    try:
        thread.join()
    finally:
        # shutdown() only stops the serve loop; also release the socket.
        server.server_close()


def main() -> None:
    """
    Command-line entry point.

    Parses arguments and registers a :py:class:`SolcastCollector` in place of
    prometheus_client's process, platform and GC collectors. Then serves
    metrics on the given address/port until SIGINT or SIGTERM.
    """
    parser = ArgumentParser()
    parser.add_argument(
        "resource_id",
        type=str,
        help="""
            The Solcast resource ID for the solar array to fetch data for.
            Typically this is a series of hex digits of the form
            XXXX-XXXX-XXXX-XXXX.
        """,
    )
    parser.add_argument(
        "--api-key",
        type=str,
        required=True,
        help="""
            The Solcast API key.
        """,
    )
    parser.add_argument(
        "--max-age",
        type=float,
        default=180.0,
        help="""
            The maximum age of the data to return, in minutes (effectively the
            time between API requests). Default: %(default)s.
        """,
    )
    parser.add_argument(
        "--address",
        "-a",
        type=str,
        default="0.0.0.0",
        help="""
            The Prometheus listen address. Default: %(default)s.
        """,
    )
    parser.add_argument(
        "--port",
        "-p",
        type=int,
        default=9534,
        help="""
            The Prometheus listen port. Default: %(default)s.
        """,
    )
    args = parser.parse_args()

    # Expose only the Solcast metrics: drop prometheus_client's default
    # process/platform/GC collectors from the global registry.
    REGISTRY.unregister(PROCESS_COLLECTOR)
    REGISTRY.unregister(PLATFORM_COLLECTOR)
    REGISTRY.unregister(GC_COLLECTOR)
    REGISTRY.register(SolcastCollector(args.resource_id, args.api_key, args.max_age))

    run_exporter_until_terminated(port=args.port, addr=args.address)


if __name__ == "__main__":
    main()
