# Source code for graspnetAPI.moving_graspnet

__author__ = "mhgou"
# email: gouminghao@gmail.com

# GraspNet Toolbox, version 2.0.
# Data, paper, and tutorials are available at https://graspnet.net/
# Code written by Minghao Gou, 2022.
# Licensed under the non-commercial CC 4.0 license [see https://graspnet.net/about]

import os
import numpy as np
from tqdm import tqdm
import open3d as o3d
import cv2
import trimesh
import json
import open3d_plus as o3dp
import yaml


from .grasp import Grasp, GraspGroup, RectGrasp, RectGraspGroup, RECT_GRASP_ARRAY_LEN
from .utils.utils import transform_points, parse_posevector
from .utils.xmlhandler import xmlReader
from .utils.logger import get_logger


def _isArrayLike(obj):
    """Return True when *obj* is both iterable and sized.

    Used to tell a single object index apart from a collection of indices
    (list, tuple, numpy array, ...). Note that strings also qualify.
    """
    required_dunders = ("__iter__", "__len__")
    return all(hasattr(obj, dunder) for dunder in required_dunders)


# --- Dataset layout: directory / file names ---------------------------------
SCENE_DIR_NAME = "scenes"  # under the dataset root, one folder per scene
CAMERA_PREFIX = "cam_"  # camera folders are named "cam_<serial number>"
SCENE_METADATA_FILE = "metadata.json"  # stored inside each scene folder
DATASET_METADATA_FILE = "default.yaml"  # stored in the dataset root
MODELS_DIR = "models"  # object models, "<object name>.ply"
GRASP_DIR = "grasp_label"  # grasp annotations, "<object name>_labels.npz"
MULTIPLE_OBJECT_COLLISION_DIR_NAME = "collision_label_multiobj_25frame"
INITIAL_FRAME_COLLISION_DIR_NAME = "collision_label_initial_frame"
POSE_DIR = "pose"  # per-frame object poses inside each camera folder
COLOR_DIR = "color"  # RGB images, "<frame>.png"
DEPTH_DIR = "depth"  # depth images, "<frame>.png"

# --- Sensor and geometry parameters -----------------------------------------
# Passed as ``depth_scale`` when building point clouds; presumably
# raw depth / scale = meters (TODO confirm against open3d_plus).
REALSENSE_D_SERIALS_DEPTH_SCALE = 1e3  # every camera other than the L515
REALSENSE_L_SERIALS_DEPTH_SCALE = 4e3  # L515 (serial numbers starting with "f")
Z_DISTANCE_THRESH = 2.0  # scene points with z >= this are cropped away
GRASP_HEIGHT = 2e-2  # constant height written into every grasp


