記一次坎坷的算法需求實現(xiàn):輕量級人體姿態(tài)估計模型的修煉之路
本文記錄了作者實現(xiàn)輕量級人體姿態(tài)估計模型的全過程,從方案的選取到嘗試復現(xiàn)等,詳細的敘述了一個項目需求完成的整體思路,并 附有谷歌開源的MoveNet的復現(xiàn)經(jīng)驗。本文能給處在CV各個階段的朋友們帶來幫助!
一、需求背景
這天接到個新需求,需要實時檢測自然場景下目標人體的關(guān)鍵點位置。
從算法工程師的角度來拆解下需求:
1、檢測人體關(guān)鍵點位置,就是人體姿態(tài)估計任務嘛;
2、要實時,那么就是終端部署,服務端那傳輸延時就不考慮了。對了咱們硬件不大行,所以肯定是要輕量級模型的,分辨率也不能太大,剪枝量化蒸餾三件套也要做好打算;
3、“自然場景下的目標人體”,意思就是場景下可能有多人,但是我們只需要一個目標的關(guān)鍵點,要考慮如何區(qū)分(這點后面再展開說)。
二、方案探索
2.1 初見
之前的項目經(jīng)驗主要是人臉相關(guān)的分類、檢測、分割以及OCR等,雖然沒做過人體姿態(tài)估計,但是需求分析完,我的第一感覺是“應該很簡單”,這份底氣來自于之前做過的人臉關(guān)鍵點檢測項目,當時的經(jīng)驗是,處理好數(shù)據(jù)、使用合適的Loss,隨便用個剪枝后的輕量級模型都可以達到很高的精度,而且在嵌入式板子上能跑到10ms以下(都不需要PFLD之類的)。而人臉關(guān)鍵點有68點,人體也就17個點,不是簡單多了。
于是直接拿出之前人臉關(guān)鍵點的代碼,修改億點點細節(jié)(調(diào)整模型輸出、準備訓練數(shù)據(jù)、修改DataLoader等),用少部分數(shù)據(jù)先跑跑看看。數(shù)據(jù)就用COCO和MPII就差不多了,后續(xù)可以自己從網(wǎng)上爬取一些來擴充,當然如果有目標場景的數(shù)據(jù)就更好了。
工欲善其事,必先抓只小白鼠。也就是需要一個評測指標,分類常用Accuracy、F1,檢測常用mAP,而人體姿態(tài)估計又有所不同,請教了谷哥(Google), 一般是用PCK(PCKh)和OKS+mAP,簡單來說,前者就是計算predict和label點的距離再經(jīng)過head size normalize,后者就是用類似于目標檢測中IOU的OKS計算相似度,然后再算AP。 雖然前者在學術(shù)界目前已經(jīng)很少使用(主要是刷榜刷得太高了),但是我們場景只需要單人,加上工程化簡單,肯定是選擇PCK,且考慮數(shù)據(jù)不一定都有head的標注,以及我們場景目標大小比較一致,最終決定使用自定義PCK,160分辨率下,距離小于5則算正確,其實PCK也算作一種acc,所以后面用acc指代。
第一次跑完,驗證集acc達到了 0.81,感覺還行,于是開始往丹爐瘋狂加料(加數(shù)據(jù)、加Data balance、微調(diào)loss、不同關(guān)鍵點loss設置不同weight、懷疑特征提取不夠甚至把FastSCNN的backbone都拿來了、各種調(diào)參等等),一頓操作下來,終于有了一點提升:val-acc達到了0.9869。
這么高,看來是成了,部署上板測試之前,習慣性在自己電腦上先跑跑,準備好可視化demo一看,傻眼了,除了head之類的比較準,手稍微一動就很容易檢測不到。
一腔熱血終于被冰冷的現(xiàn)實擊潰,看來這個“應該很簡單”的任務并不簡單...自己挖的坑,跪著也要填完。稍微總結(jié)下失敗的教訓,然后準備老老實實從零開始。
人臉關(guān)鍵點檢測直接回歸坐標點的效果好,很大原因在于特征分布比較集中。首先是人臉大小比較一致,其次是人臉是剛體,各個關(guān)鍵點相對位置也比較固定,最后是除了側(cè)臉,基本不會遮擋,目標特征也比較清晰。相較之下,人體大小分布就不一致,且各個關(guān)節(jié)活動范圍很大,相對位置不固定,還存在各種位置遮擋和服裝遮擋的情況,相比來說任務難了很多。同時由于關(guān)鍵點的分布范圍太廣,單純用全連接層去回歸,很有可能出現(xiàn)某一部分神經(jīng)元訓練的比較好,另一部分比較差,也就導致泛化能力很差。
2.2 反思
直接回歸關(guān)鍵點的思路行不通,那就打開國門睜眼看世界——看看現(xiàn)在的主流方案。這篇2020年的綜述就挺不錯的:Deep Learning-Based Human Pose Estimation: A Survey。如同華山派著名的氣宗和劍宗之爭,人體姿態(tài)估計也有不同的派別,而且打得更熱鬧,還分了兩個方向:
-
目標學習形式: 直接回歸點坐標 VS 回歸Heatmap
-
整體檢測流程: Top-Bottom VS Bottom-Up
原來早期的人體姿態(tài)估計就是直接回歸關(guān)鍵點,看來我也只是在重復前輩的路而已。 后來發(fā)現(xiàn)性能不好,于是改為回歸heatmap的方式來訓練。 簡單來說就是把原圖縮小幾倍得到特征圖,這個特征圖有關(guān)鍵點的位置值就是1、背景就是0,因為只設1太稀疏了,所以把1改為一個高斯核,這樣也符合實際語義。
確定好學習形式后,再看看檢測流程。上面我的方案其實就屬于Top-Bottom,也就是先檢測人,再裁剪出來對每個人去檢測關(guān)鍵點;而Bottom-Up則相反,直接檢測出所有的關(guān)鍵點,再依次組合成一個個人。這里和目標檢測中的one-stage和two-stage其實有點神似,一般來說,Top-Bottom方式精度高,但是速度慢一點,而Bottom-Up速度快,但是精度差一點。
同時了解了一下目前一些優(yōu)秀的開源方案(主要側(cè)重一些相對輕量級模型,因此就不包含HRNet之流了): AlphaPose、OpenPose、Lightweight OpenPose、BlazePose、Simple Baselines、Fast Human Pose Estimation、Global Context for Convolutional Pose Machines、Simple and Lightweight Human Pose Estimation 等,那就來吧。
2.3 嘗試
2.3.1 淺嘗輒止
對上面提到的各種潛在可行方案,分別進行測試。
-
修改上文的我的人臉關(guān)鍵點回歸模型(Top-Bottom),把輸出改為heatmap,相應調(diào)整數(shù)據(jù)處理和loss,訓練完效果一般,val-acc85;
-
AlphaPose、OpenPose都屬于比較大的模型,且GPU速度也只是勉強實時,我們是端側(cè)+CPU+低性能板子,估計1秒以上了,所以暫時不考慮,同時發(fā)現(xiàn)有個Lightweight OpenPose(Bottom-Up)的項目,直接拿來上板跑了下,360ms,有點希望;
-
BlazePose,這里直接用了TNN的模型,同時發(fā)現(xiàn)他們還開源了騰訊光流實驗室的一個姿態(tài)估計模型(Top-Bottom),比BlazePose要好,于是直接測試了后者,上板200ms,且除了轉(zhuǎn)換后的模型沒有什么其他資料;
-
Simple Baselines(Top-Bottom),用預訓練模型訓練,val-acc在0.95左右還行,就是換成輕量級網(wǎng)絡后(mobilenet-v3),精度掉到0.9左右,且可視化很差;
-
Fast Human Pose Estimation(Top-Bottom),相比于沙漏網(wǎng)絡,stage砍半,channel砍半,就沒其它改變了;使用了蒸餾,loss是兩部分,一部分是老師的預測結(jié)果,一部分是GT;效果還行,720ms略慢;
-
Global Context for Convolutional Pose Machines(Top-Bottom),提供的預訓練模型,可視化效果還行,400ms左右;
-
Simple and Lightweight Human Pose Estimation(Top-Bottom),可視化效果還行,250ms;
2.3.2 集中火力
經(jīng)過綜合考慮,最后選擇了Lightweight OpenPose來做優(yōu)化。
一是它是Bottom-Up的模型,不需要額外訓練人體檢測器(對應額外的數(shù)據(jù)處理成本、模型訓練成本以及推理耗時);二是可視化來看,它的精度也是屬于表現(xiàn)最好的幾個。
要了解Lightweight OpenPose,那么就需要先提下OpenPose,它使用了VGG作為特征提取,加上兩個header,分別輸出關(guān)鍵點的heatmap和PAF的heatmap,通過多個stage迭代優(yōu)化(每個stage的輸入和輸出都是這兩個heatmap),最后通過MSE進行優(yōu)化。關(guān)鍵點的heatmap好理解,PAF是什么呢?這是論文提出的part affinity fields,一般翻譯為親和力向量場,它代表了兩個關(guān)鍵點之間的連接信息,個人覺得可以直觀理解為關(guān)鍵點之間的骨骼信息,實現(xiàn)的時候是求兩個關(guān)鍵點之間的單位向量(代表了方向),然后給對應的heatmap位置賦值。最后把多人檢測問題轉(zhuǎn)化為二分圖匹配問題,并用匈牙利算法求得相連關(guān)鍵點最佳匹配。
而Lightweight OpenPose作者測試發(fā)現(xiàn)原openpose的refinement stage的5個階段,從第一個refinement stage1之后,正確率沒有提升多少,但是運算量就加大了。并且還發(fā)現(xiàn)了refinement stage網(wǎng)絡的關(guān)鍵點定位(heatmap)和關(guān)鍵點組合(pafs)兩部分的兩條分支運算,有不少操作是一樣的,造成運算冗余。于是主要進行了以下優(yōu)化:
-
新的網(wǎng)絡設計。只采用initial stage+refinement stage兩個階段網(wǎng)絡。(此處,我覺得到refinement stage2階段的性價比比較高,refinement stage1的正確率還是有點低,盡管運算量不大。但是refinement stage2再加了約19GFLOPs的同時正確率就提高了3%左右,性價比更好。)
-
更換輕量級backbone。用MobileNet替換,并用空洞卷積優(yōu)化網(wǎng)絡。并且進行了一系列對比實驗,發(fā)現(xiàn)MoblieNet v1比MobileNet v2效果居然更好。
-
refinement stage中的關(guān)鍵點定位(heatmap)和關(guān)鍵點組合(pafs)部分的網(wǎng)絡權(quán)值共享,減少操作計算量。最后兩個卷積層才分出兩個分支分別用來預測關(guān)鍵點的定位(heatmap)和關(guān)鍵點組合(pafs)。
-
原openpose的refinement stage中都是使用7x7的卷積核。作者卻使用了三個連續(xù)的1x1, 3x3, 3x3的卷積核來代替7x7,并且最后的3x3是使用了空洞卷積, dilation=2。另外由于用3個卷積核代替原來的一個卷積核,網(wǎng)絡層數(shù)變得很深,所以作者又加上了一個residual連接。
研究了下Lightweight OpenPose源碼,繼續(xù)開始往丹爐瘋狂加料(改backbone、剪枝、加數(shù)據(jù)、調(diào)分辨率、蒸餾、各種調(diào)參等等),一頓操作下來,得到了模型:速度80ms,acc0.95(對應的蒸餾teacher acc0.98)
速度勉強能接受,精度還行,但是可視化測試稍微有點不準,可能跟數(shù)據(jù)也有一定關(guān)系。
三、峰回路轉(zhuǎn)
3.1 轉(zhuǎn)機
就在我猶豫要不要繼續(xù)深入優(yōu)化Lightweight OpenPose的時候,無意間刷到了這樣一篇文章:《實時檢測17個人體關(guān)鍵點,谷歌SOTA姿態(tài)檢測模型,手機端也能運行》。文章介紹了谷歌今年五月開源的一個輕量級人體姿態(tài)估計模型MoveNet,在tfjs有對應的api。
眾里尋他千百度,驀然回首,那人卻在燈火闌珊處。這不就是上天為我準備的嗎?趕緊用網(wǎng)頁端測試了下,感覺還行,加上之前有轉(zhuǎn)bodypix的tfjs模型到tf模型的經(jīng)驗,準備直接下模型本地跑跑,幸運的是,這次的MoveNet不僅有tfjs模型,還有tflite模型,這下方便多了。
通過官方博客的只言片語以及Netron可視化網(wǎng)絡,對模型有了初步了解,官方宣稱是基于CenterNet做的修改,但是修改了很多(The prediction scheme loosely follows CenterNet, with notable changes that improve both speed and accuracy)。具體來說,使用了mobilenetv2+fpn作為backbone,輸出四個header,最后經(jīng)過后處理輸出一個最靠近圖片中心的人的關(guān)鍵點。由此可知,它也是屬于Bottom-Up的流派,同時通過關(guān)鍵點回歸的范圍進行加權(quán),只輸出最靠近中心的一個人的關(guān)鍵點,和我們的場景絕佳匹配,因為就Lightweight OpenPose而言,其余人的PAF信息其實也是多余的。
先測測速度吧,首先嘗試上板安卓直接加載tflite模型,報錯Didn't find op for builtin opcode 'RESIZE_BILINEAR' version '3',之前用的org.tensorflow:tensorflow-lite:2.0.0,嘗試2.3.0后可以了。速度150ms左右一幀,連續(xù)跑在120ms左右。然后嘗試一些推理框架(比如Tengine、NCNN、MNN等),這里我習慣使用TNN。結(jié)果tflite轉(zhuǎn)tnn和pb轉(zhuǎn)tnn均failed,debug看是不支持ARG_MAX算子,這是后處理里面的,于是自己簡單寫了個去掉后處理的mobilenetv2+fpn+4個header的模型轉(zhuǎn)tnn,上板測試80ms左右。
考慮到Lightweight OpenPose已經(jīng)優(yōu)化了很多,繼續(xù)壓榨空間有限,且上文提到有冗余輸出,而這個MoveNet直接就能跑到80多ms,且只輸出一個人的關(guān)鍵點,還有谷歌背書,因此痛下決心,準備拋棄Lightweight OpenPose采用MoveNet的方案,嘗試復現(xiàn)它。同時我還準備了Plan B:如果復現(xiàn)不了,就直接導出tflite的backbone權(quán)重,然后新建一個pytorch模型加載權(quán)重,得到模型推理輸出后通過C++代碼實現(xiàn)后處理。這樣的缺點是沒法用自己數(shù)據(jù)訓練,無法迭代優(yōu)化,且直接使用官方提供的預訓練模型可視化來看精度也沒有達到特別高。
3.2 復現(xiàn)
好吧,那就開始準備擼起袖子復現(xiàn)MoveNet了。
只根據(jù)一個模型結(jié)構(gòu)要復現(xiàn)一個模型訓練,說難不難,說簡單也不簡單。模型結(jié)構(gòu)搭建照貓畫虎即可,關(guān)鍵是數(shù)據(jù)如何構(gòu)造、Loss如何選擇沒有任何信息,甚至還隱藏了大量細節(jié),比如高斯核的構(gòu)造、loss權(quán)重的設置、優(yōu)化器等各種超參數(shù)設置。道阻且長,行則將至,只能一步一個腳印了。
事先說明,以下所有分析均是主觀猜想,只是經(jīng)過實驗驗證效果也不錯,不代表谷歌官方原始實現(xiàn)就一定是這樣,僅供參考。
3.2.1 模型結(jié)構(gòu)
要復現(xiàn)一個模型并不難,很多模型有論文、有官方代碼或者有其他人復現(xiàn)的代碼、哪怕是各種專欄博客的拆解分析,都能幫助很多,遺憾的是,MoveNet都沒有,唯一能參考的只有兩個東西:谷歌官方博客的介紹,以及TFHub提供的訓練好的模型。這也是這篇文章分享的初衷。
觀察模型結(jié)構(gòu),主要分為三部分:backbone、header、后處理。
1. Backbone前面說過了很簡單,即mobilenetv2+fpn,為了保證可控性和靈活性 ,我沒有使用現(xiàn)成的一些mobilenetv2的實現(xiàn),而是自己一層層自己搭起來的(其實也沒有太大必要,后續(xù)換backbone我也是直接在現(xiàn)成的模型代碼上改的);
2. Header的構(gòu)建就更簡單了,輸入backbone的特征圖,經(jīng)過各自的幾個卷積層,最后輸出各自維度的特征圖即可,難點在于要理解每一個header的含義,整體結(jié)構(gòu)參考如下:
經(jīng)過不斷的分析、猜想、測試,現(xiàn)在直接分享得到的結(jié)論吧。四個header我們分別命名head_heatmap,head_center,head_reg,head_offset以便說明:
-
head_heatmap的維度是[N,K,H,W],n是batchsize,訓練時自己指定,預測時一般為1;K代表關(guān)鍵點數(shù)量,比如17;H、W就是對應的特征圖了,這里輸入是192x192,降采樣4倍就是48x48;它所代表的意義就是當前圖像上所有人的關(guān)鍵點的heatmap,注意是所有人的;
-
head_center的維度是[N,1,H,W],這里的1代表的是當前圖像上所有人的中心點的heatmap,你可以簡單理解為關(guān)鍵點,因為只有一個,所以通道為1;它所代表的意義官方博客說是arithmetic mean,即每一個人的所有關(guān)鍵點的算術(shù)平均數(shù),但是我實測這樣效果并不好,我自己最終是取得所有關(guān)鍵點得最大外接矩形的中心點,當存在一些較遠的關(guān)鍵點的時候,可能算術(shù)平均數(shù)可以很好的訓練大部分距離近的點,但是對較遠的點效果差點,而我比較關(guān)注手腕這種較遠的點,按我這么取對每一個點學習起來差不多,這個就仁者見仁智者見智了,以自己場景實驗結(jié)果為準;
-
head_reg的維度是[N,2K,H,W],K個關(guān)鍵點,坐標用x,y表示,那么就有2K個數(shù)據(jù),就是對應這里的2K通道;那么數(shù)據(jù)如何構(gòu)造呢?根據(jù)模型結(jié)構(gòu)的拆解,就是在每個人的center坐標位置,按2K通道順序依次賦值x1,y1,x2,y2,...,這里的x、y代表的是同一個人的關(guān)鍵點相對于中心點的偏移值,原始Movenet用的是特征圖48尺寸下的絕對偏移值,實測換成相對值(即除以size48轉(zhuǎn)換到0-1區(qū)間)也是可以的,可以稍微加快收斂,不過幾乎沒有區(qū)別;
-
head_offset的維度是[N,2K,H,W],通道意義一樣都是對應K個關(guān)鍵點的坐標,只不過上面是回歸偏移值,這里是offset,含義是我們模型降采樣特征圖可能存在量化誤差,比如192分辨率下x=0和x=3映射到48分辨率的特征圖時坐標都變?yōu)榱?;同時還有回歸的誤差,這里一并訓練了;
3.后處理相對來說比較麻煩,因為有各種奇奇怪怪的操作,這一步只有結(jié)合Netron可視化一步步大膽猜測,結(jié)合官方博客認真分析,結(jié)構(gòu)如下,可以看到不是常規(guī)的CNN結(jié)構(gòu)。
官方介紹如下,也解開了不少疑團。
也不賣關(guān)子了,直接說明后處理流程吧。
首先對于head_heatmap和head_center,需要求一個sigmoid,這里也可以寫到網(wǎng)絡結(jié)構(gòu)中,主要是保證取值范圍,因為后面要乘位置權(quán)重。
然后對于head_center,我們乘以一個位置權(quán)重矩陣,大小也是48x48,值可以通過Netron導出谷歌的設置,通過Numpy+OpenCV加載可視化可以看出是一個中心點高亮的高斯核。對于我們檢測出所有的可能的中心點,通過乘以這樣一個中心位置加權(quán)矩陣,可以提高靠近中心的人物權(quán)重,因為最終我們只取一個最靠近圖像中心的。
乘以權(quán)重后,就是常見的argmax操作求出最大值點的坐標,這就是最終檢測目標的中心點。拿到這個坐標就可以去head_reg的2K個通道對應坐標位置取出2K個值,然后加上中心點坐標,這就得到粗略的關(guān)鍵點位置了。我一開始以為這里是檢測的主力輸出,但是后續(xù)發(fā)現(xiàn),這里的關(guān)鍵點其實很不準的,因為只是粗暴的回歸坐標偏移。
它的實際作用是,對于每一個粗略的關(guān)鍵點坐標,咱們再生成一個以這個坐標點為中心的權(quán)重矩陣,這里也沒用高斯核了,直接0-47等差數(shù)列構(gòu)造即可。 構(gòu)造好一個粗略關(guān)鍵點坐標最小、周邊遞增的權(quán)重矩陣后,再加上一個常數(shù)(比如1.8)防止除0,然后就把一開始的head_heatmap除以這個權(quán)重矩陣,再分別通過K個通道求argmax,得到精細化的關(guān)鍵點坐標。為什么要這么做呢? 其實思想很簡單,就是通過粗定位的關(guān)鍵點范圍,然后加權(quán)取去找最可能屬于當前人物的關(guān)鍵點,因為前面說過了我們是Bottom-Up的方案,head_heatmap檢測的是所有人的所有關(guān)鍵點。通過這樣一步操作,就可以取出最靠近中心點的人的關(guān)鍵點了。
最后再根據(jù)坐標點去head_offset對應坐標位置取出offset的值加上即可。置信度就是head_heatmap的sigmoid后的值。
3.2.2 數(shù)據(jù)處理
終于分析完模型結(jié)構(gòu)和后處理流程,但是依舊不能開始訓練,因為我們還要喂數(shù)據(jù)、定目標函數(shù)。
數(shù)據(jù)構(gòu)造其實弄明白后處理流程后也不難了,一個值得注意的問題是heatmap的高斯核如何構(gòu)造。一開始我使用的是cv2.GaussianBlur生成的,可視化看過還行,后來debug發(fā)現(xiàn)一個問題,單個點沒影響,但是當存在多個接近的關(guān)鍵點時,互相的高斯核極大值會受影響,可能變成比1小的值。于是干脆拿了CenterNet和Lightweight OpenPose中的高斯核生成來嘗試,前者就是經(jīng)典的高斯核,后者是按照指數(shù)曲線生成的,最終選擇了后者。因為MoveNet后處理中,存在一個乘權(quán)重然后取最大值的步驟,為了增加不同點的區(qū)分度(比如隔得遠的點置信度為1,近的0.9,這時候應該取近的0.9這個點),那么應該增大不同距離點的權(quán)重差,但是為了減少同一個點的heatmap范圍內(nèi)變中心點改變(比如同一個高斯核內(nèi)最大值1,次大0.9,但是0.9更靠近中心,這時候還是應該取1),那么應該減少相鄰點間隔。在這兩點矛盾的情況下,權(quán)重矩陣又是谷歌提供的已知的值,比較合理的一個方案就是選擇指數(shù)曲線生成的熱力圖,即增大同一個高斯核中不同點的差值;最后還要注意高斯核的大小半徑,可以當作超參數(shù)來調(diào)整,建議覆蓋到原圖上分析是否合理,最好根據(jù)不同的目標尺寸設置不同的半徑大小。
另外一個細節(jié)是,如果只給中心點坐標對應的head_reg上那一個點賦值,每次去取回歸值的時候,由于只有一個點,學習起來比較困難,因此我是給周圍一圈也賦值了,只是要注意賦值坐標大小要對應加減。
3.2.3 損失函數(shù)
一般來說,對于這種heatmap形式的數(shù)據(jù),比較常用的是L2 loss,比如Lightweight OpenPose,而CenterNet使用的是Focal loss,其實Focal loss也是從L2 loss優(yōu)化而來,只是根據(jù)類別加權(quán)處理類別不平衡的問題,根據(jù)置信度加權(quán)處理難易樣本的問題。
最終我采用的是加權(quán)版的MSE,相當于平衡了正負樣本,因為heatmap+高斯核的方式背景占比很大,但是沒有加Focal loss中的gamma,因為實驗效果不好。這個加權(quán)也很好實現(xiàn),因為我們的label就是取值0-1的矩陣,直接乘以k然后加1,就變成取值1到k+1的范圍,即取值1的位置對應權(quán)重k+1,取值0也就是背景對應權(quán)重1,這個k可以根據(jù)數(shù)據(jù)情況自己調(diào)節(jié),參考取值5-10,參考代碼如下:
- loss = torch.pow((pre-target),2)
- weight_mask = target*k+1
- #gamma = torch.pow(torch.abs(target-pre), 2)
- loss = loss*weight_mask #*gamma
最終head_heatmap和head_center采用了加權(quán)MSE,head_reg和head_offset采用了L1 Loss,也是參考的CenterNet。
最后需要考慮的就是不同loss的權(quán)重問題,前期調(diào)試的時候由于量級存在差異(head_reg)是絕對坐標差值,我是設置的head_heatmap:head_center:head_reg:head_offset=1:1:0.1:1,最終優(yōu)化好之后,直接采用1:1:1:1。
至此,整個MoveNet就已經(jīng)復現(xiàn)差不多了,跑出來的acc能達到95,再稍微修復了一些bug,優(yōu)化了一些細節(jié),基本能跑到97了。但是直接跑視頻demo可視化還是不是特別滿意,因此需要以此為baseline,深度優(yōu)化。
四、煉丹之道
既然有了baseline,訓練代碼也搭建好了,就是喜聞樂見的煉丹環(huán)節(jié)了。
其實復現(xiàn)MoveNet的過程中就包含了一些優(yōu)化,比如loss選擇、高斯核生成,因為一開始效果不好,你不知道是復現(xiàn)的不對還是沒有優(yōu)化好。這里數(shù)據(jù)迭代清洗、一般的數(shù)據(jù)增強就不多說了,就說一些印象比較深的吧。
我的建議是先通過可視化分析+實驗驗證確定好數(shù)據(jù)增強方案,然后再進行模型結(jié)構(gòu)的調(diào)整,最后進行l(wèi)oss優(yōu)化以及各種超參數(shù)調(diào)整,這樣可以一定程度上避免改了東墻測西墻的重復實驗。
首先是數(shù)據(jù)擴充,谷歌本身除了使用清洗后的COCO數(shù)據(jù)集,也自己從youtube上爬了一些瑜伽、健身之類的視頻然后生成圖片和標注,因為原始COCO、MPII之類的圖片,還是缺少一些瑜伽之類的特殊的姿態(tài)的,這就會造成數(shù)據(jù)不平衡(類似人臉關(guān)鍵的中的張嘴閉眼),于是除了考慮代碼中加入數(shù)據(jù)平衡,最好的方式也是擴充一些數(shù)據(jù),因此我也從B站上下了一些瑜伽、健身、體操、舞蹈等視頻標注;
其次是數(shù)據(jù)增強有兩點值得一提, 一是常見的鏡像翻轉(zhuǎn) ,分類任務中翻了就翻了,目標檢測中要把目標框的橫坐標同樣翻轉(zhuǎn),而人體姿態(tài)估計更進一步,對關(guān)鍵點橫坐標翻轉(zhuǎn)之后,還要修改對應關(guān)鍵點的順序,比如左手腕的圖片水平鏡像翻轉(zhuǎn)后,它就不再是左手腕而變成右手腕了,這是一個小坑,不過閱讀一些開源代碼應該也能看到相應的處理; 第二點就是隨機裁剪crop增大目標范圍 ,主要是初步驗證的時候發(fā)現(xiàn)泛化效果不好,經(jīng)過簡單數(shù)據(jù)分析,發(fā)現(xiàn)訓練集的關(guān)鍵點區(qū)域面積集中在2000附近:
再看看目標場景采集的幾百張圖片的關(guān)鍵點區(qū)域面積分布,范圍是2000-12000,且峰值在6000-8000:
那么就可以針對性得調(diào)整隨機crop的概率、crop的范圍等,最后得到訓練集數(shù)據(jù)增強后的分布如下,可以看到相比原始分布已經(jīng)好很多了:
接著就是模型結(jié)構(gòu)的優(yōu)化,谷歌原始方案居然是使用的MobilenetV2,有點奇怪,據(jù)個人經(jīng)驗來說,這幾個常見的輕量級模型效果一般排序是MobileNetV3≈ShuffleNetV2 >= MobilenetV2 >= MobilenetV1 ≈ShuffleNetV1,如果谷歌礙于自家產(chǎn)品不使用Shufflnet能理解,那為什么不用MobileNetV3呢?分別加上FPN上手跑跑,原始MoveNet使用的FPN分別融合了四個特征層,而由于網(wǎng)絡結(jié)構(gòu)block的差異,MobileNetV3和ShuffleNetV2我均只融合了三個特征層。經(jīng)過大量實驗測試,以及一些經(jīng)驗化剪枝,最后選定的是ShuffleNetV2+FPN,寬度因子為0.75,同時對通道數(shù)做了一定的剪枝。
同時把Sigmoid換成了一個近似實現(xiàn):0.5 * (x/(1+torch.abs(x)))+0.5,訓練精度差不多,且速度略快幾毫秒,而且還順便解決了推理工具結(jié)果不對齊的問題。
后來等模型訓練過程中刷到了一篇文章《姿態(tài)估計技巧總結(jié)》,嘗試了其中一些方案,其中Bone Loss還是有提升的,AID也有輕微提升不過不明顯,Automatic Weighted Loss對此模型沒有明顯效果,嘗試了Mish速度下降很多,也就沒有訓練了。
最后就是后處理的代碼實現(xiàn),因為涉及一些矩陣運算,本來是打算用C++版的OpenCV的Mat運算來做的,后來發(fā)現(xiàn)其實涉及的矩陣運算也就是一些element-wise操作,且特征圖大小也就48*48,于是干脆用一維數(shù)組來做了,一次循環(huán)就可以搞定乘權(quán)重和記錄極大值坐標,時間復雜度O(1)。
至此,整個項目也就接近了尾聲,自己挖的坑終于填的七七八八,最終在驗證集、測試集上acc都達到了99%(都這么高了,也就沒打算蒸餾了),速度在嵌入式CPU上能跑到60+ms??梢暬Ч策€不錯,不過依舊有提升空間,鑒于驗證集已經(jīng)很高了,后續(xù)主要的優(yōu)化方向應該是擴充數(shù)據(jù)了。
五、總結(jié)
總結(jié)一下,一開始不熟悉人體姿態(tài)估計領(lǐng)域,走了一些彎路。同時在方案選擇中也做了不少貌似無用的嘗試,不過也正是對Simple Baselines和Lighweight OpenPose的熟悉,才讓自己有機會深入了解人體姿態(tài)估計的相關(guān)流程,從而根據(jù)一個MoveNet的模型文件就復現(xiàn)出整個訓練流程。最終經(jīng)過優(yōu)化,精度接近飽和,速度也從原始Movenet純backbone跑80ms提升到了包含后處理整個流程達到60多ms。
最后慣例來個展望吧,時間有限,還有很多可以嘗試的東西,列出一些供參考:
-
我的場景首先要保證速度可以接受,其次才是精度,因此這套方案的精度還有很大的提升空間,如果是用高端手機來跑、甚至是GPU終端,換用容量更大的backbone,應該可以達到很高的精度;
-
嘗試一下DSNT、DARK的方案;
-
嘗試下Adversarial Semantic Data Augmentation for Human Pose Estimation論文提出的A數(shù)據(jù)增強,mpll數(shù)據(jù)集評測榜達到94.1%,貌似目前最高的;
-
量化感知訓練(QAT),因為PyTorch的QAT貌似不支持導出onnx,也就無法轉(zhuǎn)換到推理框架使用,因此沒有做。而訓練后量化實際加速很少,掉點明顯,分類任務還好,MoveNet有精細化的heatmap乘以權(quán)重的操作,誤差影響更大所以也不考慮;后續(xù)可以考慮使用Tensorflow來做QAT,因為本身谷歌MoveNet就是TF訓練的模型;
-
試試新出的MicroNet,今年的輕量級網(wǎng)絡,關(guān)注了一段時間最近代碼也開源了,論文實驗表示甩了Mobilenetv3和ShuffleNetv2一大截,有空可以玩玩;
才疏學淺,有不對的地方歡迎指出,也歡迎友好交流討論。