Enabled GPU Acceleration for Identification Node.
This commit is contained in:
parent
c10fc1ba6d
commit
0515f8feeb
5
Installation Requirements.txt
Normal file
5
Installation Requirements.txt
Normal file
@ -0,0 +1,5 @@
|
||||
We need to run the following commands to install the prerequisites for the project.
|
||||
|
||||
|
||||
This command is used to install pytorch and torchvision for the purposes of GPU-accelerated vision tasks.
|
||||
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
|
Binary file not shown.
@ -6,13 +6,22 @@ import re
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
# Vision-related Imports
|
||||
import pytesseract
|
||||
import easyocr
|
||||
import torch
|
||||
|
||||
from PIL import Image, ImageGrab, ImageFilter
|
||||
|
||||
from PyQt5.QtWidgets import QApplication, QWidget
|
||||
from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
|
||||
from PyQt5.QtGui import QPainter, QPen, QColor, QFont
|
||||
|
||||
# Initialize EasyOCR with CUDA support
|
||||
reader_cpu = easyocr.Reader(['en'], gpu=False)
|
||||
reader_gpu = easyocr.Reader(['en'], gpu=True if torch.cuda.is_available() else False)
|
||||
|
||||
pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
|
||||
|
||||
DEFAULT_WIDTH = 180
|
||||
@ -29,11 +38,11 @@ app_instance = None
|
||||
def _ensure_qapplication():
|
||||
"""
|
||||
Ensures that QApplication is initialized before creating widgets.
|
||||
Must be called from the main thread.
|
||||
"""
|
||||
global app_instance
|
||||
if QApplication.instance() is None:
|
||||
app_instance = QApplication(sys.argv)
|
||||
threading.Thread(target=app_instance.exec_, daemon=True).start()
|
||||
if app_instance is None:
|
||||
app_instance = QApplication(sys.argv) # Start in main thread
|
||||
|
||||
|
||||
def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
|
||||
@ -102,11 +111,11 @@ def _preprocess_image(image):
|
||||
return thresh.filter(ImageFilter.MedianFilter(3))
|
||||
|
||||
|
||||
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
|
||||
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
|
||||
"""
|
||||
Finds positions of a specific word within the OCR region.
|
||||
Applies user-defined offset and margin adjustments.
|
||||
Returns a list of bounding box coordinates relative to the OCR box.
|
||||
Uses Tesseract (CPU) or EasyOCR (GPU) depending on the selected engine.
|
||||
"""
|
||||
collector_mutex.lock()
|
||||
if region_id not in regions:
|
||||
@ -136,28 +145,36 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
|
||||
scale_x = orig_width / proc_width
|
||||
scale_y = orig_height / proc_height
|
||||
|
||||
word_positions = []
|
||||
|
||||
if ocr_engine == "CPU":
|
||||
# Use Tesseract (CPU)
|
||||
data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
|
||||
|
||||
word_positions = []
|
||||
for i in range(len(data['text'])):
|
||||
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
|
||||
# Scale the detected coordinates back to region-relative positions
|
||||
x_scaled = int(data['left'][i] * scale_x)
|
||||
y_scaled = int(data['top'][i] * scale_y)
|
||||
w_scaled = int(data['width'][i] * scale_x)
|
||||
h_scaled = int(data['height'][i] * scale_y)
|
||||
|
||||
# Apply user-configured margin
|
||||
x_margin = max(0, x_scaled - margin)
|
||||
y_margin = max(0, y_scaled - margin)
|
||||
w_margin = w_scaled + (margin * 2)
|
||||
h_margin = h_scaled + (margin * 2)
|
||||
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
|
||||
|
||||
# Apply user-configured offset
|
||||
x_final = x_margin + offset_x
|
||||
y_final = y_margin + offset_y
|
||||
else:
|
||||
# Use EasyOCR (GPU) - Convert PIL image to NumPy array
|
||||
image_np = np.array(processed)
|
||||
results = reader_gpu.readtext(image_np)
|
||||
|
||||
word_positions.append((x_final, y_final, w_margin, h_margin))
|
||||
for (bbox, text, _) in results:
|
||||
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
|
||||
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
|
||||
|
||||
x_scaled = int(x_min * scale_x)
|
||||
y_scaled = int(y_min * scale_y)
|
||||
w_scaled = int((x_max - x_min) * scale_x)
|
||||
h_scaled = int((y_max - y_min) * scale_y)
|
||||
|
||||
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
|
||||
|
||||
return word_positions
|
||||
except Exception as e:
|
||||
@ -165,9 +182,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
|
||||
return []
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
|
||||
"""
|
||||
Draws non-interactive rectangles at specified positions within the given OCR region.
|
||||
|
Binary file not shown.
@ -3,14 +3,13 @@
|
||||
Identification Overlay Node:
|
||||
- Creates an OCR region in data_collector with a blue overlay.
|
||||
- Detects instances of a specified word and draws adjustable overlays.
|
||||
- Users can configure offset, margin, and polling frequency dynamically.
|
||||
- Users can configure offset, margin, polling frequency, and select OCR engine.
|
||||
"""
|
||||
|
||||
import re
|
||||
from OdenGraphQt import BaseNode
|
||||
from PyQt5.QtWidgets import QMessageBox
|
||||
from PyQt5.QtCore import QTimer
|
||||
from Modules import data_manager, data_collector
|
||||
from Modules import data_collector
|
||||
|
||||
|
||||
class IdentificationOverlayNode(BaseNode):
|
||||
@ -24,7 +23,9 @@ class IdentificationOverlayNode(BaseNode):
|
||||
self.add_text_input("search_term", "Search Term", text="Aibatt")
|
||||
self.add_text_input("offset_value", "Offset Value (X,Y)", text="0,0") # X,Y Offset
|
||||
self.add_text_input("margin", "Margin", text="5") # Box Margin
|
||||
self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500") # New input
|
||||
self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500") # Polling Rate
|
||||
self.add_combo_menu("ocr_engine", "Type", items=["CPU", "GPU"])
|
||||
self.set_property("ocr_engine", "CPU") # Set default value after adding the menu
|
||||
|
||||
self.region_id = "identification_overlay"
|
||||
data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255))
|
||||
@ -45,7 +46,7 @@ class IdentificationOverlayNode(BaseNode):
|
||||
"""
|
||||
polling_text = self.get_property("polling_freq")
|
||||
try:
|
||||
polling_interval = max(100, int(polling_text)) # Minimum 100ms to avoid overloading
|
||||
polling_interval = max(50, int(polling_text)) # Minimum 50ms for near real-time
|
||||
except ValueError:
|
||||
polling_interval = 500 # Default to 500ms
|
||||
|
||||
@ -58,6 +59,7 @@ class IdentificationOverlayNode(BaseNode):
|
||||
search_term = self.get_property("search_term")
|
||||
offset_text = self.get_property("offset_value")
|
||||
margin_text = self.get_property("margin")
|
||||
ocr_engine = self.get_property("ocr_engine")
|
||||
|
||||
# Read and apply polling frequency updates
|
||||
self.update_polling_frequency()
|
||||
@ -77,8 +79,10 @@ class IdentificationOverlayNode(BaseNode):
|
||||
if not search_term:
|
||||
return
|
||||
|
||||
# Get detected word positions
|
||||
detected_positions = data_collector.find_word_positions(self.region_id, search_term, offset_x, offset_y, margin)
|
||||
# Get detected word positions using the selected OCR engine
|
||||
detected_positions = data_collector.find_word_positions(
|
||||
self.region_id, search_term, offset_x, offset_y, margin, ocr_engine
|
||||
)
|
||||
|
||||
# Draw detected word boxes
|
||||
data_collector.draw_identification_boxes(self.region_id, detected_positions, color=(0, 0, 255))
|
||||
|
Loading…
x
Reference in New Issue
Block a user