# Modules/data_collector.py

import threading
import time
import re
import sys
import numpy as np
import cv2
import concurrent.futures

# Vision-related Imports
import pytesseract
import easyocr
import torch

from PIL import Image, ImageGrab, ImageFilter

from PyQt5.QtWidgets import QApplication, QWidget
from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
from PyQt5.QtGui import QPainter, QPen, QColor, QFont

# EasyOCR readers. They stay None until initialize_ocr_engines() runs, so
# importing this module does not load any OCR model.
# NOTE(review): reader_cpu is not used by the OCR paths in this file (the "CPU"
# engine in find_word_positions uses Tesseract) — confirm external use.
reader_cpu = None
reader_gpu = None


def initialize_ocr_engines():
    """
    Create the module-level EasyOCR readers (English).

    ``reader_cpu`` always runs without a GPU. ``reader_gpu`` requests a GPU only
    when ``torch.cuda.is_available()`` is true, so on machines without CUDA it
    is a CPU reader as well. Calling this again replaces both readers.
    """
    global reader_cpu, reader_gpu
    reader_cpu = easyocr.Reader(['en'], gpu=False)
    # is_available() already returns a bool; no conditional expression needed.
    reader_gpu = easyocr.Reader(['en'], gpu=torch.cuda.is_available())

# Tesseract executable used by pytesseract (default Windows install location).
# Raw string: each single backslash is a literal path separator, so the
# backslashes must not be doubled.
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

# Default size (pixels) of a region created by create_ocr_region().
DEFAULT_WIDTH = 180
DEFAULT_HEIGHT = 130
# Side length (pixels) of the square grab/resize handles drawn on each region.
HANDLE_SIZE = 5
# NOTE(review): not referenced anywhere in the visible code — confirm before removing.
LABEL_HEIGHT = 20

# Guards `regions`, which is read and written both by the GUI thread (mouse
# handlers) and by the background OCR thread started in start_collector().
collector_mutex = QMutex()
# region_id -> {'bbox': [x, y, w, h], 'raw_text': str, 'widget': OCRRegionWidget}
regions = {}

# QApplication used by this module; set by _ensure_qapplication().
app_instance = None

def _ensure_qapplication():
    """
    Make sure a QApplication exists before any widget is created.

    Qt supports a single application object per process, so if the host
    program already constructed one it is reused; otherwise a new one is
    created from ``sys.argv``. Must be called from the main thread.
    """
    global app_instance
    if app_instance is None:
        # Reuse an existing application instead of constructing a second one.
        app_instance = QApplication.instance() or QApplication(sys.argv)

def capture_region_as_image(region_id):
    """
    Take a screenshot of a registered OCR region.

    Args:
        region_id: Key of the region in ``regions``.

    Returns:
        The image returned by ``ImageGrab.grab`` for the region's current
        bbox, or None if ``region_id`` is not registered.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        if entry is None:
            return None
        x, y, w, h = entry['bbox']
    finally:
        # Release even if the bbox is malformed; a leaked lock would block the
        # OCR thread and every mouse handler forever.
        collector_mutex.unlock()
    # Grab outside the lock so a slow capture does not stall other users.
    return ImageGrab.grab(bbox=(x, y, x + w, y + h))

def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0), thickness=2):
    """
    Create an OCR region with a visible, draggable, resizable box on screen.

    Does nothing if ``region_id`` is already registered.

    Args:
        region_id: Key under which the region is stored in ``regions``.
        x, y, w, h: Initial screen geometry in pixels.
        color: RGB tuple for the region's outline.
        thickness: Outline pen width in pixels.
    """
    _ensure_qapplication()

    collector_mutex.lock()
    try:
        if region_id in regions:
            return
        # Build the widget before inserting the entry so a failing constructor
        # leaves no half-registered region behind.
        widget = OCRRegionWidget(x, y, w, h, region_id, color, thickness)
        regions[region_id] = {
            'bbox': [x, y, w, h],
            'raw_text': "",
            'widget': widget,
        }
    finally:
        # Always release, even if widget construction raises.
        collector_mutex.unlock()

def get_raw_text(region_id):
    """
    Return the latest OCR text stored for a region.

    The text is written by ``_update_ocr_loop`` once ``start_collector()`` has
    been called; until then it is the empty string set at creation.

    Returns:
        The stored text, or "" if ``region_id`` is not registered.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        return entry['raw_text'] if entry is not None else ""
    finally:
        # Released on every path, including unexpected exceptions.
        collector_mutex.unlock()

