from datetime import datetime, timedelta

from tzlocal import get_localzone

from apscheduler.triggers.base import BaseTrigger
from apscheduler.triggers.cron.fields import (
    DEFAULT_VALUES,
    BaseField,
    DayOfMonthField,
    DayOfWeekField,
    MonthField,
    WeekField,
)
from apscheduler.util import (
    astimezone,
    convert_to_datetime,
    datetime_ceil,
    datetime_repr,
)


class CronTrigger(BaseTrigger):
    """
    Fires whenever the current time satisfies every configured field expression, in the
    manner of the UNIX cron scheduler.

    :param int|str year: 4-digit year
    :param int|str month: month (1-12)
    :param int|str day: day of month (1-31)
    :param int|str week: ISO week (1-53)
    :param int|str day_of_week: number or name of weekday (0-6 or mon,tue,wed,thu,fri,sat,sun)
    :param int|str hour: hour (0-23)
    :param int|str minute: minute (0-59)
    :param int|str second: second (0-59)
    :param datetime|str start_date: earliest possible date/time to trigger on (inclusive)
    :param datetime|str end_date: latest possible date/time to trigger on (inclusive)
    :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations (when
        omitted, the tzinfo of ``start_date`` or ``end_date`` is used if either carries one,
        otherwise the local time zone)
    :param int|None jitter: delay the job execution by ``jitter`` seconds at most

    .. note:: The first weekday is always **monday**.
    """

    # Field order matters: fields are evaluated from most to least significant.
    FIELD_NAMES = (
        "year",
        "month",
        "day",
        "week",
        "day_of_week",
        "hour",
        "minute",
        "second",
    )
    # Field class used to parse and evaluate the expression of each field name.
    FIELDS_MAP = {
        "year": BaseField,
        "month": MonthField,
        "week": WeekField,
        "day": DayOfMonthField,
        "day_of_week": DayOfWeekField,
        "hour": BaseField,
        "minute": BaseField,
        "second": BaseField,
    }

    __slots__ = ("timezone", "start_date", "end_date", "fields", "jitter")

    def __init__(
        self,
        year=None,
        month=None,
        day=None,
        week=None,
        day_of_week=None,
        hour=None,
        minute=None,
        second=None,
        start_date=None,
        end_date=None,
        timezone=None,
        jitter=None,
    ):
        # Time zone precedence: explicit argument, then the tzinfo of start_date,
        # then the tzinfo of end_date, and finally the local zone.
        tz_hint = timezone if timezone else None
        if tz_hint is None:
            for bound in (start_date, end_date):
                if isinstance(bound, datetime) and bound.tzinfo:
                    tz_hint = bound.tzinfo
                    break
        self.timezone = get_localzone() if tz_hint is None else astimezone(tz_hint)

        self.start_date = convert_to_datetime(start_date, self.timezone, "start_date")
        self.end_date = convert_to_datetime(end_date, self.timezone, "end_date")

        self.jitter = jitter

        # Collect only the field expressions the caller actually provided.
        given = (
            ("year", year),
            ("month", month),
            ("day", day),
            ("week", week),
            ("day_of_week", day_of_week),
            ("hour", hour),
            ("minute", minute),
            ("second", second),
        )
        pending = {
            name: expression
            for name, expression in given
            if name in self.FIELD_NAMES and expression is not None
        }

        # Fields ahead of the last explicitly given one fall back to the "*" wildcard;
        # fields after it take their entry from DEFAULT_VALUES instead.
        self.fields = []
        past_last_explicit = False
        for name in self.FIELD_NAMES:
            if name in pending:
                expression = pending.pop(name)
                explicit = True
                past_last_explicit = not pending
            else:
                expression = DEFAULT_VALUES[name] if past_last_explicit else "*"
                explicit = False

            self.fields.append(self.FIELDS_MAP[name](name, expression, not explicit))

    @classmethod
    def from_crontab(cls, expr, timezone=None):
        """
        Build a :class:`~CronTrigger` out of a five-field crontab expression.

        See https://en.wikipedia.org/wiki/Cron for more information on the format accepted here.

        :param expr: whitespace separated fields, in order: minute, hour, day of month, month,
            day of week
        :param datetime.tzinfo|str timezone: time zone to use for the date/time calculations
            (handed to the constructor unchanged)
        :return: a :class:`~CronTrigger` instance
        :raises ValueError: if ``expr`` does not contain exactly five fields

        """
        parts = expr.split()
        count = len(parts)
        if count != 5:
            raise ValueError(f"Wrong number of fields; got {count}, expected 5")

        minute, hour, day, month, day_of_week = parts
        return cls(
            minute=minute,
            hour=hour,
            day=day,
            month=month,
            day_of_week=day_of_week,
            timezone=timezone,
        )

    def _increment_field_value(self, dateval, fieldnum):
        """
        Bump the designated field by one and reset every less significant field to its
        minimum value.

        When the designated field cannot be bumped (it has no datetime component of its own,
        or it is already at its maximum), the bump carries over to the preceding field.

        :type dateval: datetime
        :type fieldnum: int
        :return: a tuple containing the new date, and the number of the field that was actually
            incremented
        :rtype: tuple
        """

        components = {}
        field_count = len(self.fields)
        index = 0
        while index < field_count:
            field = self.fields[index]
            carry = False

            if index == fieldnum:
                if field.REAL:
                    current = field.get_value(dateval)
                    if current == field.get_max(dateval):
                        # Already at the ceiling: the preceding field must be bumped instead
                        carry = True
                    else:
                        components[field.name] = current + 1
                else:
                    # Not a datetime component: the preceding field must be bumped instead
                    carry = True
            elif field.REAL:
                if index < fieldnum:
                    # More significant than the target: keep as is
                    components[field.name] = field.get_value(dateval)
                else:
                    # Less significant than the target: reset
                    components[field.name] = field.get_min(dateval)

            if carry:
                fieldnum -= 1
                index -= 1
            else:
                index += 1

        # Apply the wall clock difference on top of the original timestamp and convert the
        # result back to the trigger's time zone.
        delta = datetime(**components) - dateval.replace(tzinfo=None)
        shifted = datetime.fromtimestamp(
            dateval.timestamp() + delta.total_seconds(), self.timezone
        )
        return shifted, fieldnum

    def _set_field_value(self, dateval, fieldnum, new_value):
        """
        Return a datetime where the designated field holds ``new_value``, more significant
        fields keep their value from ``dateval`` and less significant fields are reset to their
        minimum.

        :type dateval: datetime
        :type fieldnum: int
        :param new_value: value to store in the designated field
        :rtype: datetime
        """

        def component(position, field):
            if position == fieldnum:
                return new_value
            if position < fieldnum:
                return field.get_value(dateval)
            return field.get_min(dateval)

        # Only fields flagged as REAL are passed on to the datetime constructor
        components = {
            field.name: component(position, field)
            for position, field in enumerate(self.fields)
            if field.REAL
        }
        return datetime(tzinfo=self.timezone, fold=dateval.fold, **components)

    def get_next_fire_time(self, previous_fire_time, now):
        """
        Find the next date/time that satisfies every field expression.

        :param datetime previous_fire_time: the previous fire time, if any
        :param datetime now: the current date/time
        :return: the next fire time, or ``None`` when no match exists or the search moved
            past ``end_date``
        """
        one_microsecond = timedelta(microseconds=1)
        if previous_fire_time:
            earliest = min(now, previous_fire_time + one_microsecond)
            if earliest == previous_fire_time:
                # Never hand out the previous fire time a second time
                earliest += one_microsecond
        elif self.start_date:
            earliest = max(now, self.start_date)
        else:
            earliest = now

        candidate = datetime_ceil(earliest).astimezone(self.timezone)
        field_count = len(self.fields)
        position = 0
        while 0 <= position < field_count:
            field = self.fields[position]
            present = field.get_value(candidate)
            upcoming = field.get_next_value(candidate)

            if upcoming is None:
                # This field has no acceptable value left: bump the preceding field
                candidate, position = self._increment_field_value(
                    candidate, position - 1
                )
            elif not upcoming > present:
                # The current value is acceptable already: move on to the next field
                position += 1
            elif field.REAL:
                # Jump straight to the acceptable value, which lies ahead of the current one
                candidate = self._set_field_value(candidate, position, upcoming)
                position += 1
            else:
                # The value cannot be set directly, so step forward and re-evaluate
                candidate, position = self._increment_field_value(candidate, position)

            # Give up as soon as the search has rolled past the end date
            if self.end_date and candidate > self.end_date:
                return None

        if position < 0:
            # The most significant field could not be incremented any further
            return None

        jittered = self._apply_jitter(candidate, self.jitter, now)
        if self.end_date:
            return min(jittered, self.end_date)
        return jittered

    def __getstate__(self):
        """Return the serializable state (format version 2) of this trigger."""
        state = {"version": 2}
        for attribute in ("timezone", "start_date", "end_date", "fields", "jitter"):
            state[attribute] = getattr(self, attribute)
        return state

    def __setstate__(self, state):
        """
        Restore this trigger from serialized state.

        :raises ValueError: if the state was written by a format version newer than 2
        """
        # This is for compatibility with APScheduler 3.0.x
        payload = state[1] if isinstance(state, tuple) else state

        version = payload.get("version", 1)
        if version > 2:
            raise ValueError(
                f"Got serialized data for version {version} of "
                f"{self.__class__.__name__}, but only versions up to 2 can be handled"
            )

        self.timezone = astimezone(payload["timezone"])
        self.start_date = payload["start_date"]
        self.end_date = payload["end_date"]
        self.fields = payload["fields"]
        # States without a "jitter" entry restore to no jitter
        self.jitter = payload.get("jitter")

    def __str__(self):
        """Return a short description listing only the non-default fields."""
        shown = ", ".join(
            f"{field.name}='{field}'" for field in self.fields if not field.is_default
        )
        return f"cron[{shown}]"

    def __repr__(self):
        """Return a description with the non-default fields, date bounds, jitter and time zone."""
        parts = [
            f"{field.name}='{field}'" for field in self.fields if not field.is_default
        ]
        if self.start_date:
            parts.append(f"start_date={datetime_repr(self.start_date)!r}")
        if self.end_date:
            parts.append(f"end_date={datetime_repr(self.end_date)!r}")
        if self.jitter:
            parts.append(f"jitter={self.jitter}")

        class_name = self.__class__.__name__
        joined = ", ".join(parts)
        return f"<{class_name} ({joined}, timezone='{self.timezone}')>"
