Bladeren bron

introduce multiprocessing with temporary storing on ramdisk

Richard Köhl 1 jaar geleden
bovenliggende
commit
f922a1659e
1 gewijzigde bestanden met toevoegingen van 56 en 13 verwijderingen
  1. 56 13
      helper.py

+ 56 - 13
helper.py

@@ -5,6 +5,9 @@ import numpy as np
 import os
 import threading
 import subprocess
+import multiprocessing
+import sys
+import tempfile
 from io import BytesIO
 from PIL import Image
 from ppadb.client import Client as AdbClient
@@ -19,27 +22,68 @@ else:
     with_cuda = 0
 
 load_dotenv()
+ram_drive_path = os.getenv("RAMDRIVE_PATH")
+if (
+    not ram_drive_path
+    or not os.path.exists(ram_drive_path)
+    or not os.path.ismount(ram_drive_path)
+):
+    print("no ram drive ({ram_drive_path})")
+
 android_address = os.getenv("ANDROID_ADDRESS")
+client = AdbClient(host="127.0.0.1", port=5037)
+if not android_address or not client:
+    print(f"android address wrong? ({android_address})")
+    sys.exit()
 
 
 def get_current_screen():
+    if not current_screen:
+        print("something went wrong. not able to get screen.")
+        sys.exit()
     return current_screen
 
 
-def capture_current_screen(timeout=5):  # Timeout in seconds
-    def target():
-        global current_screen
-        current_screen = device.screencap()
+def screencap_worker(device, temp_file_name):
+    try:
+        screenshot = device.screencap()
+        with open(temp_file_name, "wb") as f:
+            f.write(screenshot)
+    except Exception as e:
+        print(f"Error in worker process: {e}")
 
-    capture_thread = threading.Thread(target=target)
-    capture_thread.start()
-    capture_thread.join(timeout)
 
-    if capture_thread.is_alive():
+def capture_current_screen(timeout=10):
+    # Create a temporary file
+    temp_file = tempfile.NamedTemporaryFile(delete=False)
+    temp_file_name = temp_file.name
+    temp_file.close()
+
+    capture_process = multiprocessing.Process(
+        target=screencap_worker, args=(device, temp_file_name)
+    )
+    capture_process.start()
+    capture_process.join(timeout)
+
+    if capture_process.is_alive():
+        capture_process.terminate()
+        capture_process.join()
         print("Screen capture timed out")
-        # Handle the timeout situation, e.g., by retrying or aborting
-        capture_thread.join()
+        os.remove(temp_file_name)
+        return None
+
+    if not os.path.exists(temp_file_name) or os.path.getsize(temp_file_name) == 0:
+        print("No data in the temporary file")
+        os.remove(temp_file_name)
+        return None
 
+    # Read the screenshot from the temporary file
+    global current_screen
+    with open(temp_file_name, "rb") as f:
+        current_screen = f.read()
+
+    # Clean up
+    os.remove(temp_file_name)
     return current_screen
 
 
@@ -115,7 +159,7 @@ def first_template(template_image):
 def find_template(template_image):
     if with_cuda == 1:
         # Ensure the images are in the correct format (BGR for OpenCV)
-        target_image = capture_current_screen()
+        target_image = get_current_screen()
 
         # Upload images to GPU
         target_image_gpu = cv2.cuda_GpuMat()
@@ -218,7 +262,7 @@ def non_max_suppression(boxes, overlapThresh):
 
 def save_screenshot(path="test"):
     # Take a screenshot
-    result = capture_current_screen()
+    result = get_current_screen()
 
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     image = Image.open(io.BytesIO(result))
@@ -249,7 +293,6 @@ def save_screenshot2(path="test"):
     print(f"snap: {jpeg_filename}")
 
 
-client = AdbClient(host="127.0.0.1", port=5037)
 device = client.device(android_address)
 
 current_screen = capture_current_screen()