# Modules/data_collector.py

import threading
import time
import re
import sys
import numpy as np
import cv2
import pytesseract
from PIL import Image, ImageGrab, ImageFilter

from PyQt5.QtWidgets import QApplication, QWidget
from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
from PyQt5.QtGui import QPainter, QPen, QColor, QFont

# Path of the Tesseract executable that pytesseract will invoke.
# NOTE(review): hard-coded Windows install location; confirm it matches the
# machines this module is deployed on.
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"

# Default width/height of a new OCR region (defaults of create_ocr_region()).
DEFAULT_WIDTH = 180
DEFAULT_HEIGHT = 130
# Side length of the square resize handles drawn by OCRRegionWidget.
HANDLE_SIZE = 8
# NOTE(review): not referenced anywhere in the visible part of this module;
# presumably the height of a label strip. Confirm or remove.
LABEL_HEIGHT = 20

# Locked around reads and writes of `regions`, which is touched both by the
# OCR polling thread and by the widget's mouse handlers.
collector_mutex = QMutex()
# region_id -> {'bbox': [x, y, w, h], 'raw_text': str, 'widget': OCRRegionWidget}
regions = {}

# QApplication created by _ensure_qapplication() when none existed yet;
# stays None otherwise.
app_instance = None


def _ensure_qapplication():
    """
    Make sure a QApplication exists before any widget is constructed.

    If Qt already reports an application instance this is a no-op. Otherwise a
    new QApplication is created, stored in the module-level ``app_instance``
    and its ``exec_`` is handed to a daemon thread.
    """
    global app_instance

    if QApplication.instance() is not None:
        return

    app_instance = QApplication(sys.argv)
    # NOTE(review): Qt documents that exec() is meant to be called from the
    # thread that created the application object; here it is started on a
    # separate thread. Confirm the event loop actually runs in this setup.
    event_loop_thread = threading.Thread(target=app_instance.exec_, daemon=True)
    event_loop_thread.start()


def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
    """
    Creates an OCR region with a visible, resizable box on the screen.

    Registers ``region_id`` in the module-level ``regions`` dict with its
    bounding box ``[x, y, w, h]``, an empty ``raw_text`` and a freshly created
    OCRRegionWidget. If ``region_id`` is already registered the call does
    nothing.

    Args:
        region_id: Hashable key identifying the region.
        x, y: Screen position of the region's top-left corner.
        w, h: Width and height of the region.
        color: Colour tuple unpacked into a QColor for the widget frame
            (default yellow).

    Returns:
        None.
    """
    _ensure_qapplication()

    collector_mutex.lock()
    try:
        if region_id in regions:
            return
        regions[region_id] = {
            'bbox': [x, y, w, h],
            'raw_text': "",
            'widget': OCRRegionWidget(x, y, w, h, region_id, color),
        }
    finally:
        # Always release the mutex: if widget construction raises, a mutex
        # left locked would block every other user of `regions` forever.
        collector_mutex.unlock()


