强化学习的学习之路(十)_2021-01-10:K臂老虎机介绍及其Python实现

作为一个新手,写这个教程也是想和大家分享一下自己学习强化学习的心路历程,希望对大家能有所帮助。这个系列后面会不断更新,希望自己能保证起码平均一天一更的速度,先是介绍强化学习的一些基础知识,后面介绍强化学习的相关论文。本来是想每一篇多更新一点内容的,后面想着大家看CSDN的话可能还是喜欢短一点的文章,就把很多拆分开来了,目录我单独放在一篇单独的博客里面了。完整的我整理好了会放在github上,大家一起互相学习啊!可能会有很多错漏,希望大家批评指正!
歇了几天就发现落下了好多!明天全部补上!

K臂老虎机介绍及其Python实现

如果大家想对K臂老虎机做一个比较深入的了解的话,建议大家去阅读这篇博客,作者写的挺清楚的,而且还推荐了很多的其他材料,我这里主要是对K臂老虎机做一个简要的介绍。

定义

K臂老虎机(Multi-armed bandit,简称MAB)最早的场景是在赌场里面。赌场里面有K台老虎机,每次去摇老虎机都需要一个代币,且老虎机都会以一定概率吐出钱,你手上如果有T个代币,也就是你一共可以摇T次,你怎么才能使你的期望回报最大?当然我们要先假设每个老虎机吐钱的概率是不一样的,不然你怎么摇就都是一样的了。

在这里插入图片描述

我们一般也将所有的下面这种形式的问题成为K臂老虎机问题: 你可以反复面对 k 种不同的选择或行动。在每次选择之后,你会收到一个数值奖励,该奖励取决于你选择的行动的固定概率分布。 你的目标是在一段时间内最大化预期的总奖励。

如果我们是贝叶斯人,我们在实际对老虎机进行操作之前其实对老虎机吐钱的概率就已经有了一个先验的分布,然后我们不断地进行试验,根据试验的结果取调整我们前面的分布;而如果我们是频率学家,那我们一开始对这些机器吐钱的概率其实是没有先验的,我们会通过实验去预测出每台机器吐钱的概率,然后根据这个概率去不断优化我们的决策。

但不管从哪种角度出发,K臂老虎机的问题其实就是一个探索与利用的问题,就比如说我们先进行来m次实验(m<T),发现了第一个臂吐钱的频率更高,那接下来我们是一直去摇第一个臂(利用:exploitation)还是说我们还去试着摇一下其他的臂(探索:exploration),从短期来看利用是好的,但是从长期来看探索是好的。

基本概念

在我们的K臂老虎机中,只要选择了该动作, k k k 个动作的每一个都有预期的或平均的奖励, 让我们 称之为该行动的 价值。我们将在时间步 t t t 选择的动作表示为 A t , A_{t}, At, 并将相应的奖励表示为 R t ∘ R_{t_{\circ}} Rt 然 后, 对于任意动作 a a a 的价值, 定义 q ∗ ( a ) q_{*}(a) q(a) 是给定 a a a 选择的预期奖励:
q ∗ ( a ) ≐ E [ R t ∣ A t = a ] q_{*}(a) \doteq \mathbb{E}\left[R_{t} \mid A_{t}=a\right] q(a)E[RtAt=a]
如果我们知道每个动作的价值, 那么解决 K臂老虎机将是轻而易举的:你总是选择具有最高价值 的动作。但是我们不知道实际动作价值, 尽管你可能有估计值。 我们将在时间步骤 t t t 的动作 a a a 的估计值表示为 Q t ( a ) Q_{t}(a) Qt(a) 。 我们希望 Q t ( a ) Q_{t}(a) Qt(a) 接近 q ∗ ( a ) q_{*}(a) q(a)

K臂老虎机的变种

我们在上面定义中介绍的K臂老虎机其实是最简单的一种场景,K臂老虎机还有很多其他的变形:

  • 如果那些臂的吐钱的概率分布在一开始就设定好了,而且之后不再改变,则称为oblivious adversary setting。
  • 如果那些臂吐钱的概率设定好了之后还会发生变化,那么称为adaptive adversary setting。
  • 如果把待推荐的商品作为MAB问题的arm,那么我们在推荐系统中我们就还需要考虑用户作为一个活生生的个体本身的兴趣点、偏好、购买力等因素都是不同的,也就是我们需要考虑同一臂在不同上下文中是不同的。
  • 如果每台老虎机每天摇的次数有上限,那我们就得到了一个Bandit with Knapsack问题。

greedy和 ϵ − g r e e d y \epsilon-greedy ϵgreedy

