Skip to content

Latest commit

 

History

History
153 lines (131 loc) · 5.66 KB

1_finetune.md

File metadata and controls

153 lines (131 loc) · 5.66 KB

教程 1:如何微调模型

在 COCO 数据集上进行预训练,然后在其他数据集(如 COCO-WholeBody 数据集)上进行微调,往往可以提升模型的效果。 本教程介绍如何使用模型库中的预训练模型,并在其他数据集上进行微调。

概要

对新数据集上的模型微调需要两个步骤:

  1. 支持新数据集。详情参见 教程 2:如何增加新数据集
  2. 修改配置文件。这部分将在本教程中做具体讨论。

例如,如果想要在自定义数据集上,微调 COCO 预训练的模型,则需要修改 配置文件 中 网络头、数据集、训练策略、预训练模型四个部分。

修改网络头

如果自定义数据集的关键点个数,与 COCO 不同,则需要相应修改 keypoint_head 中的 out_channels 参数。 网络头(head)的最后一层的预训练参数不会被载入,而其他层的参数都会被正常载入。 例如,COCO-WholeBody 拥有 133 个关键点,因此需要把 17 (COCO 数据集的关键点数目) 改为 133。

channel_cfg = dict(
    num_output_channels=133,  # 从 17 改为 133
    dataset_joints=133,  # 从 17 改为 133
    dataset_channel=[
        list(range(133)),  # 从 17 改为 133
    ],
    inference_channel=list(range(133)))  # 从 17 改为 133

# model settings
model = dict(
    type='TopDown',
    pretrained='https://download.openmmlab.com/mmpose/'
    'pretrain_models/hrnet_w48-8ef0771d.pth',
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(48, 96)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(48, 96, 192)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(48, 96, 192, 384))),
    ),
    keypoint_head=dict(
        type='TopdownHeatmapSimpleHead',
        in_channels=48,
        out_channels=channel_cfg['num_output_channels'], # 已对应修改
        num_deconv_layers=0,
        extra=dict(final_conv_kernel=1, ),
        loss_keypoint=dict(type='JointsMSELoss', use_target_weight=True)),
    train_cfg=dict(),
    test_cfg=dict(
        flip_test=True,
        post_process='unbiased',
        shift_heatmap=True,
        modulate_kernel=17))

其中, pretrained='https://download.openmmlab.com/mmpose/pretrain_models/hrnet_w48-8ef0771d.pth' 表示采用 ImageNet 预训练的权重,初始化主干网络(backbone)。 不过,pretrained 只会初始化主干网络(backbone),而不会初始化网络头(head)。因此,我们模型微调时的预训练权重一般通过 load_from 指定,而不是使用 pretrained 指定。

支持自己的数据集

MMPose 支持十余种不同的数据集,包括 COCO, COCO-WholeBody, MPII, MPII-TRB 等数据集。 用户可将自定义数据集转换为已有数据集格式,并修改如下字段。

data_root = 'data/coco'
data = dict(
    samples_per_gpu=32,
    workers_per_gpu=2,
    val_dataloader=dict(samples_per_gpu=32),
    test_dataloader=dict(samples_per_gpu=32),
    train=dict(
        type='TopDownCocoWholeBodyDataset', # 对应修改数据集名称
        ann_file=f'{data_root}/annotations/coco_wholebody_train_v1.0.json', # 修改数据集标签路径
        img_prefix=f'{data_root}/train2017/',
        data_cfg=data_cfg,
        pipeline=train_pipeline),
    val=dict(
        type='TopDownCocoWholeBodyDataset', # 对应修改数据集名称
        ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json', # 修改数据集标签路径
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=val_pipeline),
    test=dict(
        type='TopDownCocoWholeBodyDataset', # 对应修改数据集名称
        ann_file=f'{data_root}/annotations/coco_wholebody_val_v1.0.json', # 修改数据集标签路径
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=val_pipeline)
)

修改训练策略

通常情况下,微调模型时设置较小的学习率和训练轮数,即可取得较好效果。

# 优化器
optimizer = dict(
    type='Adam',
    lr=5e-4, # 可以适当减小
)
optimizer_config = dict(grad_clip=None)
# 学习策略
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[170, 200]) # 可以适当减小
total_epochs = 210 # 可以适当减小

使用预训练模型

网络设置中的 pretrained,仅会在主干网络模型上加载预训练参数。若要载入整个网络的预训练参数,需要通过 load_from 指定模型文件路径或模型链接。

# 将预训练模型用于整个 HRNet 网络
load_from = 'https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w48_coco_384x288_dark-741844ba_20200812.pth'  # 模型路径可以在 model zoo 中找到