Yan's World.

基于Transformer的人体姿态重建(二):实现模板

Word count: 1.4kReading time: 7 min
2024/08/09
loading

原创性声明

本文为作者原创,在个人Blog首次发布,如需转载请注明引用出处。(yanzhang.cg@gmail.comhttps://graphicyan.github.io/)


一、项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
transformer-pose-reconstruction/
├── configs/ # 配置文件
│ └── base_config.yaml
├── data/ # 数据处理模块
│ ├── datasets/
│ │ ├── h36m_dataset.py
│ │ └── agora_dataset.py
│ └── transforms.py
├── models/ # 模型定义模块
│ ├── backbones/
│ │ └── vit.py
│ ├── heads/
│ │ └── smplx_head.py
│ └── pose_regressor.py
├── loss/ # 损失函数模块
│ └── loss_functions.py
├── utils/ # 工具函数
│ ├── visualization.py
│ └── smplx_utils.py
├── training/ # 训练脚本
│ └── train.py
├── evaluation/ # 评估脚本
│ └── evaluate.py
├── inference/ # 推理模块
│ └── infer.py
├── checkpoints/ # 模型权重保存路径
├── logs/ # 日志文件(TensorBoard等)
├── README.md
└── requirements.txt

二、核心模块详解

2.1 配置文件(configs/base_config.yaml)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# Model configuration
model:
backbone:
name: "vit-base"
pretrained: true
image_size: 224
patch_size: 16
head:
type: "transformer"
num_queries: 10
smplx_dim: 126
smplx:
model_path: "data/models/smplx/SMPLX_NEUTRAL.npz"

# Dataset configuration
data:
dataset:
name: "h36m"
root: "data/datasets/h36m"
image_size: 224
split_ratio: 0.9
batch_size: 64
num_workers: 4

# Training configuration
training:
epochs: 100
lr: 1e-4
weight_decay: 0.01
device: "cuda"
log_freq: 10
save_freq: 10
log_dir: "logs/"
checkpoint_dir: "checkpoints/"

# Evaluation configuration
evaluation:
metrics:
- "mpjpe"
- "pa-mpjpe"
- "mpvpe"

2.2 数据模块(data/datasets/h36m_dataset.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import numpy as np
import cv2

class H36MDataset(Dataset):
def __init__(self, root, image_size=224, transform=None):
self.root = root
self.image_size = image_size
self.transform = transform or transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 加载标注数据(此处为示意,实际应从文件加载)
self.samples = [...] # 包含图像路径、3D关节、SMPLX参数等

def __len__(self):
return len(self.samples)

def __getitem__(self, idx):
sample = self.samples[idx]
image = cv2.imread(sample['image_path'])
image = cv2.resize(image, (self.image_size, self.image_size))
joints_3d = sample['joints_3d']
smplx_params = sample['smplx']

if self.transform:
image = self.transform(image)

return {
'image': image,
'joints_3d': torch.tensor(joints_3d, dtype=torch.float32),
'smplx': torch.tensor(smplx_params, dtype=torch.float32)
}

2.3 模型模块(models)

Vision Transformer Backbone(models/backbones/vit.py)

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn
from transformers import ViTModel

class ViTBackbone(nn.Module):
def __init__(self, model_name="google/vit-base-patch16-224", pretrained=True):
super().__init__()
self.vit = ViTModel.from_pretrained(model_name) if pretrained else ViTModel(config=...)

def forward(self, x):
outputs = self.vit(x)
return outputs.last_hidden_state # shape: (B, N, D)

SMPLX Head(models/heads/smplx_head.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import torch.nn as nn

class SMPLXHead(nn.Module):
def __init__(self, embed_dim, smplx_dim=126):
super().__init__()
self.head = nn.Sequential(
nn.Linear(embed_dim, 512),
nn.ReLU(),
nn.Linear(512, smplx_dim)
)

def forward(self, x):
# x: (B, N, D)
x = x.mean(dim=1) # 全局平均池化
return self.head(x)

总体模型(models/pose_regressor.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from models.backbones.vit import ViTBackbone
from models.heads.smplx_head import SMPLXHead

class PoseRegressor(nn.Module):
def __init__(self, config):
super().__init__()
self.backbone = ViTBackbone(config['model']['backbone']['name'])
self.head = SMPLXHead(
embed_dim=768,
smplx_dim=config['model']['head']['smplx_dim']
)

def forward(self, x):
features = self.backbone(x)
smplx_params = self.head(features)
return smplx_params

2.4 损失函数(loss/loss_functions.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch
import torch.nn as nn

class PoseLoss(nn.Module):
def __init__(self, lambda_joint=1.0, lambda_smplx=1.0):
super().__init__()
self.lambda_joint = lambda_joint
self.lambda_smplx = lambda_smplx
self.l2_loss = nn.MSELoss()

def forward(self, pred_smplx, target_smplx, pred_joints, target_joints):
loss_joint = self.l2_loss(pred_joints, target_joints)
loss_smplx = self.l2_loss(pred_smplx, target_smplx)
total_loss = self.lambda_joint * loss_joint + self.lambda_smplx * loss_smplx
return total_loss

2.5 训练脚本(training/train.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
from torch.utils.data import DataLoader
from models.pose_regressor import PoseRegressor
from data.datasets.h36m_dataset import H36MDataset
from loss.loss_functions import PoseLoss
from configs.base_config import load_config
import os

def train():
config = load_config("configs/base_config.yaml")
device = torch.device(config['training']['device'])

# 数据集
train_dataset = H36MDataset(config['data']['root'])
train_loader = DataLoader(train_dataset, batch_size=config['data']['batch_size'],
shuffle=True, num_workers=config['data']['num_workers'])

# 模型
model = PoseRegressor(config).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=config['training']['lr'],
weight_decay=config['training']['weight_decay'])

criterion = PoseLoss()

# 训练循环
for epoch in range(config['training']['epochs']):
model.train()
for batch in train_loader:
images = batch['image'].to(device)
joints_3d = batch['joints_3d'].to(device)
smplx_params = batch['smplx'].to(device)

pred_smplx = model(images)
# 这里应调用SMPLX模型生成joints用于监督
pred_joints = smplx_forward(pred_smplx) # 需要实现SMPLX前向函数

loss = criterion(pred_smplx, smplx_params, pred_joints, joints_3d)

optimizer.zero_grad()
loss.backward()
optimizer.step()

# 保存模型
if (epoch + 1) % config['training']['save_freq'] == 0:
torch.save(model.state_dict(), os.path.join(config['training']['checkpoint_dir'], f"model_{epoch}.pth"))

print(f"Epoch {epoch}, Loss: {loss.item()}")

if __name__ == "__main__":
train()

2.6 推理模块(inference/infer.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import cv2
import numpy as np
from models.pose_regressor import PoseRegressor
from configs.base_config import load_config
from utils.smplx_utils import load_smplx_model, get_joints_from_smplx

def infer(image_path, model_path):
config = load_config("configs/base_config.yaml")
device = torch.device(config['training']['device'])

# 加载模型
model = PoseRegressor(config)
model.load_state_dict(torch.load(model_path))
model.to(device)
model.eval()

# 图像预处理
image = cv2.imread(image_path)
image = cv2.resize(image, (config['data']['image_size'], config['data']['image_size']))
image = image / 255.0
image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0).to(device)

with torch.no_grad():
smplx_params = model(image)
vertices, joints = get_joints_from_smplx(smplx_params)

# 可视化
print("Predicted 3D Joints:", joints)
# 可视化代码见 utils/visualization.py

if __name__ == "__main__":
infer("data/images/test.jpg", "checkpoints/model_final.pth")

2.7 SMPLX 工具函数(utils/smplx_utils.py)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch
import smplx

def load_smplx_model(model_path):
model = smplx.create(model_path, model_type='smplx',
gender='neutral', ext='npz')
return model

def get_joints_from_smplx(smplx_params):
# 假设smplx_params = [beta, theta, gamma]
# 需要拆分参数
model = load_smplx_model("data/models/smplx/SMPLX_NEUTRAL.npz")
output = model(betas=smplx_params[:, :10],
global_orient=smplx_params[:, 10:13],
body_pose=smplx_params[:, 13:])
return output.vertices, output.joints

三、依赖安装(requirements.txt)

1
2
3
4
5
6
7
8
9
torch>=1.10
torchvision
transformers
opencv-python
numpy
yacs # 用于配置文件
smplx
pytorch3d
tensorboard

四、使用说明

1. 安装依赖

1
pip install -r requirements.txt

2. 数据准备

  • 下载并解压 H36M、AGORA 等数据集到 data/datasets/
  • 下载 SMPLX 模型并解压到 data/models/smplx/

3. 开始训练

1
python training/train.py

4. 推理测试

1
python inference/infer.py --image data/images/test.jpg --model checkpoints/model_final.pth

五、扩展建议

  • 增加 Transformer Decoder 结构用于更精细的参数预测
  • 添加 Temporal Transformer 模块用于视频序列建模
  • 支持 ONNX 导出TensorRT 加速部署
  • 集成 MAE 等自监督预训练方法
  • 支持多视角融合与深度估计

CATALOG
  1. 1. 原创性声明
  2. 2. 一、项目结构
  3. 3. 二、核心模块详解
    1. 3.1. 2.1 配置文件(configs/base_config.yaml)
    2. 3.2. 2.2 数据模块(data/datasets/h36m_dataset.py)
    3. 3.3. 2.3 模型模块(models)
      1. 3.3.1. Vision Transformer Backbone(models/backbones/vit.py)
      2. 3.3.2. SMPLX Head(models/heads/smplx_head.py)
      3. 3.3.3. 总体模型(models/pose_regressor.py)
    4. 3.4. 2.4 损失函数(loss/loss_functions.py)
    5. 3.5. 2.5 训练脚本(training/train.py)
    6. 3.6. 2.6 推理模块(inference/infer.py)
    7. 3.7. 2.7 SMPLX 工具函数(utils/smplx_utils.py)
  4. 4. 三、依赖安装(requirements.txt)
  5. 5. 四、使用说明
    1. 5.1. 1. 安装依赖
    2. 5.2. 2. 数据准备
    3. 5.3. 3. 开始训练
    4. 5.4. 4. 推理测试
  6. 6. 五、扩展建议