greedy(贪婪)的算法也就是选择具有最高估计值的动作之一: A t = argmax ⁡ a Q t ( a ) A_{t}=\underset{a}{\operatorname{argmax}} Q_{t}(a) At=aargmaxQt(a),也就是相当于我们只做exploitation ; 而 ϵ − g r e e d y \epsilon-greedy ϵgreedy以较小的概率 ϵ \epsilon ϵ地从具有相同概率的所有动作中随机选择, 相当于我们在做exploitation的同时也做一定程度的exploration。greedy的算法很容易陷入执行次优动作的怪圈,当reward的方差更大时,我们为了做更多的探索应该选择探索度更大的 ϵ − g r e e d y \epsilon-greedy ϵgreedy,但是当reward的方差很小时,我们可以选择更greedy的方法,在实际当中我们很多时候都会让 ϵ \epsilon ϵ 从一个较大的值降低到一个较小的值,比如说从1降低到0.1,相当于我们在前期基本上只做探索,后期只做利用。

softmax 方法

softamx是另一种兼顾探索与利用的方法,它既不像greedy算法那样贪婪,也没有像 ϵ − \epsilon- ϵ greedy那样在探索阶段做随机动作而是使用 softmax函数计算每一个arm被选中的概率,以更高的概率去摇下平均收益高的臂,以更地的概率去摇下平均收益低的臂。 a r m i a r m_{i} armi 表示第i 个手柄, U i \quad U_{i} Ui 表示手柄的平均收 益, k是手柄总数。
p ( a r m i ) = e u i ∑ j k e u i p\left(a r m_{i}\right)=\frac{e^{u_{i}}}{\sum_{j}^{k} e^{u_{i}}} p(armi)=jkeuieui
当然这里有一个问题是为什么要用softmax,我们直接用某一个臂得到的平均收益除以总的平均收益不行吗?我理解上感觉softmax方法是在agrmax方法和直接除这种方法之间的方法,因为softmax加上e之后其实会让平均收益低的臂和平均收益高的臂走向极端,也就是让策略越来越激进,甚至到最终收敛成argmax?而且我感觉图像分类里面经常用softmax一方面是因为求梯度比较好计算,另一方面是因为有时候softmax之前得到的分数可能有负数,那我们这里的好处好可以加上就是刚开始某一个臂的平均收益是0的时候我们依旧会有一定概率选它而不会像下面公式里面的这种一样不选它。
p ( a r m i ) = u i ∑ j k u i p\left(a r m_{i}\right)=\frac{{u_{i}}}{\sum_{j}^{k} {u_{i}}} p(armi)=jkuiui
所以总的来说softmax有三个好处:

  • 便于求梯度
  • 在刚开始某一个臂收益为0的时候这个臂依旧有被选上的可能
  • softmax算法让平均收益低的臂和平均收益高的臂走向极端,也就是让策略越来越激进,甚至到最终收敛成argmax,就有点像 ϵ − g r e e d y \epsilon-greedy ϵgreedy ϵ \epsilon ϵ不断下降一样。

一个简单的赌博机算法

在这里插入图片描述

循环的最后一步其实用到了

Q n + 1 = 1 n ∑ i = 1 n R i = 1 n ( R n + ∑ i = 1 n − 1 R i ) = 1 n ( R n + ( n − 1 ) 1 n − 1 ∑ i = 1 n − 1 R i ) = 1 n ( R n + ( n − 1 ) Q n ) = 1 n ( R n + n Q n − Q n ) = Q n + 1 n ( R n − Q n ) \begin{aligned} Q_{n+1} &=\frac{1}{n} \sum_{i=1}^{n} R_{i} \\ &=\frac{1}{n}\left(R_{n}+\sum_{i=1}^{n-1} R_{i}\right) \\ &=\frac{1}{n}\left(R_{n}+(n-1) \frac{1}{n-1} \sum_{i=1}^{n-1} R_{i}\right) \\ &=\frac{1}{n}\left(R_{n}+(n-1) Q_{n}\right) \\ &=\frac{1}{n}\left(R_{n}+n Q_{n}-Q_{n}\right) \\ &=Q_{n}+\frac{1}{n}\left(R_{n}-Q_{n}\right) \end{aligned} Qn+1=n1i=1nRi=n1(Rn+i=1n1Ri)=n1(Rn+(n1)n11i=1n1Ri)=n1(Rn+(n1)Qn)=n1(Rn+nQnQn)=Qn+n1(RnQn)

也就是:新估计←旧估计+步长[目标−旧估计]

Python 代码实现

