使用Rust的Linfa和Polars庫進(jìn)行機(jī)器學(xué)習(xí):線性回歸
在這篇文章中,我們將使用Rust的Linfa庫和Polars庫來實(shí)現(xiàn)機(jī)器學(xué)習(xí)中的線性回歸算法。
Linfa crate旨在提供一個全面的工具包來使用Rust構(gòu)建機(jī)器學(xué)習(xí)應(yīng)用程序。
Polars是Rust的一個DataFrame庫,它基于Apache Arrow的內(nèi)存模型。Apache arrow提供了非常高效的列數(shù)據(jù)結(jié)構(gòu),并且正在成為列數(shù)據(jù)結(jié)構(gòu)事實(shí)上的標(biāo)準(zhǔn)。
在下面的例子中,我們使用一個糖尿病數(shù)據(jù)集來訓(xùn)練線性回歸算法。
使用以下命令創(chuàng)建一個Rust新項目:
cargo new machine_learning_linfa
在Cargo.toml文件中加入以下依賴項:
[dependencies]
linfa = "0.7.0"
linfa-linear = "0.7.0"
ndarray = "0.15.6"
polars = { version = "0.35.4", features = ["ndarray"]}
在項目根目錄下創(chuàng)建一個diabetes_file.csv文件,將數(shù)據(jù)集寫入文件。
AGE SEX BMI BP S1 S2 S3 S4 S5 S6 Y
59 2 32.1 101 157 93.2 38 4 4.8598 87 151
48 1 21.6 87 183 103.2 70 3 3.8918 69 75
72 2 30.5 93 156 93.6 41 4 4.6728 85 141
24 1 25.3 84 198 131.4 40 5 4.8903 89 206
50 1 23 101 192 125.4 52 4 4.2905 80 135
23 1 22.6 89 139 64.8 61 2 4.1897 68 97
36 2 22 90 160 99.6 50 3 3.9512 82 138
66 2 26.2 114 255 185 56 4.55 4.2485 92 63
60 2 32.1 83 179 119.4 42 4 4.4773 94 110
.............
數(shù)據(jù)集從這里下載:https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt
在src/main.rs文件中寫入以下代碼:
use linfa::prelude::*;
use linfa::traits::Fit;
use linfa_linear::LinearRegression;
use ndarray::{ArrayBase, OwnedRepr};
use polars::prelude::*; // Import polars
fn main() -> Result<(), Box<dyn std::error::Error>> {
// 將制表符定義為分隔符
let separator = b'\t';
let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?
.infer_schema(None)
.with_separator(separator)
.has_header(true)
.finish()?;
println!("{:?}", df);
// 提取并轉(zhuǎn)換目標(biāo)列
let age_series = df.column("AGE")?.cast(&DataType::Float64)?;
let target = age_series.f64()?;
println!("Creating features dataset");
let mut features = df.drop("AGE")?;
// 遍歷列并將每個列強(qiáng)制轉(zhuǎn)換為Float64
for col_name in features.get_column_names_owned() {
let casted_col = df
.column(&col_name)?
.cast(&DataType::Float64)
.expect("Failed to cast column");
features.with_column(casted_col)?;
}
println!("{:?}", df);
let features_ndarray: ArrayBase<OwnedRepr<_>, _> =
features.to_ndarray::<Float64Type>(IndexOrder::C)?;
let target_ndarray = target.to_ndarray()?.to_owned();
let (dataset_training, dataset_validation) =
Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);
// 訓(xùn)練模型
let model = LinearRegression::default().fit(&dataset_training)?;
// 預(yù)測
let pred = model.predict(&dataset_validation);
// 評價模型
let r2 = pred.r2(&dataset_validation)?;
println!("r2 from prediction: {}", r2);
Ok(())
}
- 使用polar的CSV reader讀取CSV文件。
- 將數(shù)據(jù)幀打印到控制臺以供檢查。
- 從DataFrame中提取“AGE”列作為線性回歸的目標(biāo)變量。將目標(biāo)列強(qiáng)制轉(zhuǎn)換為Float64(雙精度浮點(diǎn)數(shù)),這是機(jī)器學(xué)習(xí)中數(shù)值數(shù)據(jù)的常用格式。
- 將features DataFrame轉(zhuǎn)換為narray::ArrayBase(一個多維數(shù)組)以與linfa兼容。將目標(biāo)序列轉(zhuǎn)換為數(shù)組,這些數(shù)組與用于機(jī)器學(xué)習(xí)的linfa庫兼容。
- 使用80-20的比例將數(shù)據(jù)集分割為訓(xùn)練集和驗證集,這是機(jī)器學(xué)習(xí)中評估模型在未知數(shù)據(jù)上的常見做法。
- 使用linfa的線性回歸算法在訓(xùn)練數(shù)據(jù)集上訓(xùn)練線性回歸模型。
- 使用訓(xùn)練好的模型對驗證數(shù)據(jù)集進(jìn)行預(yù)測。
- 計算驗證數(shù)據(jù)集上的R2(決定系數(shù))度量,以評估模型的性能。R2值表示回歸預(yù)測與實(shí)際數(shù)據(jù)點(diǎn)的近似程度。
執(zhí)行cargo run,運(yùn)行結(jié)果如下:
shape: (442, 11)
┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐
│ AGE ┆ SEX ┆ BMI ┆ BP ┆ … ┆ S4 ┆ S5 ┆ S6 ┆ Y │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ i64 ┆ i64 │
╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡
│ 59 ┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0 ┆ 4.8598 ┆ 87 ┆ 151 │
│ 48 ┆ 1 ┆ 21.6 ┆ 87.0 ┆ … ┆ 3.0 ┆ 3.8918 ┆ 69 ┆ 75 │
│ 72 ┆ 2 ┆ 30.5 ┆ 93.0 ┆ … ┆ 4.0 ┆ 4.6728 ┆ 85 ┆ 141 │
│ 24 ┆ 1 ┆ 25.3 ┆ 84.0 ┆ … ┆ 5.0 ┆ 4.8903 ┆ 89 ┆ 206 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 47 ┆ 2 ┆ 24.9 ┆ 75.0 ┆ … ┆ 5.0 ┆ 4.4427 ┆ 102 ┆ 104 │
│ 60 ┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95 ┆ 132 │
│ 36 ┆ 1 ┆ 30.0 ┆ 95.0 ┆ … ┆ 4.79 ┆ 5.1299 ┆ 85 ┆ 220 │
│ 36 ┆ 1 ┆ 19.6 ┆ 71.0 ┆ … ┆ 3.0 ┆ 4.5951 ┆ 92 ┆ 57 │
└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘
Creating features dataset
shape: (442, 11)
┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐
│ AGE ┆ SEX ┆ BMI ┆ BP ┆ … ┆ S4 ┆ S5 ┆ S6 ┆ Y │
│ --- ┆ --- ┆ --- ┆ --- ┆ ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ i64 ┆ f64 ┆ f64 ┆ ┆ f64 ┆ f64 ┆ i64 ┆ i64 │
╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡
│ 59 ┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0 ┆ 4.8598 ┆ 87 ┆ 151 │
│ 48 ┆ 1 ┆ 21.6 ┆ 87.0 ┆ … ┆ 3.0 ┆ 3.8918 ┆ 69 ┆ 75 │
│ 72 ┆ 2 ┆ 30.5 ┆ 93.0 ┆ … ┆ 4.0 ┆ 4.6728 ┆ 85 ┆ 141 │
│ 24 ┆ 1 ┆ 25.3 ┆ 84.0 ┆ … ┆ 5.0 ┆ 4.8903 ┆ 89 ┆ 206 │
│ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … ┆ … │
│ 47 ┆ 2 ┆ 24.9 ┆ 75.0 ┆ … ┆ 5.0 ┆ 4.4427 ┆ 102 ┆ 104 │
│ 60 ┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95 ┆ 132 │
│ 36 ┆ 1 ┆ 30.0 ┆ 95.0 ┆ … ┆ 4.79 ┆ 5.1299 ┆ 85 ┆ 220 │
│ 36 ┆ 1 ┆ 19.6 ┆ 71.0 ┆ … ┆ 3.0 ┆ 4.5951 ┆ 92 ┆ 57 │
└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘
r2 from prediction: 0.15937814745521017
對于優(yōu)先考慮快速迭代和快速原型的數(shù)據(jù)科學(xué)家來說,Rust的編譯時間可能是令人頭疼的問題。Rust的強(qiáng)靜態(tài)類型系統(tǒng)雖然有利于確保類型安全和減少運(yùn)行時錯誤,但也會在編碼過程中增加一層復(fù)雜性。