橘智橘智
FakeOrange
预计阅读时间:10分钟26秒

蒙特卡洛树搜索 (MCTS) 算法通俗指南

一套MCTS的解说及代码实现,用AI玩游戏吧!

0
0


前言



本文属于搬运内容,原作者:michelangelo (Salute!)


如果我们想真正理解 AlphaZero 和 MuZero 的工作原理,就必须先深入了解蒙特卡洛树搜索 (Monte Carlo Tree Search, MCTS) 算法,它是这两个算法的基础。



粗略理解



MCTS 让我们能够通过统计方法而非完全探索来寻找最佳的行动路径,从而可以在较少限制的情况下导航庞大的状态空间,而不像蛮力搜索那样效率低下。

事实上,MCTS 的搜索过程是有指导的:它会更频繁地探索有希望的节点,收集的统计数据越多,得到的结果就越可靠。



更具体地说,MCTS 包括以下四个步骤:



  • 选择 (Selection)我们需要一种聪明的方法来探索树中的节点,以便找到更有前景的结果。因此,需要对节点赋予某种价值,以便选择要探索的节点。


  • 扩展 (Expansion)当到达一个叶子节点(即游戏的某一状态)时,需要扩展树,考虑从该节点出发的所有可能的有效行动(即采取有效行动后游戏的后续状态)。


  • 模拟 (Simulation)从某个节点(游戏状态)开始,我们需要能够模拟一场完全随机的游戏。单独的一次模拟可能会得到随机的结果,但随着从该节点进行的模拟次数增加,节点的平均价值估计会变得越来越准确。


  • 回溯 (Backpropagation)模拟完成后,需要将结果向上传播到树的上层节点,以便在上层节点中存储这些信息,从而将叶子节点的信息传递到根节点。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/bce2900e-fb7f-46ab-b690-04062991fb2dimage.png



具体实施



简而言之,MCTS 算法的四个步骤重复执行得越多,我们关于从特定游戏状态采取最佳行动的信息就越可靠、统计上越有效。将这些步骤应用于每一次行动,我们就可以像高手一样制定游戏策略!


接下来,我们看看如何实现 MCTS 并将其用于玩一个 OpenAI Gym 的视频游戏。



创建环境



1. 创建 OpenAI Gym 环境


首先,我们需要创建一个 OpenAI Gym 环境作为算法的测试平台,并检查游戏中可能的动作数量以及观察空间的维度。以下是代码实现:


import gym

# 游戏名称
GAME_NAME = 'CartPole-v0'

# 创建环境
env = gym.make(GAME_NAME)

# 获取游戏的可能动作数和观察空间维度
GAME_ACTIONS = env.action_space.n
GAME_OBS = env.observation_space.shape[0]

# 输出游戏信息
print('In the ' + GAME_NAME + ' environment there are: ' + str(GAME_ACTIONS) + ' possible actions.')
print('In the ' + GAME_NAME + ' environment the observation is composed of: ' + str(GAME_OBS) + ' values.')

# 重置环境
env.reset()
# 关闭环境
env.close()



代码解释:


  • gym.make(GAME_NAME):加载指定游戏环境(这里是 CartPole-v0)。


  • env.action_space.n:返回游戏中可能的动作数量(离散动作空间)。


  • env.observation_space.shape[0]:返回观察空间的维度,即环境提供的观测值数量。


  • env.reset():重置环境以准备新一轮模拟。


  • env.close():关闭环境释放资源。


执行这段代码,你会看到类似以下的输出:


In the CartPole-v0 environment there are: 2 possible actions.  
In the CartPole-v0 environment the observation is composed of: 4 values.  


这表明 CartPole-v0 游戏有两个可能的动作(向左或向右),观察值由 4 个变量组成(如位置、速度、角度等)。

接下来,你可以基于这个环境实现 MCTS 算法,模拟和评估不同的动作选择策略。



CartPole 环境和 MCTS 树的 Node 类定义



1. CartPole 环境简介


