網(wǎng)上接單做衣服哪個網(wǎng)站如何創(chuàng)建一個自己的網(wǎng)站
離線學(xué)習(xí):不需要更新數(shù)據(jù)
CQL(Conservative Q-Learning)算法是一種用于離線強(qiáng)化學(xué)習(xí)的方法,它通過學(xué)習(xí)一個保守的Q函數(shù)來解決標(biāo)準(zhǔn)離線RL方法可能由于數(shù)據(jù)集和學(xué)習(xí)到的策略之間的分布偏移而導(dǎo)致的過高估計問題 。CQL算法的核心思想是在Q值的基礎(chǔ)上增加一個正則化項(xiàng)(regularizer),從而得到真實(shí)動作值函數(shù)的下界估計。這種方法在理論上被證明可以產(chǎn)生當(dāng)前策略的真實(shí)值下界,并且可以進(jìn)行策略評估和策略提升的過程 。
CQL算法通過修改值函數(shù)的備份方式,添加正則化項(xiàng)來實(shí)現(xiàn)保守性。在優(yōu)化過程中,CQL旨在找到一個Q函數(shù),該函數(shù)在給定策略下的期望值低于其真實(shí)值。這通過在Q學(xué)習(xí)的目標(biāo)函數(shù)中添加一個懲罰項(xiàng)來實(shí)現(xiàn),該懲罰項(xiàng)限制了策略π下Q函數(shù)的期望值不能偏離數(shù)據(jù)分布Q函數(shù)的期望值 。
CQL算法的實(shí)現(xiàn)相對簡單,只需要在現(xiàn)有的深度Q學(xué)習(xí)和行動者-評論家實(shí)現(xiàn)的基礎(chǔ)上添加少量代碼。在實(shí)驗(yàn)中,CQL在多個領(lǐng)域和數(shù)據(jù)集上的表現(xiàn)優(yōu)于現(xiàn)有的離線強(qiáng)化學(xué)習(xí)方法,尤其是在學(xué)習(xí)復(fù)雜和多模態(tài)數(shù)據(jù)分布時,通??梢允箤W(xué)習(xí)策略獲得2到5倍的最終回報 。
此外,CQL算法的一個關(guān)鍵優(yōu)勢是它提供了一種有效的解決方案,可以在不與環(huán)境進(jìn)行額外交互的情況下,利用先前收集的靜態(tài)數(shù)據(jù)集學(xué)習(xí)有效的策略。這使得CQL在自動駕駛和醫(yī)療機(jī)器人等領(lǐng)域具有潛在的應(yīng)用價值,這些領(lǐng)域中與環(huán)境的交互次數(shù)在成本和風(fēng)險方面都是有限的 。
總的來說,CQL算法通過其保守的Q函數(shù)估計和正則化策略,為離線強(qiáng)化學(xué)習(xí)領(lǐng)域提供了一種有效的策略學(xué)習(xí)框架,并在理論和實(shí)踐上都顯示出了其有效性
import gym
from matplotlib import pyplot as plt
import numpy as np
import random
%matplotlib inline
#創(chuàng)建環(huán)境
env = gym.make('Pendulum-v1')
env.reset()#打印游戲
def show():plt.imshow(env.render(mode='rgb_array'))plt.show()
定義sac模型,代碼略http://t.csdnimg.cn/ic2HX
定義teacher模型
#定義teacher模型
teacher = SAC()teacher.train(torch.tandn(5,3),torch.randn(5,1),torch.randn(5,1),torch.randn(5,3),torch.zeros(5,1).long(),
)
定義Data類
#樣本池
datas = []#向樣本池中添加N條數(shù)據(jù),刪除M條最古老的數(shù)據(jù)
def update_data():#初始化游戲state = env.reset()#玩到游戲結(jié)束為止over = Falsewhile not over:#根據(jù)當(dāng)前狀態(tài)得到一個動作action = get_action(state)#執(zhí)行當(dāng)作,得到反饋next_state,reward,over, _ = env.step([action])#記錄數(shù)據(jù)樣本datas.append((states,action,reward,next_state,over))#更新游戲狀態(tài),開始下一個當(dāng)作state = next_state#數(shù)據(jù)上限,超出時從最古老的開始刪除while len(datas)>10000:datas.pop(0)#獲取一批數(shù)據(jù)樣本
def get_sample():samples = random.sample(datas,64)#[b,4]state = torch.FloatTensor([i[0]for i in samples]).reshape(-1,3)#[b,1]action = torch.LongTensor([i[1]for i in samples]).reshape(-1,1)#[b,1]reward = torch.FloatTensor([i[2]for i in samples]).reshape(-1,1)#[b,4]next_state = torch.FloatTensor([i[3]for i in samples]).reshape(-1,3)#[b,1]over = torch.LongTensor([i[4]for i in samples]).reshape(-1,1)return state,action,reward,next_state,overstate,action,reward,next_state,over=get_sample()state[:5],action[:5],reward[:5],next_state[:5],over[:5]
data = Data()
data.update_data(teacher),data.get_sample()
訓(xùn)練teacher模型
#訓(xùn)練teacher模型
for epoch in range(100):#更新N條數(shù)據(jù)datat.update_data(teacher)#每次更新過數(shù)據(jù)后,學(xué)習(xí)N次for i in range(200):teacher.train(*data.get_sample())if epoch%10==0:test_result = sum([teacher.test(play=False)for _ in range(10)])/10print(epoch,test_result)
定義CQL模型
class CQL(SAC):def __init__(self):super().__init__()def _get_loss_value(self,model_value,target,state,action,next_state):#計算valuevalue = model_value(state,action)#計算loss,value的目標(biāo)是要貼近targetloss_value = self.loss_fn(value,tarfet)"""以上與SAC相同,以下是CQL部分"""#把state復(fù)制5彼遍state = state.unsqueeze(dim=1)state = state.repeat(1,5,1).reshape(-1,3)#把next_state復(fù)制5遍next_state = next_state.unsqueeze(1)next_state = next_state.repeat(1,5,1).reshape(-1,3)#隨機(jī)一批動作,數(shù)量是數(shù)據(jù)量的5倍,值域在-1到1之間rand_action = torch.empty([len(state),1]).uniform_(-1,1)#計算state的動作和熵curr_action,next_entropy = self..mdoel_action(next_state)#計算三方動作的valuevalue_rand = model_value(state,rand_action).reshape(-1,5,1)value_curr = model_value(state,curr_action).reshape(-1,5,1)value_next = model_value(state,next_action).reshape(-1,5,1)curr_entropy = curr_entropy.detach().reshape(-1,5,1)next_entropy = next_entropy.detach().reshape(-1,5,1)#三份value分別減去他們的熵value_rand -=mat.log(0.5)value_curr -=curr_entropyvalue_next -=next_entropy#拼合三份valuevalue_cat = torch.cat([value_rand,value_curr,value_next],dim=1)#等價t.logsumexp(dim=1),t.exp().sum(dim=1).log()loss_cat = torch.logsumexp(value_cat,dim =1).mean()#在原本的loss上增加上這一部分loss_value += 5.0*(loss_cat - value.mean())"""差異到此為止"""
學(xué)生模型
student = CQL()
student.train(torch.randn(5,3),torch.randn(5,1),torch.randn(5,1),torch.randn(5,3),torch.zeros(5,1)long(),
)
離線訓(xùn)練,訓(xùn)練過程中完全不更新數(shù)據(jù)
#訓(xùn)練N次,訓(xùn)練過程中不需要更新數(shù)據(jù)
for i in range(50000):#采樣一批數(shù)據(jù)student.train(*data.get_sample())if i%2000 ==0:test_result = sum([student.test(play = False) for _ in range(10)])print(i,test_result)