使用回調函數訓練YOLO模型
大多數人可能熟悉如何訓練計算機視覺模型,比如流行的YOLO模型,甚至知道如何使用這些模型進行預測。但你知道我們可以通過回調函數為這些模型增加一些靈活性,以便在模型訓練和模型推斷中使用嗎?大多數最先進的(SOTA)YOLO模型,如YOLOv8和YOLO-NAS,都實現了回調函數,我們可以調整這些函數以有效地利用我們的計算機視覺模型的訓練和推斷。
考慮以下情景。假設你是一名計算機視覺工程師,與團隊中的許多工程師一起工作。你正在使用自定義數據集訓練自定義的計算機視覺模型(也許是YOLO),以實現一些業(yè)務邏輯。你負責實現訓練和推斷邏輯。除此之外,你還需要報告模型的訓練進度、訓練模型的準確性等。作為一名工程師,你決定在很多個epoch上訓練你的模型,這可能需要幾天的時間,具體取決于一些因素,比如數據集的數量、服務器資源等。你需要密切關注模型的訓練進度,因為由于諸如服務器資源問題等原因,模型可能在一段時間后停止訓練,導致訓練崩潰。你可能也希望在模型訓練完成后收到自動警報,比如在訓練結束后收到帶有驗證指標的電子郵件,或者在模型訓練完成后自動向團隊負責人發(fā)送報告。這些以及許多其他事情都是你作為計算機視覺工程師可能想要做的事情。
要實現以上任何一種情況,我們需要一種回調函數。這就是在訓練計算機視覺模型時回調函數的作用。好消息是,大多數SOTA YOLO模型默認實現了這些回調函數。例如,默認情況下,YOLOv8和YOLO-NAS實現了這些回調函數,你可以在訓練或進行模型預測時有效地利用它們。在本文章中,我將向你展示一些示例,演示在訓練YOLO模型時如何使用回調函數。在本例中,我將使用YOLOv8,但請注意,這可以擴展到其他一些YOLO模型,比如YOLO-NAS。
讓我們繼續(xù)演示如何在YOLOv8上實現回調函數。我們將編寫代碼并在自定義數據集上訓練我們的模型。我們將實現回調函數。其中一個功能是在模型訓練結束后向我們的團隊工程師發(fā)送電子郵件。我們發(fā)送的電子郵件將包含受過訓練模型的報告,如指標、訓練模型所花費的時間等。
項目實施步驟
第1步:創(chuàng)建一個文件夾并給它命名(在我的案例中,我將我的文件夾命名為“yolo_with_callbacks”)。
在你創(chuàng)建的文件夾中,創(chuàng)建一個新的文本文件(requirements.txt)并添加以下內容:
opencv-python==4.8.1.78
Pillow==10.0.1
tqdm==4.66.1
ultralytics==8.1.2
python-dotenv==1.0.1
然后,在你的項目文件夾中創(chuàng)建一個Python虛擬環(huán)境,并安裝requirements.txt文件中列出的依賴項。
python3 -m venv env
接下來,通過運行以下命令激活新創(chuàng)建的虛擬環(huán)境:
source env/bin/activate # if you are using Ubuntu
source env/Scripts/activate # if you are using Windows
然后,通過運行以下命令安裝依賴項:
pip install -r requirements.txt
第2步:下載一個用于自定義模型訓練的示例數據集。
你可以使用任何你選擇的數據集,只要注釋是以YOLO格式提供的即可。在我的案例中,為了本教程的目的,我將使用來自Roboflow的POTHOLE數據集,你可以從這個鏈接下載:POTHOLE數據集。下載數據集后,你將得到三個文件夾(train、val和test)?,F在,在你的項目目錄中創(chuàng)建一個數據集文件夾,并將你下載的數據集(train、val和test)復制到這個文件夾中。你的數據集文件夾應該如下所示:
Datasets
└── train
├── images
└── labels
└── val
├── images
└── labels
└── test
├── images
└── labels
接下來,在項目根目錄中創(chuàng)建一個數據集配置文件(我們稱之為data.yaml)并在YAML文件中添加以下內容:
train: ./dataset/train/images
val: ./dataset/val/images
test: ./dataset/test/images
nc: 1
names: ['pothole']
第3步:創(chuàng)建模型訓練腳本。
接下來,我們需要編寫代碼來使用我們的自定義數據集訓練模型。之后,我們將繼續(xù)實現模型的回調函數,這是本教程的唯一目的?,F在,在你的項目根目錄中創(chuàng)建一個新文件(命名為training.py)。在這個training.py文件中,我們將實現模型訓練和回調函數。首先,讓我們編寫一個用于訓練YOLOV8模型的函數:
def train_yolov8_model(config_path, num_epochs, training_result_dir):
model = YOLO("yolov8x.pt")
model.add_callback("on_train_start", on_train_start)
model.add_callback("on_train_epoch_end", on_train_epoch_end)
model.add_callback("on_train_end", on_train_end)
model.start_time = datetime.now()
start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Train the model
model.train(
data=config_path,
name="Yolo_Model_Training",
project=training_result_dir,
task="detect",
epochs=num_epochs,
patience=20,
batch=16,
cache=True,
imgsz=640,
iou=0.5,
augment=True,
degrees=25.0,
fliplr=0.0,
lr0=0.0001,
optimizer="Adam",
device=device,
)
注意:函數參數中的config_path是我們之前創(chuàng)建的數據集yaml配置文件。我們稍后將定義的回調函數,就像model.add_callback這樣的調用,稍等一下。
接下來,讓我們實現回調函數。在這種情況下,我們將要實現的回調函數包括:on_train_start、on_train_epoch_end和on_train_end。on_train_start回調是在模型開始訓練時立即觸發(fā)的回調函數。on_train_epoch_end是在每個epoch結束后立即觸發(fā)的回調函數。on_train_end是在模型完成訓練后觸發(fā)的回調函數。
實現回調函數
def on_train_start(trainer):
start_time = datetime.now()
def on_train_epoch_end(trainer):
curr_epoch = trainer.epoch + 1
text = f"Epoch Number: {curr_epoch}/{trainer.epochs} finished"
print(text)
print("-" * 50)
對于on_train_start回調,我們需要追蹤模型開始訓練的確切時間。你實際上可以在這里實現更復雜的邏輯。對于on_train_epoch_end,我們只是獲取了當前epoch并打印出來。這只是一個簡單的演示。我們可以在這里實現更復雜的邏輯。例如,如果我們有一個用戶正在從中訓練模型的前端應用程序,我們可以在每個epoch結束后更新GUI的訓練進度條。我們可以在這個函數中實現這個功能。
現在,讓我們繼續(xù)實現本教程的主要邏輯。我們將繼續(xù)實現on_train_end回調函數。如前所述,此函數僅在模型訓練成功完成后觸發(fā)。在我們的情況下,我們想要發(fā)送一個包含模型訓練報告的電子郵件給我們的團隊工程師。為了實現這一點,首先,讓我們編寫一個發(fā)送電子郵件的函數。我們將使用Gmail發(fā)送電子郵件。
以下是發(fā)送電子郵件的函數:
def send_email(
body,
from_email=FROM_EMAIL,
to_emails=RECIPENT_EMAIL,
subject=subject,
api=EMAIL_API_KEY,
):
msg = MIMEMultipart()
msg["From"] = from_email
msg["To"] = to_emails
msg["Subject"] = subject
msg.attach(MIMEText(body, "html"))
try:
smtp_server = smtplib.SMTP("smtp.gmail.com", 587)
smtp_server.starttls()
smtp_server.login(from_email, api)
smtp_server.sendmail(from_email, to_emails, msg.as_string())
smtp_server.quit()
print("Email sent.")
except Exception as e:
print("Email not sent", e)
但請注意,我們需要將諸如EMAIL API KEY、SENDER EMAIL等秘密憑證存儲到一個環(huán)境文件中?;诖?,請在你的項目根目錄中創(chuàng)建一個新文件(命名為.env)。在.env文件中,添加以下示例內容。
EMAIL_API_KEY=your Gmail app password goes here
EMAIL_ACCOUNT=your Gmail account which you created app password goes here
RECIPENT_EMAIL=the email address you will be sending the report email goes here.
現在,讓我們繼續(xù)實現回調函數(on_train_end),該函數將在模型訓練成功完成后觸發(fā)發(fā)送電子郵件功能。
def on_train_end(trainer):
trainer_epoch = trainer.epoch
trainer_metrics = trainer.metrics
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
end_time = datetime.now()
time_taken = end_time - start_time
hours, remainder = divmod(time_taken.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
time_taken_str = ""
if int(hours) > 0:
time_taken_str += f"{int(hours)} hr "
if int(minutes) > 0:
time_taken_str += f"{int(minutes)} mins "
if int(seconds) > 0:
time_taken_str += f"{int(seconds)} secs"
time_taken_str = time_taken_str.strip()
body = f"""
<html>
<head>
<style>
table, th, td {{
border: 1px solid black;
border-collapse: collapse;
padding: 5px;
}}
</style>
</head>
<body>
<h1>Training Report</h1>
<p>Date and Time: {current_time}</p>
<p>Total Epoch Trained: {trainer_epoch + 1} </p>
<p>Time Taken to Train Model: {time_taken_str} </p>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
{''.join([f'<tr><td>{k}</td><td>{v:.2f}</td></tr>' for k, v in trainer_metrics.items()])}
</table>
</body>
</html>
"""
send_email(body)
以上回調函數將在模型訓練完成后向指定收件人發(fā)送報告郵件?,F在,我們已經編寫了所有必要的函數,將它們全部封裝在一個名為ModelTraining的類中是一個好主意。所以,我們training.py文件中的完整代碼現在應該如下所示:
import os
from datetime import datetime
from dotenv import find_dotenv, load_dotenv
import torch
from ultralytics import YOLO
import smtplib
from email.mime.text import MIMEText
from email.mime.multipart import MIMEMultipart
load_dotenv(find_dotenv())
EMAIL_API_KEY = os.getenv("EMAIL_API_KEY")
FROM_EMAIL = os.getenv("EMAIL_ACCOUNT")
RECIPIENT_EMAIL = os.getenv("RECIPIENT_EMAIL")
subject = "Model Training Completed"
class ModelTraining:
def __init__(self):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.start_time = None
self.end_time = None
def send_email(
self,
body,
from_email=FROM_EMAIL,
to_emails=RECIPIENT_EMAIL,
subject=subject,
api=EMAIL_API_KEY,
):
msg = MIMEMultipart()
msg["From"] = from_email
msg["To"] = to_emails
msg["Subject"] = subject
msg.attach(MIMEText(body, "html"))
try:
smtp_server = smtplib.SMTP("smtp.gmail.com", 587)
smtp_server.starttls()
smtp_server.login(from_email, api)
smtp_server.sendmail(from_email, to_emails, msg.as_string())
smtp_server.quit()
print("Email sent.")
except Exception as e:
print("Email not sent", e)
def on_train_start(self, trainer):
self.start_time = datetime.now()
def on_train_epoch_end(self, trainer):
curr_epoch = trainer.epoch + 1
text = f"Epoch Number: {curr_epoch}/{trainer.epochs} finished"
print(text)
print("-" * 50)
def on_train_end(self, trainer):
trainer_epoch = trainer.epoch
trainer_metrics = trainer.metrics
current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self.end_time = datetime.now()
time_taken = self.end_time - self.start_time
hours, remainder = divmod(time_taken.total_seconds(), 3600)
minutes, seconds = divmod(remainder, 60)
time_taken_str = ""
if int(hours) > 0:
time_taken_str += f"{int(hours)} hr "
if int(minutes) > 0:
time_taken_str += f"{int(minutes)} mins "
if int(seconds) > 0:
time_taken_str += f"{int(seconds)} secs"
time_taken_str = time_taken_str.strip()
body = f"""
<html>
<head>
<style>
table, th, td {{
border: 1px solid black;
border-collapse: collapse;
padding: 5px;
}}
</style>
</head>
<body>
<h1>Training Report</h1>
<p>Date and Time: {current_time}</p>
<p>Total Epochs Trained: {trainer_epoch + 1} </p>
<p>Time Taken to Train Model: {time_taken_str} </p>
<table>
<tr>
<th>Metric</th>
<th>Value</th>
</tr>
{''.join([f'<tr><td>{k}</td><td>{v:.2f}</td></tr>' for k, v in trainer_metrics.items()])}
</table>
</body>
</html>
"""
self.send_email(body)
def train_yolov8_model(self, config_path, num_epochs, training_result_dir):
model = YOLO("yolov8x.pt")
model.add_callback("on_train_start", self.on_train_start)
model.add_callback("on_train_epoch_end", self.on_train_epoch_end)
model.add_callback("on_train_end", self.on_train_end)
model.start_time = datetime.now()
start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# Train the model
model.train(
data=config_path,
name="Yolo_Model_Training",
project=training_result_dir,
task="detect",
epochs=num_epochs,
patience=20,
batch=16,
cache=True,
imgsz=640,
iou=0.5,
augment=True,
degrees=25.0,
fliplr=0.0,
lr0=0.0001,
optimizer="Adam",
device=self.device,
)
model.end_time = datetime.now()
if __name__ == "__main__":
model_training = ModelTraining()
# Load the dataset configuration file
current_dir = os.path.dirname(os.path.abspath(__file__))
config_path = os.path.join(current_dir, "data.yaml")
num_epochs = 40 # Change it to any number of epochs you want.
training_result_path = "./results"
os.makedirs(training_result_path, exist_ok=True)
model_training.train_yolov8_model(config_path, num_epochs, training_result_path)
完整的項目結構應該如下所示:
yolo_with_callback/
│
├── dataset/ # Directory containing dataset files
│
├── env/ # python virtual environment directory
│
│── .env # Environment variables file containing secret keys
├── results/ # Directory for storing training results
│
├── data.yaml # Dataset configuration file
│
├── requirements.txt # File listing required Python packages
│
└── training.py # Main script for model training
現在,你已經完成了實現,可以繼續(xù)運行training.py代碼。訓練完成后,訓練結果報告將發(fā)送到指定的收件人郵箱。