def get_raw_text(region_id):
    """
    Return the most recently stored OCR text for ``region_id``.

    Args:
        region_id: Key of a region previously registered in ``regions``.

    Returns:
        The region's ``raw_text`` string, or ``""`` when the region is not
        registered.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        return "" if entry is None else entry['raw_text']
    finally:
        # Released on every path, including an exception during the lookup.
        collector_mutex.unlock()


def start_collector():
    """
    Start the OCR polling loop (``_update_ocr_loop``) on a daemon thread.

    Returns immediately; the thread object is not kept.
    """
    threading.Thread(target=_update_ocr_loop, daemon=True).start()


def _update_ocr_loop():
    """
    Poll every registered region forever and refresh its ``raw_text``.

    Each pass snapshots the region ids, then for every region copies its
    current bbox, grabs that screen area, preprocesses it and runs Tesseract
    on it. The mutex is only held while ``regions`` is read or written, never
    during the screen grab or the OCR call. A pass is followed by a 0.7 s
    sleep.

    A region that disappears between the snapshot and its turn is skipped,
    and a failure while capturing or recognising one region is reported and
    does not stop the loop.
    """
    while True:
        collector_mutex.lock()
        try:
            region_ids = list(regions)
        finally:
            collector_mutex.unlock()

        for rid in region_ids:
            # Re-read the bbox right before capturing so that a move/resize
            # made while earlier regions were processed is picked up.
            collector_mutex.lock()
            try:
                entry = regions.get(rid)
                bbox = None if entry is None else list(entry['bbox'])
            finally:
                collector_mutex.unlock()

            if bbox is None:
                # Region was removed after the id snapshot was taken.
                continue

            x, y, w, h = bbox
            try:
                screenshot = ImageGrab.grab(bbox=(x, y, x + w, y + h))
                processed = _preprocess_image(screenshot)
                raw_text = pytesseract.image_to_string(processed, config='--psm 6 --oem 1')
            except Exception as e:
                # Keep polling: one bad capture must not kill the collector
                # thread and freeze every region's text.
                print(f"[ERROR] OCR update failed for region {rid!r}: {e}")
                continue

            collector_mutex.lock()
            try:
                if rid in regions:
                    regions[rid]['raw_text'] = raw_text
            finally:
                collector_mutex.unlock()

        time.sleep(0.7)


def _preprocess_image(image):
    """
    Prepare a screenshot for OCR.

    The image is converted to 8-bit grayscale ("L"), enlarged to three times
    its width and height, binarised (values above 200 become 255, everything
    else 0) and finally passed through a 3x3 median filter.

    Returns the processed image; the input image is not modified.
    """
    grayscale = image.convert("L")

    enlarged_size = (grayscale.width * 3, grayscale.height * 3)
    enlarged = grayscale.resize(enlarged_size)

    def _to_black_or_white(p):
        return 255 if p > 200 else 0

    binary = enlarged.point(_to_black_or_white)
    return binary.filter(ImageFilter.MedianFilter(3))


def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
    """
    Finds positions of a specific word within the OCR region.

    Grabs the region's current screen area, runs Tesseract on the
    preprocessed image and collects the box of every recognised token that
    contains ``word`` as a whole word (case-insensitive). ``word`` is matched
    literally: regex metacharacters in it have no special meaning.

    Args:
        region_id: Key of a region registered in ``regions``.
        word: Text to look for. An empty word yields no matches.
        offset_x, offset_y: Added to each box's x / y after the margin is
            applied.
        margin: Padding added around each detected box. The left/top edge is
            clamped at 0; the box still ends ``margin`` past the detected
            word on the right/bottom.

    Returns:
        A list of ``(x, y, w, h)`` tuples relative to the OCR box. The list is
        empty when the region is unknown, its bounds are invalid, nothing
        matches, or capturing/recognising fails (the error is printed).
    """
    needle = str(word)
    if not needle:
        return []

    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        # Copy so the values cannot change under us once the lock is released.
        bbox = None if entry is None else list(entry['bbox'])
    finally:
        collector_mutex.unlock()

    if bbox is None:
        return []

    # Extract OCR region position and size
    x, y, w, h = bbox
    left, top, right, bottom = x, y, x + w, y + h

    if right <= left or bottom <= top:
        print(f"[ERROR] Invalid OCR region bounds: {bbox}")
        return []

    # Escape the word so characters such as "+", "(" or "." are literal.
    # The lookarounds act like \b for ordinary words but also work when the
    # word itself starts or ends with a non-word character.
    pattern = re.compile(r"(?<!\w)" + re.escape(needle) + r"(?!\w)", re.IGNORECASE)

    try:
        image = ImageGrab.grab(bbox=(left, top, right, bottom))
        processed = _preprocess_image(image)

        # Scale factor between processed image and original screenshot
        orig_width, orig_height = image.size
        proc_width, proc_height = processed.size
        scale_x = orig_width / proc_width
        scale_y = orig_height / proc_height

        data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)

        word_positions = []
        tokens = zip(data['text'], data['left'], data['top'], data['width'], data['height'])
        for text, tok_left, tok_top, tok_width, tok_height in tokens:
            if not pattern.search(text):
                continue

            # Scale the detected coordinates back to region-relative positions
            x_scaled = int(tok_left * scale_x)
            y_scaled = int(tok_top * scale_y)
            w_scaled = int(tok_width * scale_x)
            h_scaled = int(tok_height * scale_y)

            # Apply the margin edge by edge, so that clamping the left/top
            # edge at 0 does not push the right/bottom edge further out.
            box_left = max(0, x_scaled - margin)
            box_top = max(0, y_scaled - margin)
            box_right = x_scaled + w_scaled + margin
            box_bottom = y_scaled + h_scaled + margin

            word_positions.append((
                box_left + offset_x,
                box_top + offset_y,
                box_right - box_left,
                box_bottom - box_top,
            ))

        return word_positions
    except Exception as e:
        print(f"[ERROR] Failed to capture OCR region: {e}")
        return []





def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
    """
    Draws non-interactive rectangles at specified positions within the given OCR region.

    Forwards ``positions`` and ``color`` to the region widget's
    ``set_draw_positions``. Does nothing when the region is unknown or has no
    'widget' entry.

    Args:
        region_id: Key of a region registered in ``regions``.
        positions: Iterable of ``(x, y, w, h)`` boxes in widget coordinates.
        color: Colour tuple passed through to the widget.
    """
    collector_mutex.lock()
    try:
        entry = regions.get(region_id)
        if entry is not None and 'widget' in entry:
            entry['widget'].set_draw_positions(positions, color)
    finally:
        # Released even if the widget call raises (e.g. on a malformed colour
        # tuple), so other users of `regions` are never blocked.
        collector_mutex.unlock()


class OCRRegionWidget(QWidget):
    def __init__(self, x, y, w, h, region_id, color):
        """
        Create and immediately show the on-screen box for one OCR region.

        The window is frameless, stays on top, is flagged as a tool window,
        has a translucent background and accepts mouse events. ``region_id``
        is the key later used to update this region's entry in ``regions``;
        ``color`` is unpacked into the QColor used for the frame.
        """
        super().__init__()

        # Plain Python state first; none of it depends on the Qt setup below.
        self.region_id = region_id
        self.box_color = QColor(*color)
        self.draw_positions = []
        self.drag_offset = None
        self.selected_handle = None

        # Qt window setup, in the same order as before.
        self.setGeometry(x, y, w, h)
        window_flags = Qt.FramelessWindowHint | Qt.WindowStaysOnTopHint | Qt.Tool
        self.setWindowFlags(window_flags)
        self.setAttribute(Qt.WA_TranslucentBackground, True)
        self.setAttribute(Qt.WA_TransparentForMouseEvents, False)

        self.show()

    def paintEvent(self, event):
        """
        Paint the region frame, the word overlay boxes and the resize handles.

        The frame uses ``self.box_color`` with a 5 px pen. Overlay boxes from
        ``self.draw_positions`` are always outlined in blue with a 2 px pen.
        The handles are filled with ``self.box_color`` and outlined with the
        overlay pen that is still active at that point.
        """
        painter = QPainter(self)

        # Outer frame of the region.
        frame_pen = QPen(self.box_color)
        frame_pen.setWidth(5)
        painter.setPen(frame_pen)
        painter.drawRect(0, 0, self.width(), self.height())

        # Detected-word overlays.
        # NOTE(review): the overlay colour is fixed to blue here regardless of
        # the colour handed to set_draw_positions(); confirm that is intended.
        overlay_pen = QPen(QColor(0, 0, 255))
        overlay_pen.setWidth(2)
        painter.setPen(overlay_pen)
        for box_x, box_y, box_w, box_h in self.draw_positions:
            painter.drawRect(box_x, box_y, box_w, box_h)

        # Resize handles: filled squares in the frame colour.
        painter.setBrush(self.box_color)
        for handle_rect in self._resize_handles():
            painter.drawRect(handle_rect)

    def set_draw_positions(self, positions, color):
        """
        Replace the overlay boxes to paint and schedule a repaint.

        ``positions`` becomes ``self.draw_positions`` as-is (no copy is made).
        ``color`` is unpacked into a QColor and stored as ``self.box_color``.
        """
        # NOTE(review): box_color is the colour of the frame and handles, so
        # this call recolours them rather than the overlay boxes; confirm
        # that is the intended effect of `color`.
        self.box_color = QColor(*color)
        self.draw_positions = positions
        self.update()

    def _resize_handles(self):
        """
        Return the two square grab handles as QRects in widget coordinates.

        Index 0 is the top-left corner, index 1 the bottom-right corner; each
        is HANDLE_SIZE pixels on a side.
        """
        top_left = QRect(0, 0, HANDLE_SIZE, HANDLE_SIZE)
        bottom_right = QRect(
            self.width() - HANDLE_SIZE,
            self.height() - HANDLE_SIZE,
            HANDLE_SIZE,
            HANDLE_SIZE,
        )
        return [top_left, bottom_right]

    def mousePressEvent(self, event):
        """
        Begin a resize or a drag on a left-button press.

        If the press lands on a resize handle, that handle's index is stored
        in ``self.selected_handle``. Otherwise the press position (widget
        coordinates) is stored in ``self.drag_offset``. Other buttons are
        ignored.
        """
        if event.button() != Qt.LeftButton:
            return

        press_pos = event.pos()
        hit_index = next(
            (
                index
                for index, handle_rect in enumerate(self._resize_handles())
                if handle_rect.contains(press_pos)
            ),
            None,
        )

        if hit_index is None:
            self.drag_offset = press_pos
        else:
            self.selected_handle = hit_index

    def mouseMoveEvent(self, event):
        """
        Resize or move the widget while the mouse moves, and mirror the new
        geometry into this region's 'bbox' entry in ``regions``.

        A selected resize handle takes precedence over dragging. Width and
        height are never allowed below 20.
        """
        if self.selected_handle is not None:
            w, h = self.width(), self.height()
            if self.selected_handle == 0:  # Top-left
                # Grow/shrink by how far the cursor is from the current
                # top-left corner, then make the cursor the new top-left.
                new_w = w + (self.x() - event.globalX())
                new_h = h + (self.y() - event.globalY())
                new_x = event.globalX()
                new_y = event.globalY()
                # NOTE(review): when the size is clamped to 20 the origin
                # still follows the cursor, so the box moves instead of
                # stopping at the minimum size; confirm intended.
                if new_w < 20: new_w = 20
                if new_h < 20: new_h = 20
                self.setGeometry(new_x, new_y, new_w, new_h)
            elif self.selected_handle == 1:  # Bottom-right
                # Top-left stays fixed; size is the cursor's distance from it.
                new_w = event.globalX() - self.x()
                new_h = event.globalY() - self.y()
                if new_w < 20: new_w = 20
                if new_h < 20: new_h = 20
                self.setGeometry(self.x(), self.y(), new_w, new_h)

            # Publish the resized geometry for the OCR code that reads 'bbox'.
            collector_mutex.lock()
            if self.region_id in regions:
                regions[self.region_id]['bbox'] = [self.x(), self.y(), self.width(), self.height()]
            collector_mutex.unlock()

            self.update()
        elif self.drag_offset:
            # Dragging: keep the point that was pressed under the cursor.
            new_x = event.globalX() - self.drag_offset.x()
            new_y = event.globalY() - self.drag_offset.y()
            self.move(new_x, new_y)

            # Publish the moved position; the size is unchanged by a drag.
            collector_mutex.lock()
            if self.region_id in regions:
                regions[self.region_id]['bbox'] = [new_x, new_y, self.width(), self.height()]
            collector_mutex.unlock()