diff --git a/internnav/dataset/navdp_lerobot_dataset.py b/internnav/dataset/navdp_lerobot_dataset.py index 95753d9c..16183b4b 100644 --- a/internnav/dataset/navdp_lerobot_dataset.py +++ b/internnav/dataset/navdp_lerobot_dataset.py @@ -9,6 +9,7 @@ import open3d as o3d import pandas as pd import torch +import jsonlines from PIL import Image from scipy.interpolate import CubicSpline from torch.utils.data import Dataset @@ -30,6 +31,8 @@ def print(*args, **kwargs): builtins.print = print + + class NavDP_Base_Datset(Dataset): def __init__( self, @@ -42,6 +45,7 @@ def __init__( scene_data_scale=1.0, trajectory_data_scale=1.0, pixel_channel=7, + action_dim=3, debug=False, preload=False, random_digit=False, @@ -54,8 +58,9 @@ def __init__( self.scene_scale_size = scene_data_scale self.trajectory_data_scale = trajectory_data_scale self.predict_size = predict_size + self.action_dim = action_dim self.debug = debug - self.trajectory_dirs = [] + self.trajectory_data_dir = [] self.trajectory_rgb_path = [] self.trajectory_depth_path = [] @@ -69,70 +74,80 @@ def __init__( self._last_time = None if preload is False: - for group_dir in self.dataset_dirs: # gibson_zed, 3dfront ... + for group_dir in self.dataset_dirs: # gibson_zed, 3dfront ... all_scene_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir))]) select_scene_dirs = all_scene_dirs[ np.arange(0, all_scene_dirs.shape[0], 1 / self.scene_scale_size).astype(np.int32) ] - for scene_dir in select_scene_dirs: - all_traj_dirs = np.array([p for p in os.listdir(os.path.join(root_dirs, group_dir, scene_dir))]) - select_traj_dirs = all_traj_dirs[ - np.arange(0, all_traj_dirs.shape[0], 1 / self.trajectory_data_scale).astype(np.int32) - ] - for traj_dir in tqdm(select_traj_dirs): - entire_task_dir = os.path.join(root_dirs, group_dir, scene_dir, traj_dir) - rgb_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.rgb/") - depth_dir = os.path.join(entire_task_dir, "videos/chunk-000/observation.images.depth/") - data_path = os.path.join( - entire_task_dir, 'data/chunk-000/episode_000000.parquet' - ) # intrinsic, extrinsic, cam_traj, path - afford_path = os.path.join(entire_task_dir, 'data/chunk-000/path.ply') - rgbs_length = len([p for p in os.listdir(rgb_dir)]) - depths_length = len([p for p in os.listdir(depth_dir)]) - - rgbs_path = [] - depths_path = [] - if depths_length != rgbs_length: - continue - for i in range(rgbs_length): - rgbs_path.append(os.path.join(rgb_dir, "%d.jpg" % i)) - depths_path.append(os.path.join(depth_dir, "%d.png" % i)) - if os.path.exists(data_path) is False: - continue - self.trajectory_dirs.append(entire_task_dir) - self.trajectory_data_dir.append(data_path) - self.trajectory_rgb_path.append(rgbs_path) - self.trajectory_depth_path.append(depths_path) - self.trajectory_afford_path.append(afford_path) - + + for scene_dir in tqdm(select_scene_dirs): + chunk_name = os.listdir(os.path.join(root_dirs, group_dir, scene_dir, 'data'))[0] + data_dir = os.path.join(root_dirs, group_dir, scene_dir, f'data/{chunk_name}') + afford_dir = os.path.join(root_dirs, group_dir, scene_dir, 'meta/pointcloud.ply') + with jsonlines.open(os.path.join(root_dirs, group_dir, scene_dir, 'meta/episodes_stats.jsonl'), 'r') as reader: + episode_info = list(reader) + rgb_dir = os.path.join(root_dirs, group_dir, scene_dir, f"videos/{chunk_name}/observation.images.rgb/") + rgb_paths = [os.path.join(rgb_dir, p) for p in sorted(os.listdir(rgb_dir))] + + depth_dir = os.path.join(root_dirs, group_dir, scene_dir, f"videos/{chunk_name}/observation.images.depth/") + depth_paths = [os.path.join(depth_dir, p) for p in sorted(os.listdir(depth_dir))] + + data_paths = [os.path.join(data_dir, p) for p in sorted(os.listdir(data_dir))] + + for episode_idx, episode in enumerate(episode_info): + image_start_index = episode['image_index']['min'] + image_end_index = episode['image_index']['max'] + episode_rgb_path = np.array(rgb_paths)[image_start_index:image_end_index+1].tolist() + episode_depth_path = np.array(depth_paths)[image_start_index:image_end_index+1].tolist() + + try: + self.trajectory_data_dir.append(data_paths[episode_idx]) + self.trajectory_rgb_path.append(episode_rgb_path) + self.trajectory_depth_path.append(episode_depth_path) + self.trajectory_afford_path.append(afford_dir) + except: + import pdb + pdb.set_trace() + + save_dict = { - 'trajectory_dirs': self.trajectory_dirs, - 'trajectory_data_dir': self.trajectory_data_dir, + 'trajectory_data_dir': self.trajectory_data_dir, 'trajectory_rgb_path': self.trajectory_rgb_path, 'trajectory_depth_path': self.trajectory_depth_path, 'trajectory_afford_path': self.trajectory_afford_path, } with open(preload_path, 'w') as f: json.dump(save_dict, f, indent=4) + + # replicate the data 50 times + self.trajectory_data_dir = self.trajectory_data_dir * 50 + self.trajectory_rgb_path = self.trajectory_rgb_path * 50 + self.trajectory_depth_path = self.trajectory_depth_path * 50 + self.trajectory_afford_path = self.trajectory_afford_path * 50 else: load_dict = json.load(open(preload_path, 'r')) - self.trajectory_dirs = load_dict['trajectory_dirs'] * 50 self.trajectory_data_dir = load_dict['trajectory_data_dir'] * 50 self.trajectory_rgb_path = load_dict['trajectory_rgb_path'] * 50 self.trajectory_depth_path = load_dict['trajectory_depth_path'] * 50 self.trajectory_afford_path = load_dict['trajectory_afford_path'] * 50 def __len__(self): - return len(self.trajectory_dirs) + return len(self.trajectory_data_dir) def load_image(self, image_url): - image = Image.open(image_url) - image = np.array(image, np.uint8) + try: + image = Image.open(image_url) + image = np.array(image, np.uint8) + except: + image = np.zeros((self.image_size, self.image_size, 3), dtype=np.uint8) return image def load_depth(self, depth_url): - depth = Image.open(depth_url) - depth = np.array(depth, np.uint16) + try: + depth = Image.open(depth_url) + depth = np.array(depth, np.uint16) + except: + depth = np.zeros((self.image_size, self.image_size), dtype=np.uint16) return depth def load_pointcloud(self, pcd_url): @@ -176,39 +191,19 @@ def process_data_parquet(self, index): camera_intrinsic = np.vstack(np.array(df['observation.camera_intrinsic'].tolist()[0])).reshape(3, 3) camera_extrinsic = np.vstack(np.array(df['observation.camera_extrinsic'].tolist()[0])).reshape(4, 4) trajectory_length = len(df['action'].tolist()) - camera_trajectory = np.array([np.stack(frame) for frame in df['action']], dtype=np.float64) + camera_trajectory = np.array([np.stack(frame) for frame in df['action']], dtype=np.float64).reshape(-1,4,4) return camera_intrinsic, camera_extrinsic, camera_trajectory, trajectory_length - def process_path_points(self, index): - trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index]) - trajectory_color = np.array(trajectory_pcd.colors) - color_distance = np.abs(trajectory_color - np.array([0, 0, 0])).sum( - axis=-1 - ) # sometimes, the path are saved as black points + def process_obstacle_points(self, index): + scene_pcd = self.load_pointcloud(self.trajectory_afford_path[index]) + scene_color = np.array(scene_pcd.colors) + scene_points = np.array(scene_pcd.points) + color_distance = np.abs(scene_color - np.array([0, 0, 0.5])).sum(axis=-1) select_index = np.where(color_distance < 0.05)[0] - trajectory_path = o3d.geometry.PointCloud() - trajectory_path.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index]) - trajectory_path.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index]) - return np.array(trajectory_path.points), trajectory_path - - def process_obstacle_points(self, index, path_points): - trajectory_pcd = self.load_pointcloud(self.trajectory_afford_path[index]) - trajectory_color = np.array(trajectory_pcd.colors) - trajectory_points = np.array(trajectory_pcd.points) - color_distance = np.abs(trajectory_color - np.array([0, 0, 0.5])).sum(axis=-1) # the obstacles are save in blue - path_lower_bound = path_points.min(axis=0) - path_upper_bound = path_points.max(axis=0) - condition_x = (trajectory_points[:, 0] >= path_lower_bound[0] - 2.0) & ( - trajectory_points[:, 0] <= path_upper_bound[0] + 2.0 - ) - condition_y = (trajectory_points[:, 1] >= path_lower_bound[1] - 2.0) & ( - trajectory_points[:, 1] <= path_upper_bound[1] + 2.0 - ) - select_index = np.where((color_distance < 0.05) & condition_x & condition_y)[0] - trajectory_obstacle = o3d.geometry.PointCloud() - trajectory_obstacle.points = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.points)[select_index]) - trajectory_obstacle.colors = o3d.utility.Vector3dVector(np.asarray(trajectory_pcd.colors)[select_index]) - return np.array(trajectory_obstacle.points), trajectory_obstacle + scene_obstacle = o3d.geometry.PointCloud() + scene_obstacle.points = o3d.utility.Vector3dVector(scene_points[select_index]) + scene_obstacle.colors = o3d.utility.Vector3dVector(scene_color[select_index]) + return np.array(scene_obstacle.points), scene_obstacle def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1): memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step + 1, memory_digit) @@ -220,8 +215,11 @@ def process_memory(self, rgb_paths, depth_paths, start_step, memory_digit=1): return context_image, context_depth, memory_index def process_pixel_goal(self, image_url, target_point, camera_intrinsic, camera_extrinsic): - image = Image.open(image_url) - image = np.array(image, np.uint8) + try: + image = Image.open(image_url) + image = np.array(image, np.uint8) + except: + image = np.zeros((self.image_size, self.image_size, 3), dtype=np.uint8) resize_image = self.process_image(image_url) coordinate = np.array([-target_point[1], target_point[0], camera_extrinsic[2, 3] * 0.8]) @@ -422,10 +420,11 @@ def __getitem__(self, index): trajectory_length, ) = self.process_data_parquet(index) - trajectory_path_points, trajectory_path_pcd = self.process_path_points(index) - trajectory_obstacle_points, trajectory_obstacle_pcd = self.process_obstacle_points( - index, trajectory_path_points - ) + # trajectory_path_points, trajectory_path_pcd = self.process_path_points(index) + # trajectory_obstacle_points, trajectory_obstacle_pcd = self.process_obstacle_points( + # index, trajectory_path_points + # ) + trajectory_obstacle_points, trajectory_obstacle_pcd = self.process_obstacle_points(index) if self.prior_sample: pixel_start_choice, target_choice = self.rank_steps() @@ -520,7 +519,10 @@ def __getitem__(self, index): pred_actions = (pred_actions[1:] - pred_actions[:-1]) * 4.0 augment_actions = (augment_actions[1:] - augment_actions[:-1]) * 4.0 - + + pred_actions = np.pad(pred_actions, ((0,0),(0,self.action_dim - pred_actions.shape[-1])), mode='constant', constant_values=(0,0)) + augment_actions = np.pad(augment_actions, ((0,0),(0,self.action_dim - augment_actions.shape[-1])), mode='constant', constant_values=(0,0)) + # Summarize avg time of batch end_time = time.time() self.item_cnt += 1 @@ -553,6 +555,264 @@ def __getitem__(self, index): float(pixel_flag), ) +class SekaiDataset(Dataset): + def __init__( + self, + root_dir, + preload_path=False, + memory_size=8, + predict_size=24, + batch_size=64, + image_size=224, + scene_data_scale=1.0, + trajectory_data_scale=1.0, + debug=False, + preload=False, + random_digit=False, + ): + self.video_dataset_dir = os.path.join(root_dir,'sekai-real-walking') + self.action_dataset_dir = os.path.join(root_dir,'sekai-real-walking-hq') + self.trajectory_video_dir = [] + self.trajectory_action_dir = [] + + self.preload_path = preload_path + self.preload = preload + self.random_digit = random_digit + self.predict_size = predict_size + self.memory_size = memory_size + self.image_size = image_size + + for video_group in os.listdir(self.video_dataset_dir)[0:100]: + video_dir = os.path.join(self.video_dataset_dir, video_group) + for video_name in os.listdir(os.path.join(video_dir)): + video_id = video_name.split(".mp4")[0] + action_path = os.path.join(self.action_dataset_dir,video_id + ".npz") + video_path = os.path.join(video_dir, video_name) + if os.path.exists(action_path) and os.path.exists(video_path): + self.trajectory_video_dir.append(video_path) + self.trajectory_action_dir.append(action_path) + + def __len__(self): + return len(self.trajectory_action_dir) + + def load_video(self, idx): + """ + Load MP4 video file and return all frames as a list of numpy arrays. + + Args: + idx: Index of the video in trajectory_video_dir + + Returns: + List of numpy arrays, each representing a frame (H, W, 3) in BGR format + """ + video_path = self.trajectory_video_dir[idx] + frames = [] + try: + cap = cv2.VideoCapture(video_path) + if not cap.isOpened(): + raise ValueError(f"Failed to open video file: {video_path}") + + while True: + ret, frame = cap.read() + if not ret: + break + frames.append(frame) + cap.release() + if len(frames) == 0: + print(f"Warning: No frames extracted from video: {video_path}") + return [] + + except Exception as e: + print(f"Error loading video {video_path}: {e}") + return [] + return frames + + def load_npdata(self, idx): + """ + Load .npz file and return the data as a dictionary. + + Args: + idx: Index of the .npz file in trajectory_action_dir + + Returns: + Dictionary containing arrays from the .npz file, or empty dict on error + """ + data_path = self.trajectory_action_dir[idx] + + try: + if not os.path.exists(data_path): + raise FileNotFoundError(f"NPZ file not found: {data_path}") + + data = np.load(data_path) + # Convert to dictionary for easier access + # Note: np.load returns a NpzFile object, which can be accessed like a dict + # We can return it directly or convert to a regular dict + return dict(data) + + except Exception as e: + print(f"Error loading NPZ file {data_path}: {e}") + return {} + + def process_image(self, image): + H, W, C = image.shape + prop = self.image_size / max(H, W) + image = cv2.resize(image, (-1, -1), fx=prop, fy=prop) + pad_width = max((self.image_size - image.shape[1]) // 2, 0) + pad_height = max((self.image_size - image.shape[0]) // 2, 0) + pad_image = np.pad( + image, ((pad_height, pad_height), (pad_width, pad_width), (0, 0)), mode='constant', constant_values=0 + ) + image = cv2.resize(pad_image, (self.image_size, self.image_size)) + image = np.array(image, np.float32) / 255.0 + return image + + def process_memory(self, rgb_images, start_step, memory_digit=1): + memory_index = np.arange(start_step - (self.memory_size - 1) * memory_digit, start_step + 1, memory_digit) + outrange_sum = (memory_index < 0).sum() + memory_index = memory_index[outrange_sum:] + context_image = np.zeros((self.memory_size, self.image_size, self.image_size, 3), np.float32) + context_image[outrange_sum:] = np.array([self.process_image(rgb_images[i].copy()) for i in memory_index]) + return context_image, memory_index + + def relative_pose(self, R_base, T_base, R_world, T_world, base_extrinsic): + R_base = np.matmul(R_base, np.linalg.inv(base_extrinsic[0:3, 0:3])) + if len(T_world.shape) == 1: + homo_RT = np.eye(4) + homo_RT[0:3, 0:3] = R_base + homo_RT[0:3, 3] = T_base + R_frame = np.dot(R_world, R_base.T) + T_frame = np.dot(np.linalg.inv(homo_RT), np.array([*T_world, 1]).T)[0:3] + T_frame = np.array([T_frame[1], -T_frame[0], T_frame[2]]) # [:T[1],-T[0],T[2] + return R_frame, T_frame + else: + homo_RT = np.eye(4) + homo_RT[0:3, 0:3] = R_base + homo_RT[0:3, 3] = T_base + R_frame = np.dot(R_world, R_base.T) + T_frame = np.dot( + np.linalg.inv(homo_RT), np.concatenate((T_world, np.ones((T_world.shape[0], 1))), axis=-1).T + ).T[:, 0:3] + T_frame = T_frame[:, [1, 0, 2]] + T_frame[:, 1] = -T_frame[:, 1] + return R_frame, T_frame + + def absolute_pose(self, R_base, T_base, R_frame, T_frame, base_extrinsic): + R_base = np.matmul(R_base, np.linalg.inv(base_extrinsic[0:3, 0:3])) + if len(T_frame.shape) == 1: + homo_RT = np.eye(4) + homo_RT[0:3, 0:3] = R_base + homo_RT[0:3, 3] = T_base + R_world = np.dot(R_frame, R_base) + T_world = np.dot(homo_RT, np.array([-T_frame[1], T_frame[0], T_frame[2], 1]).T)[0:3] + else: + homo_RT = np.eye(4) + homo_RT[0:3, 0:3] = R_base + homo_RT[0:3, 3] = T_base + R_world = np.dot(R_frame, R_base) + T_world = np.dot( + homo_RT, + np.concatenate( + (np.stack((-T_frame[:, 1], T_frame[:, 0], T_frame[:, 2]), axis=-1), np.ones((T_frame.shape[0], 1))), + axis=-1, + ).T, + ).T[:, 0:3] + return R_world, T_world + + def xyz_to_xyt(self, xyz_actions, init_vector): + xyt_actions = [] + for i in range(0, xyz_actions.shape[0] - 1): + current_vector = xyz_actions[i + 1] - xyz_actions[i] + dot_product = np.dot(init_vector[0:2], current_vector[0:2]) + cross_product = np.cross(init_vector[0:2], current_vector[0:2]) + theta = np.arctan2(cross_product, dot_product) + xyt_actions.append([xyz_actions[i][0], xyz_actions[i][1], theta]) + return np.array(xyt_actions) + + def process_actions(self, extrinsics, base_extrinsic, start_step, end_step, pred_digit=1): + label_linear_pos = [] + for f_ext in extrinsics[start_step : end_step + 1]: + R, T = self.relative_pose( + extrinsics[start_step][0:3, 0:3], + extrinsics[start_step][0:3, 3], + f_ext[0:3, 0:3], + f_ext[0:3, 3], + base_extrinsic, + ) + label_linear_pos.append(T) + label_actions = np.array(label_linear_pos) + + # this is usesd for action augmentations: + # (1) apply random rotation to the future steps + # (2) interpolate between the rotated actions and origin actions + + origin_world_points = extrinsics[start_step : end_step + 1, 0:3, 3] + local_label_points = [] + + for f_ext in origin_world_points: + Rf, Tf = self.relative_pose( + extrinsics[start_step][0:3, 0:3], extrinsics[start_step][0:3, 3], np.eye(3), f_ext, base_extrinsic + ) + local_label_points.append(Tf) + local_label_points = np.array(local_label_points) + action_indexes = np.clip(np.arange(self.predict_size + 1) * pred_digit, 0, label_actions.shape[0] - 2) + return local_label_points, origin_world_points, action_indexes + + + def __getitem__(self, index): + video_data = self.load_video(index) + action_data = self.load_npdata(index) + + camera_intrinsic = action_data['intrinsic'] + trajectory_extrinsics = action_data['extrinsic'] + trajectory_grid = np.abs(trajectory_extrinsics[:-1,0:3,3] - trajectory_extrinsics[1:,0:3,3]).sum(axis=-1).mean() + downsample_digit = (0.04 / trajectory_grid).astype(np.int32) + downsample_index = np.arange(0,len(video_data),downsample_digit) + + video_data = [video_data[idx] for idx in downsample_index] + trajectory_extrinsics = trajectory_extrinsics[downsample_index] + trajectory_length = trajectory_extrinsics.shape[0] + + import imageio + fps_writer = imageio.get_writer('fps.mp4', fps=10) + for image in video_data: + fps_writer.append_data(image) + fps_writer.close() + + import pdb + pdb.set_trace() + + target_choice = np.random.randint(max((trajectory_length - 1) // 2, 0), max(trajectory_length - 1, 0)) + memory_start_choice = np.random.randint(0, max(target_choice,1)) + + if self.random_digit: + memory_digit = np.random.randint(2, 8) + pred_digit = memory_digit + else: + memory_digit = 4 + pred_digit = 4 + + memory_images, memory_index = self.process_memory( + video_data, + memory_start_choice, + memory_digit=memory_digit, + ) + + target_local_points, target_world_points, action_indexes = self.process_actions(trajectory_extrinsics, + np.eye(4), + memory_start_choice, + target_choice, + pred_digit) + + import pdb + pdb.set_trace() + + + + + + + + def navdp_collate_fn(batch): @@ -569,19 +829,31 @@ def navdp_collate_fn(batch): } return collated - if __name__ == "__main__": + dataset = SekaiDataset(root_dir="/mnt/data/caiwenzhe/interndata-vln-n1/sekai/") + for i in range(100): + import time + start_time = time.time() + dataset.__getitem__(i) + print(time.time() - start_time) + import pdb + pdb.set_trace() + + #import pdb + #pdb.set_trace() os.makedirs("./navdp_dataset_test/", exist_ok=True) dataset = NavDP_Base_Datset( - "/shared/smartbot_new/liuyu/vln-n1-minival/", - "./navdp_dataset_test/dataset_lerobot.json", + "/mnt/data/liuyu/InternDate-N1-v05/vln-n1", + "./navdp_dataset_test/dataset_lerobot_v05_with_interiorgs.json", 8, 24, 224, - trajectory_data_scale=0.1, - scene_data_scale=0.1, + trajectory_data_scale=1.0, + scene_data_scale=1.0, preload=False, ) + # import pdb + # pdb.set_trace() for i in range(10): ( diff --git a/internnav/model/basemodel/navdp/navdp_policy.py b/internnav/model/basemodel/navdp/navdp_policy.py index e2353529..58e97bec 100644 --- a/internnav/model/basemodel/navdp/navdp_policy.py +++ b/internnav/model/basemodel/navdp/navdp_policy.py @@ -266,8 +266,10 @@ def forward(self, goal_point, goal_image, goal_pixel, input_images, input_depths noise_pred_mg, cr_label_pred, cr_augment_pred, - [ng_noise, mg_noise], - [imagegoal_aux_pred, pixelgoal_aux_pred], + ng_noise, + mg_noise, + imagegoal_aux_pred, + pixelgoal_aux_pred, ) def _get_device(self): diff --git a/internnav/trainer/navdp_trainer.py b/internnav/trainer/navdp_trainer.py index 5542562e..c3cb53bf 100644 --- a/internnav/trainer/navdp_trainer.py +++ b/internnav/trainer/navdp_trainer.py @@ -77,7 +77,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N batch_label_critic = inputs["batch_label_critic"] batch_augment_critic = inputs["batch_augment_critic"] - pred_ng, pred_mg, critic_pred, augment_pred, noise, aux_pred = model( + pred_ng, pred_mg, critic_pred, augment_pred, ng_noise, mg_noise, imagegoal_aux_pred, pixelgoal_aux_pred = model( inputs_on_device["batch_pg"], inputs_on_device["batch_ig"], inputs_on_device["batch_tg"], @@ -87,11 +87,11 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N inputs_on_device["batch_augments"], ) - ng_action_loss = (pred_ng - noise[0]).square().mean() - mg_action_loss = (pred_mg - noise[1]).square().mean() + ng_action_loss = (pred_ng - ng_noise).square().mean() + mg_action_loss = (pred_mg - mg_noise).square().mean() aux_loss = ( - 0.5 * (inputs_on_device["batch_pg"] - aux_pred[0]).square().mean() - + 0.5 * (inputs_on_device["batch_pg"] - aux_pred[1]).square().mean() + 0.5 * (inputs_on_device["batch_pg"] - imagegoal_aux_pred).square().mean() + + 0.5 * (inputs_on_device["batch_pg"] - pixelgoal_aux_pred).square().mean() ) action_loss = 0.5 * mg_action_loss + 0.5 * ng_action_loss critic_loss = (critic_pred - batch_label_critic).square().mean() + ( @@ -104,7 +104,7 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N 'pred_mg': pred_mg, 'critic_pred': critic_pred, 'augment_pred': augment_pred, - 'noise': noise, + 'noise': [ng_noise, mg_noise], 'loss': loss, 'ng_action_loss': ng_action_loss, 'mg_action_loss': mg_action_loss, diff --git a/scripts/train/base_train/train.py b/scripts/train/base_train/train.py index 7bb3e278..d9427174 100755 --- a/scripts/train/base_train/train.py +++ b/scripts/train/base_train/train.py @@ -127,6 +127,9 @@ def main(config, model_class, model_config_class): model = model_class.from_pretrained(pretrained_model_name_or_path=config.il.ckpt_to_load, config=model_cfg) if config.model_name == "navdp": model.to(device) + for name, param in model.named_parameters(): + if 'mask_token' in name: + param.requires_grad = False # Check that all parameters and buffers are on the correct device for name, param in model.named_parameters(): if param.device != device: