掌握 PyTorch 張量乘法:八個(gè)關(guān)鍵函數(shù)與應(yīng)用場景對比解析
PyTorch提供了幾種張量乘法的方法,每種方法都是不同的,并且有不同的應(yīng)用。我們來詳細(xì)介紹每個(gè)方法,并且詳細(xì)解釋這些函數(shù)有什么區(qū)別:
一、torch.matmul
torch.matmul 是 PyTorch 中用于矩陣乘法的函數(shù)。它能夠處理各種不同維度的張量,并根據(jù)張量的維度自動(dòng)調(diào)整其操作方式。
torch.matmul 可以執(zhí)行以下幾種矩陣乘法:
- 二維張量之間的矩陣乘法:
- 這是經(jīng)典的矩陣乘法操作。當(dāng)兩個(gè)張量都是二維的 (即矩陣),torch.matmul 進(jìn)行標(biāo)準(zhǔn)的矩陣乘法操作。
- 例如:假設(shè) A 是形狀為 (m, n) 的張量,B 是形狀為 (n, p) 的張量,那么 torch.matmul(A, B) 結(jié)果是一個(gè)形狀為 (m, p) 的張量。
- 高維張量之間的矩陣乘法:
torch.matmul 可以處理更高維的張量。當(dāng)輸入張量的維度大于2時(shí),它將執(zhí)行批量矩陣乘法。
對于形狀為 (..., m, n) 的張量 A 和形狀為 (..., n, p) 的張量 B,torch.matmul(A, B) 的結(jié)果是形狀為 (..., m, p) 的張量,其中 ... 表示相同的批量維度。批量維度部分將自動(dòng)廣播。
一維和二維張量的乘法:
當(dāng)?shù)谝粋€(gè)張量是1D張量(向量),第二個(gè)張量是2D張量時(shí),torch.matmul 會(huì)將1D張量視為行向量(或列向量)參與矩陣乘法。
例如:A 是形狀為 (n,) 的張量,B 是形狀為 (n, p) 的張量,那么 torch.matmul(A, B) 的結(jié)果是形狀為 (p,) 的張量。
反之,如果第一個(gè)張量是2D張量,第二個(gè)是1D張量,則結(jié)果是一個(gè)形狀為 (m,) 的張量。
import torch
# 示例 1: 二維張量之間的矩陣乘法
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.matmul(A, B)
print(result) # 輸出: tensor([[19, 22], [43, 50]])
# 示例 2: 高維張量之間的矩陣乘法(批次矩陣乘法)
A = torch.rand(2, 3, 4)
B = torch.rand(2, 4, 5)
result = torch.matmul(A, B)
print(result.shape) # 輸出: torch.Size([2, 3, 5])
# 示例 3: 1D 和 2D 張量之間的乘法
A = torch.tensor([1, 2, 3])
B = torch.tensor([[4, 5], [6, 7], [8, 9]])
result = torch.matmul(A, B)
print(result) # 輸出: tensor([40, 46])
torch.matmul 支持廣播,這意味著當(dāng)輸入張量的形狀不完全匹配時(shí),它可以自動(dòng)擴(kuò)展維度以進(jìn)行相應(yīng)的矩陣乘法。例如,兩個(gè)張量的形狀分別為 (1, 2, 3) 和 (3, 4),torch.matmul 可以將第二個(gè)張量自動(dòng)擴(kuò)展為形狀 (1, 3, 4),然后進(jìn)行批次矩陣乘法。
torch.matmul 底層使用了高效的線性代數(shù)庫(如 BLAS),確保了矩陣乘法的性能。對于大型矩陣運(yùn)算,torch.matmul 通常是非常高效的。它的靈活性和性能使得它成為 PyTorch 中廣泛使用的操作之一。
二、torch.mm
torch.mm 是 PyTorch 中專門用于二維張量(矩陣)之間進(jìn)行矩陣乘法的函數(shù)。與 torch.matmul 不同,torch.mm 僅適用于2D張量,并且不支持高維張量或廣播操作。
torch.mm 進(jìn)行標(biāo)準(zhǔn)的矩陣乘法操作,適用于兩個(gè)2D張量(矩陣)之間的乘法。對于形狀為 (m, n) 的張量 A 和形狀為 (n, p) 的張量 B,torch.mm(A, B) 的結(jié)果是一個(gè)形狀為 (m, p) 的張量。
import torch
# 示例 1: 二維張量之間的矩陣乘法
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([[7, 8], [9, 10], [11, 12]])
result = torch.mm(A, B)
print(result) # 輸出: tensor([[ 58, 64], [139, 154]])
在這個(gè)例子中,矩陣 A 的形狀是 (2, 3),矩陣 B 的形狀是 (3, 2)。結(jié)果矩陣的形狀是 (2, 2),且每個(gè)元素是通過對應(yīng)行與列元素的乘積之和計(jì)算得出的。
torch.mm 不支持廣播機(jī)制,這意味著兩個(gè)輸入矩陣的形狀必須嚴(yán)格匹配(即第一個(gè)矩陣的列數(shù)必須等于第二個(gè)矩陣的行數(shù))。
torch.mm 是針對二維矩陣乘法優(yōu)化的,它利用了底層的高效線性代數(shù)庫(如 BLAS)。當(dāng)僅需要進(jìn)行2D張量的矩陣乘法時(shí),torch.mm 可能比 torch.matmul 更加高效,因?yàn)樗苊饬?torch.matmul 中針對高維張量所做的額外處理。
注意事項(xiàng):
輸入張量必須是二維的。如果輸入是高維張量,使用 torch.mm 會(huì)導(dǎo)致錯(cuò)誤。兩個(gè)矩陣的形狀必須是兼容的,即第一個(gè)矩陣的列數(shù)必須等于第二個(gè)矩陣的行數(shù),否則會(huì)拋出維度不匹配的錯(cuò)誤。
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([1, 2])
# 這會(huì)引發(fā)一個(gè)錯(cuò)誤,因?yàn)?B 不是二維張量
result = torch.mm(A, B) # RuntimeError: matrices expected, got 1D, 2D tensors
在上面的示例中,由于 B 是一維張量而非二維矩陣,因此 torch.mm 會(huì)拋出錯(cuò)誤。解決方法是將 B 轉(zhuǎn)換為二維張量,例如 B.unsqueeze(1),以使其形狀符合矩陣乘法的要求。
torch.mm 常用于涉及矩陣乘法的各種場景,特別是在機(jī)器學(xué)習(xí)和深度學(xué)習(xí)中。例如,在神經(jīng)網(wǎng)絡(luò)的全連接層中,計(jì)算權(quán)重矩陣和輸入向量的乘積時(shí)經(jīng)常使用 torch.mm。此外,torch.mm 也可以用于線性代數(shù)中的基本操作,如求解線性方程組、計(jì)算特征值等。
torch.mm 它操作簡潔且性能高效,適用于需要進(jìn)行標(biāo)準(zhǔn)矩陣乘法的場景。對于二維矩陣乘法來說,它比 torch.matmul 更直接,因此在需要矩陣乘法且確定張量維度為2D的情況下,torch.mm 是一個(gè)理想的選擇。
三、torch.bmm
torch.bmm 是 PyTorch 中用于進(jìn)行批次矩陣乘法的函數(shù)。它專門處理三維張量,其中第一個(gè)維度表示批次大小,后兩個(gè)維度表示需要進(jìn)行矩陣乘法的矩陣。因此torch.bmm 是進(jìn)行批次矩陣操作的一個(gè)高效工具。
torch.bmm 用于對形狀為 (b, m, n) 的張量 A 和形狀為 (b, n, p) 的張量 B 進(jìn)行批次矩陣乘法,輸出結(jié)果是形狀為 (b, m, p) 的張量。這里,b 表示批次大小,m 和 n 是矩陣的行和列數(shù),p 是結(jié)果矩陣的列數(shù)。
import torch
# 示例: 批次矩陣乘法
A = torch.randn(10, 3, 4) # 形狀為 (10, 3, 4)
B = torch.randn(10, 4, 5) # 形狀為 (10, 4, 5)
result = torch.bmm(A, B)
print(result.shape) # 輸出: torch.Size([10, 3, 5])
在這個(gè)例子中:
- 張量 A 的形狀是 (10, 3, 4),表示有10個(gè)3x4的矩陣。
- 張量 B 的形狀是 (10, 4, 5),表示有10個(gè)4x5的矩陣。
- torch.bmm(A, B) 的結(jié)果是形狀為 (10, 3, 5) 的張量,這表示批次中的每一對矩陣都進(jìn)行了乘法操作。
torch.bmm 實(shí)際上是對批次中的每一對矩陣單獨(dú)進(jìn)行矩陣乘法操作,因此它要求輸入張量的第一個(gè)維度(即批次大?。┦窍嗤模⑶液髢蓚€(gè)維度必須滿足矩陣乘法的要求(即第一個(gè)矩陣的列數(shù)等于第二個(gè)矩陣的行數(shù))。
torch.bmm 對批次矩陣乘法進(jìn)行了優(yōu)化,使用了高效的底層線性代數(shù)庫。它在處理大型批次矩陣乘法時(shí)性能非常高效。由于它可以在批次上并行執(zhí)行操作,因此特別適用于深度學(xué)習(xí)中的批量計(jì)算場景。
torch.bmm 只適用于三維張量,其中第一個(gè)維度表示批次大小。對于高于或低于三維的張量,它會(huì)報(bào)錯(cuò)?;蛘哒f他是torch.mm的批次化版本。torch.bmm 不支持廣播機(jī)制,因此輸入張量的第一個(gè)維度(批次大?。┍仨殗?yán)格相同。
torch.bmm 常用于需要對多個(gè)矩陣對同時(shí)進(jìn)行乘法操作的場景,特別是在深度學(xué)習(xí)中的以下情境:
- 批量計(jì)算:在訓(xùn)練神經(jīng)網(wǎng)絡(luò)時(shí),我們通常將輸入數(shù)據(jù)分批處理,每批次數(shù)據(jù)對應(yīng)多個(gè)矩陣。torch.bmm 可以有效地處理這種批次矩陣操作。
- 圖卷積網(wǎng)絡(luò)(GCN):在圖神經(jīng)網(wǎng)絡(luò)中,批次矩陣乘法經(jīng)常用于計(jì)算節(jié)點(diǎn)特征和鄰接矩陣的乘積。
- 時(shí)間序列模型:在時(shí)間序列建模中,可能需要對每個(gè)時(shí)間步長應(yīng)用不同的變換矩陣,這時(shí)可以使用 torch.bmm 進(jìn)行批量處理。
torch.bmm 是專門用于批次矩陣乘法。當(dāng)需要對多個(gè)矩陣對同時(shí)進(jìn)行乘法操作時(shí),它提供了高效且簡潔的解決方案。
四、torch.mul
torch.mul 是 PyTorch 中用于執(zhí)行元素級乘法(也稱為逐元素乘法)的函數(shù)。它可以對張量的每個(gè)元素進(jìn)行對應(yīng)位置的乘法操作,支持任意維度的張量,并且可以自動(dòng)進(jìn)行廣播操作來適應(yīng)不同形狀的張量。
torch.mul 可以對兩個(gè)張量的對應(yīng)元素進(jìn)行乘法運(yùn)算。假設(shè)有兩個(gè)張量 A 和 B,那么 torch.mul(A, B) 將返回一個(gè)新的張量,其中每個(gè)元素是 A 和 B 在相同位置的元素的乘積。這個(gè)操作等同于使用 * 操作符,如 A * B。
import torch
# 示例 1: 相同形狀的張量的元素級乘法
A = torch.tensor([1, 2, 3])
B = torch.tensor([4, 5, 6])
result = torch.mul(A, B)
print(result) # 輸出: tensor([ 4, 10, 18])
# 示例 2: 不同形狀的張量進(jìn)行廣播后的元素級乘法
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([10, 20, 30])
result = torch.mul(A, B)
print(result) # 輸出: tensor([[10, 40, 90], [40, 100, 180]])
# 示例 3: 通過標(biāo)量進(jìn)行元素級乘法
A = torch.tensor([1, 2, 3])
result = torch.mul(A, 10)
print(result) # 輸出: tensor([10, 20, 30])
在這些示例中:
- 在第一個(gè)示例中,A 和 B 是形狀相同的張量,因此對應(yīng)元素直接相乘。
- 在第二個(gè)示例中,A 是二維張量,而 B 是一維張量,PyTorch 自動(dòng)對 B 進(jìn)行廣播,使其形狀與 A 匹配,然后進(jìn)行逐元素乘法。
- 在第三個(gè)示例中,A 和一個(gè)標(biāo)量值相乘,每個(gè)元素都乘以該標(biāo)量。
torch.mul 支持廣播機(jī)制,這意味著當(dāng)兩個(gè)張量的形狀不完全相同時(shí),它可以自動(dòng)擴(kuò)展較小形狀的張量,使其與較大形狀的張量兼容,然后進(jìn)行逐元素乘法。
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
B = torch.tensor([10, 20, 30])
result = torch.mul(A, B)
在這個(gè)例子中,A 的形狀是 (2, 3),而 B 的形狀是 (3,)。PyTorch 自動(dòng)將 B 擴(kuò)展為 (2, 3),然后對每個(gè)對應(yīng)元素進(jìn)行乘法運(yùn)算。
torch.mul 是一個(gè)高效的逐元素操作,因?yàn)樗苯釉谠丶墑e上進(jìn)行計(jì)算,適用于需要對大批量數(shù)據(jù)進(jìn)行逐元素操作的場景。它可以充分利用現(xiàn)代硬件的并行計(jì)算能力(如GPU),在處理大型張量時(shí)非常高效。
注意事項(xiàng)
雖然 torch.mul 支持廣播,但在進(jìn)行操作時(shí),確保兩個(gè)張量的形狀是兼容的非常重要。如果形狀不兼容,將會(huì)引發(fā)運(yùn)行時(shí)錯(cuò)誤。當(dāng)使用標(biāo)量時(shí),標(biāo)量會(huì)被自動(dòng)廣播到張量的每個(gè)元素,因此直接操作是安全的。
import torch
A = torch.tensor([1, 2, 3])
B = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 形狀不兼容,無法進(jìn)行逐元素乘法
result = torch.mul(A, B) # 會(huì)引發(fā) RuntimeError: The size of tensor a (3) must match the size of tensor b (2) at non-singleton dimension 0
在這個(gè)錯(cuò)誤示例中,由于 A 是一維張量,而 B 是二維張量且第一個(gè)維度不匹配,因此無法廣播,導(dǎo)致錯(cuò)誤。
torch.mul 在許多機(jī)器學(xué)習(xí)和深度學(xué)習(xí)任務(wù)中都非常有用。例如:
- 權(quán)重調(diào)整:在神經(jīng)網(wǎng)絡(luò)中,可以通過 torch.mul 來逐元素調(diào)整權(quán)重或激活值。
- 掩碼操作:在圖像處理中,可以使用 torch.mul 來對圖像應(yīng)用掩碼,逐元素控制哪些部分需要保留或修改。
- 歸一化:可以逐元素將張量歸一化或縮放,以滿足特定的算法要求。
torch.mul 在處理各種張量操作時(shí)非常有用。它支持廣播機(jī)制,可以自動(dòng)適應(yīng)不同形狀的張量,從而在多種應(yīng)用場景中提供簡潔而高效的解決方案。
五、torch.mv
torch.mv 是 PyTorch 中用于進(jìn)行矩陣與向量乘法的函數(shù)。它專門用于二維張量(矩陣)和一維張量(向量)之間的乘法操作。torch.mv 是矩陣乘法的一種特殊情況,適用于當(dāng)你需要將矩陣乘以向量時(shí)使用。
torch.mv 執(zhí)行的是矩陣與向量的乘法操作。假設(shè)有一個(gè)矩陣 A,它的形狀為 (m, n),以及一個(gè)向量 v,它的形狀為 (n,),那么 torch.mv(A, v) 將返回一個(gè)形狀為 (m,) 的一維張量(向量),結(jié)果是矩陣 A 與向量 v 的乘積。
import torch
# 示例: 矩陣與向量的乘法
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
v = torch.tensor([7, 8, 9])
result = torch.mv(A, v)
print(result) # 輸出: tensor([ 50, 122])
在這個(gè)示例中,矩陣 A 的形狀為 (2, 3),向量 v 的形狀為 (3,)。通過 torch.mv(A, v),我們得到的結(jié)果是形狀為 (2,) 的向量 [50, 122],其中每個(gè)元素是通過矩陣與向量的標(biāo)準(zhǔn)乘法計(jì)算得出的。
torch.mv 執(zhí)行的矩陣與向量乘法遵循以下規(guī)則:對于矩陣 A 中的每一行,將該行與向量 v 的所有元素逐元素相乘,并將乘積的結(jié)果求和,得到一個(gè)標(biāo)量。這個(gè)標(biāo)量就是結(jié)果向量對應(yīng)位置的值。
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
v = torch.tensor([7, 8, 9])
result = torch.mv(A, v)
# 結(jié)果:
# result[0] = 1*7 + 2*8 + 3*9 = 50
# result[1] = 4*7 + 5*8 + 6*9 = 122
torch.mv 專門用于矩陣和向量的乘法,比通用的矩陣乘法函數(shù)如 torch.matmul 或 torch.mm 更加高效,因?yàn)樗苊饬藢Χ嘤嗑S度的處理。這使得 torch.mv 在執(zhí)行矩陣與向量乘法時(shí)速度更快,并且更適合用于大規(guī)模計(jì)算。
注意事項(xiàng)
矩陣 A 的列數(shù)(第二個(gè)維度)必須等于向量 v 的長度(第一個(gè)維度),否則將會(huì)報(bào)錯(cuò)。
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]])
v = torch.tensor([7, 8])
# 這將引發(fā)錯(cuò)誤,因?yàn)?v 的形狀與 A 的列數(shù)不匹配
result = torch.mv(A, v) # 會(huì)引發(fā) RuntimeError: size mismatch, m1: [2x3], m2: [2] at THTensorMath.cpp:41
在這個(gè)錯(cuò)誤示例中,向量 v 的長度與矩陣 A 的列數(shù)不匹配,因此無法進(jìn)行矩陣與向量乘法。
torch.mv 是 PyTorch 中用于執(zhí)行矩陣與向量乘法的專用函數(shù)。它對矩陣與向量乘法進(jìn)行了優(yōu)化,能夠高效處理這類操作,是線性代數(shù)、深度學(xué)習(xí)和科學(xué)計(jì)算中常用的工具。在許多應(yīng)用場景中都很有用,特別是在以下情況下:
- 線性代數(shù)操作:在計(jì)算線性方程組、特征值問題等線性代數(shù)問題時(shí),經(jīng)常需要進(jìn)行矩陣與向量的乘法。
- 神經(jīng)網(wǎng)絡(luò)計(jì)算:在神經(jīng)網(wǎng)絡(luò)的前向傳播過程中,特別是全連接層中,權(quán)重矩陣與輸入向量的乘法操作可以通過 torch.mv 高效地實(shí)現(xiàn)。
- 物理模擬:在一些物理模擬中,狀態(tài)向量與轉(zhuǎn)換矩陣的乘法操作可以通過 torch.mv 實(shí)現(xiàn)。
六、torch.dot
torch.dot 是 PyTorch 中用于計(jì)算兩個(gè)一維張量(即向量)之間的點(diǎn)乘(內(nèi)積)的函數(shù)。點(diǎn)乘是一種基本的向量操作,在許多數(shù)學(xué)和工程應(yīng)用中都有廣泛的應(yīng)用。
torch.dot 計(jì)算的是兩個(gè)向量之間的點(diǎn)積。假設(shè)有兩個(gè)向量 a 和 b,它們的長度相同(即形狀都為 (n,)),那么 torch.dot(a, b) 的結(jié)果是一個(gè)標(biāo)量(即一個(gè)數(shù)值),這個(gè)值是通過對應(yīng)位置的元素相乘后再求和得到的。
import torch
# 示例: 兩個(gè)向量的點(diǎn)乘
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.dot(a, b)
print(result) # 輸出: tensor(32)
在這個(gè)示例中:向量 a 的形狀為 (3,),向量 b 的形狀也是 (3,)。通過 torch.dot(a, b),我們得到了標(biāo)量 32,其計(jì)算過程為:1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32。
torch.dot 計(jì)算點(diǎn)乘的方式是逐元素相乘,然后將結(jié)果求和。對于兩個(gè)長度為 n 的向量 a 和 b,點(diǎn)積的計(jì)算公式如下:
result = (a[0] * b[0]) + (a[1] * b[1]) + ... + (a[n-1] * b[n-1])
torch.dot 是對兩個(gè)一維張量進(jìn)行點(diǎn)積的優(yōu)化實(shí)現(xiàn),由于其簡單的計(jì)算流程和對向量操作的專門優(yōu)化,它通常具有非常高的性能,特別是在 GPU 上處理大規(guī)模數(shù)據(jù)時(shí)表現(xiàn)尤為優(yōu)異。
torch.dot 僅適用于一維張量(向量),如果輸入的張量不是一維的,會(huì)引發(fā)錯(cuò)誤。并且torch.dot 返回一個(gè)標(biāo)量(標(biāo)量張量),而不是張量。由于點(diǎn)積的對稱性,torch.dot(a, b) 與 torch.dot(b, a) 的結(jié)果是相同的。
與其他操作的對比
- torch.matmul 和 torch.mm:這些函數(shù)用于矩陣乘法,適用于高維張量。torch.dot 只用于一維張量的點(diǎn)積。
- torch.mul:這是逐元素乘法,不是點(diǎn)積。torch.mul(a, b) 會(huì)返回一個(gè)與 a 和 b 形狀相同的張量,其中每個(gè)元素是對應(yīng)元素的乘積,而 torch.dot(a, b) 會(huì)返回一個(gè)標(biāo)量。
torch.dot 是一個(gè)簡單而高效的函數(shù),專門用于計(jì)算一維張量之間的點(diǎn)積。在許多數(shù)學(xué)、物理和工程應(yīng)用中,它是一個(gè)非常重要的工具。點(diǎn)積在很多場景中都有應(yīng)用,包括但不限于:
- 向量投影:在幾何中,點(diǎn)乘可以用于計(jì)算一個(gè)向量在另一個(gè)向量方向上的投影。
- 相似性計(jì)算:在信息檢索和機(jī)器學(xué)習(xí)中,兩個(gè)向量的點(diǎn)積可以用于衡量它們的相似性。例如,在詞向量(Word Embeddings)的相似性計(jì)算中,點(diǎn)積是常用的度量方法之一。
- 能量計(jì)算:在物理學(xué)中,點(diǎn)積用于計(jì)算力和位移的乘積(即功的計(jì)算)。
七、torch.outer
torch.outer 是 PyTorch 中用于計(jì)算兩個(gè)一維張量(即向量)之間的外積(外積矩陣)的函數(shù)。外積是線性代數(shù)中的一種基本運(yùn)算,結(jié)果是一個(gè)矩陣,其元素是兩個(gè)輸入向量各元素的乘積。
torch.outer 計(jì)算的是兩個(gè)向量的外積。假設(shè)有兩個(gè)向量 a 和 b,它們的形狀分別是 (n,) 和 (m,),那么 torch.outer(a, b) 的結(jié)果是一個(gè)形狀為 (n, m) 的二維張量(矩陣),這個(gè)矩陣中的元素由 a[i] * b[j] 計(jì)算得到。
import torch
# 示例: 兩個(gè)向量的外積
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.outer(a, b)
print(result)
# 輸出:
# tensor([[ 4, 5, 6],
# [ 8, 10, 12],
# [12, 15, 18]])
在這個(gè)示例中:
- 向量 a 的形狀為 (3,),向量 b 的形狀也為 (3,)。
- 通過 torch.outer(a, b),我們得到了形狀為 (3, 3) 的矩陣。這個(gè)矩陣的每個(gè)元素都是由 a[i] 和 b[j] 的乘積計(jì)算得出。
torch.outer 是對兩個(gè)一維張量進(jìn)行外積的優(yōu)化實(shí)現(xiàn)。由于其操作涉及大量的元素乘法,因此在處理大型向量時(shí),特別是在 GPU 上計(jì)算,torch.outer 的性能表現(xiàn)十分出色。
torch.outer 僅適用于一維張量,即向量,并返回一個(gè)二維張量(矩陣),其形狀為 (n, m),其中 n 和 m 是輸入向量的長度。
與其他操作的對比
- torch.matmul 和 torch.mm:這些函數(shù)用于矩陣乘法,適用于高維張量。torch.outer 專用于計(jì)算兩個(gè)一維張量之間的外積。
- torch.mul:這是逐元素乘法。如果兩個(gè)張量的形狀相同,torch.mul(a, b) 將執(zhí)行逐元素乘法,而不是計(jì)算外積。
torch.outer 是一個(gè)用于計(jì)算兩個(gè)一維張量之間外積的高效工具。它在生成矩陣、處理雙線性形式、構(gòu)建張量積等應(yīng)用中非常有用。外積在很多場景中都有應(yīng)用,包括但不限于:
- 矩陣構(gòu)建:外積可用于生成特定類型的矩陣,例如克羅內(nèi)克積。
- 雙線性形式:在雙線性形式的表示中,外積經(jīng)常用于構(gòu)建張量。
- 機(jī)器學(xué)習(xí):在神經(jīng)網(wǎng)絡(luò)的權(quán)重更新、特征交互等場景中,外積運(yùn)算可以構(gòu)造高階特征。
8、torch.einsum
torch.einsum 是 PyTorch 中一個(gè)非常強(qiáng)大的函數(shù),它使用愛因斯坦求和約定(Einstein Summation Convention)來執(zhí)行復(fù)雜的張量操作。torch.einsum 的靈活性使得它可以用于各種矩陣和張量運(yùn)算,包括矩陣乘法、轉(zhuǎn)置、內(nèi)積、外積、以及其他高階張量運(yùn)算。
愛因斯坦求和約定是一種簡化張量操作的符號表示方法,其中重復(fù)的指標(biāo)自動(dòng)表示求和。torch.einsum 使用字符串表示張量操作,將輸入張量的維度與輸出維度通過指定的模式進(jìn)行映射。
torch.einsum(equation, *operands)
equation:一個(gè)字符串,描述了輸入和輸出張量的維度關(guān)系。
*operands:一個(gè)或多個(gè)張量,參與計(jì)算的張量。
使用示例
1、矩陣乘法
矩陣乘法是最常見的張量操作之一。對于兩個(gè)矩陣 A 和 B,使用 torch.einsum 進(jìn)行矩陣乘法可以表示為:
import torch
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
result = torch.einsum('ik,kj->ij', A, B)
print(result) # 輸出: tensor([[19, 22], [43, 50]])
這里,'ik,kj->ij' 表示:
- A 的維度為 i(行)和 k(列)。
- B 的維度為 k(行)和 j(列)。
- 輸出的矩陣 C 的維度為 i(行)和 j(列),其中 k 是求和維度。
2、向量內(nèi)積(點(diǎn)積)
對于兩個(gè)向量 a 和 b,它們的內(nèi)積可以用 torch.einsum 表示為:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,i->', a, b)
print(result) # 輸出: tensor(32)
這里,'i,i->' 表示:
- a 和 b 都是一維向量,維度為 i。
- 輸出是一個(gè)標(biāo)量(沒有索引),表示所有元素的乘積之和。
3、向量外積
向量外積可以表示為:
a = torch.tensor([1, 2, 3])
b = torch.tensor([4, 5, 6])
result = torch.einsum('i,j->ij', a, b)
print(result)
# 輸出:
# tensor([[ 4, 5, 6],
# [ 8, 10, 12],
# [12, 15, 18]])
這里,'i,j->ij' 表示:
- a 的維度為 i,b 的維度為 j。
- 輸出矩陣 C 的維度為 ij,表示 a[i] 和 b[j] 的乘積。
torch.einsum 是一個(gè)通用且靈活的工具,但其性能可能不如專門為某些操作優(yōu)化的函數(shù)(如 torch.matmul)。所以在性能關(guān)鍵的應(yīng)用中,使用專門的張量操作函數(shù)可能會(huì)更高效。不過對于需要簡潔表示復(fù)雜操作的場景,torch.einsum 仍然是首選。
總結(jié)
以下是對 PyTorch 中幾種常用張量操作函數(shù)的總結(jié):
- torch.matmul (矩陣乘法)
- 功能:執(zhí)行矩陣乘法,支持二維矩陣、批量矩陣乘法、高維張量乘法。
- 應(yīng)用:廣泛用于神經(jīng)網(wǎng)絡(luò)中的矩陣運(yùn)算,如全連接層的計(jì)算。
- torch.mm (矩陣乘法)
- 功能:專門用于二維張量(矩陣)之間的乘法,不支持廣播和高維張量。
- 應(yīng)用:適用于明確為二維矩陣的乘法操作,性能高效。
- torch.bmm (批次矩陣乘法)
- 功能:對三維張量進(jìn)行批次矩陣乘法,適用于批量處理的場景。
- 應(yīng)用:常用于深度學(xué)習(xí)中的批量數(shù)據(jù)處理和圖神經(jīng)網(wǎng)絡(luò)中的鄰接矩陣計(jì)算。
- torch.mul (元素級乘法)
- 功能:逐元素乘法,支持任意維度張量并自動(dòng)廣播。
- 應(yīng)用:用于權(quán)重調(diào)整、掩碼操作、數(shù)據(jù)歸一化等逐元素運(yùn)算。
- torch.mv (矩陣與向量乘法)
- 功能:用于二維矩陣與一維向量之間的乘法操作。
- 應(yīng)用:適用于神經(jīng)網(wǎng)絡(luò)中的前向傳播、線性代數(shù)操作。
- torch.dot (點(diǎn)乘)
- 功能:計(jì)算兩個(gè)一維張量(向量)之間的點(diǎn)積,結(jié)果是一個(gè)標(biāo)量。
- 應(yīng)用:用于計(jì)算向量內(nèi)積、向量相似性、物理學(xué)中的能量計(jì)算。
- torch.outer (外積)
- 功能:計(jì)算兩個(gè)一維張量之間的外積,結(jié)果是一個(gè)二維矩陣。
- 應(yīng)用:用于構(gòu)建矩陣、處理雙線性形式、特征交互等。
- torch.einsum (愛因斯坦求和約定)
- 功能:使用愛因斯坦求和約定進(jìn)行復(fù)雜張量運(yùn)算,包括矩陣乘法、轉(zhuǎn)置、內(nèi)積、外積等。
- 應(yīng)用:廣泛用于線性代數(shù)、物理學(xué)計(jì)算、機(jī)器學(xué)習(xí)中的復(fù)雜操作。
這些 PyTorch 張量操作函數(shù)各有其專門用途和應(yīng)用場景。torch.matmul、torch.mm 和 torch.bmm 主要用于矩陣乘法;torch.mul 和 torch.outer 用于逐元素和外積操作;torch.mv 和 torch.dot 處理矩陣與向量、向量與向量的乘法;torch.einsum 則是處理復(fù)雜張量運(yùn)算的多功能工具。