class MovingGraspNet:
    """API for MovingGraspNet.

    Args:
        root(str): root directory for MovingGraspNet.
    """

    def __init__(self, root):
        self.root = root
        # Created first so every helper called below is free to log.
        self.logger = get_logger("Moving-GraspNet")
        self.scenes_root = os.path.join(self.root, SCENE_DIR_NAME)
        self.scene_name_list = self._get_scene_name_list()
        self.collision_path_dict_multiobj_25frame = self._get_gt_collision_frame_dict(
            "multiobj_25frame"
        )
        self.collision_path_dict_initial_frame = self._get_gt_collision_frame_dict(
            "initial_frame"
        )
        self._load_camera_intrinsics()
        self._load_object_name_list()

    # ------------------------------------------------------------------
    # Path helpers
    # ------------------------------------------------------------------
    def _get_camera_dir(self, scene_name, camera_sn):
        """Return the folder holding all data of one camera in one scene."""
        return os.path.join(
            self.root,
            SCENE_DIR_NAME,
            scene_name,
            "{}{}".format(CAMERA_PREFIX, camera_sn),
        )

    def _get_frame_pose_dir(self, scene_name, camera_sn, frame):
        """Return the folder holding the object pose files of one frame."""
        return os.path.join(self._get_camera_dir(scene_name, camera_sn), POSE_DIR, frame)

    @staticmethod
    def _as_obj_id_list(obj_ids):
        """Validate ``obj_ids`` and return it as a sequence of object indices.

        Accepts a single integer (Python or numpy) or an array-like of integers.

        Raises:
            AssertionError: if ``obj_ids`` is neither.
        """
        assert _isArrayLike(obj_ids) or isinstance(
            obj_ids, (int, np.integer)
        ), "obj_ids must be an integer or a list/numpy array of integers"
        return obj_ids if _isArrayLike(obj_ids) else [obj_ids]

    # ------------------------------------------------------------------
    # Scenes, cameras and frames
    # ------------------------------------------------------------------
    def _get_scene_name_list(self):
        """Get all scene names.

        Returns:
            list: the sorted scene names.
        """
        return sorted(os.listdir(self.scenes_root))

    def _get_camera_sn_list(self, scene_name):
        """Get all camera serial numbers of a scene.

        Args:
            scene_name(str): the scene name.

        Returns:
            list: the sorted camera serial numbers.
        """
        scene_dir = os.path.join(self.scenes_root, scene_name)
        # Strip the prefix instead of splitting on "_" so that a serial
        # containing an underscore is not truncated.
        return sorted(
            name[len(CAMERA_PREFIX):]
            for name in os.listdir(scene_dir)
            if name.startswith(CAMERA_PREFIX)
        )

    def _check_metadata(self, scene_name, metadata):
        """Check whether the metadata matches the cameras of a scene.

        Args:
            scene_name(str): the scene name.
            metadata(dict): the dict of metadata.

        Returns:
            bool: True if the metadata keys are exactly ``<sn>`` and
            ``num_<sn>`` for every camera serial number in the scene.
        """
        cam_sn_list = self._get_camera_sn_list(scene_name=scene_name)
        cam_set = set(cam_sn_list + ["num_" + cam_sn for cam_sn in cam_sn_list])
        return cam_set == set(metadata.keys())

    def load_scene_metadata(self, scene_name):
        """Get the metadata of a scene.

        Args:
            scene_name(str): the scene name.

        Returns:
            dict: the metadata of the scene (parsed ``metadata.json``).
        """
        metadata_path = os.path.join(self.scenes_root, scene_name, SCENE_METADATA_FILE)
        with open(metadata_path, "r") as metadata_json_file:
            return json.load(metadata_json_file)

    def _get_valid_frame_list(self, scene_name, camera_sn):
        """Get all valid frames of a camera in a scene, as listed in the metadata.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.

        Returns:
            list(str): the valid frame ids.
        """
        scene_meta = self.load_scene_metadata(scene_name=scene_name)
        assert len(scene_meta[camera_sn]) == scene_meta["num_" + camera_sn]
        return [str(frame_id) for frame_id in scene_meta[camera_sn]]

    def get_frame_list(self, scene_name, camera_sn):
        """Get all frames of a camera in a scene.

        A frame is any ``<frame>.png`` file in the camera's color folder;
        other files in that folder are ignored.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.

        Returns:
            list(str): the frame ids, sorted as strings.
        """
        color_dir = os.path.join(self._get_camera_dir(scene_name, camera_sn), COLOR_DIR)
        return sorted(
            os.path.splitext(file_name)[0]
            for file_name in os.listdir(color_dir)
            if file_name.endswith(".png")
        )

    def get_near_frames(self, scene_name, camera_sn, frame, max_distance):
        """Find the frames that are close in time to a given frame.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the given frame id.
            max_distance(int): maximum distance (in ms, 1/1000 second); only
                frames strictly closer than this are returned.

        Returns:
            list(str): the frame ids within ``max_distance``, excluding ``frame``.

        Raises:
            ValueError: if ``frame`` is not a frame of this camera.
        """
        frame_list = self.get_frame_list(scene_name=scene_name, camera_sn=camera_sn)
        if frame not in frame_list:
            raise ValueError(
                "frame {} is not a valid frame in scene {} camera {}".format(
                    frame, scene_name, camera_sn
                )
            )
        return [
            other
            for other in frame_list
            if other != frame and abs(int(frame) - int(other)) < max_distance
        ]

    # ------------------------------------------------------------------
    # Images, intrinsics and point clouds
    # ------------------------------------------------------------------
    def _load_camera_intrinsics(self):
        """Load every ``cam_intrinsics/<serial>.npy`` into ``self.camera_intrinsics``.

        Files without the ``.npy`` extension are ignored.
        """
        intrinsics_dir = os.path.join(self.root, "cam_intrinsics")
        self.camera_intrinsics = {
            os.path.splitext(file_name)[0]: np.load(os.path.join(intrinsics_dir, file_name))
            for file_name in os.listdir(intrinsics_dir)
            if file_name.endswith(".npy")
        }

    def load_cam_intrinsic(self, camera_sn):
        """Get the intrinsic matrix of a camera.

        Args:
            camera_sn(str): camera serial number.

        Returns:
            np.ndarray: the intrinsic matrix loaded from ``<serial>.npy``
            (presumably 3x3).
        """
        return self.camera_intrinsics[camera_sn]

    def get_rgb_path(self, scene_name, camera_sn, frame):
        """Get the RGB image path.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.

        Returns:
            str: path of the RGB image.
        """
        return os.path.join(
            self._get_camera_dir(scene_name, camera_sn), COLOR_DIR, "{}.png".format(frame)
        )

    def get_depth_path(self, scene_name, camera_sn, frame):
        """Get the depth image path.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.

        Returns:
            str: path of the depth image.
        """
        return os.path.join(
            self._get_camera_dir(scene_name, camera_sn), DEPTH_DIR, "{}.png".format(frame)
        )

    def load_point_cloud(self, scene_name, camera_sn, frame):
        """Get the open3d point cloud of a frame.

        Points with z >= ``Z_DISTANCE_THRESH`` are cropped away.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.

        Returns:
            open3d.geometry.PointCloud: the point cloud.
        """
        depth_path = self.get_depth_path(scene_name, camera_sn, frame)
        rgb_path = self.get_rgb_path(scene_name, camera_sn, frame)
        intrinsic = self.load_cam_intrinsic(camera_sn)
        # L515 serial numbers start with "f" and use a different depth scale.
        depth_scale = (
            REALSENSE_L_SERIALS_DEPTH_SCALE
            if camera_sn.startswith("f")
            else REALSENSE_D_SERIALS_DEPTH_SCALE
        )
        pcd = o3dp.generate_scene_pointcloud(
            depth=depth_path,
            rgb=rgb_path,
            intrinsics=intrinsic,
            depth_scale=depth_scale,
        )
        points, colors = o3dp.pcd2array(pcd)
        mask = points[:, 2] < Z_DISTANCE_THRESH
        return o3dp.array2pcd(points[mask], colors[mask])

    # ------------------------------------------------------------------
    # Objects and poses
    # ------------------------------------------------------------------
    def _load_object_name_list(self):
        """Load the object index -> object name correspondence."""
        with open(os.path.join(self.root, DATASET_METADATA_FILE), "r") as metadata_file:
            dataset_metadata = yaml.load(metadata_file.read(), Loader=yaml.FullLoader)
        self.object_name_list = [
            object_file_name.replace(".ply", "")
            for object_file_name in dataset_metadata["object_model_list"]
        ]

    def load_object_point_cloud(self, obj_id):
        """Load an object point cloud.

        Args:
            obj_id(int): object index.

        Returns:
            open3d.geometry.PointCloud: the object point cloud.
        """
        return o3d.io.read_point_cloud(
            os.path.join(
                self.root, MODELS_DIR, "{}.ply".format(self.object_name_list[obj_id])
            )
        )

    def load_scene_object_list(self, scene_name, camera_sn, frame):
        """Get the indices of all objects with any pose file in a frame.

        Both ``<id>.npy`` and ``registered_<id>.npy`` files are counted.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.

        Returns:
            list(int): the object indices, sorted ascending.
        """
        pose_dir = self._get_frame_pose_dir(scene_name, camera_sn, frame)
        return sorted(
            {
                int(file_name.replace(".npy", "").replace("registered_", ""))
                for file_name in os.listdir(pose_dir)
                if file_name.endswith(".npy")
            }
        )

    def load_scene_registered_object_list(self, scene_name, camera_sn, frame):
        """Get the indices of objects with a registered pose in a frame.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.

        Returns:
            list(int): the object indices, sorted ascending.
        """
        pose_dir = self._get_frame_pose_dir(scene_name, camera_sn, frame)
        return sorted(
            {
                int(file_name.replace(".npy", "").replace("registered_", ""))
                for file_name in os.listdir(pose_dir)
                if file_name.startswith("registered_") and file_name.endswith(".npy")
            }
        )

    def load_object_pose(self, scene_name, camera_sn, frame, obj_id, registered=True):
        """Get the 6d pose of an object.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.
            obj_id(int): the index of the object.
            registered(bool): whether to use the registered pose.

        Returns:
            np.array(4x4) or None: the object pose, or None if no pose file exists.
        """
        prefix = "registered_" if registered else ""
        pose_path = os.path.join(
            self._get_frame_pose_dir(scene_name, camera_sn, frame),
            "{}{}.npy".format(prefix, obj_id),
        )
        if not os.path.exists(pose_path):
            return None
        return np.load(pose_path)

    def get_registered_object_pose_filename(self, scene_name, camera_sn, frame, obj_id):
        """Get the path of an object's registered pose file.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.
            obj_id(int): the index of the object.

        Returns:
            str: the registered object pose filename.
        """
        return os.path.join(
            self._get_frame_pose_dir(scene_name, camera_sn, frame),
            "registered_{}.npy".format(obj_id),
        )

    def load_scene_with_object(self, scene_name, camera_sn, frame, registered=True):
        """Get open3d point clouds of the scene and of every posed object.

        Objects lacking the requested kind of pose (registered or not) are
        skipped with a warning.

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id.
            registered(bool): whether to use registered object poses.

        Returns:
            list(open3d.geometry.PointCloud): the scene cloud followed by the
            transformed object clouds.
        """
        pcds = [
            self.load_point_cloud(scene_name=scene_name, camera_sn=camera_sn, frame=frame)
        ]
        for obj_id in self.load_scene_object_list(scene_name, camera_sn, frame):
            pose = self.load_object_pose(
                scene_name, camera_sn, frame, obj_id, registered=registered
            )
            if pose is None:
                # Only the other kind of pose file exists for this object;
                # transforming by None would crash.
                self.logger.warning(
                    "object %d has no %s pose in %s/%s frame %s, skipped"
                    % (
                        obj_id,
                        "registered" if registered else "unregistered",
                        scene_name,
                        camera_sn,
                        frame,
                    )
                )
                continue
            pcd = self.load_object_point_cloud(obj_id)
            pcd.transform(pose)
            pcds.append(pcd)
        return pcds

    # ------------------------------------------------------------------
    # Grasp labels
    # ------------------------------------------------------------------
    def load_obj_grasp_labels(self, obj_ids=None):
        """Load grasp labels of objects.

        Author: Chenxi Wang

        Args:
            obj_ids(int/list): object indices; all objects when None.

        Returns:
            grasp_labels(dict): obj_id -> (points, offsets, fric_coefs), float32.
        """
        if obj_ids is None:
            obj_ids = list(range(len(self.object_name_list)))
        obj_ids = self._as_obj_id_list(obj_ids)
        num_objects = len(self.object_name_list)
        assert all(
            0 <= obj_id < num_objects for obj_id in obj_ids
        ), "obj_ids out of range"

        grasp_labels = {}
        for obj_id in tqdm(obj_ids, desc="Loading grasp labels..."):
            label_path = os.path.join(
                self.root, GRASP_DIR, "%s_labels.npz" % self.object_name_list[obj_id]
            )
            # NpzFile keeps the archive open; close it once the arrays are copied.
            with np.load(label_path) as labels:
                grasp_labels[obj_id] = (
                    labels["points"].astype(np.float32),
                    labels["offsets"].astype(np.float32),
                    labels["scores"].astype(np.float32),
                )
        return grasp_labels

    def load_obj_grasp_poses(
        self, obj_id, grasp_labels=None, collision_labels=None, fric_coef_thresh=0.4
    ):
        """Load grasp poses on an object.

        Author: Chenxi Wang

        Args:
            obj_id(int): the index of the object.
            grasp_labels(dict): grasp labels of target objects; loaded if None.
            collision_labels(dict): obj_id -> collision mask; grasps whose mask
                entry is non-zero are dropped. Ignored when None.
            fric_coef_thresh(float): maximum friction coefficient for grasps.

        Returns:
            grasp_group: grasp poses of the object in object coordinates. Each
            row is [score, width, height, depth, rotation(9), translation(3), obj_id].
        """
        from .utils.utils import generate_views
        from .utils.rotation import batch_viewpoint_params_to_matrix

        if grasp_labels is None:
            self.logger.warning(
                "grasp_labels are not given, calling self.load_obj_grasp_labels to retrieve them"
            )
            grasp_labels = self.load_obj_grasp_labels(obj_id)

        points, offsets, fric_coefs = grasp_labels[obj_id]
        num_points, num_views, num_angles, num_depths, _ = offsets.shape

        # Broadcast every grasp point and approach view over the whole
        # (point, view, angle, depth) label grid.
        translations = np.tile(
            points[:, np.newaxis, np.newaxis, np.newaxis, :],
            [1, num_views, num_angles, num_depths, 1],
        )
        template_views = generate_views(num_views)
        views = np.tile(
            template_views[np.newaxis, :, np.newaxis, np.newaxis, :],
            [num_points, 1, num_angles, num_depths, 1],
        )

        # Keep grasps whose friction coefficient is positive and within the
        # threshold, and (if given) that are collision free.
        mask = (fric_coefs <= fric_coef_thresh) & (fric_coefs > 0)
        if collision_labels is not None:
            mask &= collision_labels[obj_id] == 0

        translations = translations[mask]
        views = views[mask]
        offsets = offsets[mask]
        fric_coefs = fric_coefs[mask][:, np.newaxis]

        angles = offsets[:, 0:1]
        depths = offsets[:, 1:2]
        widths = offsets[:, 2:3]
        rotations = batch_viewpoint_params_to_matrix(-views, angles[:, 0]).reshape(
            [-1, 9]
        )
        num_grasps = angles.shape[0]
        heights = GRASP_HEIGHT * np.ones([num_grasps, 1])
        object_ids = obj_id * np.ones([num_grasps, 1])
        # Lower friction coefficient -> higher score.
        scores = 1.1 - fric_coefs
        ggarray = np.concatenate(
            [scores, widths, heights, depths, rotations, translations, object_ids],
            axis=-1,
        ).astype(np.float32)
        return GraspGroup(ggarray)

    # ------------------------------------------------------------------
    # Collision labels
    # ------------------------------------------------------------------
    def _get_gt_collision_frame_dict(self, split="multiobj_25frame"):
        """Load the dictionary of frames which contain ground-truth collision labels.

        Author: Chenxi Wang

        Args:
            split(str): which collision labels to load, support 'multiobj_25frame'
                and 'initial_frame'.

        Returns:
            frame_dict(dict): parsed ``frame_dicts.json``; for 'initial_frame' the
            innermost keys (object ids) are converted from str to int.
        """
        assert split in [
            "multiobj_25frame",
            "initial_frame",
        ], "argument 'split' only support 'multiobj_25frame' and 'initial_frame'"
        collision_dir = (
            MULTIPLE_OBJECT_COLLISION_DIR_NAME
            if split == "multiobj_25frame"
            else INITIAL_FRAME_COLLISION_DIR_NAME
        )
        with open(os.path.join(self.root, collision_dir, "frame_dicts.json"), "r") as f:
            frame_dict = json.load(f)
        if split == "initial_frame":
            # JSON object keys are always strings; convert obj_id str -> int.
            for scene_dict in frame_dict.values():
                for camera_sn, camera_dict in scene_dict.items():
                    scene_dict[camera_sn] = {
                        int(obj_id): initial_frame
                        for obj_id, initial_frame in camera_dict.items()
                    }
        return frame_dict

    def _interpolate_collision_labels(self, scene_name, camera_sn, frame, obj_id):
        """Return the collision labels of the nearest frame that has them.

        Author: Chenxi Wang

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the given frame id.
            obj_id(int): object index.

        Returns:
            collision(np.array): collision labels of the given frame, or of the
            temporally nearest labelled frame (ties go to the later frame).
        """
        collision_dir = os.path.join(
            self.root, MULTIPLE_OBJECT_COLLISION_DIR_NAME, scene_name, camera_sn
        )

        def _get_collision_path(frame_id):
            return os.path.join(collision_dir, frame_id, "%d.npy" % obj_id)

        # Sort numerically: frame ids are integer timestamps whose string
        # forms need not share a width, and lexicographic order ("100" < "90")
        # would break the neighbour search below.
        full_frames = sorted(self.get_frame_list(scene_name, camera_sn), key=int)
        num_frames = len(full_frames)
        frame_index = full_frames.index(frame)

        # Walk outwards (including the frame itself) until a labelled frame
        # is found on each side.
        left = frame_index
        while left >= 0 and not os.path.exists(_get_collision_path(full_frames[left])):
            left -= 1
        right = frame_index
        while right < num_frames and not os.path.exists(
            _get_collision_path(full_frames[right])
        ):
            right += 1

        assert left >= 0 or right < num_frames, (
            "no collision label of object %d exists in %s/cam_%s"
            % (obj_id, scene_name, camera_sn)
        )
        if left < 0:
            chosen = full_frames[right]
        elif right >= num_frames:
            chosen = full_frames[left]
        elif abs(int(frame) - int(full_frames[left])) < abs(
            int(full_frames[right]) - int(frame)
        ):
            chosen = full_frames[left]
        else:
            chosen = full_frames[right]
        return np.load(_get_collision_path(chosen))

    def _load_initial_frame_collision_labels(self, scene_name, camera_sn, frame, obj_id):
        """Load the collision labels of an object at its initial frame.

        Author: Chenxi Wang

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the frame id; must equal the object's recorded initial frame.
            obj_id(int): object index.

        Returns:
            collision(np.array): collision labels of the object at the initial frame.
        """
        frame_dict = self.collision_path_dict_initial_frame
        log_info = "%s contains no valid collision label, check self.collision_path_dict_initial_frame"
        assert scene_name in frame_dict, log_info % (scene_name)
        assert camera_sn in frame_dict[scene_name], log_info % (
            "%s/%s" % (scene_name, camera_sn)
        )
        assert obj_id in frame_dict[scene_name][camera_sn], log_info % (
            "object %d in %s/%s" % (obj_id, scene_name, camera_sn)
        )
        assert frame == frame_dict[scene_name][camera_sn][obj_id], (
            "incorrect initial frame of object %d, check self.collision_path_dict_initial_frame"
            % obj_id
        )
        collision_path = os.path.join(
            self.root,
            INITIAL_FRAME_COLLISION_DIR_NAME,
            scene_name,
            camera_sn,
            frame,
            "%d.npy" % obj_id,
        )
        return np.load(collision_path)

    def load_frame_collision_labels(
        self, scene_name, camera_sn, frame, obj_ids=None, split="multiobj_25frame"
    ):
        """Load collision labels of the given frame.

        Author: Chenxi Wang

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the given frame id.
            obj_ids(int/list): object indices; all registered objects when None.
            split(str): which collision labels to load, support 'multiobj_25frame'
                and 'initial_frame'.

        Returns:
            collision_labels(dict): obj_id -> collision labels of the given frame.
        """
        full_obj_ids = self.load_scene_registered_object_list(
            scene_name, camera_sn, frame
        )
        assert split in [
            "multiobj_25frame",
            "initial_frame",
        ], "argument 'split' only support 'multiobj_25frame' and 'initial_frame'"
        assert (
            split == "initial_frame" or len(full_obj_ids) >= 2
        ), "split 'multiobj_25frame' requires more than one objects"
        obj_ids = full_obj_ids if obj_ids is None else self._as_obj_id_list(obj_ids)
        for obj_id in obj_ids:
            assert obj_id in full_obj_ids, "obj %d is not in the frame" % obj_id

        if split == "multiobj_25frame":
            load_labels = self._interpolate_collision_labels
        else:
            load_labels = self._load_initial_frame_collision_labels
        return {
            obj_id: load_labels(scene_name, camera_sn, frame, obj_id)
            for obj_id in obj_ids
        }

    def load_frame_grasp_poses(
        self,
        scene_name,
        camera_sn,
        frame,
        obj_ids=None,
        grasp_labels=None,
        fric_coef_thresh=0.4,
        num_sample=None,
        collision_split="multiobj_25frame",
    ):
        """Load the grasp poses of objects in a frame.

        Author: Chenxi Wang

        Args:
            scene_name(str): the scene name.
            camera_sn(str): the camera serial number.
            frame(str): the given frame id.
            obj_ids(int/list): object indices; all registered objects when None.
            grasp_labels(dict): grasp labels of target objects; loaded if None.
            fric_coef_thresh(float): maximum friction coefficient for grasps.
            num_sample(int): if given and smaller than the number of grasps,
                randomly sample this many grasps without replacement.
            collision_split(str): which collision labels to load, support
                'multiobj_25frame' and 'initial_frame'.

        Returns:
            grasp_group: grasp poses of the given frame in camera coordinates.
        """
        full_obj_ids = self.load_scene_registered_object_list(
            scene_name, camera_sn, frame
        )
        assert collision_split in [
            "multiobj_25frame",
            "initial_frame",
        ], "argument 'collision_split' only support 'multiobj_25frame' and 'initial_frame'"
        if obj_ids is None:
            obj_ids = full_obj_ids
        else:
            obj_ids = self._as_obj_id_list(obj_ids)
            for obj_id in obj_ids:
                assert obj_id in full_obj_ids, "object %d is not in the scene." % obj_id

        if grasp_labels is None:
            self.logger.warning(
                "grasp_labels are not given, calling self.load_obj_grasp_labels to retrieve them"
            )
            grasp_labels = self.load_obj_grasp_labels(obj_ids)
        collision_labels = self.load_frame_collision_labels(
            scene_name, camera_sn, frame, obj_ids, split=collision_split
        )

        gg_array = []
        for obj_id in obj_ids:
            grasp_group_i = self.load_obj_grasp_poses(
                obj_id, grasp_labels, collision_labels, fric_coef_thresh
            )
            gg_array_i = grasp_group_i.grasp_group_array
            obj_pose = self.load_object_pose(
                scene_name, camera_sn, frame, obj_id, registered=True
            )
            # Move grasps from object to camera coordinates: columns 4:13 hold
            # the flattened rotation, 13:16 the translation.
            rotations = gg_array_i[:, 4:13].reshape([-1, 3, 3])
            rotations = np.matmul(obj_pose[np.newaxis, :3, :3], rotations)
            translations = transform_points(gg_array_i[:, 13:16], obj_pose)
            gg_array_i[:, 4:13] = rotations.reshape([-1, 9])
            gg_array_i[:, 13:16] = translations
            gg_array.append(gg_array_i)
        gg_array = np.concatenate(gg_array, axis=0)

        if num_sample is not None and 0 < num_sample < gg_array.shape[0]:
            indices = np.random.choice(gg_array.shape[0], num_sample, replace=False)
            gg_array = gg_array[indices]
        return GraspGroup(gg_array)

    def get_camera_obj_ids(self, scene_name, camera_sn):
        """Get the indices of all objects registered in any frame of a camera.

        Args:
            scene_name(str): scene name.
            camera_sn(str): camera serial number.

        Returns:
            list: the sorted object indices.
        """
        obj_ids = set()
        for frame in self.get_frame_list(scene_name=scene_name, camera_sn=camera_sn):
            obj_ids.update(
                self.load_scene_registered_object_list(
                    scene_name=scene_name, camera_sn=camera_sn, frame=frame
                )
            )
        return sorted(obj_ids)