在代码里面实现了 ϵ − g r e e d y \epsilon-greedy ϵgreedy、softmax,以及直接根据当前各个臂的平均收益去决策三种方法,完整的代码放在github上了,写的比较匆忙,后面会再更新一下放到github的仓库之中

# 作者:Yunhui
# 创建时间:2021/1/13 23:54
# IDE:PyCharm 
# encoding: utf-8

import random
import math
import numpy as np
import matplotlib.pyplot as plt

ARM_NUM = 5
E_GREEDY_FACTOR = 0.9
SEED = random.randint(1, 10000)
TEST_STEPS = 1000


class MAB:
    def __init__(self, arm_num: int) -> None:
        """
        :param arm_num:  the number of arms  臂的数量
        """
        self.arm_num = arm_num    # 设置臂的数量
        self.probability = dict({})  # 设置每个臂能摇出一块钱的概率
        self.try_time = dict({})  # 每个臂已经摇过的次数
        self.reward = dict({})  # 每个臂已经获得的钱
        self.reward_all = 0  # 所有臂获得的收益之和
        self.try_time_all = 0  # 总的尝试的次数

    def reset(self, seed: int) -> None:
        """
        Each arm is initialized, and each arm is set the same when passing in the same random seed
        对每一个臂进行初始化,当传入的随机种子一样时,每个臂的设置相同
        :param seed: random seed  传入的随机种子
        """
        print("We have %d arms" % self.arm_num)
        for num in range(self.arm_num):
            random.seed(num+seed)
            self.probability[str(num + 1)] = random.random()
            self.try_time[str(num + 1)] = 0
            self.reward[str(num + 1)] = 0

    def step(self, arm_id: str):
        """
        Change the arm according to the arm_id
        当传入每次要摇下的臂的编号后,老虎机的状态发生变化
        :param arm_id: the id of the arm in this step 这一步控制摇下杆的id
        """
        self.try_time[arm_id] += 1
        self.try_time_all += 1
        if random.random() < self.probability[arm_id]:
            self.reward[arm_id] += 1
            self.reward_all += 1

    def render(self):
        """
           draw the multi-armed bandit,including tried times and reward
           for each arm, and total tried times and rewards.
        """
        if self.arm_num <= 10:
            print('*' * 8 * (self.arm_num + 1) + '**')
            title = str(self.arm_num) + " arm bandit"
            title_format = '{:^' + str(8 * (self.arm_num + 1)) + 's}'
            print('*' + title_format.format(title) + '*')

            print('*' + ' ' * 8 * (self.arm_num + 1) + '*')

            print('*{:^8s}'.format('arm'), end='')
            for arm in range(self.arm_num):
                print('{:^8d}'.format(arm + 1), end='')
            print('*\n')

            print('*{:^8s}'.format('tried'), end='')
            for arm in range(self.arm_num):
                print('{:^8d}'.format(self.try_time[str(arm + 1)]), end='')
            print('*\n')

            print('*{:^8s}'.format('reward'), end='')
            for arm in range(self.arm_num):
                print('{:^8d}'.format(self.reward[str(arm + 1)]), end='')
            print('*\n')

            print('*' + title_format.format("total tried:" + str(self.try_time_all)) + '*')
            print('*' + title_format.format("total rewards:" + str(self.reward_all)) + '*')
            print('*' + ' ' * 8 * (self.arm_num + 1) + '*')
            print('*' * 8 * (self.arm_num + 1) + '**')


def e_greedy_method(mab):
    """
       e greedy method: define a e_greedy_factor and create a random number,
       when the random number is less then e_greedy_factor, then pick a arm
       randomly, else pick the arm with argmax q_table.
       :param mab: the class MBA
       :return: selected arm_id
    """
    q_table = []
    for arm_num in range(mab.arm_num):
        if mab.try_time[str(arm_num+1)] != 0:
            q_table.append(mab.reward[str(arm_num+1)]/mab.try_time[str(arm_num+1)])
        else:
            q_table.append(0)
    if random.random() < E_GREEDY_FACTOR:
        arm_id = random.randint(1, mab.arm_num)
    else:
        arm_id = np.argmax(q_table) + 1
    return arm_id


