# Source code for schrodinger.ui.qt.network_visualizer

"""
network_visualizer.py

Description: This module helps visualize network-connection data, in
conjunction with network_views.py. A good example of the type of data it is
meant to visualize is at:
http://networkx.lanl.gov/index.html

The `Graph` class is meant as a wrapper for `networkx.Graph` objects, which can
then act as a model for `AbstractNetworkView` and the associated network view
classes defined in network_views.py.

Copyright Schrodinger, LLC. All rights reserved.
"""

#Author:  Pat Lorton

import copy
import math
from sys import maxsize

import networkx as nx
import numpy as np

from schrodinger.Qt.QtCore import QObject
from schrodinger.Qt.QtCore import pyqtSignal

#### Spring layout parameters ####
# NOTE(review): STARTING_TEMPERATURE, PUSH and PUSHEXP are not referenced in
# the visible part of this module; presumably they tune `spring_layout` --
# confirm against its implementation before changing them.
STARTING_TEMPERATURE = .1
ITERATIONS = 100  # Spring-force relaxation iterations per layout run
SCALE = 3.0  # Scale of distances to node size. Higher number is greater separation
PUSH = 1  # Multiplier for node-node repulsion
PUSHEXP = 6  # Exponential dependence of node-node repulsion

#===============================================================================
# Graph Model Classes
#===============================================================================


class GraphSignals(QObject):
    """
    Container for the Qt signals emitted by `Graph` (exposed on each graph as
    `Graph.signals`).
    """
    # (selected nodes/edges, object that requested the change); emitted by
    # Graph.setSelectedObjs
    selectionChanged = pyqtSignal(set, object)
    # set of nodes whose positions changed; emitted by
    # Graph.setMultipleNodePos
    positionChanged = pyqtSignal(set)
    # set of changed nodes. NOTE(review): not emitted by the Graph methods in
    # this module chunk; presumably emitted by Node setters -- confirm.
    nodesChanged = pyqtSignal(set)
    # set of newly created Node objects; emitted by Graph.addNodes
    nodesAdded = pyqtSignal(set)
    # set of removed nodes; emitted by Graph.removeNodes
    nodesDeleted = pyqtSignal(set)
    # set of added or removed edges; emitted by Graph.addEdges and
    # Graph.removeEdges
    edgesChanged = pyqtSignal(set)
    # emitted by Graph.setData and Graph.setState (undo/redo)
    graphChanged = pyqtSignal()
    # emitted by Graph.setUndoPoint
    undoPointSet = pyqtSignal()


