Source code for the_utils.plt

"""Draw plots.
"""

from typing import Any
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union

import matplotlib.axes as Axes
import matplotlib.pyplot as plt


def charts(t, ax: "Axes.Axes"):
    """Return the drawing method of ``ax`` that implements chart type ``t``.

    Args:
        t (str): chart type, one of ``"line"``, ``"scatter"`` or ``"bar"``.
        ax: axes object whose ``plot``, ``scatter`` and ``bar`` methods are
            looked up.

    Returns:
        The bound method ``ax.plot`` for ``"line"``, ``ax.scatter`` for
        ``"scatter"`` or ``ax.bar`` for ``"bar"``.

    Raises:
        KeyError: if ``t`` is not one of the supported chart types.
    """
    # NOTE: this file binds the *module* ``matplotlib.axes`` to the name
    # ``Axes``, so the axes class is spelled ``Axes.Axes`` in the annotation.
    # It is quoted so that it is not evaluated when the function is defined.
    drawers = {
        "line": ax.plot,
        "scatter": ax.scatter,
        "bar": ax.bar,
    }
    try:
        return drawers[t]
    except KeyError:
        # Keep the exception type, but tell the caller what would have worked.
        raise KeyError(
            f"Unsupported chart type {t!r}; expected one of {sorted(drawers)}."
        ) from None
