# helper.py
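"""Helpers for automating an Android device over ADB: capture the screen,
locate UI elements with OpenCV template matching (on the GPU when CUDA is
available), and send tap/swipe input events."""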

import cv2
import numpy as np
from io import BytesIO
from PIL import Image
from ppadb.client import Client as AdbClient

with_cuda = 0
if cv2.cuda.getCudaEnabledDeviceCount() > 0:
    print("CUDA is available")
    with_cuda = 1
else:
    print("CUDA is not available")
    with_cuda = 0


def get_current_screen():
    # Return the most recently captured screenshot (raw PNG bytes)
    return current_screen


def capture_current_screen():
    # Take a fresh screenshot via ADB and cache it
    global current_screen
    current_screen = device.screencap()
    return current_screen


def find_center(x1, y1, x2, y2):
    # Center point of the rectangle (x1, y1)-(x2, y2)
    centerX = round(x1 + (x2 - x1) / 2)
    centerY = round(y1 + (y2 - y1) / 2)
    return centerX, centerY


def tap(x, y=None):
    # Check if x is an int
    if isinstance(x, int):
        if not isinstance(y, int):
            raise ValueError("y must be an int when x is an int")
        # Construct the location string from both x and y
        location = f"{x} {y}"
    # Check if x is a string such as "100 200"
    elif isinstance(x, str):
        location = x
    # Or an (x, y) tuple
    elif isinstance(x, tuple):
        location = f"{x[0]} {x[1]}"
    else:
        raise TypeError("x must be an int, a string or an (x, y) tuple")
    # 'device' is the ADB device connected at the bottom of this module
    action = f"input tap {location}"
    print(action)
    device.shell(action)
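
# tap() accepts any of the three location forms handled above, e.g.:
#   tap(100, 200)      # x and y as ints
#   tap("100 200")     # a pre-built "x y" string
#   tap((100, 200))    # an (x, y) tuple, e.g. a center returned by find_template()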


def tap_button(template):
    button = find_template(template)
    if len(button) == 0:
        return
    tap(f"{button[0][0]} {button[0][1]}")


def swipe(start, end, duration=1000):
    action = f"input swipe {start} {end} {duration}"
    print(action)
    device.shell(action)


def look_for_templates(templates):
    for name, template in templates.items():
        locations = find_template(template)
        if len(locations) > 0:
            return name, locations
    return None, None


def first_template(template_image):
    result = find_template(template_image)
    if len(result) > 0:
        return result[0]
    return None


def find_template(template_image):
    # Template size; matches are expanded into (x1, y1, x2, y2) rectangles below
    h, w = template_image.shape[:-1]
    if with_cuda == 1:
        # Decode a fresh screenshot into BGR format (which OpenCV uses)
        target_image = Image.open(BytesIO(capture_current_screen()))
        target_image = cv2.cvtColor(np.array(target_image), cv2.COLOR_RGB2BGR)
        # Upload images to the GPU
        target_image_gpu = cv2.cuda_GpuMat()
        template_image_gpu = cv2.cuda_GpuMat()
        target_image_gpu.upload(target_image)
        template_image_gpu.upload(template_image)
        # Perform template matching on the GPU
        matcher = cv2.cuda.createTemplateMatching(cv2.CV_8UC3, cv2.TM_CCOEFF_NORMED)
        result = matcher.match(target_image_gpu, template_image_gpu)
        # Download the result from the GPU to the CPU
        result = result.download()
    else:
        target_image = Image.open(BytesIO(get_current_screen()))
        # Convert the image to a NumPy array and then to BGR format (which OpenCV uses)
        target_image = np.array(target_image)
        target_image = cv2.cvtColor(target_image, cv2.COLOR_RGB2BGR)
        # Template matching on the CPU
        result = cv2.matchTemplate(target_image, template_image, cv2.TM_CCOEFF_NORMED)
    # Define a threshold
    threshold = 0.9  # Adjust this threshold based on your requirements
    # Find all locations where the match exceeds the threshold
    locations = np.where(result >= threshold)
    locations = list(zip(*locations[::-1]))
    # Create a list of rectangles
    rectangles = [(*loc, loc[0] + w, loc[1] + h) for loc in locations]
    # Apply non-maximum suppression to remove overlaps
    rectangles = non_max_suppression(rectangles, 0.3)
    # Collect the center coordinates of the remaining rectangles
    coordinates = []
    for startX, startY, endX, endY in rectangles:
        coordinates.append(find_center(startX, startY, endX, endY))
    # Sort the coordinates by y value in ascending order
    return sorted(coordinates, key=lambda x: x[1])


def non_max_suppression(boxes, overlapThresh):
    if len(boxes) == 0:
        return []
    # Convert to float
    boxes = np.array(boxes, dtype="float")
    # Initialize the list of picked indexes
    pick = []
    # Grab the coordinates of the bounding boxes
    x1 = boxes[:, 0]
    y1 = boxes[:, 1]
    x2 = boxes[:, 2]
    y2 = boxes[:, 3]
    # Compute the area of the bounding boxes and sort by bottom-right y-coordinate
    area = (x2 - x1 + 1) * (y2 - y1 + 1)
    idxs = np.argsort(y2)
    # Keep looping while some indexes still remain in the indexes list
    while len(idxs) > 0:
        # Grab the last index in the indexes list and add it to the list of picked indexes
        last = len(idxs) - 1
        i = idxs[last]
        pick.append(i)
        # Find the largest (x, y) coordinates for the start of the bounding box
        # and the smallest (x, y) coordinates for the end of the bounding box
        xx1 = np.maximum(x1[i], x1[idxs[:last]])
        yy1 = np.maximum(y1[i], y1[idxs[:last]])
        xx2 = np.minimum(x2[i], x2[idxs[:last]])
        yy2 = np.minimum(y2[i], y2[idxs[:last]])
        # Compute the width and height of the intersection
        w = np.maximum(0, xx2 - xx1 + 1)
        h = np.maximum(0, yy2 - yy1 + 1)
        # Compute the ratio of overlap
        overlap = (w * h) / area[idxs[:last]]
        # Delete all indexes from the index list whose overlap exceeds the threshold
        idxs = np.delete(
            idxs, np.concatenate(([last], np.where(overlap > overlapThresh)[0]))
        )
    # Return only the bounding boxes that were picked
    return boxes[pick].astype("int")


# Connect to the local ADB server and to the device at this address,
# then take an initial screenshot.
client = AdbClient(host="127.0.0.1", port=5037)
device = client.device("192.168.178.32:5555")
current_screen = capture_current_screen()
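

# Minimal usage sketch, assuming a template screenshot saved as "button.png"
# (a hypothetical file name) in the working directory:
if __name__ == "__main__":
    template = cv2.imread("button.png")  # BGR template, same scale as the device screen
    capture_current_screen()             # refresh the cached screenshot
    matches = find_template(template)    # match centers, sorted top to bottom
    if matches:
        tap(matches[0])                  # tap the topmost match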