class Graph:
    """
    A model class for an undirected graph. This wraps around the NetworkX Graph
    class and provides QT signals, an easier-to-use API, and access control.

    All persistent data should be stored in self._ggraph.

    Note that Graph itself cannot be pickled; Graph has Graph.signals, which is
    a QObject and cannot be pickled. For this reason selection information
    (which contains references to Graph) is not placed in self._ggraph, so that
    self._ggraph can be pickled.

    """

    def __init__(self, ggraph=None, node_class=None, edge_class=None):
        """
        Constructs a new Graph object

        :param ggraph: The graph underlying this graph.
        :type ggraph: `networkx.Graph`
        :param node_class: The class to represent the graph's nodes (should be
            subclass of `Node`)
        :type node_class: class
        :param edge_class: The class to represent the graph's edges (should be
            subclass of `Edge`)
        :type edge_class: class
        """
        self.signals = GraphSignals()
        if ggraph is None:
            ggraph = nx.Graph()

        self.node_class = node_class or Node
        self.edge_class = edge_class or Edge

        self._ggraph = ggraph
        self.selected_nodes = set()
        self.selected_edges = set()
        # Maps node name (str(gnode)) -> Node object wrapping that gnode.
        self.node_objects = {}
        self._updateNodeMap()

        self.connection_validator = None

        # Each entry is a state tuple as returned by getState().
        self.undo_stack = []
        self.max_undo_stack = 100
        self.redo_stack = []

    @property
    def ggraph(self):
        """
        The underlying `networkx.Graph` (the live object, not a copy).
        """
        return self._ggraph

    def update(self):
        """
        Update any derived aspects of the graph after changes. This base
        implementation does nothing; subclasses may override it.
        """
        return

    def _updateNodeMap(self):
        """
        Update the `node_objects` dictionary with any new nodes from the
        underlying ggraph.

        :return: the newly created node objects
        :rtype: set(Node)
        """

        new_nodes = set()
        # Iterating the graph yields its gnodes in every networkx version
        # (the ``Graph.node`` attribute was removed in networkx 2.4).
        for gnode in self.ggraph:
            if str(gnode) not in self.node_objects:
                node = self.node_class(gnode, self)
                self.node_objects[node.name] = node
                new_nodes.add(node)
        return new_nodes

    def setEdgeValidator(self, validator):
        """
        Set an edge validator that will be run when adding edges between nodes.

        :param validator: the validator
        :type validator: ConnectionValidator
        :raise TypeError: if `validator` is not a `ConnectionValidator`
        """
        if not isinstance(validator, ConnectionValidator):
            raise TypeError('Validator must be a subclass of '
                            'ConnectionValidator')
        self.connection_validator = validator

    def toNetworkX(self):
        """
        Return a deep copy of the underlying NetworkX graph.

        :rtype: networkx.Graph
        """
        return copy.deepcopy(self._ggraph)

    def getData(self, key):
        """
        Return the requested item from the graph's data dictionary. Returns None
        if the key is not found.

        :param key: key into the graph-level data dictionary
        :type key: hashable
        """
        return self._ggraph.graph.get(key)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the graph's data dictionary.

        :param key: key into the graph-level data dictionary
        :type key: hashable
        :param value: the value to store
        :type value: object
        :param signal: whether to emit `graphChanged` when done
        :type signal: bool
        """
        self._ggraph.graph[key] = value
        if signal:
            self.signals.graphChanged.emit()

    def isConnected(self):
        """
        Checks whether the graph is connected, that is, whether every node is
        connected by some path to every other node. A graph with no nodes is
        reported as not connected.

        :return: Whether the graph is connected
        :rtype: bool
        """
        # nx.is_connected raises on a graph with no nodes, so guard first.
        # Comparing the length (rather than relying on the graph's truthiness)
        # guarantees a bool is returned instead of the empty graph object.
        return len(self._ggraph) > 0 and nx.is_connected(self._ggraph)

    #===========================================================================
    # Node methods
    #===========================================================================

    def nodeCount(self):
        """
        :return: the number of nodes in the graph
        :rtype: int
        """

        return self.ggraph.number_of_nodes()

    def getIsolates(self):
        """
        :return: a complete set of nodes in the graph that have degree 0
        :rtype: set(Node)
        """

        return {self.getNode(gnode) for gnode in nx.isolates(self.ggraph)}

    def getConnectedComponents(self, nodes=None):
        """
        Return a set of nodes for each connected component in the graph.

        :param nodes: optionally, a set of nodes to filter the returned
            components. If provided, this method will only return components for
            which at least one node is in `nodes`
        :type nodes: set(Node) or NoneType
        :return: a generator over each connected component in the graph
        :rtype: typing.Generator[set[Node], None, None]
        """

        for gnodes in nx.connected_components(self.ggraph):
            component_nodes = {self.getNode(gnode) for gnode in gnodes}
            if nodes is None or nodes.intersection(component_nodes):
                yield component_nodes

    def getNodeConnectedComponent(self, node):
        """
        Return a set of nodes that are part of the same connected component as
        `node`.

        :param node: a node
        :type node: Node
        :return: a set of nodes connected to `node` through any path of edges
        :rtype: set(Node)
        """

        return {
            self.getNode(gnode)
            for gnode in nx.node_connected_component(self.ggraph, node.gnode)
        }

    def _getNodeInstances(self, objects):
        """
        Given a list of objects, return only the `Node` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of node instances among the supplied list of objects
        :rtype: set(Node)
        """

        return {obj for obj in objects if isinstance(obj, self.node_class)}

    def _getEdgeInstances(self, objects):
        """
        Given a list of objects, return only the `Edge` instances among them.

        :param objects: a list of objects
        :type objects: list(object)
        :return: the set of edge instances among the supplied list of objects
        :rtype: set(Edge)
        """

        return {obj for obj in objects if isinstance(obj, self.edge_class)}

    def getNode(self, node_key):
        """
        Retrieve a node via its name. Retrieved nodes are cached, so getting the
        same Node again will return the same instance. Returns None if no
        matching Node exists.

        :param node_key: a node, gnode, or string that corresponds to the
            desired node
        :type node_key: object
        :return: a node if found, else `None`
        :rtype: Node or NoneType
        """

        if isinstance(node_key, self.node_class):
            return node_key
        return self.node_objects.get(str(node_key))

    def _resolveNode(self, node_key):
        """
        Return `node_key` itself if it is a `Node`, otherwise look it up with
        `getNode()`.

        :param node_key: a node, gnode, or string naming a node in this graph
        :type node_key: object
        :return: the corresponding node
        :rtype: Node
        :raise ValueError: if `node_key` does not correspond to a known node
        """

        if isinstance(node_key, Node):
            return node_key
        node = self.getNode(node_key)
        if node is None:
            raise ValueError('Node %s not found in graph.' % (node_key,))
        return node

    def getNodes(self, node_keys=None):
        """
        Retrieve a set of nodes optionally indicated by a list of keys. If none
        is provided, return all nodes.

        :param node_keys: optionally, a list of nodes, gnodes, or strings that
            correspond to the desired nodes
        :type node_keys: list(object) or NoneType
        :return: a set of nodes; keys that match no node are skipped
        :rtype: set(Node)
        """

        if node_keys is None:
            return set(self.node_objects.values())

        nodes = set()
        for node_key in node_keys:
            node = self.getNode(node_key)
            if node:
                nodes.add(node)
        return nodes

    def getNeighbors(self, node):
        """
        Return a set of all nodes connected to a specified node

        :param node: center node
        :type node: Node
        :return: neighboring nodes
        :rtype: set of Node
        """

        gnodes = self._ggraph.neighbors(node.gnode)
        return self.getNodes(gnodes)

    def addNodes(self, nodes, signal=True):
        """
        Add a list of nodes to this graph. The `nodes` argument can either be a
        list of `Node` objects or a list of hashable objects that can be used
        as new gnodes.

        Note that any time a new gnode is created for use in this graph, its
        string representation must be unique among the other nodes in this
        graph: nodes are keyed in the `node_objects` dictionary by the string
        representation of their corresponding gnode.

        All entries are validated before any of them is added, so a
        `ValueError` leaves the graph unchanged.

        :param nodes: list of gnodes or nodes
        :type nodes: `list(object)` or `list(Node)`

        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool

        :return: a set of added nodes
        :rtype: set(Node)

        :raise ValueError: if a node already exists in the graph, or appears
            more than once in `nodes`
        """

        pending_nodes = []
        batch_gnodes = set()
        batch_names = set()
        for node in nodes:
            if not isinstance(node, Node):
                name = str(node)
                if name in self.node_objects or name in batch_names:
                    msg = ('A node with the same string representation ("{0}")'
                           ' already exists in this graph. New nodes must have'
                           ' a unique string representation.').format(node)
                    raise ValueError(msg)
                new_node = self.node_class(node, self)
            else:
                new_node = node

            gnode = new_node.gnode
            if self.ggraph.has_node(gnode) or gnode in batch_gnodes:
                # Wrap gnode in a tuple so tuple-valued gnodes format
                # correctly.
                msg = 'Node %s already exists in graph.' % (gnode,)
                raise ValueError(msg)
            batch_gnodes.add(gnode)
            batch_names.add(new_node.name)
            pending_nodes.append(new_node)

        for new_node in pending_nodes:
            self._ggraph.add_node(new_node.gnode, **new_node.gdata())
        new_nodes = self._updateNodeMap()

        if signal and new_nodes:
            self.signals.nodesAdded.emit(new_nodes)
        return new_nodes

    def addNode(self, node, signal=True):
        """
        Convenience method for adding a single node to the graph. See
        `addNodes()` for full documentation.

        :param node: gnode or node
        :type node: hashable or `Node`

        :param signal: whether the `nodesAdded` signal should be emitted when
            done
        :type signal: bool

        :return: the added node
        :rtype: Node
        """

        new_nodes = self.addNodes([node], signal=signal)
        return new_nodes.pop()

    def removeNodes(self, nodes, signal=True):
        """
        Remove specified nodes from the graph and optionally emit a signal.
        Keys that do not correspond to any node are ignored, matching the
        behavior of `networkx.Graph.remove_nodes_from`.

        :param nodes: the nodes (or gnodes) to be removed
        :type nodes: iterable(Node or object)

        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        """

        # Materialize first: `nodes` is iterated more than once below and may
        # be a one-shot iterator.
        resolved = [
            node if isinstance(node, Node) else self.getNode(node)
            for node in nodes
        ]
        nodes = [node for node in resolved if node is not None]

        self._ggraph.remove_nodes_from([n.gnode for n in nodes])
        for node in nodes:
            # pop() rather than del: a stale node (e.g. from an old selection)
            # may already be gone from the map.
            self.node_objects.pop(node.name, None)

        if signal:
            self.signals.nodesDeleted.emit(set(nodes))

    def removeNode(self, node, signal=True):
        """
        Convenience function for removing a single node. See `removeNodes()`
        for full documentation.

        :param node: a gnode or node to remove
        :type node: `object` or `Node`

        :param signal: whether to emit a `nodesDeleted` signal when done
        :type signal: bool
        """

        self.removeNodes([node], signal=signal)

    def setMultipleNodePos(self, pos_dict, signal=True):
        """
        Set the positions of nodes from a dictionary.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param signal: whether to emit `positionChanged` with the moved nodes
        :type signal: bool
        """
        changednodes = set()
        for node, pos in pos_dict.items():
            # Suppress per-node signals; one aggregate signal is sent below.
            node.setPos(pos[0], pos[1], False)
            changednodes.add(node)
        if signal:
            self.signals.positionChanged.emit(changednodes)

    #===========================================================================
    # Edge methods
    #===========================================================================
    def edgeCount(self):
        """
        :return: the number of edges in the graph
        :rtype: int
        """

        return self.ggraph.number_of_edges()

    def hasEdge(self, node1, node2):
        """
        Return whether there is an edge between the supplied nodes.

        :param node1: a node from this graph
        :type node1: Node
        :param node2: a node from this graph
        :type node2: Node
        :return: whether there exists an edge between the two supplied nodes
        :rtype: bool
        """

        return self._ggraph.has_edge(node1.gnode, node2.gnode)

    def getGEdge(self, node0, node1):
        """
        Return the underlying gedge object corresponding to two supplied nodes.
        This can be overwritten in subclasses, but the returned class should
        define a consistent edge ordering that is independent of the order of
        the supplied node parameters.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the underlying gedge between the two nodes, if it exists
        :rtype: tuple(networkx.Node) or NoneType
        """

        if not self.hasEdge(node0, node1):
            return None

        # Sorting makes the gedge independent of argument order.
        return tuple(sorted([node0.gnode, node1.gnode]))

    def getEdge(self, node0, node1):
        """
        Given two nodes, return the corresponding edge if it exists.

        :param node0: a node
        :type node0: Node
        :param node1: a node
        :type node1: Node
        :return: the edge connecting the two nodes if it exists
        :rtype: Edge or NoneType
        """

        gedge = self.getGEdge(node0, node1)
        if gedge is None:
            return None
        return self.edge_class(gedge, self)

    def getEdges(self, nodes=None):
        """
        Return all edges connected to a node or set of nodes. If no node is
        specified, all the edges in the graph are returned.

        :param nodes: optionally a node or iterable of nodes
        :type nodes: `iterable(Node)`, `Node`, or `None`
        :return: a set of edges connected to at least one of the supplied nodes,
            or a set of all edges if `nodes` is not specified
        :rtype: set(Edge)
        """

        if nodes is None:
            gnodes = None
        else:
            if not hasattr(nodes, '__iter__'):
                nodes = [nodes]
            gnodes = [node.gnode for node in nodes]

        edges = set()
        for gnode1, gnode2 in self.ggraph.edges(gnodes):
            node1, node2 = self.getNode(gnode1), self.getNode(gnode2)
            edges.add(self.getEdge(node1, node2))
        return edges

    def addEdges(self, edge_tuples, signal=True):
        """
        Add edges to graph. All tuples are validated before any edge is added,
        so a `ValueError` leaves the graph unchanged.

        :param edge_tuples: list of tuples indicating the edges to add,
            containing two gnodes or nodes and an edge attribute dictionary (or
            `None`)
        :type edge_tuples: list(tuple(Node, Node, dict)) or
            list(tuple(Node, Node, None))
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        :raise ValueError: if an edge already exists (or is repeated in
            `edge_tuples`), or a gnode does not correspond to a known node
        """

        pending_edges = []
        batch_pairs = set()
        for node1, node2, data_dict in edge_tuples:
            node1 = self._resolveNode(node1)
            node2 = self._resolveNode(node2)
            # Undirected graph: the pair is identified regardless of order.
            pair = frozenset((node1.gnode, node2.gnode))
            if self.hasEdge(node1, node2) or pair in batch_pairs:
                msg = 'Edge {} already exists in graph.'.format((node1, node2))
                raise ValueError(msg)
            batch_pairs.add(pair)
            if data_dict is None:
                data_dict = {}
            pending_edges.append((node1, node2, data_dict))

        new_edges = set()
        for node1, node2, data_dict in pending_edges:
            self._ggraph.add_edge(node1.gnode, node2.gnode, **data_dict)
            new_edges.add(self.getEdge(node1, node2))

        if signal:
            self.signals.edgesChanged.emit(new_edges)

    def addEdge(self, node1, node2, signal=True, data=None):
        """
        Convenience function to add a single edge to the graph given two nodes.
        The order of the nodes does not matter.

        :param node1: a gnode or node connected by the edge
        :type node1: `object` or `Node`

        :param node2: a gnode or node connected by the edge
        :type node2: `object` or `Node`

        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool

        :param data: optional edge attribute dictionary
        :type data: dict or NoneType
        """

        self.addEdges([(node1, node2, data)], signal=signal)

    def removeEdges(self, edges, signal=True):
        """
        Removes specified edges from the graph. All edges are checked before
        any is removed, so a `ValueError` leaves the graph unchanged.

        :param edges: a list of edges
        :type edges: list(Edge)
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        :raise ValueError: if any edge is not present in the graph
        """

        # Materialize first: `edges` is iterated twice and may be a one-shot
        # iterator.
        edges = list(edges)
        gedges = []
        for edge in edges:
            node1, node2 = edge
            if not self.hasEdge(node1, node2):
                raise ValueError('Edge not found between %s and %s' % (node1,
                                                                       node2))
            gedges.append((node1.gnode, node2.gnode))
        # remove_edges_from silently skips pairs that are already gone, so an
        # edge listed twice (e.g. as two distinct Edge objects) is harmless.
        self._ggraph.remove_edges_from(gedges)

        if signal:
            self.signals.edgesChanged.emit(set(edges))

    def removeEdge(self, edge, signal=True):
        """
        Convenience function to remove a single edge from the graph.

        :param edge: an edge
        :type edge: Edge
        :param signal: whether `edgesChanged` signal should be emitted when done
        :type signal: bool
        """

        self.removeEdges([edge], signal=signal)

    def getEdgeApproval(self, node1, node2):
        """
        Test whether a new edge can be added between two nodes. Doesn't actually
        add an edge, just returns whether it is allowable to add.

        :return: whether the edge is allowed, and a message explaining why (if
            a connection validator is set, its `validate()` result is returned)
        :rtype: tuple(bool, str)
        """

        if self.hasEdge(node1, node2):
            return False, "This connection already exists"
        if node1 == node2:
            return False, "Can't connect a node to itself."
        if self.connection_validator:
            return self.connection_validator.validate(node1, node2)
        return True, "No Problem"

    #===========================================================================
    # Selection
    #===========================================================================
    def selectedNodes(self):
        """
        Return the currently selected nodes (the live set, not a copy).

        :rtype: set of Nodes
        """
        return self.selected_nodes

    def selectedEdges(self):
        """
        :return: the set of selected edges (the live set, not a copy)
        :rtype: set(Edge)
        """
        return self.selected_edges

    def setSelectedObjs(self, objs, source=None, signal=True):
        """
        Specify the current selection.

        :param objs: a list of objects (nodes or edges) to be selected
        :type objs: list(Node or Edge)
        :param source: the class instance calling this method (used to avoid
            infinite recursion when updating selection state)
        :type source: object
        :param signal: whether to emit a signal when changing selection state
        :type signal: bool
        """

        nodes = self._getNodeInstances(objs)
        edges = self._getEdgeInstances(objs)
        # Only replace the stored sets when their contents actually change.
        if nodes != self.selectedNodes():
            self.selected_nodes = nodes
        if edges != self.selectedEdges():
            self.selected_edges = edges
        items = nodes.union(edges)
        if signal:
            self.signals.selectionChanged.emit(items, source)

    #===========================================================================
    # Layout methods
    #===========================================================================
    def springLayout(self, signal=True):
        """
        Performs a spring layout on the current graph.

        :param signal: whether to emit `positionChanged` when done
        :type signal: bool
        """
        node_coord_map = self._getSpringLayoutCoords(
            iterations=ITERATIONS, weight_attr=None, scale=SCALE)
        self.setMultipleNodePos(node_coord_map, signal)

    def _getSpringLayoutCoords(self,
                               dim=2,
                               node_pos_map=None,
                               fixed_nodes=None,
                               iterations=50,
                               weight_attr='weight',
                               scale=1):
        """
        Calculate and return a dictionary mapping nodes to optimally-computed
        Cartesian coordinates for each node. Convenience method that wraps
        `spring_layout()`.

        :param dim: number of dimensions of the layout
        :type dim: int
        :param node_pos_map: optionally, initial positions for nodes; otherwise,
            use random initial positions
        :type node_pos_map: dict(Node, tuple(float))
        :param fixed_nodes: optionally, a list of nodes to keep fixed at their
            initial positions
        :type fixed_nodes: list(Node)
        :param iterations: number of iterations of spring-force relaxation
        :type iterations: int
        :param weight_attr: the edge attribute that holds the numerical value
            used for the edge weight.  If None, then all edge weights are 1.
        :type weight_attr: str or None
        :param scale: scale factor for positions
        :type scale: float
        :return: a dictionary mapping nodes to their calculated positions
        :rtype: dict(Node, tuple(float))
        """

        if fixed_nodes:
            fixed_gnodes = [n.gnode for n in fixed_nodes]
        else:
            fixed_gnodes = None
        if node_pos_map:
            gnode_pos_map = {n.gnode: pos for n, pos in node_pos_map.items()}
        else:
            gnode_pos_map = None
        node_coords = spring_layout(
            self.ggraph,
            dim=dim,
            pos=gnode_pos_map,
            fixed=fixed_gnodes,
            iterations=iterations,
            weight=weight_attr,
            scale=scale)
        return {
            self.getNode(gnode): coords
            for gnode, coords in node_coords.items()
        }

    def minCrossingSpringLayout(self,
                                num_iterations=100,
                                fixed_nodes=None,
                                fraction=1.0):
        """
        Perform multiple spring layouts and keep the one with the fewest edge
        intersections, keeping the original positions if the layout could not be
        improved.

        Layouts are computed in unscaled coordinates (see `_scaleNodeCoords`)
        and the winning layout is scaled back before being applied, which
        emits `positionChanged`. Nothing is emitted if no better layout was
        found.

        :param num_iterations: number of spring layouts to try
        :type num_iterations: int

        :param fixed_nodes: optionally, nodes whose positions should be fixed
        :type fixed_nodes: iterable(Node) or NoneType

        :param fraction: stop iterating if no reduction in crossings is found
            within this fraction of num_iterations
        :type fraction: float
        """

        min_crossings = maxsize
        best_pos_map = None
        edges = self.getEdges()

        fixed_pos_map = None
        if fixed_nodes is not None:
            # Materialize: fixed_nodes is handed to every layout run below.
            fixed_nodes = list(fixed_nodes)
            fixed_pos_map = {node: node.pos() for node in fixed_nodes}
            if fixed_pos_map:
                # Unscale so fixed nodes keep their location once the final
                # layout is scaled back up.
                self._scaleNodeCoords(fixed_pos_map, reverse=True)
        initial_pos_map = fixed_pos_map

        # If every node already has a position, take the current layout as the
        # one to improve on.
        if self.hasPositions():
            current_pos_map = {node: node.pos() for node in self.getNodes()}
            _, min_crossings = _has_fewer_crossings(edges, current_pos_map,
                                                    maxsize)
            # unscale coordinates to preserve location
            initial_pos_map = self._scaleNodeCoords(
                current_pos_map, reverse=True)
            best_pos_map = initial_pos_map

        new_pos_map = initial_pos_map
        max_not_better_iters = min(num_iterations, fraction * num_iterations)
        not_better_iters = 0
        for _ in range(num_iterations):
            not_better_iters += 1
            if min_crossings == 0 or not_better_iters > max_not_better_iters:
                break
            new_pos_map = self._getSpringLayoutCoords(
                iterations=ITERATIONS,
                node_pos_map=new_pos_map,
                fixed_nodes=fixed_nodes,
                weight_attr=None,
                scale=SCALE)
            is_better, crossings = _has_fewer_crossings(edges, new_pos_map,
                                                        min_crossings)
            if is_better:
                not_better_iters = 0
                min_crossings = crossings
                best_pos_map = new_pos_map
            else:
                # Restart from only the fixed-node positions.
                new_pos_map = fixed_pos_map

        # No layout was produced, or none beat the existing one.
        if best_pos_map is None or best_pos_map is initial_pos_map:
            return

        self._scaleNodeCoords(best_pos_map)
        self.setMultipleNodePos(best_pos_map)

    def _scaleNodeCoords(self, pos_dict, reverse=False):
        """
        Scales the positions in the pos_dict dictionary in place by factor or,
        if reverse is True, by 1/factor, where factor = 0.5 * sqrt(NNodes)
        (NNodes is treated as 1 for an empty graph). Through manual testing 0.5
        was determined to be a good multiplier.

        :param pos_dict: A dictionary mapping nodes or node names to (x,y)
            tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param reverse: Whether to reverse the scaling
        :type reverse: bool
        :return: `pos_dict`, for convenience
        :rtype: dict
        """
        num_nodes = len(self._ggraph) or 1
        scale = 0.5 * math.sqrt(num_nodes)
        if reverse:
            scale = 1.0 / scale
        return self._scaleDictPositions(pos_dict, scale)

    @staticmethod
    def _scaleDictPositions(pos_dict, factor):
        """
        Multiplies the positions in {node: (x_pos, y_pos)} dictionary by factor,
        in place. Each value is replaced by a new [x, y] list.

        :param pos_dict: A dictionary mapping nodes to (x,y) tuples.
        :type pos_dict: dict {Node : (int, int)}
        :param factor: multiplication factor for positions
        :type factor: float
        :return: `pos_dict`
        :rtype: dict
        """
        for node, xy in pos_dict.items():
            scaled_x_pos = xy[0] * factor
            scaled_y_pos = xy[1] * factor
            pos_dict[node] = [scaled_x_pos, scaled_y_pos]
        return pos_dict

    def hasPositions(self, accept_partial=False):
        """
        Determines whether the nodes in this graph have x-y coordinates.

        :param accept_partial: if set to True, the method will check whether
            at least one node has coordinates (so an empty graph returns
            False). Otherwise it requires that all nodes have coordinates.
        :type accept_partial: bool
        :rtype: bool
        """
        positioned = (node.pos() is not None for node in self.getNodes())
        if accept_partial:
            return any(positioned)
        return all(positioned)

    #===========================================================================
    # Undo/redo
    #===========================================================================

    def getState(self):
        """
        Get the current state of the Graph

        :return: a deep copy of the underlying graph and a shallow copy of the
            node map
        :rtype: tuple(networkx.Graph, dict(str, Node))
        """
        ggraph = copy.deepcopy(self._ggraph)
        node_objects = self.node_objects.copy()
        return ggraph, node_objects

    def setState(self, state):
        """
        Set the current state of the Graph. Clears the selection and emits
        `graphChanged`.

        :param state: a state as returned by `getState()`
        :type state: tuple(networkx.Graph, dict(str, Node))
        """
        ggraph, node_objects = state
        self._ggraph = ggraph
        self.node_objects = node_objects
        # Node objects are shared between states and hold a reference to a
        # node data dict (`_gdata`) from the graph that was live before this
        # call, not from the deep copy stored in `state`. Re-point them at the
        # restored graph's data so node data (e.g. positions) actually reverts.
        gdata_by_gnode = dict(ggraph.nodes(data=True))
        for node in node_objects.values():
            gdata = gdata_by_gnode.get(node.gnode)
            if gdata is not None:
                node._gdata = gdata
        self.selected_nodes = set()
        self.selected_edges = set()
        self.signals.graphChanged.emit()

    def _pushState(self, stack, state):
        """
        Append `state` to `stack`, dropping the oldest entries so the stack
        holds at most `max_undo_stack` states.
        """
        stack.append(state)
        while len(stack) > self.max_undo_stack:
            stack.pop(0)

    def setUndoPoint(self, signal=True):
        """
        Store the current state to the undo stack. Also wipes out the redo
        stack.

        :param signal: whether to emit `undoPointSet` when done
        :type signal: bool
        """
        self._pushState(self.undo_stack, self.getState())
        self.redo_stack = []
        if signal:
            self.signals.undoPointSet.emit()

    def undo(self):
        """
        Revert to the last state on the undo stack. Does nothing if the undo
        stack is empty.
        """

        if not self.undo_stack:
            return

        self._pushState(self.redo_stack, self.getState())
        state = self.undo_stack.pop()
        self.setState(state)

    def redo(self):
        """
        Undo the undo: restore the last state on the redo stack. Does nothing
        if the redo stack is empty.
        """

        if not self.redo_stack:
            return
        self._pushState(self.undo_stack, self.getState())
        state = self.redo_stack.pop()
        self.setState(state)

    def clearUndoHistory(self):
        """
        Clears both undo and redo stacks
        """
        self.undo_stack = []
        self.redo_stack = []

    def merge(self, g):
        """
        Merge data from another graph into this graph. Nodes whose gnodes
        compare equal are considered to be the same node (e.g. the same
        ligand).

        Before copying, any edge of `g` lacking a ``'direction'`` attribute is
        given one (the name-sorted pair of its node names) via `Edge.setData`,
        i.e. presumably on `g` itself.

        :param g: graph from which data is being merged.
        :type g: `Graph`
        """
        for edge in g.getEdges():
            data_dict = edge.data()
            n1, n2 = edge
            if 'direction' not in data_dict:
                hex1, hex2 = n1.name, n2.name
                d = (hex1, hex2) if hex1 < hex2 else (hex2, hex1)
                edge.setData('direction', d)

        self._ggraph.add_nodes_from(g._ggraph.nodes(data=True))
        self._ggraph.add_edges_from(g._ggraph.edges(data=True))
        # Wrap the newly merged gnodes so getNode()/getNodes() can see them.
        self._updateNodeMap()

    def deleteSelectedItems(self, include_edges=True, include_nodes=True):
        """
        Delete selected nodes and/or selected edges. Sets an undo point and
        clears the selection first; does nothing if there is nothing to delete.

        :param include_edges: whether selected edges should be deleted
        :type include_edges: bool
        :param include_nodes: whether selected nodes should be deleted
        :type include_nodes: bool
        """
        nodes = self.selectedNodes() if include_nodes else set()
        edges = self.selectedEdges() if include_edges else set()
        if not nodes and not edges:
            return

        self.setUndoPoint()
        self.setSelectedObjs([])
        self.deleteItems(nodes, edges)

    def deleteItems(self, nodes=None, edges=None):
        """
        Delete specified nodes and edges from the graph. Edges attached to the
        deleted nodes are removed as well.

        :param nodes: nodes to delete
        :type nodes: Set[Node]
        :param edges: edges to delete
        :type edges: Set[Tuple[Node, Node]]
        """
        nodes = nodes or set()
        # set() so that any iterable of edges (not only a set) is accepted.
        edges = set(edges or ())
        connected_edges = self.getEdges(nodes)
        edges = edges.union(connected_edges)
        if edges:
            self.removeEdges(edges)
        if nodes:
            self.removeNodes(nodes)
        if edges or nodes:
            self.update()

class Node:
    """
    Model class for Node. Wraps the data dictionary that the underlying
    NetworkX graph stores for the node (an entry of ``graph.ggraph.node``).

    :cvar x_key: data-dictionary key under which the x coordinate is stored
    :cvar y_key: data-dictionary key under which the y coordinate is stored
    """
    x_key = 'storedX'
    y_key = 'storedY'

    def __init__(self, name, graph=None):
        """
        Construct a Node object. Most of the time, this will be constructed
        around an existing NetworkX node (i.e. an entry in the
        networkx.Graph.node dict).

        If a (truthy) graph is specified and it contains a node called `name`,
        this object wraps that node's data dictionary, so data and positions
        set through this object are stored in the graph. If the graph has no
        such node, no error is raised. Instead, a new empty dictionary that is
        NOT attached to the graph is used, and data set on this object is not
        stored in the graph. Without a graph, a new empty dictionary is used
        as well.

        QT signals will only be emitted if a graph is specified.

        :param name: a unique identifier for this node
        :type name: hashable
        :param graph: the graph object to which this node belongs
        :type graph: `Graph`

        :ivar _gnode: the underlying graph node that this node wraps. In this
                class, we use the node name as the graph node, but any hashable
                object can be used.
        :ivar _gdata: dictionary that stores data belonging to the underlying
                graph node.
        """
        gdata = {}
        if graph:
            # dict.get never raises, so a node missing from the graph silently
            # gets a detached dictionary (see docstring).
            gdata = graph.ggraph.node.get(name, {})
        self._gnode = name
        self._gdata = gdata
        self.graph = graph

    @property
    def gnode(self):
        """
        Return the underlying graph node object wrapped by this `Node` instance
        (not the data dictionary `_gdata`).
        """
        return self._gnode

    @property
    def name(self):
        """
        Return unique string associated with this node. Convert to string for
        subclasses which do not necessarily use strings as graph nodes.
        """
        return str(self.gnode)

    #===========================================================================
    # Positioning
    #===========================================================================
    def x(self):
        """
        :return: the stored x coordinate, or None if it has not been set
        """
        return self._gdata.get(self.x_key, None)

    def y(self):
        """
        :return: the stored y coordinate, or None if it has not been set
        """
        return self._gdata.get(self.y_key, None)

    def pos(self):
        """
        Returns the Node's current position coordinates. Returns None if either
        coordinate is missing.

        :rtype: tuple (float, float) or None
        """
        pos = (self.x(), self.y())
        if None in pos:
            return None
        return pos

    def setX(self, x, signal=True):
        """
        Set the x coordinate. Does nothing if the value is unchanged.

        :param x: x coordinate
        :type x: float
        :param signal: whether to emit `positionChanged` (only if this node has
            a graph)
        :type signal: bool
        """
        if self.x() == x:
            return

        self._gdata[self.x_key] = x
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setY(self, y, signal=True):
        """
        Set the y coordinate. Does nothing if the value is unchanged.

        :param y: y coordinate
        :type y: float
        :param signal: whether to emit `positionChanged` (only if this node has
            a graph)
        :type signal: bool
        """
        if self.y() == y:
            return

        self._gdata[self.y_key] = y
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    def setPos(self, x, y, signal=True):
        """
        Set the node's position coordinates. At most one `positionChanged`
        signal is emitted, and none if the position is unchanged.

        :param x: x coordinate
        :type x: float
        :param y: y coordinate
        :type y: float
        :param signal: whether to emit `positionChanged` (only if this node has
            a graph)
        :type signal: bool
        """
        if self.x() == x and self.y() == y:
            return

        # Suppress the per-coordinate signals; emit a single one below.
        self.setX(x, False)
        self.setY(y, False)
        if signal and self.graph:
            self.graph.signals.positionChanged.emit({self})

    #===========================================================================
    # General node properties
    #===========================================================================

    def gdata(self):
        """
        Directly access the node data dictionary. Use this object carefully, as
        directly altering its contents can lead to internal inconsistencies.

        This may be wrapped to restrict access.
        """
        return self._gdata

    def getData(self, key):
        """
        Return the requested item from the node's data dictionary. Returns None
        if the key is not found.
        """
        return self._gdata.get(key, None)

    def setData(self, key, value, signal=True):
        """
        Set the value of an item in the node's data dictionary. If `signal` is
        true and the node has a graph, `nodesChanged` is emitted (even if the
        value did not change).
        """
        self._gdata[key] = value
        if signal and self.graph:
            self.graph.signals.nodesChanged.emit({self})

    @property
    def degree(self):
        """
        :return: the degree (number of edges) of the node. Requires the node
            to have a graph; otherwise AttributeError is raised.
        :rtype: int
        """
        return self.graph.ggraph.degree(self.gnode)

    def __repr__(self):
        return '<Node("%s")>' % self.name

    def __str__(self):
        return self.__repr__()

    def __eq__(self, rhs):
        # Equal when both belong to the very same graph object and share a name.
        try:
            return id(self.graph) == id(rhs.graph) and self.name == rhs.name
        except AttributeError:
            return False

    def __ne__(self, rhs):
        return not self == rhs

    def __hash__(self):
        # Must agree with __eq__: graph identity plus name.
        return hash((id(self.graph), self.name))


class Edge:
    """
    Model class for an edge. Wraps an underlying graph edge (an indexable pair
    of graph nodes) and gives access to that edge's data dictionary in the
    owning graph's NetworkX graph.
    """

    def __init__(self, gedge, graph):
        """
        :param gedge: the underlying edge object wrapped by this object; an
            indexable, iterable pair of graph nodes
        :type gedge: object
        :param graph: the graph object to which this edge belongs
        :type graph: Graph
        """
        self._gedge = gedge
        self._graph = graph

    @property
    def gedge(self):
        """
        :return: the underlying edge object wrapped by this object, i.e. the
            pair of graph nodes ``(gnode0, gnode1)``
        :rtype: object
        """
        return self._gedge

    @property
    def graph(self):
        """
        :return: the graph to which this edge belongs
        :rtype: Graph
        """
        return self._graph

    @property
    def nodes(self):
        """
        :return: the nodes connected by this edge in a consistent order, as
            determined by the underlying graph edge. Each entry is whatever
            ``graph.getNode`` returns for the corresponding graph node.
        :rtype: tuple(Node, Node)
        """
        return tuple(self.graph.getNode(gnode) for gnode in self.gedge)

    def data(self):
        """
        :return: a copy of the data dictionary associated with this edge;
            modifying it does not affect the graph (use `setData`)
        :rtype: dict(str, object)
        """
        ggraph, gedge = self.graph.ggraph, self.gedge
        return dict(ggraph[gedge[0]][gedge[1]])

    def getData(self, key):
        """
        Return the requested item from the edge's data dictionary. Returns None
        if the key is not found.

        :param key: the data item key
        :type key: str
        :return: the value stored under the specified key in the edge's data
            dictionary, or `None` if it is not found
        :rtype: object
        """
        data_dict = self.data()
        return data_dict.get(key)

    def setData(self, key, value, signal=True):
        """
        Set the specified item in the edge's data dictionary. `edgesChanged`
        is emitted only when `signal` is true and the value actually changed.

        :param key: the data item key
        :type key: str
        :param value: the value to set for the data item
        :type value: object
        :param signal: whether to emit `edgesChanged` on a change
        :type signal: bool
        """
        ggraph, gedge = self.graph.ggraph, self.gedge
        data_dict = ggraph[gedge[0]][gedge[1]]
        old_value = data_dict.get(key)
        data_dict[key] = value
        if signal and old_value != value:
            self.graph.signals.edgesChanged.emit({self})

    @property
    def name(self):
        """
        :return: the name of the edge, a composite of the connected node names
            (``'None'`` is used for a node that the graph cannot supply)
        :rtype: str
        """
        node0, node1 = self.nodes
        name0 = 'None' if node0 is None else node0.name
        name1 = 'None' if node1 is None else node1.name
        return f'"{name0}" - "{name1}"'

    def __iter__(self):
        """
        Iterate over the two nodes connected by this edge, so that
        ``node0, node1 = edge`` works. This looks the nodes up once. Iteration
        through `__getitem__` would rebuild `nodes` for every index.
        """
        return iter(self.nodes)

    def __getitem__(self, idx):
        """
        Return a node connected by this edge. Valid indices are 0 and 1
        (negative indices follow tuple semantics).

        :param idx: node index
        :type idx: `int`

        :return: node corresponding to supplied index
        :rtype: `Node`
        :raise IndexError: if the index is out of range
        """
        return self.nodes[idx]

    def __eq__(self, rhs):
        # Edges are equal when they wrap the same underlying edge of the same
        # graph. The gedge comparison is order-sensitive.
        try:
            return self.graph == rhs.graph and self.gedge == rhs.gedge
        except AttributeError:
            return False

    def __ne__(self, rhs):
        return not self == rhs

    def __hash__(self):
        return hash((self.graph, self.gedge))

    def __str__(self):
        return f'<{self.__class__.__name__}({self.name})>'

    def __repr__(self):
        return self.__str__()


class ConnectionValidator:
    """
    Base class for checking whether two nodes may be connected.

    Subclass this, override `validate`, and install an instance with
    NetworkViewer.setConnectionValidator() to do extra work making sure nodes
    are compatible to connect. `validateSecondVal` calls `validate` with the
    ``val`` attribute of the first node and the supplied second value.
    """

    def __init__(self):
        # First node of a prospective connection (see setFirstNode), or None.
        self.first_node = None

    def validate(self, node1, node2):
        """
        Return an ``(ok, message)`` pair. The base implementation accepts
        every pair.
        """
        ok, message = True, "No problem"
        return ok, message

    def firstNode(self):
        """
        :return: the first node, or None if it has not been set
        """
        return self.first_node

    def setFirstNode(self, node):
        """
        :param node: the first node of a prospective connection
        """
        self.first_node = node

    def validateSecondVal(self, val):
        """
        Validate `val` against the first node's ``val`` attribute.

        :return: the result of `validate`, or None if no (truthy) first node
            has been set
        """
        if not self.firstNode():
            return None
        return self.validate(self.first_node.val, val)


#===============================================================================
# Network View Classes
#===============================================================================


class AbstractNetworkView:
    """
    A base class for views on Graph models. Use setModel to replace the model
    object. Signals from the model are automatically connected to appropriate
    synchronization slots.

    The abstract view does not provide any built-in support for effecting
    changes back into the model (ex. deleting nodes, changing selection). Any
    such operations should be implemented in the subclass by making calls
    directly to the model. These changes will then be automatically synchronized
    forward to all views.

    self.nodes is a dictionary mapping model node objects (`Node`) to view node
    objects.

    self.edges is a dictionary mapping model edge objects (`Edge`) to view edge
    objects.

    Note that the words node and edge in the names of the view-side hooks refer
    to view objects. For example, makeNode() will make a view node, addEdge()
    will add an edge view object to the view. Methods with "Model" in their
    name (getModelNodes, getModelEdges) return model objects, and the update*
    methods take model objects.

    :cvar MODEL_CLASS: an instance of this class will be created as the default
        model when `setModel` is called with `None`
    :vartype MODEL_CLASS: `Graph` or subclass of `Graph`
    :ivar _sync_with_model: whether to automatically synchronize this view (and
        its subviews) with the model
    :vartype _sync_with_model: bool
    :ivar skip_selectionChanged: True while `syncSelection` is applying a model
        selection to this view. Presumably this lets subclasses ignore the
        view selection-changed notifications that result.
    :vartype skip_selectionChanged: bool
    """
    MODEL_CLASS = Graph

    def __init__(self):
        self.model = None
        self.nodes = {}
        self.edges = {}
        self.skip_selectionChanged = False
        self._subviews = set()
        self._sync_with_model = True

    #===========================================================================
    # Model-View Connections
    #===========================================================================

    def syncAll(self):
        """
        Synchronize the full model and selection state.
        """
        model = self.model
        self.syncModel()
        selection = model.selected_nodes.union(model.selected_edges)
        self.syncSelection(selection, model)

    def syncRecursive(self):
        """
        Synchronize the full model and selection state on this view and all
        subviews.
        """
        self.syncAll()
        for subview in self._subviews:
            subview.syncRecursive()

    def setModelSyncEnabled(self, enable):
        """
        Enable or disable automatic synchronization with the model for this view
        and all subviews. Enabling also performs a full synchronization.

        :param enable: whether automatic synchronization should be enabled
        :type enable: bool
        """
        for subview in self._subviews:
            subview.setModelSyncEnabled(enable)

        if self._sync_with_model == enable:
            return

        self._sync_with_model = enable
        if enable:
            self._connectSignals()
            self.syncAll()
        else:
            self._disconnectSignals()

    def setModel(self, model):
        """
        Set the model for this view and synchronize to it. Any subviews will
        have the model set on them as well.

        :param model: the graph model; if None, a new `MODEL_CLASS` instance is
            used
        :type model: Graph or None
        """
        if model is None:
            model = self.MODEL_CLASS()
        if self._sync_with_model:
            self._disconnectSignals()
        self.model = model
        for subview in self._subviews:
            subview.setModel(model)
        if self._sync_with_model:
            self._connectSignals()
            self.syncAll()

    def _connectSignals(self):
        """
        If a model is defined, connect all signal/slot pairs.
        """
        if self.model:
            for signal, slot in self.getSignalsAndSlots(self.model):
                signal.connect(slot)

    def _disconnectSignals(self):
        """
        If a model is defined, disconnect all signal/slot pairs.
        """
        if self.model:
            for signal, slot in self.getSignalsAndSlots(self.model):
                signal.disconnect(slot)

    def getSignalsAndSlots(self, model):
        """
        Get a list of signal/slot pairs for a model. This list will be used when
        setting a new model to disconnect the old model signals from their slots
        and connect the new model's signals to those slots.

        Override this method to modify or extend signals/slots in derived
        classes.

        :param model: the graph model
        :type model: Graph
        :return: (signal, slot) pairs
        :rtype: list(tuple)
        """
        signals = model.signals
        ss_list = [
            (signals.graphChanged, self.syncModel),
            (signals.nodesAdded, self.syncNodesAdded),
            (signals.nodesDeleted, self.syncNodesDeleted),
            (signals.nodesChanged, self.syncNodesChanged),
            (signals.edgesChanged, self.syncModel),
            (signals.selectionChanged, self.syncSelection),
        ]
        return ss_list

    def addSubview(self, subview):
        """
        Add a subview to this view. A subview is another AbstractNetworkView
        that should always have the same model as its parent view (this view).

        Adding will automatically set its model to the current model. If this
        view has no model yet, the subview receives a new default model. Changing
        the model on this view will result in all its subviews getting the new
        model set.

        :param subview: the new subview to add to this view
        :type subview: AbstractNetworkView
        """
        self._subviews.add(subview)
        subview.setModel(self.model)

    def removeSubview(self, subview):
        """
        Removes the specified subview. The subview is not deleted or altered,
        and the model remains set.

        :param subview: a subview previously added with `addSubview`
        :type subview: AbstractNetworkView
        :raise KeyError: if `subview` is not a subview of this view
        """
        self._subviews.remove(subview)

    #===========================================================================
    # Model-View Synchronization
    #===========================================================================

    def syncModel(self):
        """
        Synchronize all view nodes and then all view edges with the model.
        """
        self.syncNodes()
        self.syncEdges()

    def syncNodes(self):
        """
        Add view nodes for new model nodes, update all existing ones, and remove
        view nodes whose model nodes no longer exist.
        """
        graph = self.model
        nodeset = graph.getNodes()
        delnodes = set(self.nodes).difference(nodeset)
        self.syncNodesAdded(nodeset)
        self.syncNodesChanged(nodeset)
        self.syncNodesDeleted(delnodes)

    def syncNodesDeleted(self, nodes):
        """
        Remove the view nodes for the given model nodes, then resync edges.

        :param nodes: deleted model nodes
        :type nodes: set(Node)
        """
        self._removeNodes(nodes)
        self.syncEdges()

    def syncNodesAdded(self, nodes):
        """
        Create view nodes for those model nodes not yet in this view, then
        resync edges.

        :param nodes: added model nodes
        :type nodes: set(Node)
        """
        new_nodes = nodes.difference(set(self.nodes))
        self._addNodes(new_nodes)
        self.syncEdges()

    def syncNodesChanged(self, nodes):
        """
        Update the view for the given model nodes and, if this view has any
        edges, for the model edges connected to them.

        :param nodes: changed model nodes
        :type nodes: set(Node)
        """
        self.updateNodes(nodes)

        if self.edges:
            edges = self.model.getEdges(nodes)
            self.updateEdges(edges)

    def syncEdges(self):
        """
        Remove view edges whose model edges are gone, add view edges for new
        model edges, and update the rest.
        """
        model_edges = set(self.model.getEdges())
        known_edges = set(self.edges)

        del_edges = known_edges.difference(model_edges)
        self._removeEdges(del_edges)

        add_edges = model_edges.difference(known_edges)
        self._addEdges(add_edges)

        up_edges = model_edges.intersection(known_edges)
        self.updateEdges(up_edges)

    def syncSelection(self, selection, source):
        """
        Select the view objects that correspond to the model nodes and edges in
        `selection`. Model objects without a view object are skipped.

        :param selection: selected model objects (`Node` and `Edge` instances)
        :type selection: set
        :param source: the object that changed the selection; nothing is done
            if it is this view
        """
        if source == self:
            return
        selected_view_objects = []
        for model_obj in selection:
            if isinstance(model_obj, Node):
                viewnode = self.nodes.get(model_obj)
                if viewnode:
                    selected_view_objects.append(viewnode)
            elif isinstance(model_obj, Edge):
                viewedge = self.getEdge(model_obj)
                if viewedge:
                    selected_view_objects.append(viewedge)
        self.skip_selectionChanged = True
        try:
            self.selectItems(selected_view_objects)
        finally:
            # Reset even if selectItems raises, so the flag cannot get stuck.
            self.skip_selectionChanged = False

    #===========================================================================
    # Node operations
    #===========================================================================

    def _addNodes(self, nodes):
        """
        Make view nodes for the given model nodes, record them in `self.nodes`
        and add them to the view.
        """
        node_map = self.makeNodes(nodes)
        self.nodes.update(node_map)
        self.addNodes(set(node_map.values()))

    def _removeNodes(self, nodes):
        """
        Remove the view nodes for the given model nodes from the view and from
        `self.nodes`. Model nodes that have no view node are ignored (e.g. a
        subclass may not create view nodes for every model node).
        """
        known_nodes = [node for node in nodes if node in self.nodes]
        viewnodes = [self.getNode(node) for node in known_nodes]
        self.removeNodes(viewnodes)
        for node in known_nodes:
            self.nodes.pop(node)

    def makeNodes(self, nodes):
        """
        Create new view nodes and return a dictionary mapping supplied model
        nodes to corresponding view nodes. Do not add new view nodes to the
        view.

        By default this method returns an "identity dictionary" that maps nodes
        to themselves. Subclasses should override this method to implement their
        own view nodes.

        :param nodes: model nodes
        :type nodes: list(Node)

        :return: a dictionary mapping supplied nodes to view nodes
        :rtype: dict(Node, object)
        """
        return {node: node for node in nodes}

    def makeNode(self, node):
        """
        Convenience method for calling `makeNodes()` with a single node. Rather
        than returning a dictionary mapping nodes to view nodes, returns the
        view node corresponding to the supplied node.

        :param node: the model node
        :type node: Node

        :return: the view node
        :rtype: object
        """
        node_map = self.makeNodes([node])
        return node_map.get(node)

    def addNode(self, viewnode):
        """
        A convenience function for calling `addNodes()` for a single node.

        :param viewnode: a view node
        :type viewnode: object
        """
        self.addNodes([viewnode])

    def removeNode(self, viewnode):
        """
        Convenience method for calling `removeNodes()` for a single node.

        :param viewnode: a view node
        :type viewnode: object
        """
        self.removeNodes([viewnode])

    def updateNode(self, node):
        """
        Convenience method for calling `updateNodes()` for a single node.

        :param node: the model node to update to
        :type node: Node
        """
        self.updateNodes([node])

    def getModelNodes(self, node_keys=None):
        """
        Retrieve a set of model nodes optionally indicated by a list of keys. If
        none is provided, return all nodes. Only nodes present in this view are
        returned.

        :param node_keys: optionally, a list of nodes, gnodes, or strings that
            correspond to the desired model nodes
        :type node_keys: list(object) or NoneType
        :return: a set of nodes
        :rtype: set(Node)
        """
        nodes_in_view = set(self.nodes)
        nodes_in_model = self.model.getNodes(node_keys)
        return nodes_in_view.intersection(nodes_in_model)

    def getNode(self, node):
        """
        :param node: a model node
        :type node: Node

        :return: corresponding view node, if available
        :rtype: `object` or `None`
        """
        return self.nodes.get(node, None)

    #===========================================================================
    # Edge operations
    #===========================================================================

    def _addEdges(self, edges):
        """
        Make view edges for the given model edges, record them in `self.edges`
        and add them to the view.

        :raise ValueError: if a view edge already exists for any of the edges
        """
        for edge in edges:
            view_edge = self.getEdge(edge)
            if view_edge is not None:
                msg = f'A view edge already exists for {edge}.'
                raise ValueError(msg)
        edge_map = self.makeEdges(edges)
        for edge, view_edge in edge_map.items():
            self.edges[edge] = view_edge
        self.addEdges(list(edge_map.values()))

    def _removeEdges(self, edges):
        """
        Remove the view edges for the given model edges from the view and from
        `self.edges`. Model edges that have no view edge are ignored.
        """
        known_edges = [edge for edge in edges if edge in self.edges]
        view_edges = [self.getEdge(edge) for edge in known_edges]
        self.removeEdges(view_edges)
        for edge in known_edges:
            self.edges.pop(edge)

    def makeEdges(self, edges):
        """
        Given a list of model edges, return a dictionary mapping them to
        corresponding view edges. Does not add view edges to the view.

        By default this method returns an identity dictionary, mapping model
        edges to themselves. Subclasses should override this method if they
        want to implement their own view edges.

        :param edges: model edges
        :type edges: list(Edge)
        :return: a dictionary mapping model edges to view edges
        :rtype: dict(Edge, object)
        """
        return {edge: edge for edge in edges}

    def makeEdge(self, edge):
        """
        Convenience method for calling `makeEdges()` for a single edge. Rather
        than return a dictionary mapping model edges to view edges, returns a
        single view edge. Does not add a view edge to the view.

        :param edge: a model edge
        :type edge: Edge
        :return: a view edge
        :rtype: object
        """
        edge_map = self.makeEdges([edge])
        return edge_map.get(edge)

    def addEdge(self, viewedge):
        """
        Convenience method for calling `addEdges()` for a single edge.

        :param viewedge: the view edge to add to the view
        :type viewedge: object
        """
        self.addEdges([viewedge])

    def removeEdge(self, viewedge):
        """
        Convenience method for calling `removeEdges()` for a single edge.

        :param viewedge: the view edge to remove from the view
        :type viewedge: object
        """
        self.removeEdges([viewedge])

    def updateEdge(self, edge):
        """
        A convenience method for calling `updateEdges()` for a single edge.

        :param edge: the model edge corresponding to the view edge to update
        :type edge: Edge
        """
        self.updateEdges([edge])

    def getModelEdges(self, nodes=None):
        """
        Return all model edges connected to a model node or set of model nodes.
        If no node is specified, all the edges in the graph are returned. This
        method acts like `Graph.getEdges()`, but it filters for model edges that
        are available in this view.

        :param nodes: optionally a node or list of nodes
        :type nodes: `list(Node)`, `Node`, or `None`
        :return: a list of model edges
        :rtype: list(Edge)
        """
        model_edges = set(self.model.getEdges(nodes=nodes))
        return list(model_edges.intersection(set(self.edges)))

    def getEdge(self, edge):
        """
        Return the view edge corresponding to the supplied model edge.

        :param edge: a model edge
        :type edge: Edge
        :return: the corresponding view edge if available
        :rtype: object or None
        """
        return self.edges.get(edge)

    def getEdges(self, nodes=None):
        """
        Return a list of view edges, filtering the list so that the edges are
        connected to the optionally-supplied node or iterable of nodes. Model
        edges without a view edge contribute `None` entries.

        :param nodes: a node or iterable of nodes
        :type nodes: iterable[Node] or Node or NoneType
        :return: list of view edges
        :rtype: list[NetworkEdge or NoneType]
        """
        return [self.getEdge(edge) for edge in self.model.getEdges(nodes)]

    #===========================================================================
    # Pure virtual methods
    #===========================================================================

    def addNodes(self, viewnodes):
        """
        Takes view nodes and adds them to the view if that makes sense (eg. add
        graphics items to scene, add rows to table, etc.) It should not add
        the view node to `self.nodes`; that is handled in `_addNodes()`.

        :param viewnodes: view nodes to add to the view
        :type viewnodes: list(object)
        """

    def removeNodes(self, viewnodes):
        """
        Removes view nodes from the view if that makes sense (eg. remove
        graphics items from scene, remove table rows, etc.) It should not remove
        view nodes from `self.nodes`; that is handled in `_removeNodes()`.

        :param viewnodes: a list of view nodes
        :type viewnodes: list(object)
        """

    def updateNodes(self, nodes):
        """
        Performs any operations necessary to update the view to the current
        model state. Note that this method takes model nodes, not view nodes.

        :param nodes: model nodes which must have their views updated
        :type nodes: list(Node)
        """

    def addEdges(self, viewedges):
        """
        Adds view edges to the view. Does not add view edges to `self.edges`.

        :param viewedges: view edges to add to the view
        :type viewedges: list(object)
        """

    def removeEdges(self, viewedges):
        """
        Removes view edges from the view. Does not remove view edges from
        `self.edges`.

        :param viewedges: view edges to remove from the view
        :type viewedges: list(object)
        """

    def updateEdges(self, edges):
        """
        Performs any operations necessary to update the view to the current
        model state.

        :param edges: a list of model edges corresponding to view edges that
            should be updated
        :type edges: list(Edge)
        """

    def selectItems(self, selected_view_objects):
        """
        Selects view objects in the view. `syncSelection` passes both view
        nodes and view edges.

        :param selected_view_objects: a list of view objects to be selected
        :type selected_view_objects: list(object)
        """


#===============================================================================
# Layout calculations
#===============================================================================

#
# line segment intersection using vectors
# see Computer Graphics by F.S. Hill
#


def perp(a):
    """
    Return the 2-D vector `a` rotated by 90 degrees counterclockwise, i.e.
    ``(-a[1], a[0])``, with the same dtype as `a`.

    :param a: a 2-D vector
    :type a: numpy.array
    :rtype: numpy.array
    """
    b = np.empty_like(a)
    b[0] = -a[1]
    b[1] = a[0]
    return b


# Tolerance used by `seg_intersect` so that segments which merely touch at an
# endpoint (for example two edges sharing a node) are not counted as crossing.
_eps = 1e-8


def seg_intersect(a1, a2, b1, b2):
    """
    Checks whether two line segments cross each other. Segments that are
    parallel, or that only meet at (or within `_eps` of) an endpoint of either
    segment, do not count as crossing.

    :param a1: first endpoint of line segment a
    :type a1: numpy.array
    :param a2: second endpoint of line segment a
    :type a2: numpy.array
    :param b1: first endpoint of line segment b
    :type b1: numpy.array
    :param b2: second endpoint of line segment b
    :type b2: numpy.array

    :return: whether the line segments intersect
    :rtype: bool
    """
    da = a2 - a1
    db = b2 - b1
    dp = a1 - b1
    denom = np.dot(perp(da), db)
    if denom == 0:  # Line segments are parallel (or degenerate)
        return False
    # Solve a1 + t * da == b1 + u * db for the fractional positions t and u of
    # the intersection point along segments a and b. Testing these parameters
    # (instead of only the intersection's x-coordinate) handles vertical
    # segments, whose x-range is a single value.
    t = np.dot(perp(db), dp) / denom
    u = np.dot(perp(da), dp) / denom
    # The epsilon excludes intersections at the segment endpoints.
    return bool(_eps < t < 1 - _eps and _eps < u < 1 - _eps)


def _has_fewer_crossings(edges, node_coords, goal):
    """
    Determine whether a layout has fewer edge crossings than `goal`, stopping
    early once that is known to be false.

    :param edges: pairs of graph nodes, one pair per edge
    :param node_coords: mapping from graph node to an ``(x, y)`` position
    :type node_coords: dict
    :param goal: the crossing count to beat
    :type goal: int
    :return: whether there are fewer than `goal` crossings, and the number of
        crossings counted. If `goal` crossings have been counted and another is
        found, counting stops and `goal` is returned as the count.
    :rtype: tuple(bool, int)
    """

    def as_point(gnode):
        px, py = node_coords[gnode]
        return np.array([px, py])

    segments = [(as_point(start), as_point(end)) for start, end in edges]
    found = 0
    # Check every unordered pair of segments exactly once.
    for idx, (a1, a2) in enumerate(segments):
        for b1, b2 in segments[idx + 1:]:
            if not seg_intersect(a1, a2, b1, b2):
                continue
            if found == goal:
                return False, found
            found += 1
    return found < goal, found


#===============================================================================
# Code copied and modified from networkx.drawing.layout
#===============================================================================


def fruchterman_reingold_layout(G,
                                dim=2,
                                pos=None,
                                fixed=None,
                                iterations=50,
                                weight='weight',
                                scale=1):
    """
    Position nodes using Fruchterman-Reingold force-directed algorithm.

    :param G: NetworkX graph

    :param dim: Dimension of layout
    :type dim: int

    :param pos: Initial positions for nodes as a dictionary with node as keys
        and values as a list or tuple. Nodes missing from it (or all nodes, if
        it is None) start at random positions.
    :type pos: dict

    :param fixed: Nodes to keep fixed at initial position. optional. When
        given (even if empty), the resulting positions are not rescaled.
    :type fixed: list

    :param iterations: Number of iterations of spring-force relaxation
    :type iterations: int

    :param weight: The edge attribute that holds the numerical value used for
        the edge weight.  If None, then all edge weights are 1.
    :type weight: str or None

    :param scale: Scale factor for positions
    :type scale: float

    :rtype: dict
    :returns: A dictionary of positions keyed by gnode. A single-node graph
        gets the position ``(1,) * dim``.

    Examples::

        >>> G = nx.path_graph(4)
        >>> pos = fruchterman_reingold_layout(G)

        # The same using the shorter alias
        >>> pos = spring_layout(G)
    """

    if fixed is not None:
        gnode_idx_map = {gnode: idx for idx, gnode in enumerate(G)}
        # An explicit integer dtype keeps an empty `fixed` usable as an index
        # array (np.asarray([]) would otherwise be float64).
        fixed = np.asarray([gnode_idx_map[v] for v in fixed], dtype=int)

    if pos is not None:
        pos_arr = np.asarray(np.random.random((len(G), dim)))
        for i, n in enumerate(G):
            if n in pos:
                pos_arr[i] = np.asarray(pos[n])
    else:
        pos_arr = None

    if len(G) == 0:
        return {}
    if len(G) == 1:
        return {next(iter(G.nodes())): (1,) * dim}

    # to_numpy_matrix was removed in NetworkX 3.0; prefer to_numpy_array
    # (NetworkX >= 2.0) and fall back for older releases.
    to_numpy = getattr(nx, 'to_numpy_array', None) or nx.to_numpy_matrix
    A = to_numpy(G, weight=weight)
    pos = _fruchterman_reingold(A, dim, pos_arr, fixed, iterations)

    if fixed is None:
        pos = _rescale_layout(pos, scale=scale)
    return dict(zip(G, pos))


# Alias under the name NetworkX also uses for this layout algorithm.
spring_layout = fruchterman_reingold_layout


def _fruchterman_reingold(A, dim=2, pos=None, fixed=None, iterations=50):
    """
    Position nodes in adjacency matrix `A` using a modified
    Fruchterman-Reingold force-directed algorithm. The entry point for
    NetworkX graphs is `fruchterman_reingold_layout()`.

    The repulsive term is ``PUSH * k**PUSHEXP / distance**PUSHEXP`` and the
    starting temperature (largest allowed step) is ``STARTING_TEMPERATURE``.
    The temperature decreases linearly over the iterations.

    :param A: square adjacency matrix (anything with a 2-D ``shape``)
    :param dim: dimension of the layout
    :type dim: int
    :param pos: optional initial positions, one row per node; copied, not
        modified
    :type pos: numpy.ndarray or None
    :param fixed: optional indices of nodes whose positions must not change
    :type fixed: numpy.ndarray or None
    :param iterations: number of relaxation steps
    :type iterations: int
    :return: node positions, one row per node
    :rtype: numpy.ndarray
    :raise nx.NetworkXError: if `A` has no ``shape`` attribute
    """
    try:
        nnodes, _ = A.shape
    except AttributeError:
        raise nx.NetworkXError(
            "fruchterman_reingold() takes an adjacency matrix as input")

    A = np.asarray(A)  # make sure we have an array instead of a matrix

    if pos is None:
        # random initial positions
        pos = np.asarray(np.random.random((nnodes, dim)), dtype=A.dtype)
    else:
        # make sure positions are of same type as matrix (astype also copies,
        # so the caller's array is not modified below)
        pos = pos.astype(A.dtype)

    # optimal distance between nodes
    k = np.sqrt(1.0 / nnodes)
    # the initial "temperature" is the largest step allowed in the dynamics
    t = STARTING_TEMPERATURE
    # simple cooling scheme.
    # linearly step down by dt on each iteration so last iteration is size dt.
    dt = t / (iterations + 1)
    delta = np.zeros((pos.shape[0], pos.shape[0], pos.shape[1]), dtype=A.dtype)
    # An empty index array may be float-typed (np.asarray([])), which numpy
    # rejects as an index; an empty selection fixes nothing anyway.
    has_fixed = fixed is not None and len(fixed) > 0
    # this is still O(V^2); multilevel methods could speed it up significantly
    for _ in range(iterations):
        # matrix of difference between points: delta[i, j] = pos[i] - pos[j]
        for i in range(pos.shape[1]):
            delta[:, :, i] = pos[:, i, None] - pos[:, i]
        # distance between points
        distance = np.sqrt((delta**2).sum(axis=-1))
        # enforce minimum distance of 0.01
        distance = np.where(distance < 0.01, 0.01, distance)
        # displacement "force": repulsion minus adjacency-weighted attraction
        displacement = np.transpose(
            np.transpose(delta) *
            (PUSH * k**PUSHEXP / distance**PUSHEXP - A * distance / k
            )).sum(axis=1)
        # update positions; each node moves by at most t
        length = np.sqrt((displacement**2).sum(axis=1))
        length = np.where(length < 0.01, 0.01, length)
        delta_pos = np.transpose(np.transpose(displacement) * t / length)
        if has_fixed:
            # don't change positions of fixed nodes
            delta_pos[fixed] = 0.0
        pos += delta_pos
        # cool temperature
        t -= dt

    return pos


def _rescale_layout(pos, scale=1):
    # rescale to (0,pscale) in all axes

    # shift origin to (0,0)
    lim = 0  # max coordinate for all axes
    for i in range(pos.shape[1]):
        pos[:, i] -= pos[:, i].min()
        lim = max(pos[:, i].max(), lim)
    # rescale to (0,scale) in all directions, preserves aspect
    for i in range(pos.shape[1]):
        pos[:, i] *= scale / lim
    return pos