# pylint: disable=too-many-locals,too-many-branches,too-many-statements
def draw_chart(
    ys: List[List],
    lbls: List[str],
    y_colors: Optional[List[str]] = None,
    xs: Optional[List[Any]] = None,
    fill_between: Optional[List] = None,
    fill_colors: Optional[List[str]] = None,
    fill_alpha: Union[List[float], float] = 0.5,
    figsize: Tuple = (24, 3),
    xmode: str = "s",
    # pylint: disable=unused-argument
    ymode: str = "s",
    title: Optional[str] = None,
    set_xticks: bool = False,
    xticks: Optional[List[Any]] = None,
    tick_labels: Optional[List[Any]] = None,
    x_rotation: float = 45,
    boxes: Optional[List[List[Any]]] = None,
    box_colors: Optional[List[str]] = None,
    box_alphas: Optional[List[float]] = None,
    ylim: Optional[Tuple[float, float]] = None,
    xlim: Optional[Tuple[float, float]] = None,
    linewidth: Optional[List[int]] = None,
    types: Union[List[str], str] = "line",
    markersize: float = 1.25,
    save_path: Optional[str] = None,
    bar_width: float = 0.25,
    legend_loc: str = "center left",
    legend_bbox_to_anchor: Tuple[float, ...] = (1, 0.5),
) -> None:
    """Draw charts.

    Every entry of ``ys`` is drawn on one shared axes with the drawing method
    selected by ``types``; the figure is optionally saved and then shown.

    Args:
        ys (List[List]): value lists of the y axis, one per series.
        lbls (List[str]): legend labels of the value lists.
        y_colors (List[str], optional): colors of the ys. A ``None`` entry
            (or ``None`` for the whole argument) leaves the color to
            matplotlib. Defaults to None.
        xs (List[Any], optional): values of the x axis. With ``xmode="s"`` it
            is one list shared by all ys, with ``xmode="d"`` it holds one list
            per series. If None, the indices of each y are used.
            Defaults to None.
        fill_between (List, optional): per-series spreads; fills the range
            ``[ys[idx][i] - fill_between[idx][i],
            ys[idx][i] + fill_between[idx][i]]``. A ``None`` entry skips the
            series. Defaults to None.
        fill_colors (List[str], optional): colors of the fill ranges.
            Defaults to None.
        fill_alpha (List[float] or float, optional): alpha of the fill ranges,
            either one value for all series or one value per series.
            Defaults to 0.5.
        figsize (Tuple, optional): size of the output figure.
            Defaults to (24, 3).
        xmode (str, optional): all ys share the same x (``"s"``) or every y
            has its own x (``"d"``). Any other value plots against indices.
            Defaults to "s".
        ymode (str, optional): currently unused. Defaults to "s".
        title (str, optional): figure title. Defaults to None.
        set_xticks (bool, optional): set the x ticks. Defaults to False.
        xticks (List[Any], optional): tick positions of the x axis. If None,
            ``range(len(labels))`` of the default labels is used.
            Defaults to None.
        tick_labels (List[Any], optional): tick labels of the x axis. If None,
            the x values (the first series' x values for ``xmode="d"``) are
            used. Defaults to None.
        x_rotation (float, optional): rotation of the x tick labels.
            Defaults to 45.
        boxes (List[List], optional): boxes to draw, each given as
            ``[xy, width, height]`` in axes coordinates. Defaults to None.
        box_colors (List[str], optional): colors of the boxes; "darkgrey" is
            used when None. Defaults to None.
        box_alphas (List[float], optional): alphas of the boxes; 0.5 is used
            when None. Defaults to None.
        ylim (Tuple[float, float], optional): limits handed to ``plt.ylim``.
            Defaults to None.
        xlim (Tuple[float, float], optional): limits handed to ``plt.xlim``.
            Defaults to None.
        linewidth (List[int], optional): line width per series; 1 is used when
            None. Defaults to None.
        types (List[str] or str, optional): chart type of every series, or one
            type applied to all of them. Defaults to "line".
        markersize (float, optional): written to the global
            ``plt.rcParams["lines.markersize"]``. Defaults to 1.25.
        save_path (str, optional): if not None, save the image to the path.
            Defaults to None.
        bar_width (float, optional): width of the bars. Defaults to 0.25.
        legend_loc (str, optional): location of the legend.
            Defaults to "center left".
        legend_bbox_to_anchor (Tuple[float], optional): bbox_to_anchor of the
            legend. Defaults to (1, 0.5).

    Raises:
        ValueError: "The arg 'types:'{types} has values not supported." when
            ``types`` is neither a string nor a list/tuple.
        KeyError: raised by ``charts`` when a chart type is not one of
            "line", "scatter" or "bar".
    """
    plt.clf()
    _, ax = plt.subplots(figsize=figsize)

    # Resolve one chart type per series.
    if isinstance(types, str):
        chart_types = [types] * len(ys)
    elif isinstance(types, (list, tuple)):
        chart_types = list(types)
    else:
        raise ValueError(f"The arg 'types:'{types} has values not supported.")

    # Drawing function and type-specific keyword arguments of every series.
    funcs = [charts(chart_type, ax) for chart_type in chart_types]
    optional_args = [
        {"width": bar_width} if chart_type == "bar" else {}
        for chart_type in chart_types
    ]

    # NOTE: this changes matplotlib's global rc settings, not just this figure.
    plt.rcParams["lines.markersize"] = markersize

    # Pick the x values of every series. Indices are used when no xs were
    # given or when xmode is neither "s" nor "d".
    use_index = xs is None or xmode not in ("s", "d")
    if use_index:
        series_xs = [list(range(len(y))) for y in ys]
    elif xmode == "s":
        series_xs = [xs] * len(ys)
    else:
        series_xs = [xs[idx] for idx, _ in enumerate(ys)]

    for idx, (x, y) in enumerate(zip(series_xs, ys)):
        funcs[idx](
            x,
            y,
            label=lbls[idx],
            color=None if y_colors is None else y_colors[idx],
            linewidth=1 if linewidth is None else linewidth[idx],
            **optional_args[idx],
        )

        # fill the range of
        # [ys[idx][i] - fill_between[idx][i], ys[idx][i] + fill_between[idx][i]]
        spread = None if fill_between is None else fill_between[idx]
        if spread is None:
            continue
        if isinstance(fill_alpha, (list, tuple)):
            alpha = fill_alpha[idx]
        else:
            alpha = fill_alpha
        ax.fill_between(
            x,
            [value + spread[i] for i, value in enumerate(y)],
            [value - spread[i] for i, value in enumerate(y)],
            facecolor=None if fill_colors is None else fill_colors[idx],
            alpha=alpha,
        )

    # set the ticks of the x axis
    if set_xticks:
        if use_index:
            # Cover the longest series so ticks and labels always match.
            default_labels = list(range(max((len(y) for y in ys), default=0)))
        elif xmode == "s":
            default_labels = xs
        else:
            default_labels = xs[0]
        ax.set_xticks(
            ticks=range(len(default_labels)) if xticks is None else xticks,
            labels=default_labels if tick_labels is None else tick_labels,
            rotation=x_rotation,
            ha="right",
        )

    # draw boxes on the figure; coordinates are relative to the axes
    if boxes is not None:
        for idx, box in enumerate(boxes):
            ax.add_patch(
                plt.Rectangle(
                    box[0],
                    box[1],
                    box[2],
                    transform=ax.transAxes,
                    color=box_colors[idx] if box_colors is not None else "darkgrey",
                    alpha=box_alphas[idx] if box_alphas is not None else 0.5,
                )
            )

    ax.legend(loc=legend_loc, bbox_to_anchor=legend_bbox_to_anchor)

    if title:
        plt.title(title)
    if xlim is not None:
        plt.xlim(xlim)
    if ylim is not None:
        plt.ylim(ylim)
    if save_path is not None:
        plt.savefig(save_path)
    plt.show()