Enabled GPU Acceleration for Identification Node.

This commit is contained in:
2025-02-26 01:57:47 -07:00
parent c10fc1ba6d
commit 0515f8feeb
5 changed files with 56 additions and 33 deletions

View File

@ -6,13 +6,22 @@ import re
import sys
import numpy as np
import cv2
# Vision-related Imports
import pytesseract
import easyocr
import torch
from PIL import Image, ImageGrab, ImageFilter
from PyQt5.QtWidgets import QApplication, QWidget
from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
from PyQt5.QtGui import QPainter, QPen, QColor, QFont
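# Initialize two EasyOCR readers: one CPU-only, and one that uses the GPU only when CUDA is available (otherwise it is created with gpu=False)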
reader_cpu = easyocr.Reader(['en'], gpu=False)
reader_gpu = easyocr.Reader(['en'], gpu=True if torch.cuda.is_available() else False)
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
DEFAULT_WIDTH = 180
@ -29,11 +38,11 @@ app_instance = None
def _ensure_qapplication():
"""
Ensures that QApplication is initialized before creating widgets.
Must be called from the main thread.
"""
global app_instance
if QApplication.instance() is None:
app_instance = QApplication(sys.argv)
threading.Thread(target=app_instance.exec_, daemon=True).start()
if app_instance is None:
app_instance = QApplication(sys.argv) # Start in main thread
def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
@ -102,11 +111,11 @@ def _preprocess_image(image):
return thresh.filter(ImageFilter.MedianFilter(3))
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
"""
Finds positions of a specific word within the OCR region.
Applies user-defined offset and margin adjustments.
Returns a list of bounding box coordinates relative to the OCR box.
Uses Tesseract (CPU) or EasyOCR (GPU) depending on the selected engine.
"""
collector_mutex.lock()
if region_id not in regions:
@ -136,28 +145,36 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
scale_x = orig_width / proc_width
scale_y = orig_height / proc_height
data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
word_positions = []
for i in range(len(data['text'])):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
# Scale the detected coordinates back to region-relative positions
x_scaled = int(data['left'][i] * scale_x)
y_scaled = int(data['top'][i] * scale_y)
w_scaled = int(data['width'][i] * scale_x)
h_scaled = int(data['height'][i] * scale_y)
# Apply user-configured margin
x_margin = max(0, x_scaled - margin)
y_margin = max(0, y_scaled - margin)
w_margin = w_scaled + (margin * 2)
h_margin = h_scaled + (margin * 2)
if ocr_engine == "CPU":
# Use Tesseract (CPU)
data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
# Apply user-configured offset
x_final = x_margin + offset_x
y_final = y_margin + offset_y
for i in range(len(data['text'])):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
x_scaled = int(data['left'][i] * scale_x)
y_scaled = int(data['top'][i] * scale_y)
w_scaled = int(data['width'][i] * scale_x)
h_scaled = int(data['height'][i] * scale_y)
word_positions.append((x_final, y_final, w_margin, h_margin))
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
else:
# Use EasyOCR (GPU) - Convert PIL image to NumPy array
image_np = np.array(processed)
results = reader_gpu.readtext(image_np)
for (bbox, text, _) in results:
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
x_scaled = int(x_min * scale_x)
y_scaled = int(y_min * scale_y)
w_scaled = int((x_max - x_min) * scale_x)
h_scaled = int((y_max - y_min) * scale_y)
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
return word_positions
except Exception as e:
@ -165,9 +182,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
return []
def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
"""
Draws non-interactive rectangles at specified positions within the given OCR region.