8.2 实战案例:基于Pygame的交互式路径规划器
本项目是一个基于 Pygame 的交互式路径规划器,通过 A* 和 D* Lite 算法实现路径规划。用户可以在 GUI 界面中设置起点、终点和障碍物,并观察路径规划的过程和结果,以及实时更新的路径信息。
实例8-2:基于Pygame的交互式路径规划器(codes/8/dynamic-a-star/)
8.2.1 项目介绍
在现代社会中,路径规划和算法优化已成为许多领域的重要问题,如物流管理、自动驾驶和游戏开发等。特别是在计算机科学和人工智能领域,路径规划算法的研究和应用愈发重要。A*(A-star)和D* Lite是两种常用的路径规划算法,分别用于寻找最短路径和在动态环境中更新路径。这些算法的实现不仅可以帮助人们理解路径规划的基本原理,还可以应用于实际项目中,如智能导航系统和游戏引擎等。
本项目旨在结合A* 和 D* Lite 算法,开发一个路径规划可视化工具,以帮助用户更直观地了解这两种算法的工作原理和应用场景。该工具提供了一个交互式的图形界面,用户可以在界面上设置起点、终点和障碍物,然后观察算法如何计算最优路径,并实时显示在界面上。通过这个工具,用户可以深入了解A*算法和D* Lite算法的运行过程,以及它们在不同场景下的优劣势。
本项目的核心代码分为两部分:A*算法和D* Lite算法的实现。A*算法用于在静态环境中寻找最短路径,而D* Lite算法则适用于动态环境下的路径更新和重新规划。图形界面部分基于Pygame库实现,用户可以通过鼠标和键盘与界面进行交互,设置参数和观察算法执行过程。该工具不仅是一个教学工具,还可以作为实验平台,用于测试和比较不同路径规划算法的性能和效果。
通过本项目,用户可以学习到A*算法和D* Lite算法的原理和实现细节,并通过可视化界面直观地感受到路径规划的过程和结果。同时,也为学术界和工程界提供了一个开放的平台,以促进路径规划算法的研究和应用。
8.2.2 实现路径规划算法
文件path_finding.py实现了两个路径规划算法:D* Lite和A*,具体实现流程如下所示。
(1)定义函数 h(p1, p2),用于计算两个二维点之间的曼哈顿距离。函数接受两个二维点 p1 和 p2,分别表示为 (x1, y1) 和 (x2, y2),然后计算它们之间的曼哈顿距离,并返回结果。曼哈顿距离是指两点在直角坐标系上的横纵坐标的差的绝对值之和。
def h(p1, p2): x1, y1 = p1 x2, y2 = p2 return abs(x1 - x2) + abs(y1 - y2)(2)下面代码定义了一个用于计算节点关键值的函数 calculate_key(spot, current, k_m),该函数接受三个参数,分别是待计算关键值的节点 spot、当前节点 current 和额外的调整参数 k_m。它根据节点的属性以及当前节点和参数值,计算节点的两个关键值 k1 和 k2,然后返回一个元组 (k1, k2),表示节点的关键值。
def calculate_key(spot, current, k_m): # 计算节点的关键值 k1 = min(spot.g, spot.rhs) + h(spot.get_pos(), current.get_pos()) + k_m k2 = min(spot.g, spot.rhs) return (k1, k2)(3)定义函数 top_key(queue),用于获取优先队列中关键值最小的节点的关键值。函数接受一个优先队列 queue 作为参数。函数 top_key(queue)的具体功能如下:
- 首先,对队列 queue 进行排序,使得队列中具有最小关键值的节点排在队列的前面。
- 然后,检查队列是否为空,如果队列不为空,则返回队列中第一个节点的关键值的前两个元素(即 queue[0][:2])。
- 如果队列为空,则返回一个表示无穷大的元组 (float('inf'), float('inf')),表示找不到最小关键值。
def top_key(queue): # 获取队列中关键值最小的节点的关键值 queue.sort() if len(queue) > 0: return queue[0][:2] else: return (float('inf'), float('inf'))(4)定义函数update_vertex,功能是更新节点的状态。该函数接受六个参数,包括用于绘制的函数 draw,优先队列 queue,待更新状态的节点 spot,当前节点 current,目标节点 end,以及额外的调整参数 k_m。在函数update_vertex中,首先根据当前节点与目标节点的关系计算节点的 rhs 值,然后检查节点是否在队列中,并移除旧的状态。接着,根据节点的新状态将其重新加入队列,并更新节点的状态标记。最后,调用绘制函数 draw 更新可视化界面。综上所述,该函数实现了根据节点与目标节点的关系更新节点状态,并重新加入优先队列的功能。
def update_vertex(draw, queue, spot, current, end, k_m): # 更新节点的状态 s_goal = end if spot != s_goal: min_rhs = float('inf') for neighbor in spot.neighbors: min_rhs = min(min_rhs, neighbor.g + h(spot.get_pos(), neighbor.get_pos())) spot.rhs = min_rhs id_in_queue = [item for item in queue if spot in item] if id_in_queue != []: if len(id_in_queue) != 1: raise ValueError('more than one spot (' + spot.get_pos() + ') in the queue!') queue.remove(id_in_queue[0]) if spot.rhs != spot.g: heapq.heappush(queue, calculate_key(spot, current, k_m) + (spot,)) spot.make_open() draw()(5)定义函数next_in_shortest_path,用于在当前节点的邻居中找到路径代价最小的节点。该函数接受一个参数 current,表示当前节点。函数next_in_shortest_path的主要功能如下:
- 首先,初始化变量 min_rhs 为正无穷。
- 然后,检查当前节点的 rhs 是否为正无穷,如果是,则输出提示信息 "You are stuck!"。
- 接着,遍历当前节点的邻居节点,并计算每个邻居节点到当前节点的路径代价。
- 如果某个邻居节点的路径代价小于 min_rhs,则更新 min_rhs 和 next,分别记录当前找到的最小路径代价和对应的节点。
- 最后,如果找到了符合条件的邻居节点,则返回该节点,否则抛出异常 "No suitable child for transition!"。
def next_in_shortest_path(current): # 获取当前节点邻居中路径代价最小的节点 min_rhs = float('inf') next = None if current.rhs == float('inf'): print('You are stuck!') else: for neighbor in current.neighbors: child_cost = neighbor.g + h(current.get_pos(), neighbor.get_pos()) if (child_cost) < min_rhs: min_rhs = child_cost next = neighbor if next: return next else: raise ValueError('No suitable child for transition!')(6)定义函数scan_obstacles,功能是扫描障碍物并更新路径。该函数接受六个参数,包括用于绘制的函数 draw,优先队列 queue,当前节点 current,目标节点 end,扫描范围 scan_range,以及额外的调整参数 k_m。函数scan_obstacles的主要功能如下:
- 首先,根据给定的扫描范围,初始化 spots_to_update 列表,并将当前节点的邻居节点加入其中,同时标记已经检查的范围为1。
- 然后,利用循环来扩展扫描范围,逐层向外扩展并记录新发现的节点。
- 接着,确保扫描到的节点具有唯一性,去除重复的节点。
- 最后,遍历扫描到的节点,如果发现障碍物或目标物体,则更新相应节点的状态,并将其加入优先队列中,同时标记 new_obstacle 为 True。
def scan_obstacles(draw, queue, current, end, scan_range, k_m): # 扫描障碍物并更新路径 spots_to_update = [] range_checked = 0 if scan_range >= 1: for neighbor in current.neighbors: print(f"adding {neighbor.get_pos()} to spots to update") spots_to_update.append(neighbor) range_checked = 1 while range_checked < scan_range: new_set = [] for spot in spots_to_update: new_set.append(spot) print(f"adding {spot.get_pos()} to spots to update") for neighbor in spot.neighbors: if neighbor not in new_set: new_set.append(neighbor) print(f"adding {neighbor.get_pos()} to spots to update") range_checked += 1 spots_to_update = new_set # 确保唯一性 spots_to_update = list(set(spots_to_update)) new_obstacle = False for spot in spots_to_update: if spot.is_barrier() or spot.is_object(): print('found obstacle in ', spot.get_pos()) for neighbor in spot.neighbors: update_vertex(draw, queue, spot, current, end, k_m) new_obstacle = True return new_obstacle(7)定义函数move_and_rescan,功能是移动并重新扫描。该函数接受六个参数,包括用于绘制的函数 draw,优先队列 queue,当前节点 current,目标节点 end,扫描范围 scan_range,以及额外的调整参数 k_m。函数move_and_rescan的主要功能如下:
- 首先,检查当前节点是否为目标节点,如果是,则返回字符串 'goal' 和当前调整参数 k_m,表示已到达目标。
- 否则,记录当前节点为上一步节点 last,并通过调用 next_in_shortest_path 函数找到路径中的下一个节点 new。
- 如果下一个节点是障碍物或目标物体,则输出提示信息 "obstacle" 并将下一个节点 new 设为当前节点,以避免移动到障碍物。
- 接着,调用 scan_obstacles 函数重新扫描障碍物并更新路径,同时将新的调整参数 k_m 计算为上一步节点到当前节点的曼哈顿距离加上当前 k_m。
- 最后,调用 calc_shortest_path 函数重新计算最短路径,并返回新的当前节点 new 和更新后的调整参数 k_m。
def move_and_rescan(draw, queue, current, end, scan_range, k_m): # 移动并重新扫描 if current == end: return 'goal', k_m else: last = current new = next_in_shortest_path(current) if new.is_object() or new.is_barrier(): print("obstacle") new = current results = scan_obstacles(draw, queue, new, end, scan_range, k_m) k_m += h(last.get_pos(), new.get_pos()) calc_shortest_path(draw, queue, current, end, k_m) return new, k_m(8)定义函数calc_shortest_path,功能是计算最短路径。首先,使用 while 循环,条件为起始节点的 rhs 不等于 g 或者队列中最小关键值小于计算得到的起始节点的关键值。在循环中,首先获取队列中最小关键值 k_old 及对应的节点 u,并将其从队列中弹出。然后,根据条件比较 k_old 和节点 u 的关键值,分别进行不同的处理:
- 如果 k_old 小于节点 u 的关键值,则将节点 u 再次加入队列,并标记为开放状态。
- 如果节点 u 的 g 值大于 rhs 值,则更新节点 u 的 g 值为 rhs 值,并更新其邻居节点的状态。
- 否则,将节点 u 的 g 值设为正无穷,并更新其邻居节点的状态。
最后,将节点 u 标记为关闭状态,并调用绘制函数 draw 更新可视化界面。
def calc_shortest_path(draw, queue, start, end, k_m): # 计算最短路径 while (start.rhs != start.g) or (top_key(queue) < calculate_key(start, start, k_m)): k_old = top_key(queue) u = heapq.heappop(queue)[2] if k_old < calculate_key(u, start, k_m): heapq.heappush(queue, calculate_key(u, start, k_m) + (u,)) u.make_open() elif u.g > u.rhs: u.g = u.rhs for neighbor in u.neighbors: update_vertex(draw, queue, neighbor, start, end, k_m) else: u.g = float('inf') update_vertex(draw, queue, u, start, end, k_m) for neighbor in u.neighbors: update_vertex(draw, queue, neighbor, start, end, k_m) u.make_closed() draw()(9)定义函数d_star_lite,实现了 D* Lite 算法的核心功能,包括初始化节点属性、计算关键值、更新最短路径等步骤,用于寻找起始节点到目标节点的最短路径。该函数接受六个参数,包括用于绘制的函数 draw,网格 grid,优先队列 queue,起始节点 start,目标节点 end,以及额外的调整参数 k_m。函数d_star_lite的具体实现流程如下所示。
- 首先,通过两个嵌套的循环,将网格中所有节点的 g 和 rhs 属性初始化为正无穷。
- 然后,将目标节点 end 的 g 和 rhs 属性都设为0。
- 接着,使用 calculate_key 函数计算目标节点 end 的关键值,并将其加入优先队列 queue 中。
- 调用 calc_shortest_path 函数计算最短路径,更新节点的状态。
- 最后,返回更新后的优先队列 queue 和调整参数 k_m。
def d_star_lite(draw, grid, queue, start, end, k_m): # D* Lite算法 for row in grid: for spot in row: spot.g = float("inf") spot.rhs = float("inf") end.g = 0 end.rhs = 0 heapq.heappush(queue, calculate_key(end, start, k_m) + (end,)) calc_shortest_path(draw, queue, start, end, k_m) return queue, k_m(10)定义函数a_star实现A*算法,该算法用于在给定的网格中寻找从起始节点到目标节点的最短路径。该函数接受四个参数:用于绘制的函数 draw、网格 grid、起始节点 start 和目标节点 end。函数a_star首先初始化计数器和优先队列,然后通过计算节点间的成本来搜索最短路径。在搜索过程中,它记录每个节点的父节点以及从起始节点到每个节点的实际成本和综合成本,并使用优先队列来动态选择下一个要探索的节点。最终,当找到路径或搜索到达终点时,算法返回从起始节点到目标节点的路径。
def a_star(draw, grid, start, end): # A*算法 count = 0 open_set = PriorityQueue() open_set.put((0, count, start)) came_from = {} g_score = {spot: float("inf") for row in grid for spot in row} g_score[start] = 0 f_score = {spot: float("inf") for row in grid for spot in row} f_score[start] = h(start.get_pos(), end.get_pos()) open_set_hash = {start} while not open_set.empty(): for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() current = open_set.get()[2] open_set_hash.remove(current) if current == end: return came_from for neighbor in current.neighbors: temp_g_score = g_score[current] + 1 if temp_g_score < g_score[neighbor]: came_from[neighbor] = current g_score[neighbor] = temp_g_score f_score[neighbor] = temp_g_score + h(neighbor.get_pos(), end.get_pos()) if neighbor not in open_set_hash: count += 1 open_set.put((f_score[neighbor], count, neighbor)) open_set_hash.add(neighbor) neighbor.make_open() if current != start: current.make_closed() draw() return None