# helper.py — adb/OpenCV Android screen-automation helpers
  1. import cv2
  2. import datetime
  3. import io
  4. import numpy as np
  5. import os
  6. import threading
  7. import subprocess
  8. from io import BytesIO
  9. from PIL import Image
  10. from ppadb.client import Client as AdbClient
  11. from dotenv import load_dotenv
  12. with_cuda = 0
  13. if cv2.cuda.getCudaEnabledDeviceCount() > 0:
  14. print("CUDA is available")
  15. with_cuda = 1
  16. else:
  17. print("CUDA is not available")
  18. with_cuda = 0
  19. load_dotenv()
  20. android_address = os.getenv("ANDROID_ADDRESS")
def get_current_screen():
    """Return the most recently captured raw screenshot bytes.

    NOTE(review): reads the module-global ``current_screen``, which is
    populated by capture_current_screen(); the value may be stale if no
    capture has run recently.
    """
    return current_screen
  23. def capture_current_screen(timeout=5): # Timeout in seconds
  24. def target():
  25. global current_screen
  26. current_screen = device.screencap()
  27. capture_thread = threading.Thread(target=target)
  28. capture_thread.start()
  29. capture_thread.join(timeout)
  30. if capture_thread.is_alive():
  31. print("Screen capture timed out")
  32. # Handle the timeout situation, e.g., by retrying or aborting
  33. capture_thread.join()
  34. return current_screen
  35. def find_center(x1, y1, x2, y2):
  36. centerX = round(x1 + (x2 - x1) / 2)
  37. centerY = round(y1 + (y2 - y1) / 2)
  38. return centerX, centerY
  39. def call_device_shell(action, timeout=10):
  40. def target():
  41. device.shell(action)
  42. thread = threading.Thread(target=target)
  43. thread.start()
  44. thread.join(timeout)
  45. if thread.is_alive():
  46. print("ran into timeout")
  47. thread.join()
  48. def tap(x, y=None):
  49. # Check if x is an int
  50. if isinstance(x, int):
  51. if not isinstance(y, int):
  52. raise ValueError("y must be an int when x is an int")
  53. # Construct the location string from both x and y
  54. location = f"{x} {y}"
  55. # Check if x is a string
  56. elif isinstance(x, str):
  57. location = x
  58. elif isinstance(x, tuple):
  59. location = f"{x[0]} {x[1]}"
  60. else:
  61. raise TypeError("x must be either an int or a string")
  62. # Assuming 'device' is a previously defined object with a 'shell' method
  63. action = f"input tap {location}"
  64. print(action)
  65. call_device_shell(action, timeout=5)
  66. def tap_button(template):
  67. button = find_template(template)
  68. if len(button) == 0:
  69. return
  70. tap(f"{button[0][0]} {button[0][1]}")
  71. def swipe(start, end, duration=1000):
  72. action = f"input swipe {start} {end} {duration}"
  73. print(action)
  74. call_device_shell(action, timeout=5)
  75. def look_for_templates(templates):
  76. for name, template in templates.items():
  77. locations = find_template(template)
  78. if len(locations) > 0:
  79. return name, locations
  80. return None, None
  81. def first_template(template_image):
  82. result = find_template(template_image)
  83. if len(result) > 0:
  84. return result[0]
  85. return None
  86. def find_template(template_image):
  87. if with_cuda == 1:
  88. # Ensure the images are in the correct format (BGR for OpenCV)
  89. target_image = capture_current_screen()
  90. # Upload images to GPU
  91. target_image_gpu = cv2.cuda_GpuMat()
  92. template_image_gpu = cv2.cuda_GpuMat()
  93. target_image_gpu.upload(target_image)
  94. template_image_gpu.upload(template_image)
  95. # Perform template matching on the GPU
  96. result_gpu = cv2.cuda.createTemplateMatching(cv2.CV_8UC3, cv2.TM_CCOEFF_NORMED)
  97. result = result_gpu.match(target_image_gpu, template_image_gpu)
  98. # Download result from GPU to CPU
  99. result = result.download()
  100. else:
  101. target_image = Image.open(BytesIO(get_current_screen()))
  102. # Convert the image to a NumPy array and then to BGR format (which OpenCV uses)
  103. target_image = np.array(target_image)
  104. target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
  105. h, w = template_image.shape[:-1]
  106. # Template matching
  107. result = cv2.matchTemplate(target_image, template_image, cv2.TM_CCOEFF_NORMED)
  108. # Define a threshold
  109. threshold = 0.9 # Adjust this threshold based on your requirements
  110. # Finding all locations where match exceeds threshold
  111. locations = np.where(result >= threshold)
  112. locations = list(zip(*locations[::-1]))
  113. # Create list of rectangles
  114. rectangles = [(*loc, loc[0] + w, loc[1] + h) for loc in locations]
  115. # Apply non-maximum suppression to remove overlaps
  116. rectangles = non_max_suppression(rectangles, 0.3)
  117. # Initialize an empty list to store coordinates
  118. coordinates = []
  119. for startX, startY, endX, endY in rectangles:
  120. # Append the coordinate pair to the list
  121. coordinates.append(find_center(startX, startY, endX, endY))
  122. # Sort the coordinates by y value in ascending order
  123. return sorted(coordinates, key=lambda x: x[1])
  124. def non_max_suppression(boxes, overlapThresh):
  125. if len(boxes) == 0:
  126. return []
  127. # Convert to float
  128. boxes = np.array(boxes, dtype="float")
  129. # Initialize the list of picked indexes
  130. pick = []
  131. # Grab the coordinates of the bounding boxes
  132. x1 = boxes[:, 0]
  133. y1 = boxes[:, 1]
  134. x2 = boxes[:, 2]
  135. y2 = boxes[:, 3]
  136. # Compute the area of the bounding boxes and sort by bottom-right y-coordinate
  137. area = (x2 - x1 + 1) * (y2 - y1 + 1)
  138. idxs = np.argsort(y2)
  139. # Keep looping while some indexes still remain in the indexes list
  140. while len(idxs) > 0:
  141. # Grab the last index in the indexes list and add the index value to the list of picked indexes
  142. last = len(idxs) - 1
  143. i = idxs[last]
  144. pick.append(i)
  145. # Find the largest (x, y) coordinates for the start of the bounding box and the smallest (x, y)
  146. # coordinates for the end of the bounding box
  147. xx1 = np.maximum(x1[i], x1[idxs[:last]])
  148. yy1 = np.maximum(y1[i], y1[idxs[:last]])
  149. xx2 = np.minimum(x2[i], x2[idxs[:last]])
  150. yy2 = np.minimum(y2[i], y2[idxs[:last]])
  151. # Compute the width and height of the bounding box
  152. w = np.maximum(0, xx2 - xx1 + 1)
  153. h = np.maximum(0, yy2 - yy1 + 1)
  154. # Compute the ratio of overlap
  155. overlap = (w * h) / area[idxs[:last]]
  156. # Delete all indexes from the index list that have overlap greater than the threshold
  157. idxs = np.delete(
  158. idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))
  159. )
  160. # Return only the bounding boxes that were picked
  161. return boxes[pick].astype("int")
  162. def save_screenshot(path="test"):
  163. # Take a screenshot
  164. result = capture_current_screen()
  165. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  166. image = Image.open(io.BytesIO(result))
  167. jpeg_filename = f"{path}/{timestamp}.jpg"
  168. image = image.convert("RGB") # Convert to RGB mode for JPEG
  169. with open(jpeg_filename, "wb") as fp:
  170. image.save(fp, format="JPEG", quality=85) # Adjust quality as needed
  171. print(f"snap: {jpeg_filename}")
  172. def save_screenshot2(path="test"):
  173. proc = subprocess.Popen(
  174. "adb exec-out screencap -p", shell=True, stdout=subprocess.PIPE
  175. )
  176. image_bytes = proc.stdout.read()
  177. image = cv2.imdecode(np.frombuffer(image_bytes, np.uint8), cv2.IMREAD_COLOR)
  178. image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
  179. timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
  180. jpeg_filename = f"{path}/{timestamp}.jpg"
  181. cv2.imwrite(
  182. jpeg_filename,
  183. cv2.cvtColor(image, cv2.COLOR_RGB2BGR),
  184. [int(cv2.IMWRITE_JPEG_QUALITY), 85],
  185. )
  186. print(f"snap: {jpeg_filename}")
# Module-level ADB wiring: connect to the local adb server and resolve the
# target device by the address read from ANDROID_ADDRESS.
client = AdbClient(host="127.0.0.1", port=5037)
device = client.device(android_address)
# Prime the screenshot cache so get_current_screen() has data before any
# explicit capture is requested.
current_screen = capture_current_screen()