helper.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298
  1. import cv2
  2. import datetime
  3. import io
  4. import numpy as np
  5. import os
  6. import threading
  7. import subprocess
  8. import multiprocessing
  9. import sys
  10. import tempfile
  11. from io import BytesIO
  12. from PIL import Image
  13. from ppadb.client import Client as AdbClient
  14. from dotenv import load_dotenv
  15. with_cuda = 0
  16. if cv2.cuda.getCudaEnabledDeviceCount() > 0:
  17. print("CUDA is available")
  18. with_cuda = 1
  19. else:
  20. print("CUDA is not available")
  21. with_cuda = 0
  22. load_dotenv()
  23. ram_drive_path = os.getenv("RAMDRIVE_PATH")
  24. if (
  25. not ram_drive_path
  26. or not os.path.exists(ram_drive_path)
  27. or not os.path.ismount(ram_drive_path)
  28. ):
  29. print("no ram drive ({ram_drive_path})")
  30. android_address = os.getenv("ANDROID_ADDRESS")
  31. client = AdbClient(host="127.0.0.1", port=5037)
  32. if not android_address or not client:
  33. print(f"android address wrong? ({android_address})")
  34. sys.exit()
  35. def get_current_screen():
  36. if not current_screen:
  37. print("something went wrong. not able to get screen.")
  38. sys.exit()
  39. return current_screen
  40. def screencap_worker(device, temp_file_name):
  41. try:
  42. screenshot = device.screencap()
  43. with open(temp_file_name, "wb") as f:
  44. f.write(screenshot)
  45. except Exception as e:
  46. print(f"Error in worker process: {e}")
  47. def capture_current_screen(timeout=10):
  48. # Create a temporary file
  49. temp_file = tempfile.NamedTemporaryFile(delete=False)
  50. temp_file_name = temp_file.name
  51. temp_file.close()
  52. capture_process = multiprocessing.Process(
  53. target=screencap_worker, args=(device, temp_file_name)
  54. )
  55. capture_process.start()
  56. capture_process.join(timeout)
  57. if capture_process.is_alive():
  58. capture_process.terminate()
  59. capture_process.join()
  60. print("Screen capture timed out")
  61. os.remove(temp_file_name)
  62. return None
  63. if not os.path.exists(temp_file_name) or os.path.getsize(temp_file_name) == 0:
  64. print("No data in the temporary file")
  65. os.remove(temp_file_name)
  66. return None
  67. # Read the screenshot from the temporary file
  68. global current_screen
  69. with open(temp_file_name, "rb") as f:
  70. current_screen = f.read()
  71. # Clean up
  72. os.remove(temp_file_name)
  73. return current_screen
  74. def find_center(x1, y1, x2, y2):
  75. centerX = round(x1 + (x2 - x1) / 2)
  76. centerY = round(y1 + (y2 - y1) / 2)
  77. return centerX, centerY
  78. def call_device_shell(action, timeout=10):
  79. def target():
  80. device.shell(action)
  81. thread = threading.Thread(target=target)
  82. thread.start()
  83. thread.join(timeout)
  84. if thread.is_alive():
  85. print("ran into timeout")
  86. thread.join()
  87. def tap(x, y=None):
  88. # Check if x is an int
  89. if isinstance(x, int):
  90. if not isinstance(y, int):
  91. raise ValueError("y must be an int when x is an int")
  92. # Construct the location string from both x and y
  93. location = f"{x} {y}"
  94. # Check if x is a string
  95. elif isinstance(x, str):
  96. location = x
  97. elif isinstance(x, tuple):
  98. location = f"{x[0]} {x[1]}"
  99. else:
  100. raise TypeError("x must be either an int or a string")
  101. # Assuming 'device' is a previously defined object with a 'shell' method
  102. action = f"input tap {location}"
  103. print(action)
  104. call_device_shell(action, timeout=5)
  105. def tap_button(template):
  106. button = find_template(template)
  107. if len(button) == 0:
  108. return
  109. tap(f"{button[0][0]} {button[0][1]}")
  110. def swipe(start, end, duration=1000):
  111. action = f"input swipe {start} {end} {duration}"
  112. print(action)
  113. call_device_shell(action, timeout=5)
  114. def look_for_templates(templates):
  115. for name, template in templates.items():
  116. locations = find_template(template)
  117. if len(locations) > 0:
  118. return name, locations
  119. return None, None
  120. def first_template(template_image):
  121. result = find_template(template_image)
  122. if len(result) > 0:
  123. return result[0]
  124. return None
  125. def find_template(template_image):
  126. if with_cuda == 1:
  127. # Ensure the images are in the correct format (BGR for OpenCV)
  128. target_image = get_current_screen()
  129. # Upload images to GPU
  130. target_image_gpu = cv2.cuda_GpuMat()
  131. template_image_gpu = cv2.cuda_GpuMat()
  132. target_image_gpu.upload(target_image)
  133. template_image_gpu.upload(template_image)
  134. # Perform template matching on the GPU
  135. result_gpu = cv2.cuda.createTemplateMatching(cv2.CV_8UC3, cv2.TM_CCOEFF_NORMED)
  136. result = result_gpu.match(target_image_gpu, template_image_gpu)
  137. # Download result from GPU to CPU
  138. result = result.download()
  139. else:
  140. target_image = Image.open(BytesIO(get_current_screen()))
  141. # Convert the image to a NumPy array and then to BGR format (which OpenCV uses)
  142. target_image = np.array(target_image)
  143. target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
  144. h, w = template_image.shape[:-1]
  145. # Template matching
  146. result = cv2.matchTemplate(target_image, template_image, cv2.TM_CCOEFF_NORMED)
  147. # Define a threshold
  148. threshold = 0.9 # Adjust this threshold based on your requirements
  149. # Finding all locations where match exceeds threshold
  150. locations = np.where(result >= threshold)
  151. locations = list(zip(*locations[::-1]))
  152. # Create list of rectangles
  153. rectangles = [(*loc, loc[0] + w, loc[1] + h) for loc in locations]
  154. # Apply non-maximum suppression to remove overlaps
  155. rectangles = non_max_suppression(rectangles, 0.3)
  156. # Initialize an empty list to store coordinates
  157. coordinates = []
  158. for startX, startY, endX, endY in rectangles:
  159. # Append the coordinate pair to the list
  160. coordinates.append(find_center(startX, startY, endX, endY))
  161. # Sort the coordinates by y value in ascending order
  162. return sorted(coordinates, key=lambda x: x[1])
  163. def non_max_suppression(boxes, overlapThresh):
  164. if len(boxes) == 0:
  165. return []
  166. # Convert to float
  167. boxes = np.array(boxes, dtype="float")
  168. # Initialize the list of picked indexes
  169. pick = []
  170. # Grab the coordinates of the bounding boxes
  171. x1 = boxes[:, 0]
  172. y1 = boxes[:, 1]
  173. x2 = boxes[:, 2]
  174. y2 = boxes[:, 3]
  175. # Compute the area of the bounding boxes and sort by bottom-right y-coordinate
  176. area = (x2 - x1 + 1) * (y2 - y1 + 1)
  177. idxs = np.argsort(y2)
  178. # Keep looping while some indexes still remain in the indexes list
  179. while len(idxs) > 0:
  180. # Grab the last index in the indexes list and add the index value to the list of picked indexes
  181. last = len(idxs) - 1
  182. i = idxs[last]
  183. pick.append(i)
  184. # Find the largest (x, y) coordinates for the start of the bounding box and the smallest (x, y)
  185. # coordinates for the end of the bounding box
  186. xx1 = np.maximum(x1[i], x1[idxs[:last]])
  187. yy1 = np.maximum(y1[i], y1[idxs[:last]])
  188. xx2 = np.minimum(x2[i], x2[idxs[:last]])
  189. yy2 = np.minimum(y2[i], y2[idxs[:last]])
  190. # Compute the width and height of the bounding box
  191. w = np.maximum(0, xx2 - xx1 + 1)
  192. h = np.maximum(0, yy2 - yy1 + 1)
  193. # Compute the ratio of overlap
  194. overlap = (w * h) / area[idxs[:last]]
  195. # Delete all indexes from the index list that have overlap greater than the threshold
  196. idxs = np.delete(
  197. idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))
  198. )
  199. # Return only the bounding boxes that were picked
  200. return boxes[pick].astype("int")
  201. def save_screenshot(path="test"):
  202. # Take a screenshot
  203. result = get_current_screen()
  204. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  205. image = Image.open(io.BytesIO(result))
  206. jpeg_filename = f"{path}/{timestamp}.jpg"
  207. image = image.convert("RGB") # Convert to RGB mode for JPEG
  208. with open(jpeg_filename, "wb") as fp:
  209. image.save(fp, format="JPEG", quality=85) # Adjust quality as needed
  210. print(f"snap: {jpeg_filename}")
  211. def save_screenshot2(path="test"):
  212. proc = subprocess.Popen(
  213. "adb exec-out screencap -p", shell=True, stdout=subprocess.PIPE
  214. )
  215. image_bytes = proc.stdout.read()
  216. image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
  217. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  218. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  219. jpeg_filename = f"{path}/{timestamp}.jpg"
  220. cv2.imwrite(
  221. jpeg_filename,
  222. cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
  223. [int(cv2.IMWRITE_JPEG_QUALITY), 85],
  224. )
  225. print(f"snap: {jpeg_filename}")
  226. device = client.device(android_address)
  227. current_screen = capture_current_screen()