def start_collector():
    """
    Load the EasyOCR readers, then launch the background OCR polling loop.

    The loop runs on a daemon thread, so it does not keep the process alive by
    itself. Every call reloads the readers and starts one more loop thread.
    """
    initialize_ocr_engines()
    threading.Thread(target=_update_ocr_loop, daemon=True).start()

def _update_ocr_loop():
    """
    Background worker: repeatedly OCR every region and store the text.

    On each pass, every region is screenshotted, preprocessed with
    ``_preprocess_image`` and read with Tesseract. The result is written to
    ``regions[rid]['raw_text']``. The loop then sleeps 0.7 s. A failure on one
    region is logged and skipped, so one bad capture cannot kill the thread.
    """
    while True:
        collector_mutex.lock()
        try:
            region_ids = list(regions)
        finally:
            collector_mutex.unlock()

        for rid in region_ids:
            # Re-read the bbox per region so a drag made during a long pass is
            # honoured. Check existence here just like the write-back below does.
            collector_mutex.lock()
            try:
                entry = regions.get(rid)
                bbox = list(entry['bbox']) if entry is not None else None
            finally:
                collector_mutex.unlock()
            if bbox is None:
                continue

            try:
                x, y, w, h = bbox
                screenshot = ImageGrab.grab(bbox=(x, y, x + w, y + h))
                processed = _preprocess_image(screenshot)
                raw_text = pytesseract.image_to_string(processed, config='--psm 6 --oem 1')
            except Exception as e:
                # Best effort: an uncaught error here would silently end the
                # daemon thread and freeze the text of every region.
                print(f"[ERROR] OCR update failed for region {rid}: {e}")
                continue

            collector_mutex.lock()
            try:
                if rid in regions:
                    regions[rid]['raw_text'] = raw_text
            finally:
                collector_mutex.unlock()

        time.sleep(0.7)

def _preprocess_image(image):
    """
    Prepare a screenshot for Tesseract.

    Steps: convert to 8-bit grayscale, upscale 3x, binarise (values above 200
    become 255, all others 0), then apply a size-3 median filter.
    """
    # 256-entry lookup table: indices 0..200 -> 0, 201..255 -> 255.
    binarize = [0] * 201 + [255] * 55
    gray = image.convert("L")
    width, height = gray.size
    enlarged = gray.resize((width * 3, height * 3))
    return enlarged.point(binarize).filter(ImageFilter.MedianFilter(3))


