#!/usr/bin/env python3
"""
Main Application (flow_UI.py)

This file dynamically imports custom node classes from the 'Nodes' package,
registers them with NodeGraphQt, and sets up an empty graph.
Nodes can be added dynamically via the graph's right-click context menu,
and a "Remove Selected Node" option removes every currently selected node.
A global update timer periodically calls process_input() on nodes.
Additionally, this file patches QGraphicsScene.setSelectionArea so that it
also accepts the older (path, mode) argument order, which some Qt bindings
reject; without this, selecting multiple nodes fails.
"""

# --- Patch QGraphicsScene.setSelectionArea to handle selection arguments ---
from Qt import QtWidgets, QtCore, QtGui

# Keep a reference to the unpatched method so the wrapper can delegate to it.
_original_setSelectionArea = QtWidgets.QGraphicsScene.setSelectionArea


def _patched_setSelectionArea(self, painterPath, *args, **kwargs):
    """Call QGraphicsScene.setSelectionArea, adapting the legacy argument order.

    Callers may use the older ``(path, mode[, deviceTransform])`` form, while
    some Qt bindings only accept ``(path, selectionOperation, mode,
    deviceTransform)`` and raise TypeError for the legacy form. In that case
    the call is retried with ``Qt.ReplaceSelection`` inserted as the selection
    operation. Any transform the caller supplied is kept; otherwise an
    identity ``QTransform()`` is used.

    Args:
        painterPath: The QPainterPath describing the selection area.
        *args: Remaining positional arguments, forwarded unchanged.
        **kwargs: Keyword arguments, forwarded unchanged.

    Returns:
        Whatever the original setSelectionArea returns.

    Raises:
        TypeError: If the call fails and its shape is not the legacy
            ``(path, mode[, transform])`` form. Such calls are re-raised
            instead of being reinterpreted.
    """
    try:
        return _original_setSelectionArea(self, painterPath, *args, **kwargs)
    except TypeError:
        # Only the positional legacy shape is safe to reinterpret; anything
        # else is a genuine error and must surface unchanged.
        if kwargs or len(args) not in (1, 2):
            raise
        mode = args[0]
        transform = args[1] if len(args) == 2 else QtGui.QTransform()
        return _original_setSelectionArea(
            self,
            painterPath,
            QtCore.Qt.ReplaceSelection,
            mode,
            transform,
        )


# Monkey-patch the setSelectionArea method.
QtWidgets.QGraphicsScene.setSelectionArea = _patched_setSelectionArea

# --- End of patch section ---

import sys
import pkgutil
import importlib
import inspect
from Qt import QtWidgets, QtCore
from NodeGraphQt import NodeGraph, BaseNode

def import_nodes_from_folder(package_name):
    """
    Collect BaseNode subclasses defined anywhere inside a package.

    The package named ``package_name`` is imported, then every module and
    subpackage beneath it is imported in ``pkgutil.walk_packages`` order. From
    each one, the classes that subclass BaseNode and are defined in that module
    (not merely imported into it) are gathered.

    Args:
        package_name: Dotted name of the package to scan (it must have a
            ``__path__``).

    Returns:
        A list of the discovered node classes, in discovery order.
    """
    root = importlib.import_module(package_name)
    prefix = root.__name__ + "."
    discovered = []
    for module_info in pkgutil.walk_packages(root.__path__, prefix):
        mod = importlib.import_module(module_info.name)
        # Restrict to classes whose home module is this one, so a class
        # imported by several modules is only collected once.
        discovered.extend(
            cls
            for _, cls in inspect.getmembers(mod, inspect.isclass)
            if issubclass(cls, BaseNode) and cls.__module__ == mod.__name__
        )
    return discovered

def make_node_command(graph, nt):
    """
    Return a context-menu callback that creates a node of type ``nt``.

    Args:
        graph: NodeGraph-like object whose ``create_node(nt)`` is called.
        nt: Node type string, "<__identifier__>.<ClassName>".

    Returns:
        A callable that creates the node and returns it. It accepts and
        ignores any positional arguments; NodeGraphQt documents menu command
        functions as ``func(graph)``. If creation raises, the error is printed
        to stderr and the callable returns None.
    """
    def command(*_args):
        try:
            return graph.create_node(nt)
        except Exception as e:
            # Best effort: report and keep going rather than letting the
            # exception propagate out of a Qt menu slot.
            print(f"Error creating node of type {nt}: {e}", file=sys.stderr)
            return None
    return command

if __name__ == '__main__':
    app = QtWidgets.QApplication([])

    # Create the NodeGraph controller.
    graph = NodeGraph()
    graph.widget.setWindowTitle("Modular Nodes Demo")

    # Dynamically import custom node classes from the 'Nodes' package and
    # register them so the graph can create them by type string.
    custom_nodes = import_nodes_from_folder('Nodes')
    for node_class in custom_nodes:
        graph.register_node(node_class)

    # Add one "Add <NODE_NAME>" context-menu command per custom node class.
    graph_context_menu = graph.get_context_menu('graph')
    for node_class in custom_nodes:
        # Build the node type string: "<__identifier__>.<ClassName>"
        node_type = f"{node_class.__identifier__}.{node_class.__name__}"
        graph_context_menu.add_command(
            f"Add {node_class.NODE_NAME}",
            make_node_command(graph, node_type),
        )

    def remove_selected_nodes(*_args):
        """Remove every currently selected node from the graph.

        Positional arguments are accepted and ignored. NodeGraphQt documents
        menu command functions as ``func(graph)``, so the menu may pass one in.
        """
        for node in graph.selected_nodes():
            graph.remove_node(node)

    graph_context_menu.add_command("Remove Selected Node", remove_selected_nodes)

    # Resize and show the graph widget.
    graph.widget.resize(1200, 800)
    graph.widget.show()

    def global_update():
        """Call process_input() on every node that implements it.

        A failure in one node is reported to stderr and does not stop the
        remaining nodes from being updated.
        """
        for node in graph.all_nodes():
            if not hasattr(node, "process_input"):
                continue
            try:
                node.process_input()
            except Exception as e:
                print("Error updating node", node, e, file=sys.stderr)

    # Global update timer: poll all nodes every 500 ms.
    timer = QtCore.QTimer()
    timer.timeout.connect(global_update)
    timer.start(500)

    # Older bindings (e.g. PySide2) may only expose exec_(); prefer exec().
    run_event_loop = getattr(app, "exec", None) or app.exec_
    sys.exit(run_event_loop())