在 CartPole 环境中,目标是通过向左右施加力使小车保持平衡,从而避免连接的杆子倒下:


  • 可能的动作:向左推小车或向右推小车(2 种可能动作)。


  • 观察值:由 4 个值组成:



2. Node 类定义
Node 类用于表示 MCTS 树中的一个节点,包含了算法运行所需的信息:


class Node:
    '''
    Node 类表示 MCTS 树中的一个节点。
    它包含算法执行搜索所需的全部信息。
    '''

    def __init__(self, game, done, parent, observation, action_index):
        # 子节点字典 {action: child_node}
        self.child = {}
        
        # 来自 MCTS 探索的总奖励值
        self.T = 0
        
        # 节点访问次数
        self.N = 0        
        
        # 环境副本,表示当前节点的游戏状态
        self.game = game
        
        # 环境的当前观测值
        self.observation = observation
        
        # 游戏是否结束(胜/负/平)
        self.done = done

        # 指向父节点的链接(用于回溯)
        self.parent = parent
        
        # 导致到达当前节点的动作索引
        self.action_index = action_index

Node 类的关键属性详解


  • child子节点字典,表示当前节点采取不同动作后的后续状态。结构:{action_index: child_node}。


  • T从该节点开始的模拟(rollouts)的总奖励值。用于统计从该节点起步的探索质量。


  • N节点访问次数,表示从根节点到达该节点的次数。用于平衡探索(未被充分访问的节点)和利用(高奖励节点)。


  • game游戏环境的当前状态。是原始环境的副本,用于模拟和搜索。


  • observation节点表示的游戏状态,在 CartPole 中即 4 个观察值。


  • done表示游戏是否在当前节点结束(例如杆子倒下、小车越界等)。如果结束,则停止进一步扩展和模拟。


  • parent父节点的引用,用于回溯(backpropagation)。


  • action_index父节点采取的动作索引,用于从父节点到达当前节点。



定义 Node 中的赋值以支持搜索



我们首先定义一个方法,用于为节点赋值。这是非常关键的步骤,因为它将指导我们沿着搜索树进行探索:


def getUCBscore(self):
    '''
    该公式为节点赋值。
    MCTS 将选择值最大的节点进行进一步搜索。
    '''
    
    # 未探索的节点分配最大值以鼓励探索
    if self.N == 0:
        return float('inf')
    
    # 获取当前节点的父节点
    top_node = self
    if top_node.parent:
        top_node = top_node.parent
        
    # 使用 MCTS 的公式之一计算节点的值
    return (self.T / self.N) + c * sqrt(log(top_node.N) / self.N)



公式解释:UCB(Upper Confidence Bound)


data/78df2c1f-e442-415d-a382-fa7925af0c4b/b0d53817-2bec-4fed-a886-6ad0abaf0981image.png



方法运行逻辑


data/78df2c1f-e442-415d-a382-fa7925af0c4b/61bf5f59-52cc-467d-b1f2-f031da1f579bimage.png



定义函数创建子节点



def create_child(self):
    '''
    我们为游戏中的每个可能动作创建一个子节点,
    然后将该动作应用于当前节点环境的一个副本,
    并使用执行该动作后返回的相关信息创建对应的子节点。
    '''
    
    if self.done:
        return

    actions = []
    games = []
    for i in range(GAME_ACTIONS): 
        actions.append(i)           
        new_game = deepcopy(self.game)
        games.append(new_game)
        
    child = {} 
    for action, game in zip(actions, games):
        observation, reward, done, _ = game.step(action)
        child[action] = Node(game, done, self, observation, action)                        
        
    self.child = child



注意,我们始终使用游戏中所有可能的动作(GAME_ACTIONS),为每个动作创建一个子节点,并将游戏环境的副本与该动作关联。在此基础上,通过执行动作(game.step(action))获取游戏的结果状态信息,并将其存储到子节点中。



定义搜索



我们现在准备实现算法的核心部分,在这里进行搜索:


