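# Slot machines: inserting a coin pays back a random number of coins (the reward).
# Two machines are simulated here: one with mean 500 and standard deviation 5, the other with mean 550 and standard deviation 10.
# Both initial values are 998, and the update rule is: value = sum of rewards / number of rounds.
# The final results converge to around each machine's mean.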
import matplotlib.pyplot as plt
import numpy as np
def renew(value1, value2, round1, round2):
    """Play one greedy round on the machine with the higher value estimate.

    The machine whose current estimate is larger is played (ties go to
    machine 1). A reward is drawn for that machine only, its round counter
    is incremented, and its estimate is updated as a running mean. The
    other machine's estimate and counter are returned unchanged.

    Args:
        value1: Current value estimate of machine 1.
        value2: Current value estimate of machine 2.
        round1: Number of samples already behind ``value1``.
        round2: Number of samples already behind ``value2``.

    Returns:
        A tuple ``(value1, value2, choice, round1, round2, reward)`` where
        ``choice`` is 1 or 2 (the machine played) and ``reward`` is the
        integer reward drawn this round.
    """
    # Greedy selection: update whichever machine currently looks better.
    if value1 >= value2:
        choice = 1
        # Random reward, truncated to an integer.
        reward = int(np.random.normal(loc=500, scale=5, size=None))
        round1 += 1
        # Running mean: the previous estimate summarises (round1 - 1)
        # samples, so it must be weighted by that count before the new
        # reward is added. The result is truncated to int, so it can sit
        # slightly below the exact mean.
        value1 = int((value1 * (round1 - 1) + reward) / round1)
        print('reward1:', reward)
    else:
        choice = 2
        # Random reward, truncated to an integer.
        reward = int(np.random.normal(loc=550, scale=10, size=None))
        round2 += 1
        # Same weighted running-mean update as for machine 1.
        value2 = int((value2 * (round2 - 1) + reward) / round2)
        print('reward2:', reward)
    print('choice:', choice)
    return value1, value2, choice, round1, round2, reward
if __name__ == '__main__':
    # Both machines start from the same estimate of 998; that starting
    # value is also stored as the first entry of each reward history.
    value1 = value2 = 998
    s1, s2 = [value1], [value2]
    round1 = round2 = 1

    for _ in range(99):
        # The two estimates returned by renew are discarded because they
        # are recomputed below from the full reward history.
        _, _, choice, round1, round2, reward = renew(value1, value2, round1, round2)

        # Record the reward in the history of the machine that was played.
        (s1 if choice == 1 else s2).append(reward)

        print('round1:', round1, 'round2:', round2)
        print('s1:', s1, 's2:', s2)

        # Estimate = sum of the history / number of samples, truncated to int.
        value1, value2 = int(sum(s1) / round1), int(sum(s2) / round2)
        print('value1:', value1, 'value2:', value2)
        print(' ')

    # Plot each machine's history against its sample index.
    for samples, colour, tag in ((s1, 'b', 'value1'), (s2, 'g', 'value2')):
        plt.plot(list(range(len(samples))), samples, colour, label=tag)
    plt.legend()
    plt.show()