"""Implements the generic progress logger class, and the ProgressBar class."""

from tqdm import tqdm, tqdm_notebook
from collections import OrderedDict
import time

SETTINGS = {"notebook": False}


def notebook(turn="on"):
    SETTINGS["notebook"] = True if (turn == "on") else False


def troncate_string(s, max_length=25):
    return s if (len(s) < max_length) else (s[:max_length] + "...")


class ProgressLogger:
    """Generic class for progress loggers.

    A progress logger contains a "state" dictionary.

    Parameters
    ----------
    init_state : dict
        Dictionary representing the initial state.
    """

    def __init__(self, init_state=None):

        self.state = {}
        self.stored = {}
        self.logs = []
        self.log_indent = 0
        if init_state is not None:
            self.state.update(init_state)

    def log(self, message):
        self.logs.append((" " * self.log_indent) + message)

    def dump_logs(self, filepath=None):
        if filepath is not None:
            with open(filepath, "a") as f:
                f.write("\n".join(self.logs))
        else:
            return "\n".join(self.logs)

    def callback(self, **kw):
        """Execute something after the state has been updated by the given
        state elements.

        This default callback does nothing, overwrite it by subclassing.
        """
        pass

    def store(self, **kw):
        """Store objects in the logger and trigger ``self.store_callback``.

        This works exactly like ``logger()``, but the later is meant for simple
        data objects (text, numbers) that will be sent over the network or
        written to a file. The ``store`` method expects rather large objects
        which are not necessarily serializable, and will be used eg to draw
        plots on the fly.
        """
        self.stored.update(kw)
        self.store_callback(**kw)

    def store_callback(self, **kw):
        """Execute something after the store has been updated by the given
        state elements.

        This default callback does nothing, overwrite it by subclassing
        """
        pass

    def iter(self, **kw):
        """Iterate through a list while updating the state.

        Examples
        --------
        >>> for username in logger.iter(user=['tom', 'tim', 'lea']):
        >>>     # At every loop, logger.state['user'] is updated
        >>>     print(username)
        """
        for field, iterable in kw.items():
            for it in iterable:
                self(**{field: it})
                yield it

    def __call__(self, **kw):
        self.state.update(kw)
        self.callback(**kw)