def _tesseract_word_boxes(strip, pattern):
    """Return strip-relative (x, y, w, h) boxes of Tesseract words matching ``pattern``."""
    data = pytesseract.image_to_data(strip, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
    return [
        (int(data['left'][i]), int(data['top'][i]), int(data['width'][i]), int(data['height'][i]))
        for i, text in enumerate(data['text'])
        if pattern.search(text)
    ]


def _easyocr_word_boxes(strip, pattern):
    """Return strip-relative (x, y, w, h) boxes of ``reader_gpu`` detections matching ``pattern``."""
    boxes = []
    for corners, text, _confidence in reader_gpu.readtext(np.array(strip)):
        if pattern.search(text):
            # corners[0] is used as top-left and corners[2] as bottom-right,
            # i.e. detections are treated as axis-aligned boxes.
            (x_min, y_min), (x_max, y_max) = corners[0], corners[2]
            boxes.append((int(x_min), int(y_min), int(x_max - x_min), int(y_max - y_min)))
    return boxes


def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU", num_slices=1):
    """
    Locate every occurrence of ``word`` inside an OCR region.

    The region is captured once, split into ``num_slices`` horizontal strips,
    and the strips are OCR'd concurrently on a thread pool.

    Args:
        region_id: Key of the region in ``regions``.
        word: Literal word to find; case-insensitive, whole-word match.
            Regex metacharacters in it are matched literally.
        offset_x, offset_y: Added to every returned x / y.
        margin: Padding in pixels added on every side of each box.
        ocr_engine: "CPU" uses Tesseract. Any other value uses the EasyOCR
            ``reader_gpu``, which is created on first use if needed.
        num_slices: Number of horizontal strips, clamped to [1, image height].

    Returns:
        A list of (x, y, w, h) tuples relative to the region's top-left corner
        (plus the offsets). Returns [] if the region is unknown or has an
        invalid size, or if capture or OCR fails.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        bbox = list(entry['bbox']) if entry is not None else None
    finally:
        collector_mutex.unlock()
    if bbox is None:
        return []

    x, y, w, h = bbox
    left, top, right, bottom = x, y, x + w, y + h

    if right <= left or bottom <= top:
        print(f"[ERROR] Invalid OCR region bounds: {bbox}")
        return []

    # Escape so the search is for the literal word. Unescaped, "C++" raises
    # re.error and "$5" can never match. Compiled once and shared by all strips.
    pattern = re.compile(rf"\b{re.escape(str(word))}\b", re.IGNORECASE)

    try:
        if ocr_engine == "CPU":
            find_boxes = _tesseract_word_boxes
        else:
            if reader_gpu is None:
                # Initialise here, once, rather than racing inside the worker threads.
                initialize_ocr_engines()
            find_boxes = _easyocr_word_boxes

        image = ImageGrab.grab(bbox=(left, top, right, bottom))
        orig_width, orig_height = image.size

        # At least one strip (0 or negative would divide by zero), and no more
        # strips than pixel rows.
        num_slices = max(1, min(num_slices, orig_height))
        strip_height = orig_height // num_slices

        def process_strip(strip_id):
            strip_y = strip_id * strip_height
            # The last strip extends to the bottom edge so rows left over by
            # the integer division are still OCR'd.
            strip_bottom = orig_height if strip_id == num_slices - 1 else strip_y + strip_height
            strip = image.crop((0, strip_y, orig_width, strip_bottom))
            # Shift x/y by -margin so the padding is symmetric around the word.
            return [
                (bx + offset_x - margin, by + strip_y + offset_y - margin, bw + margin * 2, bh + margin * 2)
                for bx, by, bw, bh in find_boxes(strip, pattern)
            ]

        with concurrent.futures.ThreadPoolExecutor(max_workers=num_slices) as executor:
            strip_results = list(executor.map(process_strip, range(num_slices)))

        # Flatten in strip order (top to bottom).
        return [position for strip_result in strip_results for position in strip_result]

    except Exception as e:
        print(f"[ERROR] Failed to capture OCR region: {e}")
        return []

def draw_identification_boxes(region_id, positions, color=(0, 0, 255), thickness=2):
    """
    Show non-interactive rectangles on top of a region's overlay widget.

    ``positions`` are (x, y, w, h) tuples in widget-local pixels, the same
    layout ``find_word_positions()`` returns. This call only hands the data to
    the widget's ``update_draw_positions``, which requests a repaint. It does
    not draw synchronously and does not start any thread. Unknown regions are
    ignored.

    NOTE(review): Qt documents QWidget as usable only from the GUI thread —
    confirm callers invoke this from that thread.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        widget = entry.get('widget') if entry is not None else None
    finally:
        collector_mutex.unlock()
    # Call the widget outside the lock so an error inside it (e.g. a malformed
    # color tuple) cannot leave the mutex held.
    if widget is not None:
        widget.update_draw_positions(positions, color, thickness)

def update_region_slices(region_id, num_slices):
    """
    Set how many horizontal slice guide lines a region's widget displays.

    This only affects the overlay drawing. ``find_word_positions`` takes its
    own ``num_slices`` argument. Unknown regions are ignored.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        widget = entry.get('widget') if entry is not None else None
    finally:
        collector_mutex.unlock()
    # Call the widget outside the lock so an exception there cannot leak it.
    if widget is not None:
        widget.set_num_slices(num_slices)

class OCRRegionWidget(QWidget):
    """
    Frameless, always-on-top overlay that marks one OCR region on screen.

    It draws the region outline, optional highlight boxes and slice guide
    lines. It can be moved (top-centre handle or dragging the body) or resized
    (corner handles). Every move or resize is written back to
    ``regions[region_id]['bbox']``.
    """

    # Handle indices, matching the order returned by _resize_handles().
    HANDLE_TOP_LEFT = 0
    HANDLE_TOP_RIGHT = 1
    HANDLE_BOTTOM_LEFT = 2
    HANDLE_BOTTOM_RIGHT = 3
    HANDLE_MOVE = 4
    # Smallest width/height (pixels) a resize may produce.
    MIN_SIZE = 20

    def __init__(self, x, y, w, h, region_id, color, thickness):
        super().__init__()

        self.setGeometry(x, y, w, h)
        self.setWindowFlags(Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.Tool)
        self.setAttribute(Qt.WA_TranslucentBackground, True)
        self.setAttribute(Qt.WA_TransparentForMouseEvents, False)

        self.region_id = region_id
        # NOTE(review): box_color/line_thickness style both the region outline
        # and the highlight boxes, so update_draw_positions() also recolours the
        # outline — confirm that is intended.
        self.box_color = QColor(*color)
        self.line_thickness = thickness
        self.draw_positions = []
        self.previous_positions = []  # last positions applied; lets update_draw_positions skip no-op repaints
        self.num_slices = 1

        # --- Interactive handle state ---
        self.selected_handle = None  # index of the handle being dragged, or None
        self.drag_offset = None      # press point (widget coords) while moving the widget

        self.show()

    def paintEvent(self, event):
        """Draw the outline, highlight boxes, slice guides and handles."""
        painter = QPainter(self)
        pen = QPen(self.box_color)
        pen.setWidth(self.line_thickness)
        painter.setPen(pen)

        # Region outline.
        painter.drawRect(0, 0, self.width(), self.height())

        # Highlight boxes for detected words.
        for x, y, w, h in self.draw_positions:
            painter.drawRect(x, y, w, h)

        # Faint guide lines between slices (none at the bottom edge).
        if self.num_slices > 1:
            strip_height = self.height() // self.num_slices
            pen.setColor(QColor(150, 150, 150, 100))  # light gray, semi-transparent
            pen.setWidth(1)
            painter.setPen(pen)

            for i in range(1, self.num_slices):
                painter.drawLine(0, i * strip_height, self.width(), i * strip_height)

        # Handles: translucent black fill (alpha 50 of 255, about 20% opacity),
        # outlined with whichever pen is currently active.
        handle_color = QColor(0, 0, 0, 50)
        for handle in self._resize_handles():
            painter.fillRect(handle, handle_color)
            painter.drawRect(handle)

    def set_draw_positions(self, positions, color, thickness):
        """
        Unconditionally replace the highlight boxes and their style, then repaint.
        """
        self.draw_positions = list(positions)
        # Keep the change-detection cache in sync. Otherwise a later
        # update_draw_positions() carrying the older positions would be skipped.
        self.previous_positions = self.draw_positions
        self.box_color = QColor(*color)
        self.line_thickness = thickness
        self.update()

    def update_draw_positions(self, positions, color, thickness):
        """
        Replace the highlight boxes and their style, repainting only on change.

        Identical consecutive updates are skipped to avoid needless repaints
        (flicker). A change of color or thickness alone still triggers one.
        """
        # Snapshot: the caller may mutate its list after this call returns.
        positions = list(positions)
        new_color = QColor(*color)
        if (positions == self.previous_positions
                and new_color == self.box_color
                and thickness == self.line_thickness):
            return

        self.previous_positions = positions
        self.draw_positions = positions
        self.box_color = new_color
        self.line_thickness = thickness
        self.update()

    def set_num_slices(self, num_slices):
        """
        Set the number of horizontal slices shown as guide lines, then repaint.
        """
        self.num_slices = num_slices
        self.update()

    def _resize_handles(self):
        """
        Return the interactive handles as QRects, in index order:
          - 0: top-left (resize)
          - 1: top-right (resize)
          - 2: bottom-left (resize)
          - 3: bottom-right (resize)
          - 4: top-centre (move)
        """
        w, h = self.width(), self.height()
        return [
            QRect(0, 0, HANDLE_SIZE, HANDLE_SIZE),
            QRect(w - HANDLE_SIZE, 0, HANDLE_SIZE, HANDLE_SIZE),
            QRect(0, h - HANDLE_SIZE, HANDLE_SIZE, HANDLE_SIZE),
            QRect(w - HANDLE_SIZE, h - HANDLE_SIZE, HANDLE_SIZE, HANDLE_SIZE),
            QRect((w - HANDLE_SIZE) // 2, 0, HANDLE_SIZE, HANDLE_SIZE),
        ]

    def mousePressEvent(self, event):
        if event.button() != Qt.LeftButton:
            return
        for i, handle in enumerate(self._resize_handles()):
            if handle.contains(event.pos()):
                self.selected_handle = i
                if i == self.HANDLE_MOVE:
                    # Remember where the grabber was pressed so the widget
                    # moves with the cursor instead of snapping to it.
                    self.drag_offset = event.pos()
                return
        # Not on a handle: dragging anywhere in the body moves the widget.
        self.drag_offset = event.pos()

    def mouseMoveEvent(self, event):
        handle = self.selected_handle
        if handle is not None and handle != self.HANDLE_MOVE:
            self._resize_from_handle(handle, event.globalX(), event.globalY())
        elif self.drag_offset is not None:
            # Move via the top-centre grabber or by dragging the body.
            self.move(event.globalX() - self.drag_offset.x(),
                      event.globalY() - self.drag_offset.y())
        else:
            return
        self._sync_bbox()
        self.update()

    def mouseReleaseEvent(self, event):
        """
        Reset the drag/resize state once the mouse button is released.
        """
        self.selected_handle = None
        self.drag_offset = None

    def _resize_from_handle(self, handle, global_x, global_y):
        """Move the edges a corner handle owns to the cursor, keeping the opposite edges fixed."""
        left, top = self.x(), self.y()
        right, bottom = left + self.width(), top + self.height()

        if handle in (self.HANDLE_TOP_LEFT, self.HANDLE_BOTTOM_LEFT):
            # Clamp against the fixed right edge so reaching the minimum width
            # does not push the whole box sideways.
            left = min(global_x, right - self.MIN_SIZE)
        else:
            right = max(global_x, left + self.MIN_SIZE)

        if handle in (self.HANDLE_TOP_LEFT, self.HANDLE_TOP_RIGHT):
            # Same for the top edge against the fixed bottom edge.
            top = min(global_y, bottom - self.MIN_SIZE)
        else:
            bottom = max(global_y, top + self.MIN_SIZE)

        self.setGeometry(left, top, right - left, bottom - top)

    def _sync_bbox(self):
        """Publish the widget's current geometry as this region's capture bbox."""
        collector_mutex.lock()
        try:
            if self.region_id in regions:
                regions[self.region_id]["bbox"] = [self.x(), self.y(), self.width(), self.height()]
        finally:
            collector_mutex.unlock()