數(shù)據(jù)工程中的單元測(cè)試完全指南
在數(shù)據(jù)工程領(lǐng)域中,經(jīng)常被忽視的一項(xiàng)實(shí)踐是單元測(cè)試。許多人可能認(rèn)為單元測(cè)試僅僅是一種軟件開(kāi)發(fā)方法論,但事實(shí)遠(yuǎn)非如此。隨著我們努力構(gòu)建穩(wěn)健、無(wú)錯(cuò)誤的數(shù)據(jù)流水線和SQL數(shù)據(jù)模型,單元測(cè)試在數(shù)據(jù)工程中的價(jià)值變得越來(lái)越清晰。
本文帶你深入探索如何將這些成熟的軟件工程實(shí)踐應(yīng)用到數(shù)據(jù)工程中。
1 單元測(cè)試的重要性
在數(shù)據(jù)工程的背景下,采用單元測(cè)試可以確保您的數(shù)據(jù)和業(yè)務(wù)邏輯的準(zhǔn)確性,進(jìn)而產(chǎn)出高質(zhì)量的數(shù)據(jù),獲得您的數(shù)據(jù)分析師、科學(xué)家和決策者對(duì)數(shù)據(jù)的信任。
2 單元測(cè)試數(shù)據(jù)流水線
數(shù)據(jù)流水線通常涉及復(fù)雜的數(shù)據(jù)抽取、轉(zhuǎn)換和加載(ETL)操作序列,出錯(cuò)的可能性很大。為了對(duì)這些操作進(jìn)行單元測(cè)試,我們將流水線拆分為單個(gè)組件,并對(duì)每個(gè)組件進(jìn)行獨(dú)立驗(yàn)證。
以一個(gè)簡(jiǎn)單的流水線為例,該流水線從CSV文件中提取數(shù)據(jù),通過(guò)清除空值來(lái)轉(zhuǎn)換數(shù)據(jù),然后將其加載到數(shù)據(jù)庫(kù)中。以下是使用pandas的基于Python的示例:
import pandas as pd
from sqlalchemy import create_engine
# 加載CSV文件的函數(shù)
def load_data(file_name):
data = pd.read_csv(file_name)
return data
# 清理數(shù)據(jù)的函數(shù)
def clean_data(data):
data = data.dropna()
return data
# 將數(shù)據(jù)保存到SQL數(shù)據(jù)庫(kù)的函數(shù)
def save_data(data, db_string, table_name):
engine = create_engine(db_string)
data.to_sql(table_name, engine, if_exists='replace')
# 運(yùn)行數(shù)據(jù)流水線
data = load_data('data.csv')
data = clean_data(data)
save_data(data, 'sqlite:///database.db', 'my_table')
為了對(duì)這個(gè)流水線進(jìn)行單元測(cè)試,我們使用像pytest這樣的庫(kù)為每個(gè)函數(shù)編寫(xiě)單獨(dú)的測(cè)試。
在這個(gè)示例中,有三個(gè)主要的函數(shù):load_data、clean_data和save_data。我們會(huì)為每個(gè)函數(shù)編寫(xiě)測(cè)試。對(duì)于load_data和save_data,需要設(shè)置一個(gè)臨時(shí)的CSV文件和SQLite數(shù)據(jù)庫(kù),可以使用pytest庫(kù)的fixture功能來(lái)實(shí)現(xiàn)。
import os
import pandas as pd
import pytest
from sqlalchemy import create_engine, inspect
# 使用pytest fixture來(lái)設(shè)置臨時(shí)的CSV文件和SQLite數(shù)據(jù)庫(kù)
@pytest.fixture
def csv_file(tmp_path):
data = pd.DataFrame({
'name': ['John', 'Jane', 'Doe'],
'age': [34, None, 56] # Jane的年齡缺失
})
file_path = tmp_path / "data.csv"
data.to_csv(file_path, index=False)
return file_path
@pytest.fixture
def sqlite_db(tmp_path):
file_path = tmp_path / "database.db"
return 'sqlite:///' + str(file_path)
def test_load_data(csv_file):
data = load_data(csv_file)
assert 'name' in data.columns
assert 'age' in data.columns
assert len(data) == 3
def test_clean_data(csv_file):
data = load_data(csv_file)
data = clean_data(data)
assert data['age'].isna().sum() == 0
assert len(data) == 2 # Jane的記錄應(yīng)該被刪除
def test_save_data(csv_file, sqlite_db):
data = load_data(csv_file)
data = clean_data(data)
save_data(data, sqlite_db, 'my_table')
# 檢查數(shù)據(jù)是否保存正確
engine = create_engine(sqlite_db)
inspector = inspect(engine)
tables = inspector.get_table_names()
assert 'my_table' in tables
loaded_data = pd.read_sql('my_table', engine)
assert len(loaded_data) == 2 # 只應(yīng)該存在John和Doe的記錄
這里是另一個(gè)例子:假設(shè)您有一個(gè)從CSV文件中加載數(shù)據(jù)并將其中的“日期”列從字符串轉(zhuǎn)換為日期時(shí)間的流水線:
def convert_date(data, date_column):
data[date_column] = pd.to_datetime(data[date_column])
return data
為上述函數(shù)編寫(xiě)的單元測(cè)試將傳入具有已知日期字符串格式的DataFrame。然后,它將驗(yàn)證函數(shù)是否正確將日期轉(zhuǎn)換為日期時(shí)間對(duì)象,并且它是否適當(dāng)處理無(wú)效格式。
我們?yōu)樯鲜鰣?chǎng)景編寫(xiě)一個(gè)單元測(cè)試。該測(cè)試首先使用有效日期檢查函數(shù),斷言輸出DataFrame中的“date”列確實(shí)是datetime類型,并且值與預(yù)期相符。然后,它檢查在給出無(wú)效日期時(shí),函數(shù)是否正確引發(fā)了ValueError。
import pandas as pd
import pytest
def test_convert_date():
# 使用有效日期進(jìn)行測(cè)試
test_data = pd.DataFrame({
'date': ['2021-01-01', '2021-01-02']
})
converted_data = convert_date(test_data.copy(), 'date')
assert pd.api.types.is_datetime64_any_dtype(converted_data['date'])
assert converted_data.loc[0, 'date'] == pd.Timestamp('2021-01-01')
assert converted_data.loc[1, 'date'] == pd.Timestamp('2021-01-02')
# 使用無(wú)效日期進(jìn)行測(cè)試
test_data = pd.DataFrame({
'date': ['2021-13-01'] # 這個(gè)日期是無(wú)效的,因?yàn)闆](méi)有第13個(gè)月
})
with pytest.raises(ValueError):
convert_date(test_data, 'date')
以下是最后一個(gè)例子:假設(shè)您有一個(gè)加載數(shù)據(jù)并進(jìn)行聚合的流水線,計(jì)算每個(gè)地區(qū)的總銷售額:
def aggregate_sales(data):
aggregated = data.groupby('region').sales.sum().reset_index()
return aggregated
為該函數(shù)編寫(xiě)的單元測(cè)試將向其傳遞具有各個(gè)地區(qū)銷售數(shù)據(jù)的DataFrame。測(cè)試將驗(yàn)證函數(shù)是否正確計(jì)算每個(gè)地區(qū)的總銷售額。
我們?yōu)樵摵瘮?shù)編寫(xiě)一個(gè)單元測(cè)試。在這個(gè)測(cè)試中,我們首先向aggregate_sales函數(shù)傳遞一個(gè)具有已知銷售數(shù)據(jù)的DataFrame,并檢查它是否正確聚合了銷售額。然后,向其傳遞一個(gè)沒(méi)有銷售數(shù)據(jù)的DataFrame,并檢查它是否正確將這些銷售額聚合為0。這樣可以確保函數(shù)正確處理典型情況和邊緣情況。
以下是使用pytest庫(kù)為aggregate_sales函數(shù)編寫(xiě)單元測(cè)試的示例:
import pandas as pd
import pytest
def test_aggregate_sales():
# 各個(gè)地區(qū)的銷售數(shù)據(jù)
test_data = pd.DataFrame({
'region': ['North', 'North', 'South', 'South', 'East', 'East', 'West', 'West'],
'sales': [100, 200, 300, 400, 500, 600, 700, 800]
})
aggregated = aggregate_sales(test_data)
assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 300
assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 700
assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 1100
assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 1500
# 沒(méi)有銷售數(shù)據(jù)的測(cè)試
test_data = pd.DataFrame({
'region': ['North', 'South', 'East', 'West'],
'sales': [0, 0, 0, 0]
})
aggregated = aggregate_sales(test_data)
assert aggregated.loc[aggregated['region'] == 'North', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'South', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'East', 'sales'].values[0] == 0
assert aggregated.loc[aggregated['region'] == 'West', 'sales'].values[0] == 0
本文轉(zhuǎn)載自微信公眾號(hào)「Java學(xué)研大本營(yíng)」,可以通過(guò)以下二維碼關(guān)注。轉(zhuǎn)載本文請(qǐng)聯(lián)系公眾號(hào)。