探索MediaPipe自定义机器学习模型
MediaPipe支持人脸识别、目标检测、图像分类、人像分割、手势识别、文本分类、语音分类。每个模块都有对应的模型,但是原有模型可能比较大、推理耗时比较长,我们可以自定义模型来进行进行优化。
一、训练准备
1、准备数据
在自定义模型前,准备两种数据:原始数据、标注数据。
1.1 原始数据
1.2 标注数据
我们可以用来添加标注。支持3种形式安装:pip、Anaconda、docker。这里以pip安装为例:
# Requires Python >=3.7 <=3.9 pip install label-studio # Start the server at http://localhost:8080 label-studio
2、简化模型
2.1 减少标签
选择2-5个类别给图像打标签,遵从简单原则。
2.2 剪裁边缘
样本图像尽可能保留完整轮廓。剩下一部分样本进行裁剪,这样利于提高模型的鲁棒性。
2.3 模型复用
由于使用,即复用原有模型,使用新数据来重新训练原来的模型。这样可以节省训练时间,节约模型数据。Model Maker可用于训练物体检测、手势检测、图像分类、音频分类的模型。通过删除数据分类的层级,然后使用新数据来重建,最终输出新模型,框架图如下:
大概需要100个样本,其中80%用于训练,10%用于测试,剩下10%用于验证。
3、训练迭代
第一次训练的模型比较难达到理想效果。那么,我们需要花时间去选择合适样本,添加恰当标注,从而提升成功率。添加样本,或者修改样本,反复迭代训练,不断完善。
二、目标检测训练
1、准备安装包
安装mediepipe model maker:
pip install --upgrade pip pip install mediapipe-model-maker
导入object detector包:
import os import tensorflow as tf assert tf.__version__.startswith(2) from google.colab import files from mediapipe_model_maker import object_detector
2、准备数据集
从官网下载数据集,以小狗动物为例:
并且声明模型的训练路径、验证路径:
train_dataset_path = "dogs/train" validation_dataset_path = "dogs/validate"
3、加载数据集
加载训练、验证的数据集:
train_data = object_detector.Dataset.from_pascal_voc_folder( dogs copy/train, cache_dir="/tmp/od_data/train") validate_data = object_detector.Dataset.from_pascal_voc_folder( dogs copy/validate, cache_dir="/tmp/od_data/validatation")
4、训练模型
使用样本数据来训练TensorFlow模型,设置相关参数:
-
batch_size=8 learning_rate=0.3 epochs=50
根据参数选项、数据路径来创建模型:
hparams = object_detector.HParams(batch_size=8, learning_rate=0.3, epochs=50, export_dir=exported_model) options = object_detector.ObjectDetectorOptions( supported_model=object_detector.SupportedModels.MOBILENET_V2, hparams=hparams) model = object_detector.ObjectDetector.create( train_data=train_data, validation_data=validate_data, options=options)
5、验证模型
使用未用过的图像来验证模型:
loss, coco_metrics = model.evaluate(validate_data, batch_size=4) print(f"Validation loss: {loss}") print(f"Validation coco metrics: {coco_metrics}")
6、导出模型
以TensorFlow Lite的格式导出模型,然后下载下来:
model.export_model(dogs.tflite) !ls exported_model files.download(exported_model/dogs.tflite)
下一篇:
【固态硬盘】入门讲解