import pprint

import numpy as np

from . import viz2d
from .tools import RadioHideTool, ToggleTool, __plot_dict__

# flake8: noqa
# mypy: ignore-errors


class FormatPrinter(pprint.PrettyPrinter):
    """Pretty printer that renders selected types with ``%``-format strings.

    ``formats`` maps a type to a ``%``-style template, e.g.
    ``{np.float32: "%.4f"}``. The lookup uses the exact ``type(obj)``, so
    subclasses of a registered type are not matched.
    """

    def __init__(self, formats):
        """Store the type -> template mapping; the printer uses default settings."""
        super().__init__()
        self.formats = formats

    def format(self, obj, ctx, maxlvl, lvl):
        """Return the ``(text, readable, recursive)`` triple for ``obj``.

        Objects whose exact type is registered in ``self.formats`` are rendered
        through their template; everything else is delegated to the stock
        ``PrettyPrinter.format``.
        """
        kind = type(obj)
        if kind not in self.formats:
            return super().format(obj, ctx, maxlvl, lvl)
        rendered = self.formats[kind] % obj
        # Second and third items are the "readable" / "recursive" flags.
        return rendered, 1, 0


class TwoViewFrame:
    """Interactive matplotlib frame with one image panel per entry of ``preds``.

    The frame registers two tools on the figure's tool manager:

    * ``"switch plot"`` (key ``R``): selects which plot from ``plot_dict`` is
      drawn on top of the panels.
    * ``"toggle summary"`` (key ``t``): shows or hides the per-entry summary text.
    """

    default_conf = {
        "default": "image",
        "summary_visible": False,
    }

    # Plot name -> plot factory. Each value exposes ``required_keys`` and is
    # called as ``factory(fig, axes, data, preds)`` by ``draw``.
    plot_dict = __plot_dict__

    # NOTE(review): class-level mutable list, shared by every instance. It is
    # not used inside this class; kept because external code may rely on it.
    childs = []

    # Indexed by the ``event`` argument of ``__init__``; the result is stored
    # in ``self.plot``.
    event_to_image = [None, "image", "horizon_line", "lat_pred", "lat_gt"]

    def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
        """Build the figure, register the tools and draw the default plot.

        Args:
            conf: Configuration object read and written through the attributes
                ``default`` and ``summary_visible`` (presumably built from
                ``default_conf`` -- confirm with the caller).
            data: Mapping with at least an ``"image"`` entry; its keys also
                count towards the keys available to plots.
            preds: Mapping of name -> prediction mapping. One panel is created
                per name.
            title: Optional window title for the figure.
            event: Index into ``event_to_image``.
            summaries: Optional mapping of name -> summary object that is
                pretty-printed into each panel.
        """
        self.conf = conf
        self.data = data
        self.preds = preds
        self.names = list(preds.keys())
        self.plot = self.event_to_image[event]
        self.summaries = summaries
        self.fig, self.axes, self.summary_arts = self.init_frame()
        if title is not None:
            self.fig.canvas.manager.set_window_title(title)

        # A plot is only offered when every prediction provides the keys it
        # needs (intersection over predictions), counting keys of ``data`` too.
        pred_key_sets = [set(pred.keys()) for pred in preds.values()]
        keys = set.intersection(*pred_key_sets) if pred_key_sets else set()
        keys = keys.union(data.keys())

        self.options = [
            name
            for name, plot in self.plot_dict.items()
            if set(plot.required_keys).issubset(keys)
        ]
        self.handle = None

        manager = self.fig.canvas.manager
        self.radios = manager.toolmanager.add_tool(
            "switch plot",
            RadioHideTool,
            options=self.options,
            callback_fn=self.draw,
            active=self.conf.default,
            keymap="R",
        )

        self.toggle_summary = manager.toolmanager.add_tool(
            "toggle summary",
            ToggleTool,
            toggled=self.conf.summary_visible,
            callback_fn=self.set_summary_visible,
            keymap="t",
        )

        if manager.toolbar is not None:
            manager.toolbar.add_tool("switch plot", "navigation")
        self.draw(self.conf.default)

    def init_frame(self):
        """Create the figure with one image panel per name.

        Returns:
            Tuple ``(fig, axes, toggle_artists)`` where ``toggle_artists`` holds
            the summary text artists (empty when ``self.summaries`` is None).
        """
        # Single row; the same input image is shown once per name.
        # presumably data["image"][0] is a (C, H, W) tensor turned into
        # (H, W, C) for plotting -- TODO confirm.
        imgs = [[self.data["image"][0].permute(1, 2, 0) for _ in self.names]]

        fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
        for i, name in enumerate(self.names):
            viz2d.add_text(i, name, axes=axes[0])

        fig.canvas.mpl_connect("pick_event", self.click_artist)

        toggle_artists = []
        if self.summaries is not None:
            font_size = 7
            formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
            toggle_artists = [
                viz2d.add_text(
                    i,
                    formatter.pformat(self.summaries[name]),
                    axes=axes[0],
                    pos=(0.01, 0.01),
                    va="bottom",
                    backgroundcolor=(0, 0, 0, 0.5),
                    visible=self.conf.summary_visible,
                    fs=font_size,
                )
                for i, name in enumerate(self.names)
            ]
        return fig, axes, toggle_artists

    def draw(self, value):
        """Clear the frame and draw the plot registered under ``value``.

        Also records ``value`` as ``conf.default``. Returns the new plot handle.
        """
        self.clear()
        self.conf.default = value
        self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
        return self.handle

    def clear(self):
        """Remove the current plot: its handle, axes lines/collections, figure artists."""
        if self.handle is not None:
            # Best effort: not every plot handle implements ``clear``.
            try:
                self.handle.clear()
            except AttributeError:
                pass
        self.handle = None

        for row in self.axes:
            for ax in row:
                # Iterate over snapshots: ``remove()`` mutates the axes'
                # artist containers, and removing while iterating the live
                # container can skip artists.
                for line in list(ax.lines):
                    line.remove()
                for collection in list(ax.collections):
                    collection.remove()
        self.fig.artists.clear()
        self.fig.canvas.draw_idle()

    def click_artist(self, event):
        """Handle a pick event: toggle the picked artist's arrow style.

        An artist whose style is ``"-"`` is switched to ``"<|-|>"`` and moved
        to zorder 1; any other style is switched back to ``"-"``. The event is
        then forwarded to the current plot handle if it defines ``click_artist``.
        """
        # NOTE(review): assumes every pickable artist provides
        # ``get_arrowstyle`` / ``set_arrowstyle`` -- confirm against the plots.
        art = event.artist
        select = art.get_arrowstyle().arrow == "-"
        art.set_arrowstyle("<|-|>" if select else "-")
        if select:
            art.set_zorder(1)
        if hasattr(self.handle, "click_artist"):
            self.handle.click_artist(event)
        self.fig.canvas.draw_idle()

    def set_summary_visible(self, visible):
        """Show or hide all summary text artists and remember the state in ``conf``."""
        self.conf.summary_visible = visible
        for summary in self.summary_arts:
            summary.set_visible(visible)
        self.fig.canvas.draw_idle()
