Enabled GPU Acceleration for Identification Node.

This commit is contained in:
Nicole Rappe 2025-02-26 01:57:47 -07:00
parent c10fc1ba6d
commit 0515f8feeb
5 changed files with 56 additions and 33 deletions

View File

@ -0,0 +1,5 @@
We need to run the following commands to install the prerequisites for the project.
This command is used to install pytorch and torchvision for the purposes of GPU-accelerated vision tasks.
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121

View File

@ -6,13 +6,22 @@ import re
import sys
import numpy as np
import cv2
# Vision-related Imports
import pytesseract
import easyocr
import torch
from PIL import Image, ImageGrab, ImageFilter
from PyQt5.QtWidgets import QApplication, QWidget
from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
from PyQt5.QtGui import QPainter, QPen, QColor, QFont
# Initialize EasyOCR with CUDA support
reader_cpu = easyocr.Reader(['en'], gpu=False)
reader_gpu = easyocr.Reader(['en'], gpu=True if torch.cuda.is_available() else False)
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
DEFAULT_WIDTH = 180
@ -29,11 +38,11 @@ app_instance = None
def _ensure_qapplication():
"""
Ensures that QApplication is initialized before creating widgets.
Must be called from the main thread.
"""
global app_instance
if QApplication.instance() is None:
app_instance = QApplication(sys.argv)
threading.Thread(target=app_instance.exec_, daemon=True).start()
if app_instance is None:
app_instance = QApplication(sys.argv) # Start in main thread
def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
@ -102,11 +111,11 @@ def _preprocess_image(image):
return thresh.filter(ImageFilter.MedianFilter(3))
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
"""
Finds positions of a specific word within the OCR region.
Applies user-defined offset and margin adjustments.
Returns a list of bounding box coordinates relative to the OCR box.
Uses Tesseract (CPU) or EasyOCR (GPU) depending on the selected engine.
"""
collector_mutex.lock()
if region_id not in regions:
@ -136,28 +145,36 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
scale_x = orig_width / proc_width
scale_y = orig_height / proc_height
word_positions = []
if ocr_engine == "CPU":
# Use Tesseract (CPU)
data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
word_positions = []
for i in range(len(data['text'])):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
# Scale the detected coordinates back to region-relative positions
x_scaled = int(data['left'][i] * scale_x)
y_scaled = int(data['top'][i] * scale_y)
w_scaled = int(data['width'][i] * scale_x)
h_scaled = int(data['height'][i] * scale_y)
# Apply user-configured margin
x_margin = max(0, x_scaled - margin)
y_margin = max(0, y_scaled - margin)
w_margin = w_scaled + (margin * 2)
h_margin = h_scaled + (margin * 2)
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
# Apply user-configured offset
x_final = x_margin + offset_x
y_final = y_margin + offset_y
else:
# Use EasyOCR (GPU) - Convert PIL image to NumPy array
image_np = np.array(processed)
results = reader_gpu.readtext(image_np)
word_positions.append((x_final, y_final, w_margin, h_margin))
for (bbox, text, _) in results:
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
x_scaled = int(x_min * scale_x)
y_scaled = int(y_min * scale_y)
w_scaled = int((x_max - x_min) * scale_x)
h_scaled = int((y_max - y_min) * scale_y)
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
return word_positions
except Exception as e:
@ -165,9 +182,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
return []
def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
"""
Draws non-interactive rectangles at specified positions within the given OCR region.

View File

@ -3,14 +3,13 @@
Identification Overlay Node:
- Creates an OCR region in data_collector with a blue overlay.
- Detects instances of a specified word and draws adjustable overlays.
- Users can configure offset, margin, and polling frequency dynamically.
- Users can configure offset, margin, polling frequency, and select OCR engine.
"""
import re
from OdenGraphQt import BaseNode
from PyQt5.QtWidgets import QMessageBox
from PyQt5.QtCore import QTimer
from Modules import data_manager, data_collector
from Modules import data_collector
class IdentificationOverlayNode(BaseNode):
@ -24,7 +23,9 @@ class IdentificationOverlayNode(BaseNode):
self.add_text_input("search_term", "Search Term", text="Aibatt")
self.add_text_input("offset_value", "Offset Value (X,Y)", text="0,0") # X,Y Offset
self.add_text_input("margin", "Margin", text="5") # Box Margin
self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500") # New input
self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500") # Polling Rate
self.add_combo_menu("ocr_engine", "Type", items=["CPU", "GPU"])
self.set_property("ocr_engine", "CPU") # Set default value after adding the menu
self.region_id = "identification_overlay"
data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255))
@ -45,7 +46,7 @@ class IdentificationOverlayNode(BaseNode):
"""
polling_text = self.get_property("polling_freq")
try:
polling_interval = max(100, int(polling_text)) # Minimum 100ms to avoid overloading
polling_interval = max(50, int(polling_text)) # Minimum 50ms for near real-time
except ValueError:
polling_interval = 500 # Default to 500ms
@ -58,6 +59,7 @@ class IdentificationOverlayNode(BaseNode):
search_term = self.get_property("search_term")
offset_text = self.get_property("offset_value")
margin_text = self.get_property("margin")
ocr_engine = self.get_property("ocr_engine")
# Read and apply polling frequency updates
self.update_polling_frequency()
@ -77,8 +79,10 @@ class IdentificationOverlayNode(BaseNode):
if not search_term:
return
# Get detected word positions
detected_positions = data_collector.find_word_positions(self.region_id, search_term, offset_x, offset_y, margin)
# Get detected word positions using the selected OCR engine
detected_positions = data_collector.find_word_positions(
self.region_id, search_term, offset_x, offset_y, margin, ocr_engine
)
# Draw detected word boxes
data_collector.draw_identification_boxes(self.region_id, detected_positions, color=(0, 0, 255))