def explore(self):
    '''
    树的搜索过程如下:
    - 从当前节点开始,递归地选择最大化 MCTS 公式值的子节点
    - 当到达叶节点时:
        - 如果该节点之前从未被探索过,则执行一次随机游戏并更新其当前值
        - 否则,扩展该节点,创建其子节点,随机选择一个子节点,执行随机游戏并更新其值
    - 将更新后的统计信息从叶节点回传到树的根节点:更新值和访问次数
    '''
    
    # 通过选择具有最大 U 值的节点来寻找叶节点
    current = self
    
    while current.child:
        child = current.child
        max_U = max(c.getUCBscore() for c in child.values())
        actions = [a for a, c in child.items() if c.getUCBscore() == max_U]
        
        if len(actions) == 0:
            print("error zero length ", max_U)                      
        
        action = random.choice(actions)
        current = child[action]
    
    # 执行随机游戏,或者如果需要,扩展节点
    if current.N < 1:
        current.T = current.T + current.rollout()
    else:
        current.create_child()
        if current.child:
            current = random.choice(current.child)
        current.T = current.T + current.rollout()
    
    current.N += 1      

    # 更新统计信息并回传
    parent = current
    
    while parent.parent:
        parent = parent.parent
        parent.N += 1
        parent.T = parent.T + current.T



树的搜索过程


  • 递归选择子节点
    • 从当前节点开始,递归地选择最大化 UCB 值的子节点,直到达到一个叶节点。
  • 叶节点的处理
    • 如果是第一次访问该叶节点,就进行一次随机游戏(rollout),并更新其值。
    • 如果该节点已经被访问过,扩展该节点,创建子节点,然后随机选择一个子节点进行随机游戏并更新其值。
  • 回传更新的统计信息
    • 无论哪种情况,都会将更新的统计信息(包括值和访问次数)回传到父节点,直到根节点。



算法工作流程总结



  • 选择最大 UCB 值的子节点
    • 递归地进行搜索直到找到叶节点。


  • 叶节点处理
    • 如果叶节点没有被访问过,则进行一次随机模拟并更新值。
    • 如果叶节点已被访问,则扩展节点并选择一个随机子节点进行模拟。


  • 回传统计信息
    • 更新每个节点的统计信息并回传,直到树的根节点。



定义随机回合



我们差不多完成了,接下来我们需要定义如何进行随机回合(rollout),使用当前节点的游戏副本(即当前游戏状态):


def rollout(self):
    '''
    回合是从当前节点的环境副本开始,使用随机动作进行的游戏模拟。
    这将为当前节点提供一个值。
    单独来看,这个值是随机的,但是,进行更多的回合后,
    该节点的平均值将会更加准确。这是 MCTS 算法的核心。
    '''
    
    if self.done:
        return 0        
    
    v = 0
    done = False
    new_game = deepcopy(self.game)
    while not done:
        action = new_game.action_space.sample()  # 随机选择一个动作
        observation, reward, done, _ = new_game.step(action)
        v = v + reward  # 累加奖励
        if done:
            new_game.reset()  # 游戏结束时重置游戏
            new_game.close()
            break             
    return v


回合过程:


  • 随机播放:从当前节点的游戏副本开始,随机选择动作进行游戏直到结束。


  • 回合值:游戏的累计奖励值 v 作为当前节点的估计值。进行更多回合会使得值更准确。


然后,我们需要定义如何在搜索完成后选择下一个动作:


def next(self):
    ''' 
    一旦在树中进行了足够的搜索,树中包含的值应该是统计上准确的。
    然后,我们将选择从当前节点开始的下一个动作,这就是这个函数的作用。
    选择动作的方法有多种,这个实现的策略如下:
    - 从具有最大访问次数的节点中随机选择一个,因为这意味着该节点有一个好的值。
    '''

    if self.done:
        raise ValueError("game has ended")

    if not self.child:
        raise ValueError('no children found and game hasn\'t ended')
    
    child = self.child
    
    # 找到最大访问次数的节点
    max_N = max(node.N for node in child.values())
    
    # 找到所有访问次数为最大值的子节点
    max_children = [c for a, c in child.items() if c.N == max_N]
    
    if len(max_children) == 0:
        print("error zero length ", max_N)
        
    max_child = random.choice(max_children)
    
    return max_child, max_child.action_index



