Successfully implemented multi-core OCR functionality.

This commit is contained in:
Nicole Rappe 2025-02-26 03:36:59 -07:00
parent 5c23653d59
commit b6ef14b559
5 changed files with 176 additions and 65 deletions

View File

@ -6,6 +6,7 @@ import re
import sys
import numpy as np
import cv2
import concurrent.futures
# Vision-related Imports
import pytesseract
@ -111,11 +112,9 @@ def _preprocess_image(image):
return thresh.filter(ImageFilter.MedianFilter(3))
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU", num_slices=1):
"""
Optimized function to detect word positions in an OCR region.
Uses raw screen data without preprocessing for max performance.
Uses Tesseract (CPU) or EasyOCR (GPU) depending on user selection.
Uses user-defined horizontal slices and threading for faster inference.
"""
collector_mutex.lock()
if region_id not in regions:
@ -125,7 +124,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
bbox = regions[region_id]['bbox']
collector_mutex.unlock()
# Extract OCR region position and size
x, y, w, h = bbox
left, top, right, bottom = x, y, x + w, y + h
@ -134,61 +132,84 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
return []
try:
# Capture raw screen image (NO preprocessing)
image = ImageGrab.grab(bbox=(left, top, right, bottom))
# Get original image size
orig_width, orig_height = image.size
word_positions = []
# Ensure number of slices does not exceed image height
num_slices = min(num_slices, orig_height)
strip_height = max(1, orig_height // num_slices)
def process_strip(strip_id):
strip_y = strip_id * strip_height
strip = image.crop((0, strip_y, orig_width, strip_y + strip_height))
strip_np = np.array(strip)
detected_positions = []
if ocr_engine == "CPU":
# Use Tesseract directly on raw PIL image (no preprocessing)
data = pytesseract.image_to_data(image, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
ocr_data = pytesseract.image_to_data(strip, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
for i in range(len(data['text'])):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
x_scaled = int(data['left'][i])
y_scaled = int(data['top'][i])
w_scaled = int(data['width'][i])
h_scaled = int(data['height'][i])
for i in range(len(ocr_data['text'])):
if re.search(rf"\b{word}\b", ocr_data['text'][i], re.IGNORECASE):
x_scaled = int(ocr_data['left'][i])
y_scaled = int(ocr_data['top'][i]) + strip_y
w_scaled = int(ocr_data['width'][i])
h_scaled = int(ocr_data['height'][i])
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
else:
# Convert PIL image to NumPy array for EasyOCR
image_np = np.array(image)
# Run GPU OCR
results = reader_gpu.readtext(image_np)
results = reader_gpu.readtext(strip_np)
for (bbox, text, _) in results:
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
x_scaled = int(x_min)
y_scaled = int(y_min)
y_scaled = int(y_min) + strip_y
w_scaled = int(x_max - x_min)
h_scaled = int(y_max - y_min)
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
return detected_positions
with concurrent.futures.ThreadPoolExecutor(max_workers=num_slices) as executor:
strip_results = list(executor.map(process_strip, range(num_slices)))
for strip_result in strip_results:
word_positions.extend(strip_result)
return word_positions
except Exception as e:
print(f"[ERROR] Failed to capture OCR region: {e}")
return []
def draw_identification_boxes(region_id, positions, color=(0, 0, 255), thickness=2):
"""
Draws non-interactive rectangles at specified positions within the given OCR region.
Uses a separate rendering thread to prevent blocking OCR processing.
"""
collector_mutex.lock()
if region_id in regions and 'widget' in regions[region_id]:
widget = regions[region_id]['widget']
widget.set_draw_positions(positions, color, thickness)
widget.update_draw_positions(positions, color, thickness)
collector_mutex.unlock()
def update_region_slices(region_id, num_slices):
"""
Updates the number of visual slices in the OCR region.
"""
collector_mutex.lock()
if region_id in regions and 'widget' in regions[region_id]:
widget = regions[region_id]['widget']
widget.set_num_slices(num_slices)
collector_mutex.unlock()
class OCRRegionWidget(QWidget):
def __init__(self, x, y, w, h, region_id, color, thickness):
@ -199,12 +220,12 @@ class OCRRegionWidget(QWidget):
self.setAttribute(Qt.WA_TranslucentBackground, True)
self.setAttribute(Qt.WA_TransparentForMouseEvents, False)
self.drag_offset = None
self.selected_handle = None
self.region_id = region_id
self.box_color = QColor(*color)
self.line_thickness = thickness
self.draw_positions = []
self.previous_positions = [] # This prevents redundant redraws
self.num_slices = 1 # Ensures slice count is initialized
self.show()
@ -221,6 +242,16 @@ class OCRRegionWidget(QWidget):
for x, y, w, h in self.draw_positions:
painter.drawRect(x, y, w, h)
# Draw faint slice division lines
if self.num_slices > 1:
strip_height = self.height() // self.num_slices
pen.setColor(QColor(150, 150, 150, 100)) # Light gray, semi-transparent
pen.setWidth(1)
painter.setPen(pen)
for i in range(1, self.num_slices): # Do not draw the last one at the bottom
painter.drawLine(0, i * strip_height, self.width(), i * strip_height)
def set_draw_positions(self, positions, color, thickness):
"""
Updates the overlay positions and visual settings.
@ -230,6 +261,27 @@ class OCRRegionWidget(QWidget):
self.line_thickness = thickness
self.update()
def update_draw_positions(self, positions, color, thickness):
"""
Updates the overlay positions and redraws only if the positions have changed.
This prevents unnecessary flickering.
"""
if positions == self.previous_positions:
return # No change, do not update
self.previous_positions = positions # Store last known positions
self.draw_positions = positions
self.box_color = QColor(*color)
self.line_thickness = thickness
self.update() # Redraw only if needed
def set_num_slices(self, num_slices):
"""
Updates the number of horizontal slices for visualization.
"""
self.num_slices = num_slices
self.update()
def _resize_handles(self):
w, h = self.width(), self.height()
return [
@ -254,19 +306,23 @@ class OCRRegionWidget(QWidget):
new_h = h + (self.y() - event.globalY())
new_x = event.globalX()
new_y = event.globalY()
if new_w < 20: new_w = 20
if new_h < 20: new_h = 20
if new_w < 20:
new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(new_x, new_y, new_w, new_h)
elif self.selected_handle == 1: # Bottom-right
new_w = event.globalX() - self.x()
new_h = event.globalY() - self.y()
if new_w < 20: new_w = 20
if new_h < 20: new_h = 20
if new_w < 20:
new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(self.x(), self.y(), new_w, new_h)
collector_mutex.lock()
if self.region_id in regions:
regions[self.region_id]['bbox'] = [self.x(), self.y(), self.width(), self.height()]
regions[self.region_id]["bbox"] = [self.x(), self.y(), self.width(), self.height()]
collector_mutex.unlock()
self.update()
@ -277,5 +333,7 @@ class OCRRegionWidget(QWidget):
collector_mutex.lock()
if self.region_id in regions:
regions[self.region_id]['bbox'] = [new_x, new_y, self.width(), self.height()]
regions[self.region_id]["bbox"] = [new_x, new_y, self.width(), self.height()]
collector_mutex.unlock()

View File

@ -1,9 +1,7 @@
#!/usr/bin/env python3
"""
Identification Overlay Node:
- Creates an OCR region in data_collector with a blue overlay.
- Detects instances of a specified word and draws adjustable overlays.
- Users can configure offset, margin, polling frequency, overlay color, and thickness.
- Users can configure threads/slices for parallel processing.
"""
import re
@ -31,6 +29,7 @@ class IdentificationOverlayNode(BaseNode):
# Custom overlay options
self.add_text_input("overlay_color", "Overlay Color (RGB)", text="0,0,255") # Default blue
self.add_text_input("thickness", "Line Thickness", text="2") # Default 2px
self.add_text_input("threads_slices", "Threads / Slices", text="8") # Default 8 threads/slices
self.region_id = "identification_overlay"
data_collector.create_ocr_region(self.region_id, x=250, y=50, w=300, h=200, color=(0, 0, 255), thickness=2)
@ -46,62 +45,59 @@ class IdentificationOverlayNode(BaseNode):
self.update_polling_frequency()
def update_polling_frequency(self):
"""
Reads the user-defined polling frequency and updates the timer interval.
"""
polling_text = self.get_property("polling_freq")
try:
polling_interval = max(50, int(polling_text)) # Minimum 50ms for near real-time
polling_interval = max(50, int(polling_text))
except ValueError:
polling_interval = 500 # Default to 500ms
polling_interval = 500
self.timer.start(polling_interval)
def update_overlay(self):
"""
Updates the overlay with detected word positions.
"""
search_term = self.get_property("search_term")
offset_text = self.get_property("offset_value")
margin_text = self.get_property("margin")
ocr_engine = self.get_property("ocr_engine")
threads_slices_text = self.get_property("threads_slices")
# Read and apply polling frequency updates
self.update_polling_frequency()
# Parse user-defined offset
try:
offset_x, offset_y = map(int, offset_text.split(","))
except ValueError:
offset_x, offset_y = 0, 0 # Default to no offset if invalid input
offset_x, offset_y = 0, 0
# Parse user-defined margin
try:
margin = int(margin_text)
except ValueError:
margin = 5 # Default margin if invalid input
margin = 5
# Parse overlay color
color_text = self.get_property("overlay_color")
try:
color = tuple(map(int, color_text.split(","))) # Convert "255,0,0" -> (255,0,0)
color = tuple(map(int, color_text.split(",")))
except ValueError:
color = (0, 0, 255) # Default to blue if invalid input
color = (0, 0, 255)
# Parse thickness
thickness_text = self.get_property("thickness")
try:
thickness = max(1, int(thickness_text)) # Ensure at least 1px thickness
thickness = max(1, int(thickness_text))
except ValueError:
thickness = 2 # Default thickness
thickness = 2
try:
num_slices = max(1, int(threads_slices_text)) # Ensure at least 1 slice
except ValueError:
num_slices = 1
if not search_term:
return
# Get detected word positions using the selected OCR engine
detected_positions = data_collector.find_word_positions(
self.region_id, search_term, offset_x, offset_y, margin, ocr_engine
self.region_id, search_term, offset_x, offset_y, margin, ocr_engine, num_slices
)
# Draw detected word boxes with custom color & thickness
# Ensure slice count is updated visually in the region widget
data_collector.update_region_slices(self.region_id, num_slices)
data_collector.draw_identification_boxes(self.region_id, detected_positions, color=color, thickness=thickness)

View File

@ -0,0 +1,57 @@
{
"graph":{
"layout_direction":0,
"acyclic":true,
"pipe_collision":false,
"pipe_slicing":true,
"pipe_style":1,
"accept_connection_types":{},
"reject_connection_types":{}
},
"nodes":{
"0x20c129abb30":{
"type_":"bunny-lab.io.identification_overlay_node.IdentificationOverlayNode",
"icon":null,
"name":"Identification Overlay",
"color":[
13,
18,
23,
255
],
"border_color":[
74,
84,
85,
255
],
"text_color":[
255,
255,
255,
180
],
"disabled":false,
"selected":false,
"visible":true,
"width":271.0,
"height":330.40000000000003,
"pos":[
44.64929777820301,
256.49596595988965
],
"layout_direction":0,
"port_deletion_allowed":false,
"subgraph_session":{},
"custom":{
"search_term":"Aibatt",
"offset_value":"-10,-10",
"margin":"10",
"polling_freq":"50",
"ocr_engine":"GPU",
"overlay_color":"255,255,255",
"thickness":"5"
}
}
}
}