| import torch from torch.utils.data import Dataset from torchvision import transforms import numpy as np import cv2
class H36MDataset(Dataset):
    """Human3.6M dataset returning an image tensor with 3D joints and SMPL-X params.

    Each element of ``self.samples`` is expected to be a dict with keys
    ``'image_path'``, ``'joints_3d'``, and ``'smplx'`` — TODO confirm against
    the annotation-loading code that populates ``samples``.
    """

    def __init__(self, root, image_size=224, transform=None):
        """
        Args:
            root: Dataset root directory.
            image_size: Side length (pixels) each image is resized to.
            transform: Optional callable applied to the RGB image. Defaults to
                ``ToTensor`` followed by ImageNet-statistics normalization.
        """
        self.root = root
        self.image_size = image_size
        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        # TODO: populate from annotation files under `root`; placeholder kept
        # from the original source.
        self.samples = [...]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return a dict with 'image', 'joints_3d', and 'smplx' tensors."""
        sample = self.samples[idx]
        image = cv2.imread(sample['image_path'])
        if image is None:
            # cv2.imread silently returns None on a missing/unreadable file;
            # fail loudly here instead of crashing obscurely in cv2.resize.
            raise FileNotFoundError(
                f"Could not read image: {sample['image_path']}")
        # BUGFIX: cv2.imread yields BGR channel order, but the default
        # transform normalizes with ImageNet RGB statistics — convert first.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.image_size, self.image_size))
        joints_3d = sample['joints_3d']
        smplx_params = sample['smplx']
        if self.transform:
            image = self.transform(image)
        return {
            'image': image,
            'joints_3d': torch.tensor(joints_3d, dtype=torch.float32),
            'smplx': torch.tensor(smplx_params, dtype=torch.float32),
        }