def sofmax_method(mab):
    """
    softmax method: calculate the softmax value of each arm's avarage reward,
    and pick the arm with greatest softmax value.
    :param mab: the class MBA
    :return: selected arm_id
    """
    exp_sum = 0
    softmax_list = []

    for arm_num in range(mab.arm_num):
        if mab.try_time[str(arm_num+1)] > 0:
            exp_sum += math.exp(mab.reward[str(arm_num+1)] / mab.try_time[str(arm_num+1)])
        else:
            exp_sum += math.exp(0)
    assert exp_sum > 0
    for arm_num in range(mab.arm_num):
        if mab.try_time[str(arm_num+1)] == 0:
            avg_reward_temp = 0
        else:
            avg_reward_temp = mab.reward[str(arm_num+1)] / mab.try_time[str(arm_num+1)]
        softmax_list.append(math.exp(avg_reward_temp) / exp_sum)
    arm_id = np.random.choice(mab.arm_num, 1, p=softmax_list)[0]
    print("The softmax list is", softmax_list)
    print("The id of returned arm is ", arm_id+1)
    return arm_id + 1


def average_method(mab):
    """
    decide the arm_id according to the average return of each arm but don't do the math.exp() operation like softmax
    :param mab: the class MBA
    :return: selected arm_id
    """
    sum_average = 0
    softmax_list = []

    for arm_num in range(mab.arm_num):
        if mab.try_time[str(arm_num + 1)] > 0:
            sum_average += (mab.reward[str(arm_num + 1)] / mab.try_time[str(arm_num + 1)])
        else:
            sum_average += 0
    if sum_average == 0:
        arm_id = np.random.choice(mab.arm_num) + 1
    else:
        for arm_num in range(mab.arm_num):
            if mab.try_time[str(arm_num + 1)] == 0:
                avg_reward_temp = 0
            else:
                avg_reward_temp = mab.reward[str(arm_num + 1)] / mab.try_time[str(arm_num + 1)]
            softmax_list.append(avg_reward_temp / sum_average)
        arm_id = np.random.choice(mab.arm_num, 1, p=softmax_list)[0]
    print("The softmax list is", softmax_list)
    print("The id of returned arm is ", arm_id + 1)
    return arm_id + 1


if __name__ == '__main__':
    reward_list = []
    mab_test = MAB(ARM_NUM)
    print("****Multi-armed Bandit***")
    mab_test.reset(SEED)
    mab_test.render()
    for i in range(TEST_STEPS):
        mab_test.step(str(average_method(mab_test)))
        reward_list.append(mab_test.reward_all/mab_test.try_time_all)
        if (i+1) % 20 == 0:
            print("We have test for %i times" % (i+1))
            mab_test.render()
    plt.plot(reward_list)
    plt.show()

热门文章

编程学习 ·

SpringBoot解决跨域

