From 354da39108deddfde03107c95a26a8a7e51baa36 Mon Sep 17 00:00:00 2001 From: mrk0669 Date: Sun, 10 May 2026 12:28:36 +0530 Subject: [PATCH] Replace HIM dual-encoder + terrain table with DreamWaQ-style VAE The original HIMEstimator used two encoders (primary + target) with a prototype embedding table and Sinkhorn-based contrastive (swap) loss to learn terrain-aware latent representations. Following DreamWaQ (Nahrendra et al., 2023), this replaces that architecture with a simple Variational Autoencoder: - Encoder now outputs vel(3) + mu(16) + logvar(16) instead of vel(3) + z(16) - Reparameterization trick (z = mu + eps*sigma) is used during training; inference uses mu directly (deterministic, no noise) - Loss = MSE velocity estimation + beta * KL divergence to N(0,1) - Removes: target encoder, prototype embedding table, Sinkhorn algorithm - Old config keys (tar_hidden_dims, num_prototype, temperature) are kept as no-op params so existing config files remain compatible - him_ppo.py and him_on_policy_runner.py updated: swap_loss -> kl_loss Co-Authored-By: Claude Sonnet 4.6 --- rsl_rl/rsl_rl/algorithms/him_ppo.py | 10 +- rsl_rl/rsl_rl/modules/him_estimator.py | 120 ++++++------------ rsl_rl/rsl_rl/runners/him_on_policy_runner.py | 8 +- 3 files changed, 48 insertions(+), 90 deletions(-) diff --git a/rsl_rl/rsl_rl/algorithms/him_ppo.py b/rsl_rl/rsl_rl/algorithms/him_ppo.py index 42f1bdd..2b2c82a 100644 --- a/rsl_rl/rsl_rl/algorithms/him_ppo.py +++ b/rsl_rl/rsl_rl/algorithms/him_ppo.py @@ -120,7 +120,7 @@ def update(self): mean_value_loss = 0 mean_surrogate_loss = 0 mean_estimation_loss = 0 - mean_swap_loss = 0 + mean_kl_loss = 0 generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs) @@ -150,7 +150,7 @@ def update(self): param_group['lr'] = self.learning_rate #Estimator Update - estimation_loss, swap_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) + estimation_loss, kl_loss = self.actor_critic.estimator.update(obs_batch, next_critic_obs_batch, lr=self.learning_rate) # Surrogate loss ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch)) @@ -180,13 +180,13 @@ def update(self): mean_value_loss += value_loss.item() mean_surrogate_loss += surrogate_loss.item() mean_estimation_loss += estimation_loss - mean_swap_loss += swap_loss + mean_kl_loss += kl_loss num_updates = self.num_learning_epochs * self.num_mini_batches mean_value_loss /= num_updates mean_surrogate_loss /= num_updates mean_estimation_loss /= num_updates - mean_swap_loss /= num_updates + mean_kl_loss /= num_updates self.storage.clear() - return mean_value_loss, mean_surrogate_loss, estimation_loss, swap_loss + return mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss diff --git a/rsl_rl/rsl_rl/modules/him_estimator.py b/rsl_rl/rsl_rl/modules/him_estimator.py index 4bd920a..6a13f7e 100644 --- a/rsl_rl/rsl_rl/modules/him_estimator.py +++ b/rsl_rl/rsl_rl/modules/him_estimator.py @@ -1,11 +1,7 @@ -import copy -import math import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F -import torch.distributions as torchd -from torch.distributions import Normal, Categorical class HIMEstimator(nn.Module): @@ -13,124 +9,86 @@ def __init__(self, temporal_steps, num_one_step_obs, enc_hidden_dims=[128, 64, 16], - tar_hidden_dims=[128, 64], + tar_hidden_dims=[128, 64], # kept for config compatibility, unused activation='elu', learning_rate=1e-3, max_grad_norm=10.0, - num_prototype=32, - temperature=3.0, + num_prototype=32, # kept for config compatibility, unused + temperature=3.0, # kept for config compatibility, unused + kl_weight=1.0, **kwargs): if kwargs: - print("Estimator_CL.__init__ got unexpected arguments, which will be ignored: " + str( - [key for key in kwargs.keys()])) + print("HIMEstimator.__init__ got unexpected arguments, which will be ignored: " + + str([key for key in kwargs.keys()])) super(HIMEstimator, self).__init__() - activation = get_activation(activation) + activation_fn = get_activation(activation) self.temporal_steps = temporal_steps self.num_one_step_obs = num_one_step_obs self.num_latent = enc_hidden_dims[-1] self.max_grad_norm = max_grad_norm - self.temperature = temperature + self.kl_weight = kl_weight - # Encoder + # Encoder: outputs vel(3) + mu(num_latent) + logvar(num_latent) enc_input_dim = self.temporal_steps * self.num_one_step_obs enc_layers = [] for l in range(len(enc_hidden_dims) - 1): - enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation] + enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[l]), activation_fn] enc_input_dim = enc_hidden_dims[l] - enc_layers += [nn.Linear(enc_input_dim, enc_hidden_dims[-1] + 3)] + enc_layers += [nn.Linear(enc_input_dim, 3 + self.num_latent * 2)] self.encoder = nn.Sequential(*enc_layers) - # Target - tar_input_dim = self.num_one_step_obs - tar_layers = [] - for l in range(len(tar_hidden_dims)): - tar_layers += [nn.Linear(tar_input_dim, tar_hidden_dims[l]), activation] - tar_input_dim = tar_hidden_dims[l] - tar_layers += [nn.Linear(tar_input_dim, enc_hidden_dims[-1])] - self.target = nn.Sequential(*tar_layers) - - # Prototype - self.proto = nn.Embedding(num_prototype, enc_hidden_dims[-1]) - - # Optimizer self.learning_rate = learning_rate self.optimizer = optim.Adam(self.parameters(), lr=self.learning_rate) + def reparameterize(self, mu, logvar): + std = torch.exp(0.5 * logvar) + eps = torch.randn_like(std) + return mu + eps * std + def get_latent(self, obs_history): - vel, z = self.encode(obs_history) - return vel.detach(), z.detach() + """Inference: use mu directly (no sampling noise).""" + out = self.encoder(obs_history.detach()) + vel = out[..., :3] + mu = out[..., 3:3 + self.num_latent] + return vel.detach(), mu.detach() def forward(self, obs_history): - parts = self.encoder(obs_history.detach()) - vel, z = parts[..., :3], parts[..., 3:] - z = F.normalize(z, dim=-1, p=2) - return vel.detach(), z.detach() + return self.get_latent(obs_history) def encode(self, obs_history): - parts = self.encoder(obs_history.detach()) - vel, z = parts[..., :3], parts[..., 3:] - z = F.normalize(z, dim=-1, p=2) - return vel, z + """Training: sample z via reparameterization.""" + out = self.encoder(obs_history.detach()) + vel = out[..., :3] + mu = out[..., 3:3 + self.num_latent] + logvar = out[..., 3 + self.num_latent:] + z = self.reparameterize(mu, logvar) + return vel, mu, logvar, z def update(self, obs_history, next_critic_obs, lr=None): if lr is not None: self.learning_rate = lr for param_group in self.optimizer.param_groups: param_group['lr'] = self.learning_rate - - vel = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs+3].detach() - next_obs = next_critic_obs.detach()[:, 3:self.num_one_step_obs+3] - - z_s = self.encoder(obs_history) - z_t = self.target(next_obs) - pred_vel, z_s = z_s[..., :3], z_s[..., 3:] - - z_s = F.normalize(z_s, dim=-1, p=2) - z_t = F.normalize(z_t, dim=-1, p=2) - with torch.no_grad(): - w = self.proto.weight.data.clone() - w = F.normalize(w, dim=-1, p=2) - self.proto.weight.copy_(w) + # Ground-truth velocity from privileged obs + vel_gt = next_critic_obs[:, self.num_one_step_obs:self.num_one_step_obs + 3].detach() - score_s = z_s @ self.proto.weight.T - score_t = z_t @ self.proto.weight.T + pred_vel, mu, logvar, _ = self.encode(obs_history) - with torch.no_grad(): - q_s = sinkhorn(score_s) - q_t = sinkhorn(score_t) + estimation_loss = F.mse_loss(pred_vel, vel_gt) - log_p_s = F.log_softmax(score_s / self.temperature, dim=-1) - log_p_t = F.log_softmax(score_t / self.temperature, dim=-1) + # KL divergence: D_KL( N(mu, sigma) || N(0,1) ) + kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) - swap_loss = -0.5 * (q_s * log_p_t + q_t * log_p_s).mean() - estimation_loss = F.mse_loss(pred_vel, vel) - losses = estimation_loss + swap_loss + loss = estimation_loss + self.kl_weight * kl_loss self.optimizer.zero_grad() - losses.backward() + loss.backward() nn.utils.clip_grad_norm_(self.parameters(), self.max_grad_norm) self.optimizer.step() - return estimation_loss.item(), swap_loss.item() - - -@torch.no_grad() -def sinkhorn(out, eps=0.05, iters=3): - Q = torch.exp(out / eps).T - K, B = Q.shape[0], Q.shape[1] - Q /= Q.sum() - - for it in range(iters): - # normalize each row: total weight per prototype must be 1/K - Q /= torch.sum(Q, dim=1, keepdim=True) - Q /= K - - # normalize each column: total weight per sample must be 1/B - Q /= torch.sum(Q, dim=0, keepdim=True) - Q /= B - return (Q * B).T + return estimation_loss.item(), kl_loss.item() def get_activation(act_name): @@ -152,4 +110,4 @@ def get_activation(act_name): return nn.Sigmoid() else: print("invalid activation function!") - return None \ No newline at end of file + return None diff --git a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py index 5fa28b2..35a553f 100644 --- a/rsl_rl/rsl_rl/runners/him_on_policy_runner.py +++ b/rsl_rl/rsl_rl/runners/him_on_policy_runner.py @@ -139,7 +139,7 @@ def learn(self, num_learning_iterations, init_at_random_ep_len=False): start = stop self.alg.compute_returns(critic_obs) - mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_swap_loss = self.alg.update() + mean_value_loss, mean_surrogate_loss, mean_estimation_loss, mean_kl_loss = self.alg.update() stop = time.time() learn_time = stop - start if self.log_dir is not None: @@ -176,7 +176,7 @@ def log(self, locs, width=80, pad=35): self.writer.add_scalar('Loss/value_function', locs['mean_value_loss'], locs['it']) self.writer.add_scalar('Loss/surrogate', locs['mean_surrogate_loss'], locs['it']) self.writer.add_scalar('Loss/Estimation Loss', locs['mean_estimation_loss'], locs['it']) - self.writer.add_scalar('Loss/Swap Loss', locs['mean_swap_loss'], locs['it']) + self.writer.add_scalar('Loss/KL Loss', locs['mean_kl_loss'], locs['it']) self.writer.add_scalar('Loss/learning_rate', self.alg.learning_rate, locs['it']) self.writer.add_scalar('Policy/mean_noise_std', mean_std.item(), locs['it']) self.writer.add_scalar('Perf/total_fps', fps, locs['it']) @@ -198,7 +198,7 @@ def log(self, locs, width=80, pad=35): f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" - f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n""" + f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""" f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""") @@ -212,7 +212,7 @@ def log(self, locs, width=80, pad=35): f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" f"""{'Estimation loss:':>{pad}} {locs['mean_estimation_loss']:.4f}\n""" - f"""{'Swap loss:':>{pad}} {locs['mean_swap_loss']:.4f}\n""" + f"""{'KL loss:':>{pad}} {locs['mean_kl_loss']:.4f}\n""" f"""{'Mean action noise std:':>{pad}} {mean_std.item():.2f}\n""") # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""")