Successfully implemented multi-core OCR functionality.

This commit is contained in:
Nicole Rappe 2025-02-26 03:36:59 -07:00
parent 5c23653d59
commit b6ef14b559
5 changed files with 176 additions and 65 deletions

View File

@ -6,6 +6,7 @@ import re
import sys import sys
import numpy as np import numpy as np
import cv2 import cv2
import concurrent.futures
# Vision-related Imports # Vision-related Imports
import pytesseract import pytesseract
@ -111,11 +112,9 @@ def _preprocess_image(image):
return thresh.filter(ImageFilter.MedianFilter(3)) return thresh.filter(ImageFilter.MedianFilter(3))
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"): def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU", num_slices=1):
""" """
Optimized function to detect word positions in an OCR region. Uses user-defined horizontal slices and threading for faster inference.
Uses raw screen data without preprocessing for max performance.
Uses Tesseract (CPU) or EasyOCR (GPU) depending on user selection.
""" """
collector_mutex.lock() collector_mutex.lock()
if region_id not in regions: if region_id not in regions:
@ -125,7 +124,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
bbox = regions[region_id]['bbox'] bbox = regions[region_id]['bbox']
collector_mutex.unlock() collector_mutex.unlock()
# Extract OCR region position and size
x, y, w, h = bbox x, y, w, h = bbox
left, top, right, bottom = x, y, x + w, y + h left, top, right, bottom = x, y, x + w, y + h
@ -134,61 +132,84 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
return [] return []
try: try:
# Capture raw screen image (NO preprocessing)
image = ImageGrab.grab(bbox=(left, top, right, bottom)) image = ImageGrab.grab(bbox=(left, top, right, bottom))
# Get original image size
orig_width, orig_height = image.size orig_width, orig_height = image.size
word_positions = [] word_positions = []
if ocr_engine == "CPU": # Ensure number of slices does not exceed image height
# Use Tesseract directly on raw PIL image (no preprocessing) num_slices = min(num_slices, orig_height)
data = pytesseract.image_to_data(image, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT) strip_height = max(1, orig_height // num_slices)
for i in range(len(data['text'])): def process_strip(strip_id):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE): strip_y = strip_id * strip_height
x_scaled = int(data['left'][i]) strip = image.crop((0, strip_y, orig_width, strip_y + strip_height))
y_scaled = int(data['top'][i])
w_scaled = int(data['width'][i])
h_scaled = int(data['height'][i])
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2))) strip_np = np.array(strip)
else: detected_positions = []
# Convert PIL image to NumPy array for EasyOCR if ocr_engine == "CPU":
image_np = np.array(image) ocr_data = pytesseract.image_to_data(strip, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
# Run GPU OCR for i in range(len(ocr_data['text'])):
results = reader_gpu.readtext(image_np) if re.search(rf"\b{word}\b", ocr_data['text'][i], re.IGNORECASE):
x_scaled = int(ocr_data['left'][i])
y_scaled = int(ocr_data['top'][i]) + strip_y
w_scaled = int(ocr_data['width'][i])
h_scaled = int(ocr_data['height'][i])
for (bbox, text, _) in results: detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
x_scaled = int(x_min) else:
y_scaled = int(y_min) results = reader_gpu.readtext(strip_np)
w_scaled = int(x_max - x_min) for (bbox, text, _) in results:
h_scaled = int(y_max - y_min) if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2))) x_scaled = int(x_min)
y_scaled = int(y_min) + strip_y
w_scaled = int(x_max - x_min)
h_scaled = int(y_max - y_min)
detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
return detected_positions
with concurrent.futures.ThreadPoolExecutor(max_workers=num_slices) as executor:
strip_results = list(executor.map(process_strip, range(num_slices)))
for strip_result in strip_results:
word_positions.extend(strip_result)
return word_positions return word_positions
except Exception as e: except Exception as e:
print(f"[ERROR] Failed to capture OCR region: {e}") print(f"[ERROR] Failed to capture OCR region: {e}")
return [] return []
def draw_identification_boxes(region_id, positions, color=(0, 0, 255), thickness=2): def draw_identification_boxes(region_id, positions, color=(0, 0, 255), thickness=2):
""" """
Draws non-interactive rectangles at specified positions within the given OCR region. Draws non-interactive rectangles at specified positions within the given OCR region.
Uses a separate rendering thread to prevent blocking OCR processing.
""" """
collector_mutex.lock() collector_mutex.lock()
if region_id in regions and 'widget' in regions[region_id]: if region_id in regions and 'widget' in regions[region_id]:
widget = regions[region_id]['widget'] widget = regions[region_id]['widget']
widget.set_draw_positions(positions, color, thickness) widget.update_draw_positions(positions, color, thickness)
collector_mutex.unlock() collector_mutex.unlock()
def update_region_slices(region_id, num_slices):
    """
    Update the number of visual slices shown by an OCR region's widget.

    Looks up ``region_id`` in the shared ``regions`` table while holding
    ``collector_mutex`` and, when that region has a ``'widget'`` entry,
    forwards ``num_slices`` to ``widget.set_num_slices``. Unknown regions and
    regions without a widget are ignored.

    Args:
        region_id: Key of the OCR region in ``regions``.
        num_slices: Slice count passed through to the widget unchanged.
    """
    collector_mutex.lock()
    try:
        if region_id in regions and 'widget' in regions[region_id]:
            widget = regions[region_id]['widget']
            widget.set_num_slices(num_slices)
    finally:
        # Release the mutex even if the widget call raises; otherwise the
        # lock would stay held and later lock() callers could never proceed.
        collector_mutex.unlock()