class ProgressBarLogger(ProgressLogger):
    """Generic class for progress loggers.

    A progress logger contains a "state" dictionary.

    Parameters
    ----------
    init_state : dict
        Initial state of the logger.
    bars : None, list, tuple, or dict, optional
        Either None (will be initialized with no bar) or a list/tuple of bar
        names (e.g., ['main', 'sub']) which will be initialized with index -1 and
        no total, or a dictionary (possibly ordered) of bars, of the form
        `{bar_1: {title: 'bar1', index: 2, total: 23}, bar_2: {...}}`.
    ignored_bars : None, list of str, or 'all_others', optional
        Either None (newly met bars will be added) or a list of blacklisted bar
        names, or 'all_others' to signify that all bar names not already in
        `self.bars` will be ignored.
    logged_bars
    min_time_interval : int or float
        Time in seconds between progress bar updates.
    ignore_bars_under : int
    """

    bar_indent = 2

    def __init__(
        self,
        init_state=None,
        bars=None,
        ignored_bars=None,
        logged_bars="all",
        min_time_interval=0,
        ignore_bars_under=0,
    ):
        ProgressLogger.__init__(self, init_state)
        if bars is None:
            bars = OrderedDict()
        elif isinstance(bars, (list, tuple)):
            bars = OrderedDict(
                [
                    (b, dict(title=b, index=-1, total=None, message=None, indent=0))
                    for b in bars
                ]
            )
        if isinstance(ignored_bars, (list, tuple)):
            ignored_bars = set(ignored_bars)
        self.ignored_bars = ignored_bars
        self.logged_bars = logged_bars
        self.state["bars"] = bars
        self.min_time_interval = min_time_interval
        self.ignore_bars_under = ignore_bars_under

    @property
    def bars(self):
        """Return ``self.state['bars'].``"""
        return self.state["bars"]

    def bar_is_ignored(self, bar):
        if self.ignored_bars is None:
            return False
        elif self.ignored_bars == "all_others":
            return bar not in self.bars
        else:
            return bar in self.ignored_bars

    def bar_is_logged(self, bar):
        if not self.logged_bars:
            return False
        elif self.logged_bars == "all":
            return True
        else:
            return bar in self.logged_bars

    def iterable_is_too_short(self, iterable):
        length = len(iterable) if hasattr(iterable, "__len__") else None
        return (length is not None) and (length < self.ignore_bars_under)

    def iter_bar(self, bar_prefix="", **kw):
        """Iterate through a list while updating a state bar.

        Parameters
        ----------
        bar_prefix : str
            Bar prefix.

        Examples
        --------
        >>> for username in logger.iter_bar(user=['tom', 'tim', 'lea']):
        >>>     # At every loop, logger.state['bars']['user'] is updated
        >>>     # to {index: i, total: 3, title:'user'}
        >>>     print (username)

        """
        if "bar_message" in kw:
            bar_message = kw.pop("bar_message")
        else:
            bar_message = None
        bar, iterable = kw.popitem()

        if self.bar_is_ignored(bar) or self.iterable_is_too_short(iterable):
            return iterable
        bar = bar_prefix + bar
        if hasattr(iterable, "__len__"):
            self(**{bar + "__total": len(iterable)})

        def new_iterable():
            last_time = time.time()
            i = 0  # necessary in case the iterator is empty
            for i, it in enumerate(iterable):
                now_time = time.time()
                if (i == 0) or (now_time - last_time > self.min_time_interval):
                    if bar_message is not None:
                        self(**{bar + "__message": bar_message(it)})
                    self(**{bar + "__index": i})
                    last_time = now_time
                yield it

            if self.bars[bar]["index"] != i:
                self(**{bar + "__index": i})
            self(**{bar + "__index": i + 1})

        return new_iterable()

    def bars_callback(self, bar, attr, value, old_value=None):
        """Execute a custom action after the progress bars are updated.

        Parameters
        ----------
        bar
          Name/ID of the bar to be modified.

        attr
          Attribute of the bar attribute to be modified

        value
          New value of the attribute

        old_value
          Previous value of this bar's attribute.

        This default callback does nothing, overwrite it by subclassing.
        """
        pass

    def __call__(self, **kw):

        items = sorted(kw.items(), key=lambda kv: not kv[0].endswith("total"))

        for key, value in items:
            if "__" in key:
                bar, attr = key.split("__")
                if self.bar_is_ignored(bar):
                    continue
                kw.pop(key)
                if bar not in self.bars:
                    self.bars[bar] = dict(title=bar, index=-1, total=None, message=None)
                old_value = self.bars[bar][attr]

                if self.bar_is_logged(bar):
                    new_bar = (attr == "index") and (value < old_value)
                    if (attr == "total") or (new_bar):
                        self.bars[bar]["indent"] = self.log_indent
                    else:
                        self.log_indent = self.bars[bar]["indent"]
                    self.log("[%s] %s: %s" % (bar, attr, value))
                    self.log_indent += self.bar_indent
                self.bars[bar][attr] = value
                self.bars_callback(bar, attr, value, old_value)
        self.state.update(kw)
        self.callback(**kw)


