使用 ML.NET 實(shí)現(xiàn)圖像分類:從入門到實(shí)踐
ML.NET是微軟開發(fā)的開源機(jī)器學(xué)習(xí)框架,讓.NET開發(fā)者能夠直接在.NET應(yīng)用程序中集成機(jī)器學(xué)習(xí)功能。本文將詳細(xì)介紹如何使用ML.NET實(shí)現(xiàn)圖像分類,包括環(huán)境搭建、數(shù)據(jù)準(zhǔn)備、模型訓(xùn)練等完整流程。
環(huán)境準(zhǔn)備
- Visual Studio 2022
- .NET 6.0或更高版本
- 需要安裝的NuGet包:
Microsoft.ML
Microsoft.ML.Vision
Microsoft.ML.ImageAnalytics
SciSharp.TensorFlow.Redist (版本2.3.1)
圖片
項(xiàng)目結(jié)構(gòu)
ImageClassification/
├── Program.cs
├── assets/ # 存放訓(xùn)練圖片
│ ├── CD/ # 有裂縫的圖片
│ └── UD/ # 無裂縫的圖片
└── workspace/ # 存放模型文件
代碼實(shí)現(xiàn)
一、創(chuàng)建數(shù)據(jù)模型類
// 原始圖像數(shù)據(jù)類
public class ImageData
{
public string ImagePath { get; set; }
public string Label { get; set; }
}
ImageData 類是用來表示和存儲圖像的基本信息的數(shù)據(jù)結(jié)構(gòu),主要用于數(shù)據(jù)加載和預(yù)處理階段。它包含兩個(gè)關(guān)鍵屬性:
1.ImagePath
- 存儲圖像文件的完整路徑
- 用于后續(xù)加載和訪問圖像文件
- 是一個(gè)字符串類型的屬性
2.Label
- 存儲圖像的分類標(biāo)簽/類別
- 表示圖像所屬的類別(如示例中的"有裂縫"或"無裂縫")
- 是一個(gè)字符串類型的屬性
主要用途:
- 數(shù)據(jù)加載:在從目錄加載圖像數(shù)據(jù)時(shí),用于初步組織和存儲圖像信息
- 數(shù)據(jù)組織:將文件系統(tǒng)中的圖像與其對應(yīng)的分類標(biāo)簽關(guān)聯(lián)起來
// 模型輸入類
public class ModelInput
{
public byte[] Image { get; set; }
public UInt32 LabelAsKey { get; set; }
public string ImagePath { get; set; }
public string Label { get; set; }
}
3.Image (byte[] 類型)
- 存儲圖像的字節(jié)數(shù)組表示
- 這是模型訓(xùn)練和預(yù)測所必需的輸入格式
- 模型需要這種類型的圖像數(shù)據(jù)來進(jìn)行訓(xùn)練
4.LabelAsKey (UInt32 類型)
- 是 Label 的數(shù)值表示形式
- 將分類標(biāo)簽轉(zhuǎn)換為數(shù)值形式,因?yàn)闄C(jī)器學(xué)習(xí)模型要求輸入采用數(shù)值格式
5.ImagePath (string 類型)
- 存儲圖像的完整路徑
- 用于方便訪問原始圖像文件
6.Label (string 類型)
- 圖像所屬的類別
- 這是需要預(yù)測的目標(biāo)值
- 用于訓(xùn)練時(shí)的標(biāo)簽信息
重要說明:
- 在實(shí)際訓(xùn)練和預(yù)測中,只有 Image 和 LabelAsKey 這兩個(gè)屬性被用于模型訓(xùn)練和預(yù)測
- ImagePath 和 Label 屬性主要是為了方便訪問和追蹤原始數(shù)據(jù),不直接參與模型計(jì)算
- 這個(gè)類是連接原始圖像數(shù)據(jù)和模型訓(xùn)練需求的橋梁,將各種必要的信息整合在一起
// 模型輸出類
public class ModelOutput
{
public string ImagePath { get; set; }
public string Label { get; set; }
public string PredictedLabel { get; set; }
}
7.ImagePath (string 類型)
- 存儲圖像的完整文件路徑
- 用于追蹤和引用原始圖像文件
8.Label (string 類型)
- 存儲圖像的原始/真實(shí)類別標(biāo)簽
- 這是圖像實(shí)際應(yīng)該屬于的類別
9.PredictedLabel (string 類型)
- 存儲模型預(yù)測的類別標(biāo)簽
- 這是模型通過分析圖像后預(yù)測出的類別
重要說明:
- 在實(shí)際預(yù)測過程中,只有 PredictedLabel 是必需的,因?yàn)樗P偷念A(yù)測結(jié)果
- ImagePath 和 Label 屬性主要用于評估和驗(yàn)證目的,方便比較預(yù)測結(jié)果與實(shí)際標(biāo)簽的差異
二、 圖像加載工具方法
private static IEnumerable<ImageData> LoadImagesFromDirectory(string folder, bool useFolderNameAsLabel = true)
{
var files = Directory.GetFiles(folder, "*", searchOption: SearchOption.AllDirectories);
foreach (var file in files)
{
if ((Path.GetExtension(file) != ".jpg") && (Path.GetExtension(file) != ".png"))
continue;
var label = Path.GetFileName(file);
if (useFolderNameAsLabel)
label = Directory.GetParent(file).Name;
else
{
for (int index = 0; index < label.Length; index++)
{
if (!char.IsLetter(label[index]))
{
label = label.Substring(0, index);
break;
}
}
}
yield return new ImageData()
{
ImagePath = file,
Label = label
};
}
}
三、主程序?qū)崿F(xiàn)
class Program
{
static void Main(string[] args)
{
// 初始化ML.NET環(huán)境
MLContext mlContext = new MLContext();
// 設(shè)置路徑
var projectDirectory = Path.GetFullPath(Path.Combine(AppContext.BaseDirectory, "../../../"));
var workspaceRelativePath = Path.Combine(projectDirectory, "workspace");
var assetsRelativePath = Path.Combine(projectDirectory, "assets");
// 加載數(shù)據(jù)
IEnumerable<ImageData> images = LoadImagesFromDirectory(folder: assetsRelativePath, useFolderNameAsLabel: true);
IDataView imageData = mlContext.Data.LoadFromEnumerable(images);
IDataView shuffledData = mlContext.Data.ShuffleRows(imageData);
// 數(shù)據(jù)預(yù)處理
var preprocessingPipeline = mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
.Append(mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath"));
IDataView preProcessedData = preprocessingPipeline
.Fit(shuffledData)
.Transform(shuffledData);
// 數(shù)據(jù)集分割
var trainSplit = mlContext.Data.TrainTestSplit(data: preProcessedData, testFraction: 0.3);
var validationTestSplit = mlContext.Data.TrainTestSplit(trainSplit.TestSet);
// 配置訓(xùn)練選項(xiàng)
var classifierOptions = new ImageClassificationTrainer.Options()
{
FeatureColumnName = "Image",
LabelColumnName = "LabelAsKey",
ValidationSet = validationTestSplit.TrainSet,
Arch = ImageClassificationTrainer.Architecture.ResnetV2101,
MetricsCallback = (metrics) => Console.WriteLine(metrics),
TestOnTrainSet = false,
ReuseTrainSetBottleneckCachedValues = true,
ReuseValidationSetBottleneckCachedValues = true,
WorkspacePath = workspaceRelativePath
};
// 定義訓(xùn)練管道
var trainingPipeline = mlContext.MulticlassClassification.Trainers
.ImageClassification(classifierOptions)
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"));
// 訓(xùn)練模型
Console.WriteLine("*** 開始訓(xùn)練模型 ***");
ITransformer trainedModel = trainingPipeline.Fit(trainSplit.TrainSet);
// 進(jìn)行預(yù)測
ClassifySingleImage(mlContext, validationTestSplit.TestSet, trainedModel);
ClassifyImages(mlContext, validationTestSplit.TestSet, trainedModel);
}
}
數(shù)據(jù)預(yù)處理說明
第一步:標(biāo)簽轉(zhuǎn)換
mlContext.Transforms.Conversion.MapValueToKey(
inputColumnName: "Label",
outputColumnName: "LabelAsKey")
- 將字符串類型的標(biāo)簽("Label")轉(zhuǎn)換為數(shù)值類型("LabelAsKey")
- 例如:"CD" -> 0, "UD" -> 1
- 這是必需的,因?yàn)闄C(jī)器學(xué)習(xí)模型需要數(shù)值形式的標(biāo)簽
- 輸入是 ImageData 類中的 Label 屬性
- 輸出存儲在 ModelInput 類的 LabelAsKey 屬性中
第二步:圖像加載
mlContext.Transforms.LoadRawImageBytes(
outputColumnName: "Image",
imageFolder: assetsRelativePath,
inputColumnName: "ImagePath")
- 將圖像文件轉(zhuǎn)換為字節(jié)數(shù)組格式
- `outputColumnName: "Image"`: 輸出到 ModelInput 類的 Image 屬性
- `imageFolder: assetsRelativePath`: 指定圖像文件所在的根目錄
- `inputColumnName: "ImagePath"`: 使用 ImageData 類中的 ImagePath 屬性
配置訓(xùn)練選項(xiàng)
1.FeatureColumnName = "Image"
- 指定用作模型輸入的列名
- 這里使用"Image"列,它包含圖像的字節(jié)數(shù)組數(shù)據(jù)
2.LabelColumnName = "LabelAsKey"
- 指定要預(yù)測的目標(biāo)值列名
- 使用"LabelAsKey"列,它是標(biāo)簽的數(shù)值表示形式
3.ValidationSet = validationTestSplit.TrainSet
- 指定用于驗(yàn)證的數(shù)據(jù)集
- 用于在訓(xùn)練過程中評估模型性能
4.Arch = ImageClassificationTrainer.Architecture.ResnetV2101
- 指定使用的預(yù)訓(xùn)練模型架構(gòu)
- 這里使用 ResNet v2 的101層變體
- ResNet是一個(gè)預(yù)訓(xùn)練模型,可以將圖像分為1000個(gè)類別
5.MetricsCallback = (metrics) => Console.WriteLine(metrics)
- 用于在訓(xùn)練過程中跟蹤和顯示訓(xùn)練指標(biāo)
- 通過控制臺輸出訓(xùn)練進(jìn)度和性能指標(biāo)
6.TestOnTrainSet = false
- 設(shè)置是否在訓(xùn)練集上測試模型
- false表示不在訓(xùn)練集上測試,避免過擬合
7.ReuseTrainSetBottleneckCachedValues = true
- 是否重用訓(xùn)練集的瓶頸層計(jì)算結(jié)果
- true表示緩存并重用這些值,可以顯著減少訓(xùn)練時(shí)間
- 適用于訓(xùn)練數(shù)據(jù)不變但需要調(diào)整其他參數(shù)的情況
8.ReuseValidationSetBottleneckCachedValues = true
- 是否重用驗(yàn)證集的瓶頸層計(jì)算結(jié)果
- 與上面類似,但作用于驗(yàn)證數(shù)據(jù)集
9.WorkspacePath = workspaceRelativePath
- 指定存儲工作文件的目錄路徑
- 用于保存計(jì)算的瓶頸值和模型的.pb版本
- 便于后續(xù)重用和模型部署
這些參數(shù)的配置對模型的訓(xùn)練效果和效率有重要影響,可以根據(jù)具體需求調(diào)整這些參數(shù)來優(yōu)化模型性能。
定義訓(xùn)練管道
- 圖像分類訓(xùn)練器
mlContext.MulticlassClassification.Trainers.ImageClassification(classifierOptions)
- 使用多分類分類器進(jìn)行圖像分類
- 基于之前定義的 classifierOptions 配置
- 使用遷移學(xué)習(xí)方法,基于預(yù)訓(xùn)練的 ResNet 模型
- 主要功能:
- 提取圖像特征
- 訓(xùn)練分類器
- 生成預(yù)測模型
- 預(yù)測標(biāo)簽轉(zhuǎn)換
.Append(mlContext.Transforms.Conversion.MapKeyToValue("PredictedLabel"))
- 將模型輸出的數(shù)值預(yù)測結(jié)果轉(zhuǎn)換回原始標(biāo)簽
- 與前面的 MapValueToKey 操作相反
- 例如:將 0 轉(zhuǎn)回 "CD",1 轉(zhuǎn)回 "UD"
- 確保最終輸出是人類可讀的標(biāo)簽
整個(gè)訓(xùn)練管道的工作流程:
- 接收預(yù)處理后的數(shù)據(jù)(圖像字節(jié)數(shù)組和數(shù)值標(biāo)簽)
- 通過深度學(xué)習(xí)模型進(jìn)行特征提取和分類
- 將數(shù)值預(yù)測結(jié)果轉(zhuǎn)換為原始標(biāo)簽類別
- 輸出最終的分類結(jié)果
四、預(yù)測方法實(shí)現(xiàn)
private static void ClassifySingleImage(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
var predictionEngine = mlContext.Model.CreatePredictionEngine<ModelInput, ModelOutput>(trainedModel);
var image = mlContext.Data.CreateEnumerable<ModelInput>(data, reuseRowObject: true).First();
var prediction = predictionEngine.Predict(image);
Console.WriteLine($"單張圖片分類結(jié)果:");
Console.WriteLine($"圖片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"實(shí)際類別: {prediction.Label}");
Console.WriteLine($"預(yù)測類別: {prediction.PredictedLabel}");
}
private static void ClassifyImages(MLContext mlContext, IDataView data, ITransformer trainedModel)
{
IDataView predictionData = trainedModel.Transform(data);
var predictions = mlContext.Data.CreateEnumerable<ModelOutput>(predictionData, reuseRowObject: true)
.Take(10);
Console.WriteLine("\n批量圖片分類結(jié)果:");
foreach (var prediction in predictions)
{
Console.WriteLine($"圖片: {Path.GetFileName(prediction.ImagePath)}");
Console.WriteLine($"實(shí)際類別: {prediction.Label}");
Console.WriteLine($"預(yù)測類別: {prediction.PredictedLabel}\n");
}
}
五、執(zhí)行
訓(xùn)練速度比較慢
圖片
圖片
圖片
實(shí)際與預(yù)測都是CD。
圖片
這里會發(fā)現(xiàn)預(yù)測與實(shí)際是有出入的。
模型優(yōu)化建議
1.增加訓(xùn)練數(shù)據(jù)量: 收集更多的樣本數(shù)據(jù)可以提高模型的泛化能力。
2.數(shù)據(jù)增強(qiáng):
- 對現(xiàn)有圖片進(jìn)行旋轉(zhuǎn)、翻轉(zhuǎn)、縮放等操作
- 調(diào)整亮度、對比度
- 添加噪聲
3.調(diào)整超參數(shù):
- 增加訓(xùn)練輪數(shù)(Epoch)
- 調(diào)整學(xué)習(xí)率
- 嘗試不同的批次大小
4.使用不同的預(yù)訓(xùn)練模型:
- ResNet不同版本
- Inception
- MobileNet
總結(jié)
本文詳細(xì)介紹了如何使用ML.NET實(shí)現(xiàn)圖像分類功能。通過使用遷移學(xué)習(xí)和預(yù)訓(xùn)練模型,我們可以快速構(gòu)建高質(zhì)量的圖像分類應(yīng)用。ML.NET提供了簡單易用的API,讓.NET開發(fā)者能夠方便地將機(jī)器學(xué)習(xí)集成到應(yīng)用程序中。