helper.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
import numbers
import os
import threading
from io import BytesIO

import cv2
import numpy as np
from dotenv import load_dotenv
from PIL import Image
from ppadb.client import Client as AdbClient
  9. with_cuda = 0
  10. if cv2.cuda.getCudaEnabledDeviceCount() > 0:
  11. print("CUDA is available")
  12. with_cuda = 1
  13. else:
  14. print("CUDA is not available")
  15. with_cuda = 0
  16. load_dotenv()
  17. android_address = os.getenv("ANDROID_ADDRESS")
  18. def get_current_screen():
  19. return current_screen
  20. def capture_current_screen():
  21. global current_screen
  22. current_screen = device.screencap()
  23. return current_screen
  24. def find_center(x1, y1, x2, y2):
  25. centerX = round(x1 + (x2 - x1) / 2)
  26. centerY = round(y1 + (y2 - y1) / 2)
  27. return centerX, centerY
  28. def call_device_shell(action, timeout=10):
  29. def target():
  30. device.shell(action)
  31. thread = threading.Thread(target=target)
  32. thread.start()
  33. thread.join(timeout)
  34. if thread.is_alive():
  35. print("ran into timeout")
  36. thread.join()
  37. def tap(x, y=None):
  38. # Check if x is an int
  39. if isinstance(x, int):
  40. if not isinstance(y, int):
  41. raise ValueError("y must be an int when x is an int")
  42. # Construct the location string from both x and y
  43. location = f"{x} {y}"
  44. # Check if x is a string
  45. elif isinstance(x, str):
  46. location = x
  47. elif isinstance(x, tuple):
  48. location = f"{x[0]} {x[1]}"
  49. else:
  50. raise TypeError("x must be either an int or a string")
  51. # Assuming 'device' is a previously defined object with a 'shell' method
  52. action = f"input tap {location}"
  53. print(action)
  54. call_device_shell(action, timeout=5)
  55. def tap_button(template):
  56. button = find_template(template)
  57. if len(button) == 0:
  58. return
  59. tap(f"{button[0][0]} {button[0][1]}")
  60. def swipe(start, end, duration=1000):
  61. action = f"input swipe {start} {end} {duration}"
  62. print(action)
  63. call_device_shell(action, timeout=5)
  64. def look_for_templates(templates):
  65. for name, template in templates.items():
  66. locations = find_template(template)
  67. if len(locations) > 0:
  68. return name, locations
  69. return None, None
  70. def first_template(template_image):
  71. result = find_template(template_image)
  72. if len(result) > 0:
  73. return result[0]
  74. return None
  75. def find_template(template_image):
  76. if with_cuda == 1:
  77. # Ensure the images are in the correct format (BGR for OpenCV)
  78. target_image = capture_current_screen()
  79. # Upload images to GPU
  80. target_image_gpu = cv2.cuda_GpuMat()
  81. template_image_gpu = cv2.cuda_GpuMat()
  82. target_image_gpu.upload(target_image)
  83. template_image_gpu.upload(template_image)
  84. # Perform template matching on the GPU
  85. result_gpu = cv2.cuda.createTemplateMatching(cv2.CV_8UC3, cv2.TM_CCOEFF_NORMED)
  86. result = result_gpu.match(target_image_gpu, template_image_gpu)
  87. # Download result from GPU to CPU
  88. result = result.download()
  89. else:
  90. target_image = Image.open(BytesIO(get_current_screen()))
  91. # Convert the image to a NumPy array and then to BGR format (which OpenCV uses)
  92. target_image = np.array(target_image)
  93. target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
  94. h, w = template_image.shape[:-1]
  95. # Template matching
  96. result = cv2.matchTemplate(target_image, template_image, cv2.TM_CCOEFF_NORMED)
  97. # Define a threshold
  98. threshold = 0.9 # Adjust this threshold based on your requirements
  99. # Finding all locations where match exceeds threshold
  100. locations = np.where(result >= threshold)
  101. locations = list(zip(*locations[::-1]))
  102. # Create list of rectangles
  103. rectangles = [(*loc, loc[0] + w, loc[1] + h) for loc in locations]
  104. # Apply non-maximum suppression to remove overlaps
  105. rectangles = non_max_suppression(rectangles, 0.3)
  106. # Initialize an empty list to store coordinates
  107. coordinates = []
  108. for startX, startY, endX, endY in rectangles:
  109. # Append the coordinate pair to the list
  110. coordinates.append(find_center(startX, startY, endX, endY))
  111. # Sort the coordinates by y value in ascending order
  112. return sorted(coordinates, key=lambda x: x[1])
  113. def non_max_suppression(boxes, overlapThresh):
  114. if len(boxes) == 0:
  115. return []
  116. # Convert to float
  117. boxes = np.array(boxes, dtype="float")
  118. # Initialize the list of picked indexes
  119. pick = []
  120. # Grab the coordinates of the bounding boxes
  121. x1 = boxes[:, 0]
  122. y1 = boxes[:, 1]
  123. x2 = boxes[:, 2]
  124. y2 = boxes[:, 3]
  125. # Compute the area of the bounding boxes and sort by bottom-right y-coordinate
  126. area = (x2 - x1 + 1) * (y2 - y1 + 1)
  127. idxs = np.argsort(y2)
  128. # Keep looping while some indexes still remain in the indexes list
  129. while len(idxs) > 0:
  130. # Grab the last index in the indexes list and add the index value to the list of picked indexes
  131. last = len(idxs) - 1
  132. i = idxs[last]
  133. pick.append(i)
  134. # Find the largest (x, y) coordinates for the start of the bounding box and the smallest (x, y)
  135. # coordinates for the end of the bounding box
  136. xx1 = np.maximum(x1[i], x1[idxs[:last]])
  137. yy1 = np.maximum(y1[i], y1[idxs[:last]])
  138. xx2 = np.minimum(x2[i], x2[idxs[:last]])
  139. yy2 = np.minimum(y2[i], y2[idxs[:last]])
  140. # Compute the width and height of the bounding box
  141. w = np.maximum(0, xx2 - xx1 + 1)
  142. h = np.maximum(0, yy2 - yy1 + 1)
  143. # Compute the ratio of overlap
  144. overlap = (w * h) / area[idxs[:last]]
  145. # Delete all indexes from the index list that have overlap greater than the threshold
  146. idxs = np.delete(
  147. idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))
  148. )
  149. # Return only the bounding boxes that were picked
  150. return boxes[pick].astype("int")
  151. client = AdbClient(host="127.0.0.1", port=5037)
  152. device = client.device(android_address)
  153. current_screen = capture_current_screen()