Enabled GPU Acceleration for Identification Node.
parent c10fc1ba6d
commit 0515f8feeb

Installation Requirements.txt (new file)
@@ -0,0 +1,5 @@
+We need to run the following command to install the prerequisites for the project.
+
+This command installs PyTorch and torchvision for GPU-accelerated vision tasks.
+
+pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
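
The cu121 wheel above targets CUDA 12.1; whether the GPU path can actually be used depends on the local NVIDIA driver. A minimal sanity check (illustrative, not part of the commit) before selecting the GPU engine:

    # Verify the CUDA build of PyTorch is active before enabling the GPU OCR path.
    import torch

    print(torch.__version__)          # a CUDA wheel reports a version like "2.x.x+cu121"
    print(torch.cuda.is_available())  # False means EasyOCR will fall back to CPU
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(0))  # name of the detected NVIDIA GPU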
Binary file not shown.
@@ -6,13 +6,22 @@ import re
 import sys
 import numpy as np
 import cv2
 
+# Vision-related Imports
 import pytesseract
+import easyocr
+import torch
+
 from PIL import Image, ImageGrab, ImageFilter
 
 from PyQt5.QtWidgets import QApplication, QWidget
 from PyQt5.QtCore import QRect, QPoint, Qt, QMutex, QTimer
 from PyQt5.QtGui import QPainter, QPen, QColor, QFont
 
+# Initialize EasyOCR with CUDA support
+reader_cpu = easyocr.Reader(['en'], gpu=False)
+reader_gpu = easyocr.Reader(['en'], gpu=True if torch.cuda.is_available() else False)
+
 pytesseract.pytesseract.tesseract_cmd = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
 
 DEFAULT_WIDTH = 180
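
Note that both Reader objects are constructed at import time, so models for both engines are loaded (and downloaded on first run) even if only one engine is ever used; also, `gpu=True if torch.cuda.is_available() else False` is equivalent to `gpu=torch.cuda.is_available()`. A lazy-construction sketch, with a hypothetical `_get_reader` helper, would avoid the double load — an alternative, not what the commit does:

    # Sketch only: build each EasyOCR reader on first use instead of at import.
    # `_get_reader` is a hypothetical helper, not part of this commit.
    import easyocr
    import torch

    _readers = {}

    def _get_reader(use_gpu: bool) -> easyocr.Reader:
        key = "gpu" if use_gpu else "cpu"
        if key not in _readers:
            # Degrades to CPU automatically when no CUDA device is present.
            _readers[key] = easyocr.Reader(['en'], gpu=use_gpu and torch.cuda.is_available())
        return _readers[key]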
@@ -29,11 +38,11 @@ app_instance = None
 def _ensure_qapplication():
     """
     Ensures that QApplication is initialized before creating widgets.
-
+    Must be called from the main thread.
     """
     global app_instance
-    if QApplication.instance() is None:
-        app_instance = QApplication(sys.argv)
-        threading.Thread(target=app_instance.exec_, daemon=True).start()
+    if app_instance is None:
+        app_instance = QApplication(sys.argv)  # Start in main thread
+
 
 
 def create_ocr_region(region_id, x=250, y=50, w=DEFAULT_WIDTH, h=DEFAULT_HEIGHT, color=(255, 255, 0)):
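
The rewritten `_ensure_qapplication` drops the background `threading.Thread` around `app_instance.exec_()`: Qt requires QApplication and all widgets to live in the main thread, so the event loop must now be started there by the caller. An assumed call pattern (the caller's code is not shown in this diff):

    # Assumed usage: create widgets, then run the event loop in the main thread.
    import sys
    from PyQt5.QtWidgets import QApplication, QWidget

    app = QApplication.instance() or QApplication(sys.argv)
    overlay = QWidget()   # stand-in for the real overlay widget
    overlay.show()
    app.exec_()           # blocks; paint and timer events are processed here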
@@ -102,11 +111,11 @@ def _preprocess_image(image):
         return thresh.filter(ImageFilter.MedianFilter(3))
 
 
-def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
+def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
     """
     Finds positions of a specific word within the OCR region.
     Applies user-defined offset and margin adjustments.
-    Returns a list of bounding box coordinates relative to the OCR box.
+    Uses Tesseract (CPU) or EasyOCR (GPU) depending on the selected engine.
     """
     collector_mutex.lock()
     if region_id not in regions:
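
With the new keyword argument, callers opt into the GPU path explicitly; anything other than "CPU" currently selects EasyOCR. An illustrative call:

    # Illustrative call: search the overlay region with the GPU engine.
    positions = find_word_positions("identification_overlay", "Aibatt",
                                    offset_x=0, offset_y=0, margin=5,
                                    ocr_engine="GPU")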
@@ -136,28 +145,36 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
         scale_x = orig_width / proc_width
         scale_y = orig_height / proc_height
 
-        data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
-
         word_positions = []
 
-        for i in range(len(data['text'])):
-            if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
-                # Scale the detected coordinates back to region-relative positions
-                x_scaled = int(data['left'][i] * scale_x)
-                y_scaled = int(data['top'][i] * scale_y)
-                w_scaled = int(data['width'][i] * scale_x)
-                h_scaled = int(data['height'][i] * scale_y)
-
-                # Apply user-configured margin
-                x_margin = max(0, x_scaled - margin)
-                y_margin = max(0, y_scaled - margin)
-                w_margin = w_scaled + (margin * 2)
-                h_margin = h_scaled + (margin * 2)
-
-                # Apply user-configured offset
-                x_final = x_margin + offset_x
-                y_final = y_margin + offset_y
-
-                word_positions.append((x_final, y_final, w_margin, h_margin))
+        if ocr_engine == "CPU":
+            # Use Tesseract (CPU)
+            data = pytesseract.image_to_data(processed, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
+
+            for i in range(len(data['text'])):
+                if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
+                    x_scaled = int(data['left'][i] * scale_x)
+                    y_scaled = int(data['top'][i] * scale_y)
+                    w_scaled = int(data['width'][i] * scale_x)
+                    h_scaled = int(data['height'][i] * scale_y)
+
+                    word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
+        else:
+            # Use EasyOCR (GPU) - Convert PIL image to NumPy array
+            image_np = np.array(processed)
+            results = reader_gpu.readtext(image_np)
+
+            for (bbox, text, _) in results:
+                if re.search(rf"\b{word}\b", text, re.IGNORECASE):
+                    (x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
+
+                    x_scaled = int(x_min * scale_x)
+                    y_scaled = int(y_min * scale_y)
+                    w_scaled = int((x_max - x_min) * scale_x)
+                    h_scaled = int((y_max - y_min) * scale_y)
+
+                    word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
 
         return word_positions
     except Exception as e:
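
The GPU branch leans on EasyOCR's result format: `readtext` yields `(bbox, text, confidence)` tuples, where `bbox` holds the four corner points ordered clockwise from the top-left, so `bbox[0]` and `bbox[2]` are the top-left and bottom-right corners used above. A standalone illustration (the image path is a placeholder):

    import easyocr

    reader = easyocr.Reader(['en'], gpu=False)  # CPU suffices for a quick test
    for bbox, text, conf in reader.readtext("sample.png"):  # placeholder path
        (x_min, y_min) = bbox[0]  # top-left corner
        (x_max, y_max) = bbox[2]  # bottom-right corner
        print(f"{text!r} ({conf:.2f}) -> x={x_min}, y={y_min}, "
              f"w={x_max - x_min}, h={y_max - y_min}")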
@@ -165,9 +182,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5):
         return []
 
 
-
-
-
 def draw_identification_boxes(region_id, positions, color=(0, 0, 255)):
     """
     Draws non-interactive rectangles at specified positions within the given OCR region.
Binary file not shown.
@@ -3,14 +3,13 @@
 Identification Overlay Node:
 - Creates an OCR region in data_collector with a blue overlay.
 - Detects instances of a specified word and draws adjustable overlays.
-- Users can configure offset, margin, and polling frequency dynamically.
+- Users can configure offset, margin, polling frequency, and select OCR engine.
 """
 
 import re
 from OdenGraphQt import BaseNode
-from PyQt5.QtWidgets import QMessageBox
 from PyQt5.QtCore import QTimer
-from Modules import data_manager, data_collector
+from Modules import data_collector
 
 
 class IdentificationOverlayNode(BaseNode):
@@ -24,7 +23,9 @@ class IdentificationOverlayNode(BaseNode):
         self.add_text_input("search_term", "Search Term", text="Aibatt")
         self.add_text_input("offset_value", "Offset Value (X,Y)", text="0,0")  # X,Y Offset
         self.add_text_input("margin", "Margin", text="5")  # Box Margin
-        self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500")  # New input
+        self.add_text_input("polling_freq", "Polling Frequency (ms)", text="500")  # Polling Rate
+        self.add_combo_menu("ocr_engine", "Type", items=["CPU", "GPU"])
+        self.set_property("ocr_engine", "CPU")  # Set default value after adding the menu
 
         self.region_id = "identification_overlay"
         data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255))
@@ -45,7 +46,7 @@ class IdentificationOverlayNode(BaseNode):
         """
         polling_text = self.get_property("polling_freq")
         try:
-            polling_interval = max(100, int(polling_text))  # Minimum 100ms to avoid overloading
+            polling_interval = max(50, int(polling_text))  # Minimum 50ms for near real-time
         except ValueError:
             polling_interval = 500  # Default to 500ms
 
@@ -58,6 +59,7 @@ class IdentificationOverlayNode(BaseNode):
         search_term = self.get_property("search_term")
         offset_text = self.get_property("offset_value")
         margin_text = self.get_property("margin")
+        ocr_engine = self.get_property("ocr_engine")
 
         # Read and apply polling frequency updates
         self.update_polling_frequency()
@@ -77,8 +79,10 @@ class IdentificationOverlayNode(BaseNode):
         if not search_term:
             return
 
-        # Get detected word positions
-        detected_positions = data_collector.find_word_positions(self.region_id, search_term, offset_x, offset_y, margin)
+        # Get detected word positions using the selected OCR engine
+        detected_positions = data_collector.find_word_positions(
+            self.region_id, search_term, offset_x, offset_y, margin, ocr_engine
+        )
 
         # Draw detected word boxes
         data_collector.draw_identification_boxes(self.region_id, detected_positions, color=(0, 0, 255))
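
Lowering the polling floor to 50 ms allows up to 20 OCR passes per second, which the GPU EasyOCR path can plausibly sustain but Tesseract on CPU often cannot. The node's timer setup sits outside the shown hunks; a self-contained sketch of the assumed QTimer wiring:

    # Assumed polling wiring (hypothetical stand-ins; the node's actual timer
    # code is not shown in this diff).
    import sys
    from PyQt5.QtWidgets import QApplication
    from PyQt5.QtCore import QTimer

    app = QApplication(sys.argv)
    timer = QTimer()
    timer.timeout.connect(lambda: print("run OCR pass"))  # stand-in for the node's update
    timer.start(50)   # the new minimum polling interval, in milliseconds
    app.exec_()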