Enabled GPU Acceleration for Identification Node.

Nicole Rappe 2025-02-26 01:57:47 -07:00
parent c10fc1ba6d
commit 0515f8feeb
5 changed files with 56 additions and 33 deletions

View File

@@ -0,0 +1,5 @@
+Run the following command to install the prerequisites for this project.
+
+This command installs PyTorch and torchvision (plus torchaudio) with CUDA 12.1 support for GPU-accelerated vision tasks:
+
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
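
As a quick sanity check that the CUDA build installed correctly, something like the following can be run (a minimal sketch; the cu121 wheel targets CUDA 12.1):

    import torch

    print(torch.__version__)                   # A cu121 build reports e.g. "2.x.x+cu121"
    print(torch.cuda.is_available())           # True only with a CUDA-capable GPU and driver
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))   # Name of the first visible GPU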

View File

@@ -6,13 +6,22 @@ import re
 import sys
 import numpy as np
 import cv2
+
+# Vision-related Imports
 import pytesseract
+import easyocr
+import torch
+
 from PIL import Image, ImageGrab, ImageFilter
 from PyQt5.QtWidgets import QApplication, QWidget
 from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
 from PyQt5.QtGui import QPainter, QPen, QColor, QFont
 
+# Initialize EasyOCR with CUDA support
+reader_cpu = easyocr.Reader(['en'], gpu=False)
+reader_gpu = easyocr.Reader(['en'], gpu=True if torch.cuda.is_available() else False)
+
 pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
 
 DEFAULT_WIDTH = 180
@@ -29,11 +38,11 @@ app_instance = None
 
 def _ensure_qapplication():
     """
     Ensures that QApplication is initialized before creating widgets.
+    Must be called from the main thread.
     """
     global app_instance
-    if QApplication.instance() is None:
-        app_instance = QApplication(sys.argv)
-        threading.Thread(target=app_instance.exec_, daemon=True).start()
+    if app_instance is None:
+        app_instance = QApplication(sys.argv)  # Start in main thread
 
 def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
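
The removed threading.Thread(... exec_ ...) line is the substantive fix in this hunk: Qt expects the QApplication and its event loop to live on the main thread, so running exec_() in a daemon thread invites crashes. A minimal sketch of the resulting pattern, assuming the caller owns the main thread:

    import sys
    from PyQt5.QtWidgets import QApplication

    app_instance = None

    def _ensure_qapplication():
        """Create the QApplication once, on the main thread; do not start exec_() here."""
        global app_instance
        if app_instance is None:
            app_instance = QApplication(sys.argv)

    # Later, from the main thread, after widgets are created:
    # _ensure_qapplication()
    # app_instance.exec_()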
@@ -102,11 +111,11 @@ def _preprocess_image(image):
     return thresh.filter(ImageFilter.MedianFilter(3))
 
-def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
+def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
     """
     Finds positions of a specific word within the OCR region.
     Applies user-defined offset and margin adjustments.
-    Returns a list of bounding box coordinates relative to the OCR box.
+    Uses Tesseract (CPU) or EasyOCR (GPU) depending on the selected engine.
     """
     collector_mutex.lock()
     if region_id not in regions:
@@ -136,28 +145,36 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
         scale_x = orig_width / proc_width
         scale_y = orig_height / proc_height
 
-        data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
-
         word_positions = []
 
-        for i in range(len(data['text'])):
-            if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
-                # Scale the detected coordinates back to region-relative positions
-                x_scaled = int(data['left'][i] * scale_x)
-                y_scaled = int(data['top'][i] * scale_y)
-                w_scaled = int(data['width'][i] * scale_x)
-                h_scaled = int(data['height'][i] * scale_y)
-
-                # Apply user-configured margin
-                x_margin = max(0, x_scaled - margin)
-                y_margin = max(0, y_scaled - margin)
-                w_margin = w_scaled + (margin * 2)
-                h_margin = h_scaled + (margin * 2)
-
-                # Apply user-configured offset
-                x_final = x_margin + offset_x
-                y_final = y_margin + offset_y
-
-                word_positions.append((x_final, y_final, w_margin, h_margin))
+        if ocr_engine == "CPU":
+            # Use Tesseract (CPU)
+            data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
+            for i in range(len(data['text'])):
+                if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
+                    x_scaled = int(data['left'][i] * scale_x)
+                    y_scaled = int(data['top'][i] * scale_y)
+                    w_scaled = int(data['width'][i] * scale_x)
+                    h_scaled = int(data['height'][i] * scale_y)
+                    word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
+        else:
+            # Use EasyOCR (GPU) - Convert PIL image to NumPy array
+            image_np = np.array(processed)
+            results = reader_gpu.readtext(image_np)
+
+            for (bbox, text, _) in results:
+                if re.search(rf"\b{word}\b", text, re.IGNORECASE):
+                    (x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
+                    x_scaled = int(x_min * scale_x)
+                    y_scaled = int(y_min * scale_y)
+                    w_scaled = int((x_max - x_min) * scale_x)
+                    h_scaled = int((y_max - y_min) * scale_y)
+                    word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
 
         return word_positions
 
     except Exception as e:
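
For context on the new GPU branch: easyocr.Reader.readtext() returns a list of (bounding_box, text, confidence) tuples, where bounding_box holds four corner points ordered top-left, top-right, bottom-right, bottom-left, which is why the hunk takes bbox[0] and bbox[2] as the opposite corners. A standalone sketch of that call (the image path is illustrative):

    import easyocr
    import numpy as np
    from PIL import Image

    reader = easyocr.Reader(['en'], gpu=True)           # Warns and falls back to CPU without CUDA
    image_np = np.array(Image.open("screenshot.png"))   # readtext() accepts NumPy arrays
    for bbox, text, confidence in reader.readtext(image_np):
        (x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]   # Top-left and bottom-right corners
        print(text, confidence, (x_min, y_min, x_max, y_max))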
@@ -165,9 +182,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
         return []
 
 def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
     """
     Draws non-interactive rectangles at specified positions within the given OCR region.

View File

@@ -3,14 +3,13 @@
 Identification Overlay Node:
 - Creates an OCR region in data_collector with a blue overlay.
 - Detects instances of a specified word and draws adjustable overlays.
-- Users can configure offset, margin, and polling frequency dynamically.
+- Users can configure offset, margin, polling frequency, and select OCR engine.
 """
 
 import re
 from OdenGraphQt import BaseNode
-from PyQt5.QtWidgets import QMessageBox
 from PyQt5.QtCore import QTimer
-from Modules import data_manager, data_collector
+from Modules import data_collector
 
 class IdentificationOverlayNode(BaseNode):
@@ -24,7 +23,9 @@ class IdentificationOverlayNode(BaseNode):
         self.add_text_input("search_term", "Search Term", text="Aibatt")
         self.add_text_input("offset_value", "Offset Value (X,Y)", text="0,0")  # X,Y Offset
         self.add_text_input("margin", "Margin", text="5")  # Box Margin
-        self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500")  # New input
+        self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500")  # Polling Rate
+        self.add_combo_menu("ocr_engine", "Type", items=["CPU", "GPU"])
+        self.set_property("ocr_engine", "CPU")  # Set default value after adding the menu
 
         self.region_id = "identification_overlay"
         data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255))
@@ -45,7 +46,7 @@
         """
         polling_text = self.get_property("polling_freq")
         try:
-            polling_interval = max(100, int(polling_text))  # Minimum 100ms to avoid overloading
+            polling_interval = max(50, int(polling_text))  # Minimum 50ms for near real-time
         except ValueError:
             polling_interval = 500  # Default to 500ms
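
Presumably this interval feeds the node's QTimer; a hypothetical sketch of how the clamped value would be applied (apply_polling_interval is illustrative, not part of the diff):

    from PyQt5.QtCore import QTimer

    def apply_polling_interval(timer: QTimer, polling_text: str) -> int:
        """Hypothetical helper: parse, clamp, and apply a polling interval in ms."""
        try:
            interval = max(50, int(polling_text))   # Clamp to the 50 ms floor from the diff
        except ValueError:
            interval = 500                          # Fall back to the 500 ms default
        timer.setInterval(interval)                 # QTimer intervals are in milliseconds
        return interval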
@@ -58,6 +59,7 @@
         search_term = self.get_property("search_term")
         offset_text = self.get_property("offset_value")
         margin_text = self.get_property("margin")
+        ocr_engine = self.get_property("ocr_engine")
 
         # Read and apply polling frequency updates
         self.update_polling_frequency()
@@ -77,8 +79,10 @@
         if not search_term:
             return
 
-        # Get detected word positions
-        detected_positions = data_collector.find_word_positions(self.region_id, search_term, offset_x, offset_y, margin)
+        # Get detected word positions using the selected OCR engine
+        detected_positions = data_collector.find_word_positions(
+            self.region_id, search_term, offset_x, offset_y, margin, ocr_engine
+        )
 
         # Draw detected word boxes
         data_collector.draw_identification_boxes(self.region_id, detected_positions, color=(0, 0, 255))
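
With that, the engine selection threads through end to end. A minimal illustration of driving the updated data_collector API directly, using the same values the node defaults to (the "GPU" string routes to the EasyOCR path):

    from Modules import data_collector

    data_collector.create_ocr_region("identification_overlay", x=250, y=50, w=300, h=200, color=(0, 0, 255))
    positions = data_collector.find_word_positions(
        "identification_overlay", "Aibatt",
        offset_x=0, offset_y=0, margin=5, ocr_engine="GPU",
    )
    data_collector.draw_identification_boxes("identification_overlay", positions, color=(0, 0, 255))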