# Source code for boolforge.wiring_diagram.plotting

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import networkx as nx
import numpy as np
import warnings


class WiringDiagramPlottingMixin:
    """
    Mixin that adds matplotlib-based plots to a wiring diagram object.

    The methods read three attributes of the host object:

    - ``self.N``: number of nodes (node ids are ``range(self.N)``),
    - ``self.I``: iterable whose ``i``-th entry lists the regulators of node ``i``,
    - ``self.variables``: node names, indexed by node id.
    """

    def _build_wiring_digraph(self):
        """
        Return the wiring diagram as a ``networkx.DiGraph``.

        Every node ``0 .. N-1`` is present; there is one edge
        ``regulator -> target`` for every regulator listed in ``self.I[target]``.
        """
        graph = nx.DiGraph()
        graph.add_nodes_from(range(self.N))
        graph.add_edges_from(
            (regulator, target)
            for target, regulators in enumerate(self.I)
            for regulator in regulators
        )
        return graph

    def plot_modular_structure(
        self,
        ax=None,
        show=True,
        node_labels: bool = True,
        max_nodes: int = 50,
        curviness: float = 0.25,
    ):
        """
        Plot the wiring diagram as a directed acyclic graph of strongly connected components.

        The wiring diagram is first condensed into its strongly connected components (SCCs),
        yielding a directed acyclic graph (DAG). Each drawn node represents one SCC.
        Only SCCs that are connected to at least one other SCC are drawn.

        The layout is hierarchical (top-to-bottom) using topological generations, making
        feed-forward structure visually apparent, while condensing feedback loops.

        Node colors: SCCs of size > 1 are drawn in red, single nodes that have both
        external regulators and targets in light red, and single nodes without
        regulators (other than themselves) or without targets in light grey.

        Parameters
        ----------
        ax : matplotlib.axes.Axes, optional
            Axis to draw on. If None, a new figure and axis are created.
        show : bool, default=True
            Whether to call ``plt.show()`` after plotting.
        node_labels : bool, default=True
            Whether to label the nodes. SCCs of size > 1 are labeled
            ``"SCC of size k"``; single-node SCCs are labeled with the variable name.
            If False, unlabeled ellipses of a fixed small size are drawn.
        max_nodes : int, default=50
            If the number of SCCs in the DAG exceeds this value, only edges touching
            an SCC of size > 1 are kept to reduce clutter. If that would remove every
            edge, the full DAG is drawn instead.
        curviness : float, default=0.25
            Curvature of edges spanning multiple layers (0 = straight).

        Returns
        -------
        ax : matplotlib.axes.Axes or None
            The axis containing the plot. ``None`` (after a ``UserWarning``) if the
            network has fewer than two SCCs or no edges between distinct SCCs.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
            from matplotlib.patches import Ellipse, FancyArrowPatch
        except ImportError as err:
            raise ImportError(
                "Plotting requires matplotlib. "
                "Install it with: pip install matplotlib"
            ) from err

        def _ellipse_boundary_point(x0, y0, x1, y1, a, b):
            """
            Intersection of ray from (x0,y0) toward (x1,y1)
            with ellipse centered at (x0,y0) with semi-axes a,b.

            Returns the center itself if the ray has no direction or the
            ellipse is degenerate.
            """
            dx = x1 - x0
            dy = y1 - y0
            if (dx == 0 and dy == 0) or a <= 0 or b <= 0:
                return x0, y0

            t = 1.0 / np.sqrt((dx / a) ** 2 + (dy / b) ** 2)
            return x0 + t * dx, y0 + t * dy

        # ------------------------------------------------------------------
        # Build directed graph of the wiring diagram and compute its SCCs
        # ------------------------------------------------------------------
        g_full = self._build_wiring_digraph()

        sccs = list(nx.strongly_connected_components(g_full))
        scc_sizes = np.array([len(scc) for scc in sccs])

        if len(sccs) < 2:
            warnings.warn('No plot created. The network consists of a single SCC', UserWarning)
            return None

        scc_index = {node: i for i, scc in enumerate(sccs) for node in scc}

        # ------------------------------------------------------------------
        # Build SCC DAG (one edge per ordered pair of distinct, connected SCCs)
        # ------------------------------------------------------------------
        dag_edges = {
            (scc_index[u], scc_index[v])
            for u, v in g_full.edges()
            if scc_index[u] != scc_index[v]
        }
        dag = nx.DiGraph(dag_edges)

        if dag.number_of_nodes() > max_nodes:
            sparse_edges = [
                (u, v) for (u, v) in dag_edges
                if scc_sizes[u] > 1 or scc_sizes[v] > 1
            ]
            # Keep the full DAG if sparsifying would leave nothing to draw.
            if sparse_edges:
                dag = nx.DiGraph(sparse_edges)

        if dag.number_of_nodes() == 0:
            warnings.warn(
                'No plot created. The network has no edges between distinct SCCs',
                UserWarning,
            )
            return None

        # ------------------------------------------------------------------
        # Node types:  2 = SCC of size > 1
        #             -1 = single node regulated by nothing (or only by itself)
        #              1 = single node without outgoing edges
        #              0 = any other single node
        # ------------------------------------------------------------------
        types = np.zeros(len(sccs), dtype=int)
        for i, scc in enumerate(sccs):
            if scc_sizes[i] > 1:
                types[i] = 2
                continue
            node = next(iter(scc))
            regulators = list(g_full.predecessors(node))
            if not regulators or regulators == [node]:
                types[i] = -1
            elif g_full.out_degree(node) == 0:
                types[i] = 1

        # ------------------------------------------------------------------
        # Hierarchical layout: initial placement by generations
        # ------------------------------------------------------------------
        layers = [list(generation) for generation in nx.topological_generations(dag)]

        pos = {}
        for depth, members in enumerate(layers):
            if len(members) == 1:
                pos[members[0]] = (0.0, -depth)
            else:
                for x, node in zip(np.linspace(-0.5, 0.5, len(members)), members):
                    pos[node] = (x, -depth)

        # ------------------------------------------------------------------
        # Barycentric horizontal refinement: move each node below the mean
        # x-position of its predecessors
        # ------------------------------------------------------------------
        for members in layers[1:]:  # first layer has no predecessors
            for v in members:
                preds = list(dag.predecessors(v))
                if preds:
                    pos[v] = (np.mean([pos[p][0] for p in preds]), pos[v][1])

        # tiny deterministic jitter to avoid exact overlaps
        eps = 1e-3
        for i, v in enumerate(dag.nodes()):
            x, y = pos[v]
            pos[v] = (x + eps * (i % 7), y)

        # ------------------------------------------------------------------
        # Post-process: spread nodes within each layer to use full width
        # (keeps the left-to-right order, replaces the x values)
        # ------------------------------------------------------------------
        max_width = 3
        for members in layers:
            if len(members) <= 1:
                continue

            order = np.argsort(np.array([pos[v][0] for v in members]))
            members_sorted = [members[i] for i in order]

            width = max(1.0, min(len(members_sorted) / 3, max_width))
            new_xs = np.linspace(-width / 2, width / 2, len(members_sorted))

            for v, x in zip(members_sorted, new_xs):
                pos[v] = (x, pos[v][1])

        # ------------------------------------------------------------------
        # Vertical micro-staggering within layers (reduce edge overlap)
        # ------------------------------------------------------------------
        epsilon = 0.25  # vertical spacing scale

        for members in layers:
            if len(members) <= 3:
                continue

            # left-to-right, offsets cycle through 0, epsilon, 2*epsilon
            for i, v in enumerate(sorted(members, key=lambda v: pos[v][0])):
                x, y = pos[v]
                pos[v] = (x, y + (i % 3) * epsilon)

        # ------------------------------------------------------------------
        # Figure size / aspect ratio
        # ------------------------------------------------------------------
        xs = np.array([x for x, _ in pos.values()])
        ys = np.array([y for _, y in pos.values()])

        x_span = xs.max() - xs.min()
        y_span = ys.max() - ys.min()

        target_aspect = 1.5   # width / height
        current_aspect = x_span / max(y_span, 1e-6)

        # scale stays 1 when the layout is already wide enough
        scale = 1.0
        if 0 < current_aspect < target_aspect:
            scale = target_aspect / current_aspect
            for v, (x, y) in pos.items():
                pos[v] = (x * scale, y)

        fig_width = max(6, 0.6 * x_span * scale)
        fig_height = max(6, 0.6 * abs(y_span))

        if ax is None:
            _, ax = plt.subplots(figsize=(fig_width, fig_height))

        color_map = {
            -1: "#eeeeee",
             0: "#ffcccc",
             1: "#eeeeee",
             2: "#ff9999",
        }

        if node_labels:
            labels = {
                n: f"SCC of size {scc_sizes[n]}" if scc_sizes[n] > 1
                else self.variables[next(iter(sccs[n]))]
                for n in dag.nodes()
            }
        else:
            labels = {n: "" for n in dag.nodes()}

        # ------------------------------------------------------------------
        # Draw provisional labels so that their rendered size can be measured
        # ------------------------------------------------------------------
        xs = [x for x, _ in pos.values()]
        ys = [y for _, y in pos.values()]

        pad_x = 1.0
        pad_y = 1.0

        ax.set_xlim(min(xs) - pad_x, max(xs) + pad_x)
        ax.set_ylim(min(ys) - pad_y, max(ys) + pad_y)

        texts = {
            n: ax.text(
                x, y, labels[n],
                ha="center", va="center",
                fontsize=8,
                zorder=3,
            )
            for n, (x, y) in pos.items()
        }

        fig = ax.figure
        fig.canvas.draw()
        renderer = fig.canvas.get_renderer()

        # ------------------------------------------------------------------
        # Ellipse semi-axes from measured text extents (in data coordinates)
        # ------------------------------------------------------------------
        # Used when a label is empty and therefore has no measurable extent.
        fallback_half_width = 0.15
        fallback_half_height = 0.10

        to_data = ax.transData.inverted()
        half_width = {}
        half_height = {}

        for n, text in texts.items():
            bbox = text.get_window_extent(renderer=renderer)
            (x0, y0), (x1, y1) = to_data.transform(
                [(bbox.x0, bbox.y0), (bbox.x1, bbox.y1)]
            )
            measured_w = 0.65 * (x1 - x0)
            measured_h = 0.90 * (y1 - y0)
            half_width[n] = measured_w if measured_w > 0 else fallback_half_width
            half_height[n] = measured_h if measured_h > 0 else fallback_half_height

        # ------------------------------------------------------------------
        # Fix horizontal overlaps using measured widths
        # ------------------------------------------------------------------
        min_gap = 0.1  # extra spacing between nodes

        for members in layers:
            if len(members) <= 1:
                continue

            members_sorted = sorted(members, key=lambda v: pos[v][0])

            x_cursor = pos[members_sorted[0]][0]
            new_x = {members_sorted[0]: x_cursor}

            # push each node right until it clears its left neighbor
            for prev, curr in zip(members_sorted[:-1], members_sorted[1:]):
                required = half_width[prev] + half_width[curr] + min_gap
                x_cursor = max(pos[curr][0], x_cursor + required)
                new_x[curr] = x_cursor

            # re-center layer
            center = np.mean(list(new_x.values()))
            for v in members_sorted:
                pos[v] = (new_x[v] - center, pos[v][1])

        # ------------------------------------------------------------------
        # Final (partial) barycentric refinement
        # ------------------------------------------------------------------
        alpha = 0.1   # 0 = no barycentric, 1 = full snap
        for members in layers[1:]:
            for v in members:
                preds = list(dag.predecessors(v))
                if preds:
                    x_mean = np.mean([pos[p][0] for p in preds])
                    pos[v] = (alpha * x_mean + (1 - alpha) * pos[v][0], pos[v][1])

        node_layer = {v: depth for depth, members in enumerate(layers) for v in members}

        # the provisional labels were only needed for measuring
        for t in texts.values():
            t.remove()

        # ------------------------------------------------------------------
        # Draw nodes
        # ------------------------------------------------------------------
        ellipse_axes = {n: (half_width[n], half_height[n]) for n in pos}

        for n, (x, y) in pos.items():
            a, b = ellipse_axes[n]
            ax.add_patch(
                Ellipse(
                    (x, y),
                    width=2 * a,
                    height=2 * b,
                    facecolor=color_map[types[n]],
                    edgecolor="black",
                    zorder=2,
                )
            )

        # ------------------------------------------------------------------
        # Draw edges, from ellipse boundary to ellipse boundary
        # ------------------------------------------------------------------
        for u, v in dag.edges():
            x0, y0 = pos[u]
            x1, y1 = pos[v]

            a0, b0 = ellipse_axes[u]
            a1, b1 = ellipse_axes[v]

            sx, sy = _ellipse_boundary_point(x0, y0, x1, y1, a0, b0)
            tx, ty = _ellipse_boundary_point(x1, y1, x0, y0, a1, b1)

            # only edges skipping at least one layer are curved
            if node_layer[v] - node_layer[u] > 1 and curviness != 0.0:
                sign = -1 if pos[u][0] + pos[v][0] > 0 else 1
                conn = f"arc3,rad={sign * curviness}"
            else:
                conn = "arc3"

            ax.add_patch(
                FancyArrowPatch(
                    (sx, sy),
                    (tx, ty),
                    arrowstyle='-|>',
                    mutation_scale=10,
                    linewidth=1.2,
                    color='black',
                    connectionstyle=conn,   # curved or straight
                    zorder=1,
                )
            )

        # Draw text again, now at the final positions
        if node_labels:
            for n, (x, y) in pos.items():
                ax.text(
                    x, y, labels[n],
                    ha="center", va="center",
                    fontsize=7,
                    zorder=4,   # on top
                )

        # --------------------------------------------------
        # final hard limits: guarantee everything is visible
        # --------------------------------------------------
        xs = []
        ys = []

        for n, (x, y) in pos.items():
            a, b = ellipse_axes[n]
            xs.extend([x - a, x + a])
            ys.extend([y - b, y + b])

        PAD_X = 0.5
        PAD_Y = 0.5

        ax.set_xlim(min(xs) - PAD_X, max(xs) + PAD_X)
        ax.set_ylim(min(ys) - PAD_Y, max(ys) + PAD_Y)

        ax.set_autoscale_on(False)
        ax.set_axis_off()

        if show:
            plt.show()

        return ax

    def plot(
        self,
        max_expanded_sccs: int = 4,
        min_scc_size: int = 2,
        show: bool = True,
        curviness: float = 0.25,
    ):
        """
        Plot an integrated overview of the wiring diagram.

        The plot consists of:
          1) A top panel showing the modular structure of the network as a DAG of
             strongly connected components (SCCs). It is only drawn if the network
             has more than one SCC.
          2) Bottom panels showing the internal wiring of selected SCCs using a
             circular layout, with their direct external regulators placed above.

        By default, the largest SCCs of size >= ``min_scc_size`` are expanded,
        up to ``max_expanded_sccs``.

        Parameters
        ----------
        max_expanded_sccs : int, default=4
            Maximum number of SCCs to expand and show in detail.
        min_scc_size : int, default=2
            Minimum SCC size to be eligible for expansion.
        show : bool, default=True
            Whether to call ``plt.show()`` at the end.
        curviness : float, default=0.25
            Curvature of edges spanning multiple layers in the modular graph (0 = straight).

        Returns
        -------
        fig : matplotlib.figure.Figure
            The created figure.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError(
                "Plotting requires matplotlib. "
                "Install it with: pip install matplotlib"
            ) from err

        # ------------------------------------------------------------
        # Build full directed graph and compute SCCs
        # ------------------------------------------------------------
        g = self._build_wiring_digraph()
        sccs = [sorted(scc) for scc in nx.strongly_connected_components(g)]

        # Select SCCs to expand: largest first, ties keep discovery order
        expandable = sorted(
            (scc for scc in sccs if len(scc) >= min_scc_size),
            key=len,
            reverse=True,
        )
        expanded_sccs = expandable[:max_expanded_sccs]
        n_expanded = len(expanded_sccs)

        has_top_panel = len(sccs) > 1

        # ------------------------------------------------------------
        # Figure and GridSpec
        # ------------------------------------------------------------
        if n_expanded == 0:  # just showing the modular graph
            fig = plt.figure(figsize=(8, 4))
            gs = fig.add_gridspec(1, 1)
            ax_top = fig.add_subplot(gs[0, 0])
        elif not has_top_panel:  # just showing the single SCC
            fig = plt.figure(figsize=(4 * n_expanded, 4))
            gs = fig.add_gridspec(1)
        else:  # showing both
            fig = plt.figure(figsize=(4 * n_expanded, 6))
            gs = fig.add_gridspec(
                2,
                n_expanded,
                height_ratios=[2.2, 1.5],
            )
            ax_top = fig.add_subplot(gs[0, :])

        # ------------------------------------------------------------
        # Top panel: modular structure
        # ------------------------------------------------------------
        if has_top_panel:
            drawn = self.plot_modular_structure(
                ax=ax_top, show=False, curviness=curviness
            )
            if drawn is None:
                # nothing was drawn; hide the empty axes
                ax_top.set_axis_off()
            else:
                ax_top.set_title("Modular structure (DAG of SCCs)")

        # ------------------------------------------------------------
        # Bottom panels: internal SCC structure
        # ------------------------------------------------------------
        epsilon = 0.35  # vertical spacing scale for staggered inputs

        for col, scc in enumerate(expanded_sccs):
            ax = fig.add_subplot(gs[1, col] if has_top_panel else gs[col])

            members = set(scc)

            # direct external inputs only
            inputs = sorted({
                u
                for v in scc
                for u in g.predecessors(v)
                if u not in members
            })

            subg = g.subgraph(members | set(inputs)).copy()

            # keep only edges that point into the SCC (drops input -> input edges)
            subg.remove_edges_from(
                [(u, v) for u, v in subg.edges if v not in members]
            )

            pos = {}

            # Inputs in a row on top, staggered vertically if there are many
            k_in = len(inputs)
            for idx, v in enumerate(inputs):
                offset = 0 if k_in <= 3 else (idx % 3 - 1) * epsilon
                pos[v] = ((idx - (k_in - 1) / 2) / k_in * 2, 2.0 + offset)

            # SCC nodes in a circle below, ordered by node id
            pos.update(nx.circular_layout(scc, scale=1.0, center=(0.0, 0.0)))

            # SCC members in red, external inputs in grey
            node_colors = [
                "#ff9999" if v in members else "#eeeeee"
                for v in subg.nodes
            ]

            nx.draw_networkx(
                subg,
                pos=pos,
                ax=ax,
                node_color=node_colors,
                node_size=200,
                labels={v: self.variables[v] for v in subg.nodes()},
                font_size=9,
                with_labels=True,
            )

            ax.set_title(f"SCC of size {len(scc)}")
            ax.set_axis_off()

        fig.tight_layout()

        if show:
            plt.show()

        return fig