Successfully implemented multi-core OCR functionality.

This commit is contained in:
2025-02-26 03:36:59 -07:00
parent 5c23653d59
commit b6ef14b559
5 changed files with 176 additions and 65 deletions

View File

@ -6,6 +6,7 @@ import re
import sys
import numpy as np
import cv2
import concurrent.futures
# Vision-related Imports
import pytesseract
@ -111,11 +112,9 @@ def _preprocess_image(image):
return thresh.filter(ImageFilter.MedianFilter(3))
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU"):
def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_engine="CPU", num_slices=1):
"""
Optimized function to detect word positions in an OCR region.
Uses raw screen data without preprocessing for max performance.
Uses Tesseract (CPU) or EasyOCR (GPU) depending on user selection.
Uses user-defined horizontal slices and threading for faster inference.
"""
collector_mutex.lock()
if region_id not in regions:
@ -125,7 +124,6 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
bbox = regions[region_id]['bbox']
collector_mutex.unlock()
# Extract OCR region position and size
x, y, w, h = bbox
left, top, right, bottom = x, y, x + w, y + h
@ -134,61 +132,84 @@ def find_word_positions(region_id, word, offset_x=0, offset_y=0, margin=5, ocr_e
return []
try:
# Capture raw screen image (NO preprocessing)
image = ImageGrab.grab(bbox=(left, top, right, bottom))
# Get original image size
orig_width, orig_height = image.size
word_positions = []
if ocr_engine == "CPU":
# Use Tesseract directly on raw PIL image (no preprocessing)
data = pytesseract.image_to_data(image, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
# Ensure number of slices does not exceed image height
num_slices = min(num_slices, orig_height)
strip_height = max(1, orig_height // num_slices)
for i in range(len(data['text'])):
if re.search(rf"\b{word}\b", data['text'][i], re.IGNORECASE):
x_scaled = int(data['left'][i])
y_scaled = int(data['top'][i])
w_scaled = int(data['width'][i])
h_scaled = int(data['height'][i])
def process_strip(strip_id):
strip_y = strip_id * strip_height
strip = image.crop((0, strip_y, orig_width, strip_y + strip_height))
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
strip_np = np.array(strip)
else:
# Convert PIL image to NumPy array for EasyOCR
image_np = np.array(image)
detected_positions = []
if ocr_engine == "CPU":
ocr_data = pytesseract.image_to_data(strip, config='--psm 6 --oem 1', output_type=pytesseract.Output.DICT)
# Run GPU OCR
results = reader_gpu.readtext(image_np)
for i in range(len(ocr_data['text'])):
if re.search(rf"\b{word}\b", ocr_data['text'][i], re.IGNORECASE):
x_scaled = int(ocr_data['left'][i])
y_scaled = int(ocr_data['top'][i]) + strip_y
w_scaled = int(ocr_data['width'][i])
h_scaled = int(ocr_data['height'][i])
for (bbox, text, _) in results:
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
x_scaled = int(x_min)
y_scaled = int(y_min)
w_scaled = int(x_max - x_min)
h_scaled = int(y_max - y_min)
else:
results = reader_gpu.readtext(strip_np)
for (bbox, text, _) in results:
if re.search(rf"\b{word}\b", text, re.IGNORECASE):
(x_min, y_min), (x_max, y_max) = bbox[0], bbox[2]
word_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
x_scaled = int(x_min)
y_scaled = int(y_min) + strip_y
w_scaled = int(x_max - x_min)
h_scaled = int(y_max - y_min)
detected_positions.append((x_scaled + offset_x, y_scaled + offset_y, w_scaled + (margin * 2), h_scaled + (margin * 2)))
return detected_positions
with concurrent.futures.ThreadPoolExecutor(max_workers=num_slices) as executor:
strip_results = list(executor.map(process_strip, range(num_slices)))
for strip_result in strip_results:
word_positions.extend(strip_result)
return word_positions
except Exception as e:
print(f"[ERROR] Failed to capture OCR region: {e}")
return []
def draw_identification_boxes(region_id, positions, color=(0, 0, 255), thickness=2):
"""
Draws non-interactive rectangles at specified positions within the given OCR region.
Uses a separate rendering thread to prevent blocking OCR processing.
"""
collector_mutex.lock()
if region_id in regions and 'widget' in regions[region_id]:
widget = regions[region_id]['widget']
widget.set_draw_positions(positions, color, thickness)
widget.update_draw_positions(positions, color, thickness)
collector_mutex.unlock()
def update_region_slices(region_id, num_slices):
    """
    Updates the number of visual slices in the OCR region.

    Looks up the region's overlay widget under the collector mutex and
    forwards the new slice count so the widget can redraw its faint
    slice-division lines.

    Args:
        region_id: Key identifying the OCR region in the global ``regions``
            mapping.
        num_slices: Number of horizontal slices to visualize.
    """
    # Guard the shared `regions` dict against concurrent OCR/UI access.
    collector_mutex.lock()
    if region_id in regions and 'widget' in regions[region_id]:
        widget = regions[region_id]['widget']
        # Widget schedules its own repaint; see OCRRegionWidget.set_num_slices.
        widget.set_num_slices(num_slices)
    # NOTE(review): not released via try/finally — an exception between
    # lock() and unlock() would leave the mutex held. TODO confirm intended.
    collector_mutex.unlock()
class OCRRegionWidget(QWidget):
def __init__(self, x, y, w, h, region_id, color, thickness):
@ -199,12 +220,12 @@ class OCRRegionWidget(QWidget):
self.setAttribute(Qt.WA_TranslucentBackground, True)
self.setAttribute(Qt.WA_TransparentForMouseEvents, False)
self.drag_offset = None
self.selected_handle = None
self.region_id = region_id
self.box_color = QColor(*color)
self.line_thickness = thickness
self.draw_positions = []
self.previous_positions = [] # This prevents redundant redraws
self.num_slices = 1 # Ensures slice count is initialized
self.show()
@ -221,6 +242,16 @@ class OCRRegionWidget(QWidget):
for x, y, w, h in self.draw_positions:
painter.drawRect(x, y, w, h)
# Draw faint slice division lines
if self.num_slices > 1:
strip_height = self.height() // self.num_slices
pen.setColor(QColor(150, 150, 150, 100)) # Light gray, semi-transparent
pen.setWidth(1)
painter.setPen(pen)
for i in range(1, self.num_slices): # Do not draw the last one at the bottom
painter.drawLine(0, i * strip_height, self.width(), i * strip_height)
def set_draw_positions(self, positions, color, thickness):
"""
Updates the overlay positions and visual settings.
@ -230,6 +261,27 @@ class OCRRegionWidget(QWidget):
self.line_thickness = thickness
self.update()
def update_draw_positions(self, positions, color, thickness):
    """
    Updates the overlay positions and redraws only if the positions have changed.
    This prevents unnecessary flickering.

    Args:
        positions: Sequence of ``(x, y, w, h)`` rectangles to draw.
        color: RGB tuple unpacked into a ``QColor`` for the box outlines.
        thickness: Pen width, in pixels, for the rectangle outlines.
    """
    if positions == self.previous_positions:
        return  # No change, do not update
    self.previous_positions = positions  # Store last known positions
    self.draw_positions = positions
    self.box_color = QColor(*color)
    self.line_thickness = thickness
    self.update()  # Redraw only if needed
def set_num_slices(self, num_slices):
    """
    Updates the number of horizontal slices for visualization.

    Values greater than 1 cause the paint handler to draw faint gray
    division lines between slices.

    Args:
        num_slices: Slice count to visualize.
    """
    self.num_slices = num_slices
    self.update()  # Schedule a repaint so the new slice lines appear
def _resize_handles(self):
w, h = self.width(), self.height()
return [
@ -254,19 +306,23 @@ class OCRRegionWidget(QWidget):
new_h = h + (self.y() - event.globalY())
new_x = event.globalX()
new_y = event.globalY()
if new_w < 20: new_w = 20
if new_h < 20: new_h = 20
if new_w < 20:
new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(new_x, new_y, new_w, new_h)
elif self.selected_handle == 1: # Bottom-right
new_w = event.globalX() - self.x()
new_h = event.globalY() - self.y()
if new_w < 20: new_w = 20
if new_h < 20: new_h = 20
if new_w < 20:
new_w = 20
if new_h < 20:
new_h = 20
self.setGeometry(self.x(), self.y(), new_w, new_h)
collector_mutex.lock()
if self.region_id in regions:
regions[self.region_id]['bbox'] = [self.x(), self.y(), self.width(), self.height()]
regions[self.region_id]["bbox"] = [self.x(), self.y(), self.width(), self.height()]
collector_mutex.unlock()
self.update()
@ -277,5 +333,7 @@ class OCRRegionWidget(QWidget):
collector_mutex.lock()
if self.region_id in regions:
regions[self.region_id]['bbox'] = [new_x, new_y, self.width(), self.height()]
regions[self.region_id]["bbox"] = [new_x, new_y, self.width(), self.height()]
collector_mutex.unlock()