pytorch_keypoint/DeepPose/README.md
论文名称:DeepPose: Human Pose Estimation via Deep Neural Networks
论文arxiv链接:https://arxiv.org/abs/1312.4659
开发环境主要信息如下,其他Python依赖详情可见requirements.txt文件
该项目采用的训练数据是WFLW数据集(人脸98点检测),官方链接:https://wywu.github.io/projects/LAB/WFLW.html
在官方网页下载数据集后解压并组织成如下目录形式:
WFLW
├── WFLW_annotations
│ ├── list_98pt_rect_attr_train_test
│ └── list_98pt_test
└── WFLW_images
├── 0--Parade
├── 1--Handshaking
├── 10--People_Marching
├── 11--Meeting
├── 12--Group
└── ......
由于该项目默认使用的backbone是torchvision中的resnet50,在实例化模型时会自动下载在imagenet上的预训练权重。
~/.cache/torch/hub/checkpoints目录下即可将训练脚本中的--dataset_dir设置成自己构建的WFLW数据集绝对路径,例如/home/wz/datasets/WFLW
使用train.py脚本:
python train.py
使用train_multi_GPU.py脚本:
torchrun --nproc_per_node=8 train_multi_GPU.py
若要单独指定使用某些卡可在启动指令前加入CUDA_VISIBLE_DEVICES参数,例如:
CUDA_VISIBLE_DEVICES=4,5,6,7 torchrun --nproc_per_node=4 train_multi_GPU.py
若没有训练条件或者只想简单体验下,可使用本人训练好的模型权重(包含optimizer等信息故文件会略大),该权重在WFLW验证集上的NME指标为0.048,百度网盘下载地址:https://pan.baidu.com/s/1L_zg-fmocEyzhSTxj8IDJw
提取码:8fux
下载完成后在当前项目下创建一个weights文件夹,并将权重放置该文件夹内。
可参考predict.py文件,将img_path设置成自己要预测的人脸图片(注意这里只支持单人脸的关键点检测,故需要提供单独的人脸图片,具体使用时可配合一个人脸检测器联合使用),例如输入图片:
网络预测可视化结果为:
若需要导出ONNX模型可使用export_onnx.py脚本。