
State-Space Modeling for Reinforcement Learning: An Introduction to RSSM with a PyTorch Implementation

Recurrent State Space Models (RSSM) were first introduced by Danijar Hafner et al. in the paper "Learning Latent Dynamics for Planning from Pixels". The model plays a key role in modern model-based reinforcement learning (MBRL), whose main goal is to build reliable predictive models of the environment dynamics. With such learned models, an agent can simulate future trajectories and plan its behavior ahead of time.

Below, we introduce RSSM through a hands-on example.

Environment Setup

Setting up the environment is the first step of the implementation. We use the familiar Gym API and, to keep things modular, define several wrappers that initialize parameters and bring observations into the desired format.

InitialWrapper allows a given number of observations to be collected without executing any action, and also supports repeating the same action several times before returning an observation. This is particularly useful for environments that respond to actions with a noticeable delay.

The PreprocessFrame wrapper converts observations to the right data type (NumPy arrays in this article), resizes them, and optionally converts them to grayscale.

from typing import Sequence, Tuple

import cv2 as cv
import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from gym.core import ActType, ObsType


class InitialWrapper(gym.Wrapper):
    def __init__(self, env: gym.Env, no_ops: int = 0, repeat: int = 1):
        super(InitialWrapper, self).__init__(env)
        self.repeat = repeat
        self.no_ops = no_ops

        self.op_counter = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]:
        # Execute the remaining no-op actions before the first real action
        while self.op_counter < self.no_ops:
            obs, reward, terminated, truncated, info = self.env.step(0)
            self.op_counter += 1

        total_reward = 0.0
        terminated, truncated = False, False
        # Repeat the chosen action and accumulate the reward
        for _ in range(self.repeat):
            obs, reward, terminated, truncated, info = self.env.step(action)
            total_reward += reward
            if terminated or truncated:
                break

        return obs, total_reward, terminated, truncated, info
 
 
 class PreprocessFrame(gym.ObservationWrapper):  
     def __init__(self, env: gym.Env, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = False):  
         super(PreprocessFrame, self).__init__(env)  
         self.shape = new_shape  
         self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=self.shape, dtype=np.float32)  
         self.grayscale = grayscale  
   
         if self.grayscale:  
             self.observation_space = gym.spaces.Box(low=0.0, high=1.0, shape=(*self.shape[:-1], 1), dtype=np.float32)  
   
      def observation(self, obs: np.ndarray) -> torch.Tensor:
         obs = obs.astype(np.uint8)  
         new_frame = cv.resize(obs, self.shape[:-1], interpolation=cv.INTER_AREA)  
         if self.grayscale:  
             new_frame = cv.cvtColor(new_frame, cv.COLOR_RGB2GRAY)  
             new_frame = np.expand_dims(new_frame, -1)  
   
         torch_frame = torch.from_numpy(new_frame).float()  
         torch_frame = torch_frame / 255.0  
   
         return torch_frame  
   
 def make_env(env_name: str, new_shape: Sequence[int] = (128, 128, 3), grayscale: bool = True, **kwargs):  
     env = gym.make(env_name, **kwargs)  
     env = PreprocessFrame(env, new_shape, grayscale=grayscale)  
     return env

The make_env function creates an environment instance with the given configuration parameters.
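
A quick, hedged usage sketch: the environment name and keyword arguments simply mirror the experiment at the end of the article, and the reset call assumes the newer Gym API that returns an (observation, info) tuple.

env = make_env("CarRacing-v2", new_shape=(128, 128, 3), grayscale=True,
               render_mode="rgb_array", continuous=False)
# InitialWrapper is not applied inside make_env; wrap explicitly if no-ops or
# action repetition are needed:
env = InitialWrapper(env, no_ops=0, repeat=1)
obs, info = env.reset()
print(env.observation_space.shape)   # (128, 128, 1) because grayscale=True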

Model Architecture

The RSSM implementation relies on several key model components. Specifically, we need the following four core modules:

  • Encoder for the raw observations
  • Dynamics model: models the temporal dependencies of the encoded observations through a deterministic state h and a stochastic state s
  • Decoder: maps the stochastic and deterministic states back into the original observation space
  • Reward model: maps the stochastic and deterministic states to a reward value

Figure: structure of the RSSM components. The model maintains a stochastic state s and a deterministic state h.

Encoder Implementation

The encoder is a simple convolutional neural network (CNN) that compresses the input image into a one-dimensional embedding. BatchNorm layers are used to improve training stability.

class EncoderCNN(nn.Module):  
     def __init__(self, in_channels: int, embedding_dim: int = 2048, input_shape: Tuple[int, int] = (128, 128)):  
         super(EncoderCNN, self).__init__()  
          # Convolutional feature extractor
         self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1)  
         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1)  
         self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)  
         self.conv4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)  
   
         self.fc1 = nn.Linear(self._compute_conv_output((in_channels, input_shape[0], input_shape[1])), embedding_dim)  
   
          # Batch normalization layers
         self.bn1 = nn.BatchNorm2d(32)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(128)  
         self.bn4 = nn.BatchNorm2d(256)  
   
     def _compute_conv_output(self, shape: Tuple[int, int, int]):  
         with torch.no_grad():  
             x = torch.randn(1, shape[0], shape[1], shape[2])  
             x = self.conv1(x)  
             x = self.conv2(x)  
             x = self.conv3(x)  
             x = self.conv4(x)  
   
             return x.shape[1] * x.shape[2] * x.shape[3]  
 
     def forward(self, x):  
         x = torch.relu(self.conv1(x))  
         x = self.bn1(x)  
         x = torch.relu(self.conv2(x))  
         x = self.bn2(x)  
   
         x = torch.relu(self.conv3(x))  
         x = self.bn3(x)  
   
         x = self.conv4(x)  
         x = self.bn4(x)  
   
         x = x.view(x.size(0), -1)  
         x = self.fc1(x)  
   
         return x
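
As a quick sanity check of the encoder output shape (dimensions borrowed from the experiment section later on, purely illustrative):

encoder = EncoderCNN(in_channels=1, embedding_dim=1024, input_shape=(128, 128))
dummy = torch.randn(8, 1, 128, 128)   # (batch, channels, height, width)
print(encoder(dummy).shape)           # torch.Size([8, 1024])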

Decoder Implementation

The decoder follows a conventional autoencoder design; its job is to reconstruct observations in the original observation space from the latent states.

class DecoderCNN(nn.Module):  
     def __init__(self, hidden_size: int, state_size: int,  embedding_size: int,  
                  use_bn: bool = True, output_shape: Tuple[int, int] = (3, 128, 128)):  
         super(DecoderCNN, self).__init__()  
   
         self.output_shape = output_shape  
   
         self.embedding_size = embedding_size  
          # Fully connected layers to transform the latent features
         self.fc1 = nn.Linear(hidden_size + state_size, embedding_size)  
         self.fc2 = nn.Linear(embedding_size, 256 * (output_shape[1] // 16) * (output_shape[2] // 16))  
   
          # Transposed convolutions for upsampling
         self.conv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)  # ×2  
         self.conv4 = nn.ConvTranspose2d(32, output_shape[0], kernel_size=3, stride=2, padding=1, output_padding=1)  
   
          # Batch normalization layers
         self.bn1 = nn.BatchNorm2d(128)  
         self.bn2 = nn.BatchNorm2d(64)  
         self.bn3 = nn.BatchNorm2d(32)  
   
         self.use_bn = use_bn  
 
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = self.fc1(x)  
         x = torch.relu(x)  
         x = self.fc2(x)  
   
         x = x.view(-1, 256, self.output_shape[1] // 16, self.output_shape[2] // 16)  
   
         if self.use_bn:  
             x = torch.relu(self.bn1(self.conv1(x)))  
             x = torch.relu(self.bn2(self.conv2(x)))  
             x = torch.relu(self.bn3(self.conv3(x)))  
   
         else:  
             x = torch.relu(self.conv1(x))  
             x = torch.relu(self.conv2(x))  
             x = torch.relu(self.conv3(x))  
   
         x = self.conv4(x)  
   
         return x
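
Symmetrically, a small shape check of the decoder under the same assumed dimensions:

decoder = DecoderCNN(hidden_size=1024, state_size=512, embedding_size=1024,
                     output_shape=(1, 128, 128))
h = torch.randn(8, 1024)              # deterministic state
s = torch.randn(8, 512)               # stochastic state
print(decoder(h, s).shape)            # torch.Size([8, 1, 128, 128])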

Reward Model Implementation

The reward model is a three-layer feed-forward network that maps the stochastic state s and the deterministic state h to the parameters of a normal distribution, from which a reward prediction is then sampled.

class RewardModel(nn.Module):  
     def __init__(self, hidden_dim: int, state_dim: int):  
         super(RewardModel, self).__init__()  
   
         self.fc1 = nn.Linear(hidden_dim + state_dim, hidden_dim)  
         self.fc2 = nn.Linear(hidden_dim, hidden_dim)  
         self.fc3 = nn.Linear(hidden_dim, 2)  
   
     def forward(self, h: torch.Tensor, s: torch.Tensor):  
         x = torch.cat([h, s], dim=-1)  
         x = torch.relu(self.fc1(x))  
         x = torch.relu(self.fc2(x))  
         x = self.fc3(x)  
   
         return x
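
The two output values are interpreted as the mean and log-variance of a normal distribution; a hedged sketch of how they are consumed, mirroring what the trainer does later:

reward_model = RewardModel(hidden_dim=1024, state_dim=512)
h = torch.randn(8, 1024)
s = torch.randn(8, 512)
mean, logvar = torch.chunk(reward_model(h, s), 2, dim=-1)
# Same transform as used elsewhere in this article: std = exp(softplus(logvar))
reward = torch.distributions.Normal(mean, torch.exp(F.softplus(logvar))).rsample()
print(reward.shape)                   # torch.Size([8, 1])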

Dynamics Model Implementation

The dynamics model is the most involved component of the RSSM architecture, since it has to handle both a prior and a posterior transition model:

  1. Posterior transition model: used when the true observations are available (mainly during training); it approximates the posterior distribution of the stochastic state given the observation and the past states.
  2. Prior transition model: approximates the prior state distribution from the previous state alone, without access to observations. It is used at inference time, when posterior observations are not available.

Both models are parameterized by a single feed-forward layer that outputs the mean and log-variance of the respective normal distribution from which the state s is sampled. The networks are deliberately simple and can be replaced with more expressive architectures if needed.
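
Concretely, following the formulation of the PlaNet paper (with a_t the action, o_t the encoded observation, h_t the deterministic and s_t the stochastic state), the transitions can be summarized as:

    h_t = f(h_{t-1}, s_{t-1}, a_{t-1})        (deterministic recurrence, a GRU)
    prior:      s_t ~ p(s_t | h_t)
    posterior:  s_t ~ q(s_t | h_t, o_t)

Note that in the implementation below the prior additionally conditions on the current action, and that the KL divergence between posterior and prior later becomes part of the training loss.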

The deterministic state is implemented with a gated recurrent unit (GRU). Its inputs are:

  • the previous hidden state
  • the one-hot encoded action
  • the previous stochastic state s (posterior or prior, depending on whether observations are available)

Together these inputs give the model enough information about the action history and the system state. The implementation follows:

class DynamicsModel(nn.Module):  
     def __init__(self, hidden_dim: int, action_dim: int, state_dim: int, embedding_dim: int, rnn_layer: int = 1):  
         super(DynamicsModel, self).__init__()  
   
         self.hidden_dim = hidden_dim  
           
          # Recurrent layers; supports stacking multiple GRU cells
         self.rnn = nn.ModuleList([nn.GRUCell(hidden_dim, hidden_dim) for _ in range(rnn_layer)])  
           
          # Projection of the state-action pair
         self.project_state_action = nn.Linear(action_dim + state_dim, hidden_dim)  
           
          # Prior network: outputs the parameters of a normal distribution
         self.prior = nn.Linear(hidden_dim, state_dim * 2)  
         self.project_hidden_action = nn.Linear(hidden_dim + action_dim, hidden_dim)  
           
          # Posterior network: outputs the parameters of a normal distribution
         self.posterior = nn.Linear(hidden_dim, state_dim * 2)  
         self.project_hidden_obs = nn.Linear(hidden_dim + embedding_dim, hidden_dim)  
   
          self.state_dim = state_dim
          self.embedding_dim = embedding_dim  # needed in forward() when obs is None
         self.act_fn = nn.ReLU()  
   
     def forward(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor, actions: torch.Tensor,  
                 obs: torch.Tensor = None, dones: torch.Tensor = None):  
         """  
        動(dòng)態(tài)模型的前向傳播
        參數(shù):  
            prev_hidden: RNN的前一隱藏狀態(tài),形狀 (batch_size, hidden_dim)  
            prev_state: 前一隨機(jī)狀態(tài),形狀 (batch_size, state_dim)  
            actions: 獨(dú)熱編碼動(dòng)作序列,形狀 (sequence_length, batch_size, action_dim)  
            obs: 編碼器輸出的觀察嵌入,形狀 (sequence_length, batch_size, embedding_dim)  
            dones: 終止?fàn)顟B(tài)標(biāo)志
        """  
         B, T, _ = actions.size()  # 用于無觀察訪問時(shí)的推理
   
          # Lists that accumulate the per-timestep outputs
         hiddens_list = []  
         posterior_means_list = []  
         posterior_logvars_list = []  
         prior_means_list = []  
         prior_logvars_list = []  
         prior_states_list = []  
         posterior_states_list = []  
           
          # Store the initial hidden and stochastic states
         hiddens_list.append(prev_hidden.unsqueeze(1))    
         prior_states_list.append(prev_state.unsqueeze(1))  
         posterior_states_list.append(prev_state.unsqueeze(1))  
   
          # Unroll over the time dimension
         for t in range(T - 1):  
              # Gather the current action, observation embedding and previous states
             action_t = actions[:, t, :]  
             obs_t = obs[:, t, :] if obs is not None else torch.zeros(B, self.embedding_dim, device=actions.device)  
             state_t = posterior_states_list[-1][:, 0, :] if obs is not None else prior_states_list[-1][:, 0, :]  
             state_t = state_t if dones is None else state_t * (1 - dones[:, t, :])  
             hidden_t = hiddens_list[-1][:, 0, :]  
               
              # Combine the stochastic state and the action
             state_action = torch.cat([state_t, action_t], dim=-1)  
             state_action = self.act_fn(self.project_state_action(state_action))  
   
              # Update the deterministic state through the GRU cell(s)
             for i in range(len(self.rnn)):  
                 hidden_t = self.rnn[i](state_action, hidden_t)  
   
              # Compute the prior distribution parameters
             hidden_action = torch.cat([hidden_t, action_t], dim=-1)  
             hidden_action = self.act_fn(self.project_hidden_action(hidden_action))  
             prior_params = self.prior(hidden_action)  
             prior_mean, prior_logvar = torch.chunk(prior_params, 2, dim=-1)  
   
              # Sample from the prior (reparameterized)
             prior_dist = torch.distributions.Normal(prior_mean, torch.exp(F.softplus(prior_logvar)))  
             prior_state_t = prior_dist.rsample()  
   
              # Compute the posterior distribution parameters (requires observations)
             if obs is None:  
                 posterior_mean = prior_mean  
                 posterior_logvar = prior_logvar  
             else:  
                 hidden_obs = torch.cat([hidden_t, obs_t], dim=-1)  
                 hidden_obs = self.act_fn(self.project_hidden_obs(hidden_obs))  
                 posterior_params = self.posterior(hidden_obs)  
                 posterior_mean, posterior_logvar = torch.chunk(posterior_params, 2, dim=-1)  
   
              # Sample from the posterior (reparameterized)
             posterior_dist = torch.distributions.Normal(posterior_mean, torch.exp(F.softplus(posterior_logvar)))  
             posterior_state_t = posterior_dist.rsample()  
   
              # Collect the results for this timestep
             posterior_means_list.append(posterior_mean.unsqueeze(1))  
             posterior_logvars_list.append(posterior_logvar.unsqueeze(1))  
             prior_means_list.append(prior_mean.unsqueeze(1))  
             prior_logvars_list.append(prior_logvar.unsqueeze(1))  
             prior_states_list.append(prior_state_t.unsqueeze(1))  
             posterior_states_list.append(posterior_state_t.unsqueeze(1))  
             hiddens_list.append(hidden_t.unsqueeze(1))  
   
          # Stack along the time dimension
         hiddens = torch.cat(hiddens_list, dim=1)  
         prior_states = torch.cat(prior_states_list, dim=1)  
         posterior_states = torch.cat(posterior_states_list, dim=1)  
         prior_means = torch.cat(prior_means_list, dim=1)  
         prior_logvars = torch.cat(prior_logvars_list, dim=1)  
         posterior_means = torch.cat(posterior_means_list, dim=1)  
         posterior_logvars = torch.cat(posterior_logvars_list, dim=1)  
   
         return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars

Note that the observations passed to the dynamics model are not raw frames but the embeddings produced by the encoder. This keeps the computational cost down and helps the model generalize.
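
In practice this means flattening the batch and time dimensions before the CNN encoder and reshaping afterwards, exactly as the trainer does later (the shapes here are illustrative):

encoder = EncoderCNN(in_channels=1, embedding_dim=1024, input_shape=(128, 128))
B, T = 4, 10
obs = torch.rand(B, T, 128, 128, 1)                          # (B, T, H, W, C) from the buffer
flat = obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2)   # (B*T, C, H, W)
encoded = encoder(flat).reshape(B, T, -1)                    # (B, T, embedding_dim)
print(encoded.shape)                                         # torch.Size([4, 10, 1024])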

Overall RSSM Architecture

The components above are now assembled into the complete RSSM. Its core is the generate_rollout method, which calls the dynamics model and produces a sequence of latent representations of the environment dynamics. When no previous latent state exists (typically at the beginning of a trajectory), the method initializes one. The full implementation follows:

class RSSM:  
     def __init__(self,  
                  encoder: EncoderCNN,  
                  decoder: DecoderCNN,  
                  reward_model: RewardModel,  
                  dynamics_model: nn.Module,  
                  hidden_dim: int,  
                  state_dim: int,  
                  action_dim: int,  
                  embedding_dim: int,  
                  device: str = "mps"):  
         """  
        循環(huán)狀態(tài)空間模型(RSSM)實(shí)現(xiàn)
         
        參數(shù):
            encoder: 確定性狀態(tài)編碼器
            decoder: 觀察重構(gòu)解碼器
            reward_model: 獎(jiǎng)勵(lì)預(yù)測(cè)模型
            dynamics_model: 狀態(tài)動(dòng)態(tài)模型
            hidden_dim: RNN 隱藏層維度
            state_dim: 隨機(jī)狀態(tài)維度
            action_dim: 動(dòng)作空間維度
            embedding_dim: 觀察嵌入維度
            device: 計(jì)算設(shè)備
        """  
         super(RSSM, self).__init__()  
   
          # Model components
         self.dynamics = dynamics_model  
         self.encoder = encoder  
         self.decoder = decoder  
         self.reward_model = reward_model  
   
          # Store the dimensions
         self.hidden_dim = hidden_dim  
         self.state_dim = state_dim  
         self.action_dim = action_dim  
         self.embedding_dim = embedding_dim  
   
          # Move all components to the target device
         self.dynamics.to(device)  
         self.encoder.to(device)  
         self.decoder.to(device)  
         self.reward_model.to(device)  
   
     def generate_rollout(self, actions: torch.Tensor, hiddens: torch.Tensor = None, states: torch.Tensor = None,  
                          obs: torch.Tensor = None, dones: torch.Tensor = None):  
         """
        生成狀態(tài)序列展開
         
        參數(shù):
            actions: 動(dòng)作序列
            hiddens: 初始隱藏狀態(tài)(可選)
            states: 初始隨機(jī)狀態(tài)(可選)
            obs: 觀察序列(可選)
            dones: 終止標(biāo)志序列
             
        返回:
            完整的狀態(tài)展開序列
        """
         # 狀態(tài)初始化
         if hiddens is None:  
             hiddens = torch.zeros(actions.size(0), self.hidden_dim).to(actions.device)  
   
         if states is None:  
             states = torch.zeros(actions.size(0), self.state_dim).to(actions.device)  
   
          # Unroll the dynamics model
         dynamics_result = self.dynamics(hiddens, states, actions, obs, dones)  
         hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars = dynamics_result  
   
         return hiddens, prior_states, posterior_states, prior_means, prior_logvars, posterior_means, posterior_logvars  
   
     def train(self):  
         """啟用訓(xùn)練模式"""
         self.dynamics.train()  
         self.encoder.train()  
         self.decoder.train()  
         self.reward_model.train()  
   
     def eval(self):  
         """啟用評(píng)估模式"""
         self.dynamics.eval()  
         self.encoder.eval()  
         self.decoder.eval()  
         self.reward_model.eval()  
   
     def encode(self, obs: torch.Tensor):  
         """觀察編碼"""
         return self.encoder(obs)  
   
      def decode(self, h: torch.Tensor, s: torch.Tensor):
          """Decode the deterministic and stochastic states back into an observation"""
          return self.decoder(h, s)
   
     def predict_reward(self, h: torch.Tensor, s: torch.Tensor):  
         """獎(jiǎng)勵(lì)預(yù)測(cè)"""
         return self.reward_model(h, s)  
   
     def parameters(self):  
         """返回所有可訓(xùn)練參數(shù)"""
         return list(self.dynamics.parameters()) + list(self.encoder.parameters()) + \
                list(self.decoder.parameters()) + list(self.reward_model.parameters())  
   
     def save(self, path: str):  
         """模型狀態(tài)保存"""
         torch.save({  
             "dynamics": self.dynamics.state_dict(),  
             "encoder": self.encoder.state_dict(),  
             "decoder": self.decoder.state_dict(),  
             "reward_model": self.reward_model.state_dict()  
        }, path)  
   
     def load(self, path: str):  
         """模型狀態(tài)加載"""
         checkpoint = torch.load(path)  
         self.dynamics.load_state_dict(checkpoint["dynamics"])  
         self.encoder.load_state_dict(checkpoint["encoder"])  
         self.decoder.load_state_dict(checkpoint["decoder"])  
         self.reward_model.load_state_dict(checkpoint["reward_model"])

This gives a complete RSSM framework with training and evaluation modes as well as saving and loading of the model state. It can serve as a base structure to be extended and optimized for specific applications.

Training System Design

The RSSM training system has two core components: an experience replay buffer and an agent. The buffer stores past experience for training, while the agent acts as the interface between the environment and the RSSM and implements the data-collection strategy.

Experience Replay Buffer Implementation

The buffer is implemented as a circular queue that stores observations, actions, rewards and termination flags. The sample method draws random training sequences from it.

class Buffer:  
     def __init__(self, buffer_size: int, obs_shape: tuple, action_shape: tuple, device: torch.device):  
         """
        經(jīng)驗(yàn)回放緩沖區(qū)初始化
         
        參數(shù):
            buffer_size: 緩沖區(qū)容量
            obs_shape: 觀察數(shù)據(jù)維度
            action_shape: 動(dòng)作數(shù)據(jù)維度
            device: 計(jì)算設(shè)備
        """
         self.buffer_size = buffer_size  
         self.obs_buffer = np.zeros((buffer_size, *obs_shape), dtype=np.float32)  
         self.action_buffer = np.zeros((buffer_size, *action_shape), dtype=np.int32)  
         self.reward_buffer = np.zeros((buffer_size, 1), dtype=np.float32)  
         self.done_buffer = np.zeros((buffer_size, 1), dtype=np.bool_)  
   
         self.device = device  
         self.idx = 0  
   
     def add(self, obs: torch.Tensor, action: int, reward: float, done: bool):  
         """
        添加單步經(jīng)驗(yàn)數(shù)據(jù)
        """
         self.obs_buffer[self.idx] = obs  
         self.action_buffer[self.idx] = action  
         self.reward_buffer[self.idx] = reward  
         self.done_buffer[self.idx] = done  
         self.idx = (self.idx + 1) % self.buffer_size  
 
     def sample(self, batch_size: int, sequence_length: int):  
         """
        隨機(jī)采樣經(jīng)驗(yàn)序列
         
        參數(shù):
            batch_size: 批量大小
            sequence_length: 序列長度
             
        返回:
            經(jīng)驗(yàn)數(shù)據(jù)元組 (observations, actions, rewards, dones)
        """
         # 隨機(jī)選擇序列起始位置
         starting_idxs = np.random.randint(0, (self.idx % self.buffer_size) - sequence_length, (batch_size,))  
         
          # Build the index matrices for the full sequences
         index_tensor = np.stack([np.arange(start, start + sequence_length) for start in starting_idxs])  
         
          # Slice out the data sequences
         obs_sequence = self.obs_buffer[index_tensor]  
         action_sequence = self.action_buffer[index_tensor]  
         reward_sequence = self.reward_buffer[index_tensor]  
         done_sequence = self.done_buffer[index_tensor]  
   
         return obs_sequence, action_sequence, reward_sequence, done_sequence  
 
     def save(self, path: str):  
         """保存緩沖區(qū)數(shù)據(jù)"""
         np.savez(path, obs_buffer=self.obs_buffer, action_buffer=self.action_buffer,  
                  reward_buffer=self.reward_buffer, done_buffer=self.done_buffer, idx=self.idx)  
   
     def load(self, path: str):  
         """加載緩沖區(qū)數(shù)據(jù)"""
         data = np.load(path)  
         self.obs_buffer = data["obs_buffer"]  
         self.action_buffer = data["action_buffer"]  
         self.reward_buffer = data["reward_buffer"]  
         self.done_buffer = data["done_buffer"]  
         self.idx = data["idx"]
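
A short, illustrative usage sketch of the buffer (the dummy shapes are arbitrary, and sampling assumes the write index has not wrapped around yet):

buffer = Buffer(buffer_size=1000, obs_shape=(128, 128, 1), action_shape=(), device=torch.device("cpu"))
for _ in range(64):
    # Fill with dummy transitions so that sampling has enough history
    buffer.add(np.zeros((128, 128, 1), dtype=np.float32), action=0, reward=0.0, done=False)

obs_seq, act_seq, rew_seq, done_seq = buffer.sample(batch_size=8, sequence_length=20)
print(obs_seq.shape, act_seq.shape)   # (8, 20, 128, 128, 1) (8, 20)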

Agent Design

The agent implements data collection and planning. The current implementation collects data with a simple random policy, but the framework makes it easy to plug in more sophisticated strategies (a sketch follows after the agent code below).

from abc import ABC, abstractmethod

from gym import Env
from tqdm import tqdm


class Policy(ABC):
    """Abstract base class for data-collection policies"""
    @abstractmethod
    def __call__(self, obs):
        pass
   
 class RandomPolicy(Policy):  
     """隨機(jī)采樣策略"""
     def __init__(self, env: Env):  
         self.env = env  
   
     def __call__(self, obs):  
         return self.env.action_space.sample()  
 
 class Agent:  
     def __init__(self, env: Env, rssm: RSSM, buffer_size: int = 100000,
                  collection_policy: str = "random", device="mps"):  
         """
        智能體初始化
         
        參數(shù):
            env: 環(huán)境實(shí)例
            rssm: RSSM模型實(shí)例
            buffer_size: 經(jīng)驗(yàn)緩沖區(qū)大小
            collection_policy: 數(shù)據(jù)收集策略類型
            device: 計(jì)算設(shè)備
        """
         self.env = env  
          # Select the data-collection policy
         match collection_policy:  
             case "random":  
                 self.rollout_policy = RandomPolicy(env)  
             case _:  
                 raise ValueError("Invalid rollout policy")  
   
         self.buffer = Buffer(buffer_size, env.observation_space.shape,
                            env.action_space.shape, device=device)  
         self.rssm = rssm  
   
     def data_collection_action(self, obs):  
         """執(zhí)行數(shù)據(jù)收集動(dòng)作"""
         return self.rollout_policy(obs)  
   
     def collect_data(self, num_steps: int):  
         """
        收集訓(xùn)練數(shù)據(jù)
         
        參數(shù):
            num_steps: 收集步數(shù)
        """
         obs = self.env.reset()  
         done = False  
   
         iterator = tqdm(range(num_steps), desc="Data Collection")  
         for _ in iterator:  
             action = self.data_collection_action(obs)  
             next_obs, reward, done, _, _ = self.env.step(action)  
             self.buffer.add(next_obs, action, reward, done)  
             obs = next_obs  
             if done:  
                  obs, _ = self.env.reset()
   
     def imagine_rollout(self, prev_hidden: torch.Tensor, prev_state: torch.Tensor,
                        actions: torch.Tensor):  
         """
        執(zhí)行想象展開
         
        參數(shù):
            prev_hidden: 前一隱藏狀態(tài)
            prev_state: 前一隨機(jī)狀態(tài)
            actions: 動(dòng)作序列
             
        返回:
            完整的模型輸出,包括隱藏狀態(tài)、先驗(yàn)狀態(tài)、后驗(yàn)狀態(tài)等
        """
         hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
         posterior_means, posterior_logvars = self.rssm.generate_rollout(
             actions, prev_hidden, prev_state)  
   
          # During imagination, rewards are predicted from the prior states
         rewards = self.rssm.predict_reward(hiddens, prior_states)  
   
         return hiddens, prior_states, posterior_states, prior_means, \
                prior_logvars, posterior_means, posterior_logvars, rewards  
   
     def plan(self, num_steps: int, prev_hidden: torch.Tensor,
              prev_state: torch.Tensor, actions: torch.Tensor):  
         """
        執(zhí)行規(guī)劃
         
        參數(shù):
            num_steps: 規(guī)劃步數(shù)
            prev_hidden: 初始隱藏狀態(tài)
            prev_state: 初始隨機(jī)狀態(tài)
            actions: 動(dòng)作序列
             
        返回:
            規(guī)劃得到的隱藏狀態(tài)和先驗(yàn)狀態(tài)序列
        """
         hidden_states = []  
         prior_states = []  
   
         hiddens = prev_hidden  
         states = prev_state  
   
         for _ in range(num_steps):  
             hiddens, states, _, _, _, _, _, _ = self.imagine_rollout(
                 hiddens, states, actions)  
             hidden_states.append(hiddens)  
             prior_states.append(states)  
   
         hidden_states = torch.stack(hidden_states)  
         prior_states = torch.stack(prior_states)  
   
         return hidden_states, prior_states
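
As mentioned above, the Policy interface makes it straightforward to plug in other data-collection strategies. Purely as an illustration (StickyRandomPolicy is hypothetical and not part of the original code), a policy that keeps repeating its previous action and only occasionally re-samples could look like this; wiring it up would just need an extra case in the match statement of Agent.__init__:

class StickyRandomPolicy(Policy):
    """Hypothetical example: repeat the previous action, re-sampling a new random
    action with probability epsilon. Only meant to illustrate the interface."""
    def __init__(self, env: Env, epsilon: float = 0.2):
        self.env = env
        self.epsilon = epsilon
        self.last_action = env.action_space.sample()

    def __call__(self, obs):
        if np.random.rand() < self.epsilon:
            self.last_action = self.env.action_space.sample()
        return self.last_action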

Together, these components provide a complete data-management and interaction layer. The replay buffer stores and reuses past experience efficiently, the abstract policy interface makes it easy to swap in different data-collection strategies, and the agent also exposes model-based imagination rollouts and planning, which later decision making can build on.

Trainer Implementation and Experiments

Trainer Design

The trainer is the last key component of the implementation and coordinates the training process. It takes the RSSM, the agent and an optimizer, and implements the actual training logic.

import logging
import os

import matplotlib.pyplot as plt
from torch.distributions import Normal
from torch.utils.tensorboard import SummaryWriter

logging.basicConfig(
     level=logging.INFO,
     format="%(asctime)s - %(levelname)s - %(message)s",
     handlers=[
         logging.StreamHandler(),                       # console output
         logging.FileHandler("training.log", mode="w")  # file output
    ]
 )

 logger = logging.getLogger(__name__)
 
 class Trainer:  
     def __init__(self, rssm: RSSM, agent: Agent, optimizer: torch.optim.Optimizer,
                  device: torch.device):  
         """
        訓(xùn)練器初始化
         
        參數(shù):
            rssm: RSSM 模型實(shí)例
            agent: 智能體實(shí)例
            optimizer: 優(yōu)化器實(shí)例
            device: 計(jì)算設(shè)備
        """
         self.rssm = rssm  
         self.optimizer = optimizer  
         self.device = device  
         self.agent = agent  
          self.writer = SummaryWriter()  # TensorBoard logger
   
     def train_batch(self, batch_size: int, seq_len: int, iteration: int,
                    save_images: bool = False):  
         """
        單批次訓(xùn)練
         
        參數(shù):
            batch_size: 批量大小
            seq_len: 序列長度
            iteration: 當(dāng)前迭代次數(shù)
            save_images: 是否保存重建圖像
        """
         # 采樣訓(xùn)練數(shù)據(jù)
         obs, actions, rewards, dones = self.agent.buffer.sample(batch_size, seq_len)  
   
          # Preprocess the sampled batch
          actions = torch.tensor(actions).long().to(self.device)
          actions = F.one_hot(actions, self.rssm.action_dim).float()
          obs = torch.tensor(obs).float().to(self.device)
          rewards = torch.tensor(rewards).float().to(self.device)
         dones = torch.tensor(dones).float().to(self.device)  
   
          # Encode observations (flatten batch and time dimensions for the CNN)
         encoded_obs = self.rssm.encoder(obs.reshape(-1, *obs.shape[2:]).permute(0, 3, 1, 2))  
         encoded_obs = encoded_obs.reshape(batch_size, seq_len, -1)  
   
          # Unroll the RSSM
         rollout = self.rssm.generate_rollout(actions, obs=encoded_obs, dones=dones)  
         hiddens, prior_states, posterior_states, prior_means, prior_logvars, \
         posterior_means, posterior_logvars = rollout  
   
          # Reconstruct observations
         hiddens_reshaped = hiddens.reshape(batch_size * seq_len, -1)  
         posterior_states_reshaped = posterior_states.reshape(batch_size * seq_len, -1)  
         decoded_obs = self.rssm.decoder(hiddens_reshaped, posterior_states_reshaped)  
         decoded_obs = decoded_obs.reshape(batch_size, seq_len, *obs.shape[-3:])  
   
          # Predict rewards
         reward_params = self.rssm.reward_model(hiddens, posterior_states)  
         mean, logvar = torch.chunk(reward_params, 2, dim=-1)  
         logvar = F.softplus(logvar)  
         reward_dist = Normal(mean, torch.exp(logvar))  
         predicted_rewards = reward_dist.rsample()  
   
          # Visualization
         if save_images:  
             batch_idx = np.random.randint(0, batch_size)  
             seq_idx = np.random.randint(0, seq_len - 3)  
             fig = self._visualize(obs, decoded_obs, rewards, predicted_rewards,
                                 batch_idx, seq_idx, iteration, grayscale=True)  
             if not os.path.exists("reconstructions"):  
                 os.makedirs("reconstructions")  
             fig.savefig(f"reconstructions/iteration_{iteration}.png")  
             self.writer.add_figure("Reconstructions", fig, iteration)  
             plt.close(fig)  
   
          # Compute the losses
         reconstruction_loss = self._reconstruction_loss(decoded_obs, obs)  
         kl_loss = self._kl_loss(prior_means, F.softplus(prior_logvars),
                                posterior_means, F.softplus(posterior_logvars))  
         reward_loss = self._reward_loss(rewards, predicted_rewards)  
   
         loss = reconstruction_loss + kl_loss + reward_loss  
   
          # Backpropagation and optimization
         self.optimizer.zero_grad()  
         loss.backward()  
         nn.utils.clip_grad_norm_(self.rssm.parameters(), 1, norm_type=2)  
         self.optimizer.step()  
   
         return loss.item(), reconstruction_loss.item(), kl_loss.item(), reward_loss.item()  
   
     def train(self, iterations: int, batch_size: int, seq_len: int):  
         """
        執(zhí)行完整訓(xùn)練過程
         
        參數(shù):
            iterations: 迭代總次數(shù)
            batch_size: 批量大小
            seq_len: 序列長度
        """
         self.rssm.train()  
         iterator = tqdm(range(iterations), desc="Training", total=iterations)  
         losses = []  
         infos = []  
         last_loss = float("inf")  
         
         for i in iterator:  
              # Train on a single batch
             loss, reconstruction_loss, kl_loss, reward_loss = self.train_batch(
                 batch_size, seq_len, i, save_images=i % 100 == 0)  
   
              # Log training metrics
             self.writer.add_scalar("Loss", loss, i)  
             self.writer.add_scalar("Reconstruction Loss", reconstruction_loss, i)  
             self.writer.add_scalar("KL Loss", kl_loss, i)  
             self.writer.add_scalar("Reward Loss", reward_loss, i)  
   
              # Save the best model so far
             if loss < last_loss:  
                 self.rssm.save("rssm.pth")  
                 last_loss = loss  
   
              # Keep per-iteration statistics
             info = {  
                 "Loss": loss,  
                 "Reconstruction Loss": reconstruction_loss,  
                 "KL Loss": kl_loss,  
                 "Reward Loss": reward_loss  
            }  
             losses.append(loss)  
             infos.append(info)  
   
              # Periodically log the training status
             if i % 10 == 0:  
                 logger.info("\n----------------------------")  
                 logger.info(f"Iteration: {i}")  
                 logger.info(f"Loss: {loss:.4f}")  
                 logger.info(f"Running average last 20 losses: {sum(losses[-20:]) / 20: .4f}")  
                 logger.info(f"Reconstruction Loss: {reconstruction_loss:.4f}")  
                 logger.info(f"KL Loss: {kl_loss:.4f}")  
                 logger.info(f"Reward Loss: {reward_loss:.4f}")
 
Experiment Example

The following is a complete example of training the RSSM on the CarRacing environment:

  # Environment initialization
  env = make_env("CarRacing-v2", render_mode="rgb_array", continuous=False, grayscale=True)
 
  # Model hyperparameters
 hidden_size = 1024  
 embedding_dim = 1024  
 state_dim = 512  
 
  # Instantiate the model components
 encoder = EncoderCNN(in_channels=1, embedding_dim=embedding_dim)  
 decoder = DecoderCNN(hidden_size=hidden_size, state_size=state_dim,
                      embedding_size=embedding_dim, output_shape=(1,128,128))  
 reward_model = RewardModel(hidden_dim=hidden_size, state_dim=state_dim)  
 dynamics_model = DynamicsModel(hidden_dim=hidden_size, state_dim=state_dim,
                               action_dim=5, embedding_dim=embedding_dim)  
 
  # Build the RSSM
 rssm = RSSM(dynamics_model=dynamics_model,  
             encoder=encoder,  
             decoder=decoder,  
             reward_model=reward_model,  
             hidden_dim=hidden_size,  
             state_dim=state_dim,  
             action_dim=5,  
              embedding_dim=embedding_dim,
              device="cuda")  # keep the model on the same device the Trainer uses below
 
  # Training setup
 optimizer = torch.optim.Adam(rssm.parameters(), lr=1e-3)  
 agent = Agent(env, rssm)  
 trainer = Trainer(rssm, agent, optimizer=optimizer, device="cuda")  
 
  # Data collection and training
  agent.collect_data(20000)        # collect 20,000 steps of experience
  agent.buffer.save("buffer.npz")  # save the replay buffer
  trainer.train(10000, 32, 20)     # run 10,000 training iterations

Summary

This article walked through a complete PyTorch implementation of RSSM. The architecture is more involved than a plain VAE or RNN, mainly because it mixes stochastic and deterministic states. Implementing it by hand is a good way to understand the theory behind it and what makes it powerful: an RSSM can recursively generate future latent trajectories, which is the foundation for planning an agent's behavior.

A nice property of this implementation is its moderate computational cost: it can be trained on a single consumer GPU and, given enough time, even on a CPU. The work is based on the paper "Learning Latent Dynamics for Planning from Pixels", which laid the groundwork for RSSM-style dynamics models; follow-up work such as "Dream to Control: Learning Behaviors by Latent Imagination" developed the architecture further. Those improved architectures are worth exploring in future write-ups, as they provide important insight into MBRL methods.
