# helper.py

import cv2
import datetime
import io
import numpy as np
import os
import threading
import subprocess
import multiprocessing
import sys
import tempfile
from io import BytesIO
from PIL import Image
from ppadb.client import Client as AdbClient
from dotenv import load_dotenv

with_cuda = 0
if cv2.cuda.getCudaEnabledDeviceCount() > 0:
    print("CUDA is available")
    with_cuda = 1
else:
    print("CUDA is not available")
    with_cuda = 0

load_dotenv()

ram_drive_path = os.getenv("RAMDRIVE_PATH")
if (
    not ram_drive_path
    or not os.path.exists(ram_drive_path)
    or not os.path.ismount(ram_drive_path)
):
    ram_drive_path = os.getenv("SCREENSHOT_PATH")
    print(f"no ram drive (fallback to {ram_drive_path})")

android_address = os.getenv("ANDROID_ADDRESS")
client = AdbClient(host="127.0.0.1", port=5037)
if not android_address or not client:
    print(f"android address wrong? ({android_address})")
    sys.exit()

def get_current_screen():
    if not current_screen:
        print("something went wrong. not able to get screen.")
        sys.exit()
    return current_screen

def screencap_worker(device, temp_file_name):
    try:
        screenshot = device.screencap()
        with open(temp_file_name, "wb") as f:
            f.write(screenshot)
    except Exception as e:
        print(f"Error in worker process: {e}")

def capture_current_screen(timeout=10):
    # Create a temporary file
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    temp_file_name = temp_file.name
    temp_file.close()
    capture_process = multiprocessing.Process(
        target=screencap_worker, args=(device, temp_file_name)
    )
    capture_process.start()
    capture_process.join(timeout)
    if capture_process.is_alive():
        capture_process.terminate()
        capture_process.join()
        print("Screen capture timed out")
        os.remove(temp_file_name)
        return None
    if not os.path.exists(temp_file_name) or os.path.getsize(temp_file_name) == 0:
        print("No data in the temporary file")
        os.remove(temp_file_name)
        return None
    # Read the screenshot from the temporary file
    global current_screen
    with open(temp_file_name, "rb") as f:
        current_screen = f.read()
    # Clean up
    os.remove(temp_file_name)
    return current_screen

# def capture_current_screen(timeout=5):  # Timeout in seconds
#     def target():
#         global current_screen
#         current_screen = device.screencap()
#     capture_thread = threading.Thread(target=target)
#     capture_thread.start()
#     capture_thread.join(timeout)
#     if capture_thread.is_alive():
#         print("Screen capture timed out")
#         # Handle the timeout situation, e.g., by retrying or aborting
#         capture_thread.join()
#     return current_screen

def find_center(x1, y1, x2, y2):
    centerX = round(x1 + (x2 - x1) / 2)
    centerY = round(y1 + (y2 - y1) / 2)
    return centerX, centerY

def call_device_shell(action, timeout=10):
    def target():
        device.shell(action)

    thread = threading.Thread(target=target)
    thread.start()
    thread.join(timeout)
    if thread.is_alive():
        print("ran into timeout")
        thread.join()

def tap(x, y=None, text=None):
    # Check if x is an int
    if isinstance(x, int):
        if not isinstance(y, int):
            raise ValueError("y must be an int when x is an int")
        # Construct the location string from both x and y
        location = f"{x} {y}"
    # Check if x is a string
    elif isinstance(x, str):
        location = x
    elif isinstance(x, tuple):
        location = f"{x[0]} {x[1]}"
    else:
        raise TypeError("x must be an int, a string, or an (x, y) tuple")
    # Assuming 'device' is a previously defined object with a 'shell' method
    action = f"input tap {location}"
    print(f"{action} {text}")
    call_device_shell(action, timeout=5)
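
# Example calls (coordinates are made up for illustration): all three forms
# resolve to the same "input tap" shell command.
#   tap(540, 960)
#   tap("540 960")
#   tap((540, 960), text="confirm dialog")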

def tap_button(template):
    button = find_template(template)
    if len(button) == 0:
        return
    tap(f"{button[0][0]} {button[0][1]}")

def swipe(start, end, duration=1000):
    action = f"input swipe {start} {end} {duration}"
    print(action)
    call_device_shell(action, timeout=5)
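
# Example (coordinates are made up for illustration): scroll up by swiping from
# near the bottom of a 1080x1920 screen to near the top over 300 ms.
#   swipe("540 1600", "540 400", 300)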

def look_for_templates(templates):
    for name, template in templates.items():
        locations = find_template(template)
        if len(locations) > 0:
            return name, locations
    return None, None

def first_template(template_image):
    result = find_template(template_image)
    if len(result) > 0:
        return result[0]
    return None

def find_template(template_image):
    if with_cuda == 1:
        # Decode the raw screenshot bytes into a BGR image (screencap returns
        # PNG-encoded bytes) so it can be uploaded to the GPU
        target_image = cv2.imdecode(
            np.frombuffer(get_current_screen(), np.uint8), cv2.IMREAD_COLOR
        )
        # Upload images to GPU
        target_image_gpu = cv2.cuda_GpuMat()
        template_image_gpu = cv2.cuda_GpuMat()
        target_image_gpu.upload(target_image)
        template_image_gpu.upload(template_image)
        # Perform template matching on the GPU
        result_gpu = cv2.cuda.createTemplateMatching(cv2.CV_8UC3, cv2.TM_CCOEFF_NORMED)
        result = result_gpu.match(target_image_gpu, template_image_gpu)
        # Download result from GPU to CPU
        result = result.download()
    else:
        target_image = Image.open(BytesIO(get_current_screen()))
        # Convert the image to a NumPy array and then to BGR format (which OpenCV uses)
        target_image = np.array(target_image)
        target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
        # target_image = target_image.astype(np.uint8)
        # template_image = template_image.astype(np.uint8)
        # Template matching
        result = cv2.matchTemplate(target_image, template_image, cv2.TM_CCOEFF_NORMED)
    # Template size, needed on both paths to build the match rectangles
    h, w = template_image.shape[:-1]
    # Define a threshold
    threshold = 0.9  # Adjust this threshold based on your requirements
    # Finding all locations where match exceeds threshold
    locations = np.where(result >= threshold)
    locations = list(zip(*locations[::-1]))
    # Create list of rectangles
    rectangles = [(*loc, loc[0] + w, loc[1] + h) for loc in locations]
    # Apply non-maximum suppression to remove overlaps
    rectangles = non_max_suppression(rectangles, 0.3)
    # Initialize an empty list to store coordinates
    coordinates = []
    for startX, startY, endX, endY in rectangles:
        # Append the coordinate pair to the list
        coordinates.append(find_center(startX, startY, endX, endY))
    # Sort the coordinates by y value in ascending order
    return sorted(coordinates, key=lambda x: x[1])

def non_max_suppression(boxes, overlapThresh):
    if len(boxes) == 0:
        return []
    # Convert to float
    boxes = np.array(boxes, dtype="float")
    # Initialize the list of picked indexes
    pick = []
    # Grab the coordinates of the bounding boxes
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    # Compute the area of the bounding boxes and sort by bottom-right y-coordinate
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(y2)
    # Keep looping while some indexes still remain in the indexes list
    while len(idxs) > 0:
        # Grab the last index in the indexes list and add the index value to the list of picked indexes
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)
        # Find the largest (x, y) coordinates for the start of the bounding box and the smallest (x, y)
        # coordinates for the end of the bounding box
        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])
        # Compute the width and height of the bounding box
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        # Compute the ratio of overlap
        overlap = (w * h) / area[idxs[:last]]
        # Delete all indexes from the index list that have overlap greater than the threshold
        idxs = np.delete(
            idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))
        )
    # Return only the bounding boxes that were picked
    return boxes[pick].astype("int")
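
# Illustration (numbers are made up, not from the original file): the first two
# boxes overlap by far more than the 0.3 threshold, so only one of them is
# kept, while the third, distant box always survives.
#   non_max_suppression([(10, 10, 60, 60), (12, 11, 62, 61), (200, 200, 250, 250)], 0.3)
#   -> [[200, 200, 250, 250], [12, 11, 62, 61]]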

def save_screenshot(path="test"):
    # Take a screenshot
    result = capture_current_screen()
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    image = Image.open(io.BytesIO(result))
    jpeg_filename = f"{path}/{timestamp}.jpg"
    image = image.convert("RGB")  # Convert to RGB mode for JPEG
    with open(jpeg_filename, "wb") as fp:
        image.save(fp, format="JPEG", quality=85)  # Adjust quality as needed
    print(f"snap: {jpeg_filename}")

def save_screenshot2(path="test"):
    proc = subprocess.Popen(
        "adb exec-out screencap -p", shell=True, stdout=subprocess.PIPE
    )
    image_bytes = proc.stdout.read()
    image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    jpeg_filename = f"{path}/{timestamp}.jpg"
    cv2.imwrite(
        jpeg_filename,
        cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
        [int(cv2.IMWRITE_JPEG_QUALITY), 85],
    )
    print(f"snap: {jpeg_filename}")

# Connect to the device and take an initial screenshot at import time
device = client.device(android_address)
if device is None:
    print(f"device not found ({android_address})")
    sys.exit()
current_screen = capture_current_screen()
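
# --- Usage sketch (not part of the original module; the template path and the
# output directory below are assumptions for illustration). It shows one
# plausible way to combine the helpers: load a reference image of a button,
# look for it on the current screen, tap the first match, and save a
# screenshot for debugging.
if __name__ == "__main__":
    ok_button = cv2.imread("templates/ok_button.png")  # hypothetical template file
    if ok_button is not None:
        capture_current_screen()  # refresh the cached screen bytes
        tap_button(ok_button)  # taps the first match, if any
        save_screenshot("test")  # assumes a "test" directory exists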