class OCRRegionWidget(QWidget): class OCRRegionWidget(QWidget):
def __init__(self, x, y, w, h, region_id, color, thickness): def __init__(self, x, y, w, h, region_id, color, thickness):
@ -199,12 +220,12 @@ class OCRRegionWidget(QWidget):
self.setAttribute(Qt.WA_TranslucentBackground, True) self.setAttribute(Qt.WA_TranslucentBackground, True)
self.setAttribute(Qt.WA_TransparentForMouseEvents, False) self.setAttribute(Qt.WA_TransparentForMouseEvents, False)
self.drag_offset = None
self.selected_handle = None
self.region_id = region_id self.region_id = region_id
self.box_color = QColor(*color) self.box_color = QColor(*color)
self.line_thickness = thickness self.line_thickness = thickness
self.draw_positions = [] self.draw_positions = []
self.previous_positions = [] # This prevents redundant redraws
self.num_slices = 1 # Ensures slice count is initialized
self.show() self.show()
@ -221,6 +242,16 @@ class OCRRegionWidget(QWidget):
for x, y, w, h in self.draw_positions: for x, y, w, h in self.draw_positions:
painter.drawRect(x, y, w, h) painter.drawRect(x, y, w, h)
# Draw faint slice division lines
if self.num_slices > 1:
strip_height = self.height() // self.num_slices
pen.setColor(QColor(150, 150, 150, 100)) # Light gray, semi-transparent
pen.setWidth(1)
painter.setPen(pen)
for i in range(1, self.num_slices): # Do not draw the last one at the bottom
painter.drawLine(0, i * strip_height, self.width(), i * strip_height)
def set_draw_positions(self, positions, color, thickness): def set_draw_positions(self, positions, color, thickness):
""" """
Updates the overlay positions and visual settings. Updates the overlay positions and visual settings.
@ -230,6 +261,27 @@ class OCRRegionWidget(QWidget):
self.line_thickness = thickness self.line_thickness = thickness
self.update() self.update()
def update_draw_positions(self, positions, color, thickness):
    """
    Update the overlay rectangles and their appearance, repainting on change.

    A repaint is requested only when the positions, the colour, or the line
    thickness differ from what is currently stored, which avoids needless
    flicker while still letting appearance-only changes take effect.

    Args:
        positions: Sequence of ``(x, y, w, h)`` rectangles; stored as
            ``self.draw_positions``.
        color: Colour components, unpacked into ``QColor(*color)``.
        thickness: Line thickness; stored as ``self.line_thickness``.
    """
    new_color = QColor(*color)

    # Comparing positions alone would drop a colour/thickness change whenever
    # the detected boxes happen to stay in the same place.
    unchanged = (
        positions == self.previous_positions
        and new_color == self.box_color
        and thickness == self.line_thickness
    )
    if unchanged:
        return  # Nothing visible changed, so do not schedule a repaint

    self.previous_positions = positions  # Remember last applied positions
    self.draw_positions = positions
    self.box_color = new_color
    self.line_thickness = thickness
    self.update()  # Redraw only if needed