选择下一个动作的过程:


  • 根据访问次数选择子节点:从当前节点的子节点中选择访问次数最多的子节点,这通常意味着该节点具有较好的值(因为 MCTS 公式会优先选择访问次数多的节点)。


  • 随机选择:如果多个子节点具有相同的最大访问次数,则随机选择其中一个。


通过这种方式,我们能够确保搜索过的节点更加准确,并最终选择最合适的动作。




开始玩游戏!



现在我们已经准备好了,接下来我们需要定义用于玩游戏的策略:


MCTS_POLICY_EXPLORE = 100  # MCTS探索常数:值越高,越可靠,但执行时间更长

def Policy_Player_MCTS(mytree):  

    '''
    我们使用 MCTS 的策略非常简单:
    - 为了从当前节点选择最佳动作:
        - 从该节点开始探索树,进行一定数量的迭代,以收集可靠的统计信息
        - 根据 MCTS,选择下一个最佳的动作节点
    '''
    
    for i in range(MCTS_POLICY_EXPLORE):
        mytree.explore()
        
    next_tree, next_action = mytree.next()
        
    # 注意这里我们将当前节点与其子树分离,并返回从选定动作开始的新树。
    # 下一次搜索将不会从头开始,而是已经收集了节点的信息和统计数据,
    # 所以我们可以重用这些统计数据,使得搜索结果更加可靠!
    next_tree.detach_parent()
    
    return next_tree, next_action


策略说明:


  • 探索树并选择最佳动作:每次需要做出游戏动作时,使用 MCTS 进行探索,并选择最佳可能的下一步动作。探索树的次数由固定参数 MCTS_POLICY_EXPLORE 控制,探索次数越多,结果越可靠。


  • 重用统计信息:每次选择动作后,都会通过 detach_parent() 将当前节点与其父节点分离,从而确保下一次搜索可以重用已有的统计信息。



接下来是运行我们的策略并查看效果的代码:


episodes = 10
rewards = []
moving_average = []

'''
在这里我们测试我们的实现:
- 玩一定数量的游戏回合
- 每次做出动作时,都会应用我们的 MCTS 算法
- 收集并绘制奖励,检查 MCTS 是否有效。
- 对于 CartPole-v0,最大可能奖励为 200。 
'''

for e in range(episodes):

    reward_e = 0    
    game = gym.make(GAME_NAME)
    observation = game.reset() 
    done = False
    
    new_game = deepcopy(game)
    mytree = Node(new_game, False, 0, observation, 0)
    
    print('episode #' + str(e+1))
    
    while not done:
    
        mytree, action = Policy_Player_MCTS(mytree)
        
        observation, reward, done, _ = game.step(action)  
                        
        reward_e = reward_e + reward
        
        # game.render()  # 如果你想看到智能体的动作,取消注释此行!
                
        if done:
            print('reward_e ' + str(reward_e))
            game.close()
            break
        
    rewards.append(reward_e)
    moving_average.append(np.mean(rewards[-100:]))
    
plt.plot(rewards)
plt.plot(moving_average)
plt.show()
print('moving average: ' + str(np.mean(rewards[-20:])))


运行流程:


  • 游戏回合数:我们设置了 episodes = 10,即进行 10 次游戏。


  • 奖励收集:在每次游戏结束后,收集奖励并将其存储在 rewards 列表中。同时,我们计算并绘制每 100 次回合的移动平均奖励,以检查 MCTS 算法是否有效。


  • 显示结果:最终绘制奖励和移动平均的曲线图,查看随着时间推移奖励的变化情况。


对于 CartPole-v0 环境,最大奖励为 200,若 MCTS 算法有效,我们应该看到随着回合数的增加,奖励逐渐接近这一最大值。


data/78df2c1f-e442-415d-a382-fa7925af0c4b/63d9e4de-d6ac-4c53-84ea-2792b7318c5aimage.png



结语



恭喜你!你做到了! 干就完了! 希望你在这其中找到了乐趣~



原文链接





评论