class TqdmProgressBarLogger(ProgressBarLogger):
    """Tqdm-powered progress bar for console or Notebooks.

    Parameters
    ----------
    init_state : dict
        Initial state of the logger.
    bars : None, list, tuple, or dict, optional
        Either None (will be initialized with no bar) or a list/tuple of bar
        names (e.g., ['main', 'sub']) which will be initialized with index -1 and
        no total, or a dictionary (possibly ordered) of bars, of the form
        `{bar_1: {title: 'bar1', index: 2, total: 23}, bar_2: {...}}`.
    leave_bars : bool, optional
        Whether to leave the progress bars displayed upon completion.
    ignored_bars : None, list of str, or 'all_others', optional
        Either None (newly met bars will be added) or a list of blacklisted bar
        names, or 'all_others' to signify that all bar names not already in
        `self.bars` will be ignored.
    notebook : bool, optional
        True will make the bars look nice (HTML) in the Jupyter notebook. It is
        advised to leave to 'default' as the default can be globally set from
        inside a notebook with `import proglog; proglog.notebook_mode()`.
    print_messages : bool
        If True, every `logger(message='something')` will print a message in
        the console or notebook.
    min_time_interval : int or float
        Time in seconds between progress bar updates.
    ignore_bars_under : int
    """

    def __init__(
        self,
        init_state=None,
        bars=None,
        leave_bars=False,
        ignored_bars=None,
        logged_bars="all",
        notebook="default",
        print_messages=True,
        min_time_interval=0,
        ignore_bars_under=0,
    ):
        ProgressBarLogger.__init__(
            self,
            init_state=init_state,
            bars=bars,
            ignored_bars=ignored_bars,
            logged_bars=logged_bars,
            ignore_bars_under=ignore_bars_under,
            min_time_interval=min_time_interval,
        )
        self.leave_bars = leave_bars
        self.tqdm_bars = OrderedDict([(bar, None) for bar in self.bars])
        if notebook == "default":
            notebook = SETTINGS["notebook"]
        self.notebook = notebook
        self.print_messages = print_messages
        self.tqdm = tqdm_notebook if self.notebook else tqdm

    def new_tqdm_bar(self, bar):
        """Create a new tqdm bar, possibly replacing an existing one."""
        if (bar in self.tqdm_bars) and (self.tqdm_bars[bar] is not None):
            self.close_tqdm_bar(bar)
        infos = self.bars[bar]
        self.tqdm_bars[bar] = self.tqdm(
            total=infos["total"],
            desc=infos["title"],
            postfix=dict(now=troncate_string(str(infos["message"]))),
            leave=self.leave_bars,
        )

    def close_tqdm_bar(self, bar):
        """Close and erase the tqdm bar."""
        self.tqdm_bars[bar].close()
        if not self.notebook:
            self.tqdm_bars[bar] = None

    def bars_callback(self, bar, attr, value, old_value):
        if (bar not in self.tqdm_bars) or (self.tqdm_bars[bar] is None):
            self.new_tqdm_bar(bar)
        if attr == "index":
            if value >= old_value:
                total = self.bars[bar]["total"]
                if total and (value >= total):
                    self.close_tqdm_bar(bar)
                else:
                    self.tqdm_bars[bar].update(value - old_value)
            else:
                self.new_tqdm_bar(bar)
                self.tqdm_bars[bar].update(value + 1)
        elif attr == "message":
            self.tqdm_bars[bar].set_postfix(now=troncate_string(str(value)))
            self.tqdm_bars[bar].update(0)

    def callback(self, **kw):
        if self.print_messages and ("message" in kw) and kw["message"]:
            if self.notebook:
                print(kw["message"])
            else:
                self.tqdm.write(kw["message"])


class RqWorkerProgressLogger:
    def __init__(self, job):
        self.job = job
        if "progress_data" not in self.job.meta:
            self.job.meta["progress_data"] = {}
            self.job.save()

    def callback(self, **kw):
        self.job.meta["progress_data"] = self.state
        self.job.save()


class RqWorkerBarLogger(RqWorkerProgressLogger, ProgressBarLogger):

    def __init__(
        self,
        job,
        init_state=None,
        bars=None,
        ignored_bars=(),
        logged_bars="all",
        min_time_interval=0,
    ):
        RqWorkerProgressLogger.__init__(self, job)
        ProgressBarLogger.__init__(
            self,
            init_state=init_state,
            bars=bars,
            ignored_bars=ignored_bars,
            logged_bars=logged_bars,
            min_time_interval=min_time_interval,
        )


class MuteProgressBarLogger(ProgressBarLogger):

    def bar_is_ignored(self, bar):
        return True


def default_bar_logger(
    logger,
    bars=None,
    ignored_bars=None,
    logged_bars="all",
    min_time_interval=0,
    ignore_bars_under=0,
):
    if logger == "bar":
        return TqdmProgressBarLogger(
            bars=bars,
            ignored_bars=ignored_bars,
            logged_bars=logged_bars,
            min_time_interval=min_time_interval,
            ignore_bars_under=ignore_bars_under,
        )
    elif logger is None:
        return MuteProgressBarLogger()
    else:
        return logger
