如何在Tensorflow.js中處理MNIST圖像數(shù)據(jù)
有人開玩笑說有 80% 的數(shù)據(jù)科學家在清理數(shù)據(jù),剩下的 20% 在抱怨清理數(shù)據(jù)……在數(shù)據(jù)科學工作中,清理數(shù)據(jù)所占比例比外人想象的要多得多。一般而言,訓練模型通常只占機器學習或數(shù)據(jù)科學家工作的一小部分(少于 10%)。
——Kaggle CEO Antony Goldbloom |
對任何一個機器學習問題而言,數(shù)據(jù)處理都是很重要的一步。本文將采用 Tensorflow.js(0.11.1)的 MNIST 樣例
(https://github.com/tensorflow/tfjs-examples/blob/master/mnist/data.js),逐行運行數(shù)據(jù)處理的代碼。
MNIST 樣例
- 18 import * as tf from '@tensorflow/tfjs';
- 19
- 20 const IMAGE_SIZE = 784;
- 21 const NUM_CLASSES = 10;
- 22 const NUM_DATASET_ELEMENTS = 65000;
- 23
- 24 const NUM_TRAIN_ELEMENTS = 55000;
- 25 const NUM_TEST_ELEMENTS = NUM_DATASET_ELEMENTS - NUM_TRAIN_ELEMENTS;
- 26
- 27 const MNIST_IMAGES_SPRITE_PATH =
- 28 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png';
- 29 const MNIST_LABELS_PATH =
- 30 'https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8';`
首先,導入 TensorFlow(確保你在轉(zhuǎn)譯代碼)并建立一些常量,包括:
- IMAGE_SIZE:圖像尺寸(28*28=784)
- NUM_CLASSES:標簽類別的數(shù)量(這個數(shù)字可以是 0~9,所以這里有 10 類)
- NUM_DATASET_ELEMENTS:圖像總數(shù)量(65000)
- NUM_TRAIN_ELEMENTS:訓練集中圖像的數(shù)量(55000)
- NUM_TEST_ELEMENTS:測試集中圖像的數(shù)量(10000,亦稱余數(shù))
- MNIST_IMAGES_SPRITE_PATH&MNIST_LABELS_PATH:圖像和標簽的路徑
將這些圖像級聯(lián)為一個巨大的圖像,如下圖所示:
MNISTData
接下來,從第 38 行開始是 MnistData,該類別使用以下函數(shù):
- load:負責異步加載圖像和標注數(shù)據(jù);
- nextTrainBatch:加載下一個訓練批;
- nextTestBatch:加載下一個測試批;
- nextBatch:返回下一個批的通用函數(shù),該函數(shù)的使用取決于是在訓練集還是測試集。
本文屬于入門文章,因此只采用 load 函數(shù)。
load
- async load() {
- // Make a request for the MNIST sprited image.
- const img = new Image();
- const canvas = document.createElement('canvas');
- const ctx = canvas.getContext('2d');
異步函數(shù)(async)是 Javascript 中相對較新的語言功能,因此你需要一個轉(zhuǎn)譯器。
Image 對象是表示內(nèi)存中圖像的本地 DOM 函數(shù),在圖像加載時提供可訪問圖像屬性的回調(diào)。canvas 是 DOM 的另一個元素,該元素可以提供訪問像素數(shù)組的簡單方式,還可以通過上下文對其進行處理。
因為這兩個都是 DOM 元素,所以如果用 Node.js(或 Web Worker)則無需訪問這些元素。有關(guān)其他可替代的方法,請參見下文。
imgRequest
- const imgRequest = new Promise((resolve, reject) => {
- img.crossOrigin = '';
- img.onload = () => {
- imgimg.width = img.naturalWidth;
- imgimg.height = img.naturalHeight;
該代碼初始化了一個 new promise,圖像加載成功后該 promise 結(jié)束。該示例沒有明確處理誤差狀態(tài)。
crossOrigin 是一個允許跨域加載圖像并可以在與 DOM 交互時解決 CORS(跨源資源共享,cross-origin resource sharing)問題的圖像屬性。naturalWidth 和 naturalHeight 指加載圖像的原始維度,在計算時可以強制校正圖像尺寸。
- const datasetBytesBuffer =
- new ArrayBuffer(NUMDATASETELEMENTS * IMAGESIZE * 4);
- 57
- 58 const chunkSize = 5000;
- 59 canvas.width = img.width;
- 60 canvas.height = chunkSize;
該代碼初始化了一個新的 buffer,包含每一張圖的每一個像素。它將圖像總數(shù)和每張圖像的尺寸和通道數(shù)量相乘。
我認為 chunkSize 的用處在于防止 UI 一次將太多數(shù)據(jù)加載到內(nèi)存中,但并不能 100% 確定。
- 62 for (let i = 0; i < NUMDATASETELEMENTS / chunkSize; i++) {
- 63 const datasetBytesView = new Float32Array(
- 64 datasetBytesBuffer, i * IMAGESIZE * chunkSize * 4,
- IMAGESIZE * chunkSize);
- 66 ctx.drawImage(
- 67 img, 0, i * chunkSize, img.width, chunkSize, 0, 0, img.width,
- 68 chunkSize);
- 69
- 70 const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
該代碼遍歷了每一張 sprite 圖像,并為該迭代初始化了一個新的 TypedArray。接下來,上下文圖像獲取了一個繪制出來的圖像塊。最終,使用上下文的 getImageData 函數(shù)將繪制出來的圖像轉(zhuǎn)換為圖像數(shù)據(jù),返回的是一個表示底層像素數(shù)據(jù)的對象。
- 72 for (let j = 0; j < imageData.data.length / 4; j++) {
- 73 // All channels hold an equal value since the image is grayscale, so
- 74 // just read the red channel.
- 75 datasetBytesView[j] = imageData.data[j * 4] / 255;
- 76 }
- 77 }
我們遍歷了這些像素并除以 255(像素的可能***值),以將值限制在 0 到 1 之間。只有紅色的通道是必要的,因為它是灰度圖像。
- 78 this.datasetImages = new Float32Array(datasetBytesBuffer);
- 79
- 80 resolve();
- 81 };
- 82 img.src = MNISTIMAGESSPRITEPATH;
- );
這一行創(chuàng)建了 buffer,將其映射到保存了我們像素數(shù)據(jù)的新 TypedArray 中,然后結(jié)束了該 promise。事實上***一行(設(shè)置 src 屬性)才真正啟動函數(shù)并加載圖像。
起初困擾我的一件事是 TypedArray 的行為與其底層數(shù)據(jù) buffer 相關(guān)。你可能注意到了,在循環(huán)中設(shè)置了 datasetBytesView,但它永遠都不會返回。
datasetBytesView 引用了緩沖區(qū)的 datasetBytesBuffer(初始化使用)。當代碼更新像素數(shù)據(jù)時,它會間接編輯緩沖區(qū)的值,然后將其轉(zhuǎn)換為 78 行的 new Float32Array。
獲取 DOM 外的圖像數(shù)據(jù)
如果你在 DOM 中,使用 DOM 即可,瀏覽器(通過 canvas)負責確定圖像的格式以及將緩沖區(qū)數(shù)據(jù)轉(zhuǎn)換為像素。但是如果你在 DOM 外工作的話(也就是說用的是 Node.js 或 Web Worker),那就需要一種替代方法。
fetch 提供了一種稱為 response.arrayBuffer 的機制,這種機制使你可以訪問文件的底層緩沖。我們可以用這種方法在完全避免 DOM 的情況下手動讀取字節(jié)。這里有一種編寫上述代碼的替代方法(這種方法需要 fetch,可以用 isomorphic-fetch 等方法在 Node 中進行多邊填充):
- const imgRequest = fetch(MNISTIMAGESSPRITE_PATH).then(resp => resp.arrayBuffer()).then(buffer => {
- return new Promise(resolve => {
- const reader = new PNGReader(buffer);
- return reader.parse((err, png) => {
- const pixels = Float32Array.from(png.pixels).map(pixel => {
- return pixel / 255;
- });
- this.datasetImages = pixels;
- resolve();
- });
- });
- });
這為特定圖像返回了一個緩沖數(shù)組。在寫這篇文章時,我***次試著解析傳入的緩沖,但我不建議這樣做。如果需要的話,我推薦使用 pngjs 進行 png 的解析。當處理其他格式的圖像時,則需要自己寫解析函數(shù)。
有待深入
理解數(shù)據(jù)操作是用 JavaScript 進行機器學習的重要部分。通過理解本文所述用例與需求,我們可以根據(jù)需求在僅使用幾個關(guān)鍵函數(shù)的情況下對數(shù)據(jù)進行格式化。
TensorFlow.js 團隊一直在改進 TensorFlow.js 的底層數(shù)據(jù) API,這有助于更多地滿足需求。這也意味著,隨著 TensorFlow.js 的不斷改進和發(fā)展,API 也會繼續(xù)前進,跟上發(fā)展的步伐。
原文鏈接:
https://medium.freecodecamp.org/how-to-deal-with-mnist-image-data-in-tensorflow-js-169a2d6941dd
【本文是51CTO專欄機構(gòu)“機器之心”的原創(chuàng)文章,微信公眾號“機器之心( id: almosthuman2014)”】