def set_num_slices(self, num_slices):
    """
    Set the number of horizontal slices used for visualization.

    A repaint is requested only when the value actually changes, so callers
    that re-send the same slice count repeatedly do not trigger redundant
    redraws.

    Args:
        num_slices: New slice count; stored as ``self.num_slices``.
    """
    if num_slices == self.num_slices:
        return  # Same value as before, nothing to redraw

    self.num_slices = num_slices
    self.update()
def _resize_handles(self): def _resize_handles(self):
w, h = self.width(), self.height() w, h = self.width(), self.height()
return [ return [
@ -254,19 +306,23 @@ class OCRRegionWidget(QWidget):
new_h = h + (self.y() - event.globalY()) new_h = h + (self.y() - event.globalY())
new_x = event.globalX() new_x = event.globalX()
new_y = event.globalY() new_y = event.globalY()
if new_w < 20: new_w = 20 if new_w < 20:
if new_h < 20: new_h = 20 new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(new_x, new_y, new_w, new_h) self.setGeometry(new_x, new_y, new_w, new_h)
elif self.selected_handle == 1: # Bottom-right elif self.selected_handle == 1: # Bottom-right
new_w = event.globalX() - self.x() new_w = event.globalX() - self.x()
new_h = event.globalY() - self.y() new_h = event.globalY() - self.y()
if new_w < 20: new_w = 20 if new_w < 20:
if new_h < 20: new_h = 20 new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(self.x(), self.y(), new_w, new_h) self.setGeometry(self.x(), self.y(), new_w, new_h)
collector_mutex.lock() collector_mutex.lock()
if self.region_id in regions: if self.region_id in regions:
regions[self.region_id]['bbox'] = [self.x(), self.y(), self.width(), self.height()] regions[self.region_id]["bbox"] = [self.x(), self.y(), self.width(), self.height()]
collector_mutex.unlock() collector_mutex.unlock()
self.update() self.update()
@ -277,5 +333,7 @@ class OCRRegionWidget(QWidget):
collector_mutex.lock() collector_mutex.lock()
if self.region_id in regions: if self.region_id in regions:
regions[self.region_id]['bbox'] = [new_x, new_y, self.width(), self.height()] regions[self.region_id]["bbox"] = [new_x, new_y, self.width(), self.height()]
collector_mutex.unlock() collector_mutex.unlock()

View File

@ -1,9 +1,7 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
Identification Overlay Node: Identification Overlay Node:
- Creates an OCR region in data_collector with a blue overlay. - Users can configure threads/slices for parallel processing.
- Detects instances of a specified word and draws adjustable overlays.
- Users can configure offset, margin, polling frequency, overlay color, and thickness.
""" """
import re import re
@ -31,6 +29,7 @@ class IdentificationOverlayNode(BaseNode):
# Custom overlay options # Custom overlay options
self.add_text_input("overlay_color", "Overlay Color (RGB)", text="0,0,255") # Default blue self.add_text_input("overlay_color", "Overlay Color (RGB)", text="0,0,255") # Default blue
self.add_text_input("thickness", "Line Thickness", text="2") # Default 2px self.add_text_input("thickness", "Line Thickness", text="2") # Default 2px
self.add_text_input("threads_slices", "Threads / Slices", text="8") # Default 8 threads/slices
self.region_id = "identification_overlay" self.region_id = "identification_overlay"
data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255), thickness=2) data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255), thickness=2)
@ -46,62 +45,59 @@ class IdentificationOverlayNode(BaseNode):
self.update_polling_frequency() self.update_polling_frequency()
def update_polling_frequency(self): def update_polling_frequency(self):
"""
Reads the user-defined polling frequency and updates the timer interval.
"""
polling_text = self.get_property("polling_freq") polling_text = self.get_property("polling_freq")
try: try:
polling_interval = max(50, int(polling_text)) # Minimum 50ms for near real-time polling_interval = max(50, int(polling_text))
except ValueError: except ValueError:
polling_interval = 500 # Default to 500ms polling_interval = 500
self.timer.start(polling_interval) self.timer.start(polling_interval)
def update_overlay(self): def update_overlay(self):
"""
Updates the overlay with detected word positions.
"""
search_term = self.get_property("search_term") search_term = self.get_property("search_term")
offset_text = self.get_property("offset_value") offset_text = self.get_property("offset_value")
margin_text = self.get_property("margin") margin_text = self.get_property("margin")
ocr_engine = self.get_property("ocr_engine") ocr_engine = self.get_property("ocr_engine")
threads_slices_text = self.get_property("threads_slices")
# Read and apply polling frequency updates
self.update_polling_frequency() self.update_polling_frequency()
# Parse user-defined offset
try: try:
offset_x, offset_y = map(int, offset_text.split(",")) offset_x, offset_y = map(int, offset_text.split(","))
except ValueError: except ValueError:
offset_x, offset_y = 0, 0 # Default to no offset if invalid input offset_x, offset_y = 0, 0
# Parse user-defined margin
try: try:
margin = int(margin_text) margin = int(margin_text)
except ValueError: except ValueError:
margin = 5 # Default margin if invalid input margin = 5
# Parse overlay color
color_text = self.get_property("overlay_color") color_text = self.get_property("overlay_color")
try: try:
color = tuple(map(int, color_text.split(","))) # Convert "255,0,0" -> (255,0,0) color = tuple(map(int, color_text.split(",")))
except ValueError: except ValueError:
color = (0, 0, 255) # Default to blue if invalid input color = (0, 0, 255)
# Parse thickness
thickness_text = self.get_property("thickness") thickness_text = self.get_property("thickness")
try: try:
thickness = max(1, int(thickness_text)) # Ensure at least 1px thickness thickness = max(1, int(thickness_text))
except ValueError: except ValueError:
thickness = 2 # Default thickness thickness = 2
try:
num_slices = max(1, int(threads_slices_text)) # Ensure at least 1 slice
except ValueError:
num_slices = 1
if not search_term: if not search_term:
return return
# Get detected word positions using the selected OCR engine
detected_positions = data_collector.find_word_positions( detected_positions = data_collector.find_word_positions(
self.region_id, search_term, offset_x, offset_y, margin, ocr_engine self.region_id, search_term, offset_x, offset_y, margin, ocr_engine, num_slices
) )
# Draw detected word boxes with custom color & thickness # Ensure slice count is updated visually in the region widget
data_collector.update_region_slices(self.region_id, num_slices)
data_collector.draw_identification_boxes(self.region_id, detected_positions, color=color, thickness=thickness) data_collector.draw_identification_boxes(self.region_id, detected_positions, color=color, thickness=thickness)

View File

@ -0,0 +1,57 @@
{
"graph":{
"layout_direction":0,
"acyclic":true,
"pipe_collision":false,
"pipe_slicing":true,
"pipe_style":1,
"accept_connection_types":{},
"reject_connection_types":{}
},
"nodes":{
"0x20c129abb30":{
"type_":"bunny-lab.io.identification_overlay_node.IdentificationOverlayNode",
"icon":null,
"name":"Identification Overlay",
"color":[
13,
18,
23,
255
],
"border_color":[
74,
84,
85,
255
],
"text_color":[
255,
255,
255,
180
],
"disabled":false,
"selected":false,
"visible":true,
"width":271.0,
"height":330.40000000000003,
"pos":[
44.64929777820301,
256.49596595988965
],
"layout_direction":0,
"port_deletion_allowed":false,
"subgraph_session":{},
"custom":{
"search_term":"Aibatt",
"offset_value":"-10,-10",
"margin":"10",
"polling_freq":"50",
"ocr_engine":"GPU",
"overlay_color":"255,255,255",
"thickness":"5"
}
}
}
}