第一种:书写解决跨域的类public class AccessControlAllowOriginFilter implements Filter {@Overridepublic void init(FilterConfig filterConfig) throws ServletException {}@Overridepublic void doFilter(ServletRequest req, ServletResponse res, FilterChain chain) …
编程学习 ·

使用ssh连接window和 oracle virtualBox中的虚拟机 保姆级教程

目录环境基本连接步骤进一步配置hostname环境虚拟机 VM Virtualbox 6.1 虚拟机系统 debain 10.3虚拟机网卡:1.虚拟机网卡2.主机 windows 10配置好以上环境后开始配置虚拟机,当然host-only网络适配器的ip地址可以自己设置。 基本连接步骤 第一步: 虚拟机中运行 ps -e | gr…
编程学习 ·

springboot+idea+bootstrap的带有图片的表格编辑操作

前面已经写了 批量导入,图片显示,现在写的是批量修改,后面会写用echarts+springboot 做折线图,有时间贴上 1、jsp代码如下,编辑按钮formatter: function (value, row, index) {var edit = <input class="btn btn-primary" type="button" value=&qu…
编程学习 ·

期末复习、化学反应工程科目(第四章)

@Author:Runsen @Date:2020/7/1人生最重要的不是所站的位置,而是内心所朝的方向。只要我在每篇博文中写得自己体会,修炼身心;在每天的不断重复学习中,耐住寂寞,练就真功,不畏艰难,奋勇前行,不忘初心,砥砺前行,人生定会有所收获,不留遗憾 (作者:Runsen )作者介…
编程学习 ·

window.performance.navigation.type

performance.navigation.type(该属性返回一个整数值,表示网页的加载来源,可能有以下4种情况):0:网页通过点击链接、地址栏输入、表单提交、脚本操作等方式加载,相当于常数performance.navigation.TYPE_NAVIGATE。1:网页通过“重新加载”按钮或者location.reload()方法加…
编程学习 ·

数据结构:双向链表(1)

双向链表基本思想大体结构增加修改MyList测试遍历修改MyList删除修改MyList测试 基本思想 双向链表与单向链表大同小异,只不过双向链表还有个节点指向最后一个节点 大体结构 新建工程,结构如下package list; public class MyList {long size;Node firstNode;Node lastNode;pu…
编程学习 ·

一线互联网大厂300多道Java面试题【全面解析】,助你备战“金九银十”、进军BAT、斩获offer必备的核心知识点

前言今年因为疫情原因,很多人在家里宅了很长一段时间,“金三银四”黄金季也随之而然的“泡汤”,所有的跳槽涨薪的黄金季都集中在了“金九银十”季,所以程序员的竞争会对比往年更加激烈,为了备战“金九银十”需要有充足的时间复习筹备,为面试做足准备。我这里这筹备了一份…
编程学习 ·

Go map的增删改查及遍历

map的增删改查map 增加和更新map["key"] = value 如果 key 还没有,就是增加,如果 key 存在就是修改cities := make(map[string]string) cities["no1"] = "北京" cities["no2"] = "天津" cities["no3"] = "…
编程学习 ·

结构体学生信息输入

不知不觉学到第七章结构体了,这一章开始到后面的章节网上的免费课程就越来越少了。每次有不会的只能各种百度,心累。。。但还是会坚持的!!! 记录第7章课后习题第3题: 题目:编写一个函数print,打印一个学生的成绩数组,该数组中有5个学生的数据,每个学生的数据包括num(学…
编程学习 ·

Finereport不破解前提下解除并发数限制,突破官网2个并发限制

官方免费版具有全部系统功能,但是只有2个并发,也就是2个以内用户可以访问,第三个用户访问就会提示“未注册,无法访问”,本案例中6个用户,超过了限制,所以没法实际使用,仅仅玩玩还行这里提供一款软件,实现不对免费版进行任何修改,通过搭建特殊环境,突破2用户在线访问限…
编程学习 ·

Kafka中位移提交那些事儿

本文已收录GitHub,更有互联网大厂面试真题,面试攻略,高效学习资料等之前我们说过,Consumer 端有个位移的概念,它和消息在分区中的位移不是一回事儿,虽然它们的英文都是 Offset。今天我们要聊的位移是 Consumer 的消费位移,它记录了Consumer 要消费的下一条消息的位移。这…
编程学习 ·

简单动态字符串

SDS(simple synamic String)用作Redis默认字符串表示。C字符串只会作为字符串字面量用在一些无须对字符串进行修改的地方,例如打印日志等。 SDS定义 每个sds.h/sdshdr结构表示一个SDS值 struct sdshdr {//字符串的长度int len;// buf数组中未使用字节的数量int free;// 字节…
编程学习 ·

从永远到永远-SpringCloud项目实战(七)-前端框架NUXT

1、什么是服务端渲染 服务端渲染又称SSR (Server Side Render)是在服务端完成页面的内容,而不是在客户端通过AJAX获取数据。 服务器端渲染(SSR)的优势主要在于:更好的 SEO,由于搜索引擎爬虫抓取工具可以直接查看完全渲染的页面。 如果你的应用程序初始展示 loading 菊花图,…
编程学习 ·

LeetCode 226. 翻转二叉树

目录结构1.题目2.题解1.题目翻转一棵二叉树。示例:输入:4/ \2 7/ \ / \ 1 3 6 9输出:4/ \7 2/ \ / \ 9 6 3 1备注:这个问题是受到 Max Howell 的 原问题 启发的 :谷歌:我们90%的工程师使用您编写的软件(Homebrew),但是您却无法在面试时在白板上…
编程学习 ·

RocketMQ消费者之消息消费过程分析

心跳机制在Consumer启动后,它就会通过定时任务不断地向RocketMQ集群中的所有Broker实例发送心跳包心跳包内容包含了消息消费分组名称订阅关系集合消息通信模式客户端id的值Broker端在收到Consumer的心跳消息后,会将它维护在ConsumerManager的本地缓存变量—consumerTable,同…
编程学习 ·

千里之行 | 计算机基础要点

第一课:计算机基础要点一. 计算机基本概念1. 计算机是什么?2. 计算要机的组成二. 计算机语言1. 计算机语言的基本概念2. 计算机语言的发展三. 交互方式1. 交互方式的种类2. 文本交互模式的打开方式(Windows系统)3. DOS命令四. 文本文件和字符集1. 文本文件2. 常见字符集五. …
编程学习 ·

分布式计算(课堂测验+MOOC答案)

分布式计算(课堂测验+MOOC答案)分布式计算(课堂测验+MOOC答案)第一章节第二章节第三章节第四章节第五章节 分布式计算(课堂测验+MOOC答案) 第一章节第二章节第三章节 1、【单选题】在EC2服务中,每个实例自身携带 ()个存储模块。(A) A. 1 B. 2 C. 3 D. 4 2、【单选题】…