通过PyTorch Hub加载YOLOv5

一、准备

PyTorch安装请点

二、简单示例

这里使用轻量级yolov5s模型。

import torch

# Model
model = torch.hub.load(ultralytics/yolov5, yolov5s)

# Image
img = https://ultralytics.com/images/zidane.jpg

# Inference
results = model(img)

三、详细示例

代码中同时使用PIL和OpenCV,识别结果保存在runs/hub目录下。

import cv2
import torch
from PIL import Image

# Model
model = torch.hub.load(ultralytics/yolov5, yolov5s)

# Images
for f in [zidane.jpg, bus.jpg]:
    torch.hub.download_url_to_file(https://ultralytics.com/images/ + f, f)  # download 2 images
img1 = Image.open(zidane.jpg)  # PIL image
img2 = cv2.imread(bus.jpg)[:, :, ::-1]  # OpenCV image (BGR to RGB)
imgs = [img1, img2]  # batch of images

# Inference
results = model(imgs, size=640)  # includes NMS

# Results
results.print()  
results.save()  # or .show()

results.xyxy[0]  # img1 predictions (tensor)
results.pandas().xyxy[0]  # img1 predictions (pandas)
#      xmin    ymin    xmax   ymax  confidence  class    name
# 0  749.50   43.50  1148.0  704.5    0.874023      0  person
# 1  433.50  433.50   517.5  714.5    0.687988     27     tie
# 2  114.75  195.75  1095.0  708.0    0.624512      0  person
# 3  986.00  304.00  1028.0  420.0    0.286865     27     tie

四、参数设置

这里参数主要是指置信度阈值,阈值,分类筛选器等模型属性参数。
model.conf = 0.25  # confidence threshold (0-1)
model.iou = 0.45  # NMS IoU threshold (0-1)
model.classes = None  # (optional list) filter by class, i.e. = [0, 15, 16] for persons, cats and dogs

results = model(imgs, size=320)  # custom inference size

五、输入通道设置

加载YOLOv5s模型输入通道数默认值为3,可以通过以下方式修改。

# 这里将通道数设置为4
model = torch.hub.load(ultralytics/yolov5, yolov5s, channels=4)

六、分类设置

YOLOv5模型默认分类数为80,可以通过以下方式修改。

model = torch.hub.load(ultralytics/yolov5, yolov5s, classes=10)

七、强制重新加载

使用force_reload=True可以帮助清理缓存并且强制更新下载最新YOLOv5版本。

model = torch.hub.load(ultralytics/yolov5, yolov5s, force_reload=True)  # force reload

八、训练

加载YOLOv5模型如果是为了训练,可以设置autoshape=False。 加载模型并随机初始化权值可以设置pretrained=False。

model = torch.hub.load(ultralytics/yolov5, yolov5s, autoshape=False)  # load pretrained
model = torch.hub.load(ultralytics/yolov5, yolov5s, autoshape=False, pretrained=False)  # load scratch

九、Base64结果

示例如下:

results = model(imgs)  # inference

results.imgs # array of original images (as np array) passed to model for inference
results.render()  # updates results.imgs with boxes and labels
for img in results.imgs:
    buffered = BytesIO()
    img_base64 = Image.fromarray(img)
    img_base64.save(buffered, format="JPEG")
    print(base64.b64encode(buffered.getvalue()).decode(utf-8))  # base64 encoded image with results

十、JSON结果

示例如下:

results = model(imgs)  # inference

results.pandas().xyxy[0].to_json(orient="records")  # JSON img1 predictions
json [ {“xmin”:749.5,“ymin”:43.5,“xmax”:1148.0,“ymax”:704.5,“confidence”:0.8740234375,“class”:0,“name”:“person”}, {“xmin”:433.5,“ymin”:433.5,“xmax”:517.5,“ymax”:714.5,“confidence”:0.6879882812,“class”:27,“name”:“tie”}, {“xmin”:115.25,“ymin”:195.75,“xmax”:1096.0,“ymax”:708.0,“confidence”:0.6254882812,“class”:0,“name”:“person”}, {“xmin”:986.0,“ymin”:304.0,“xmax”:1028.0,“ymax”:420.0,“confidence”:0.2873535156,“class”:27,“name”:“tie”} ]
经验分享 程序员 微信小程序 职场和发展