深度學(xué)習Pytorch框架Tensor張量
作者:佚名
本文主要介紹了Tensor的裁剪運算、索引與數(shù)據(jù)篩選、組合/拼接、切片、變形操作、填充操作和Tensor的頻譜操作(傅里葉變換)。
1 Tensor的裁剪運算
- 對Tensor中的元素進行范圍過濾
- 常用于梯度裁剪(gradient clipping),即在發(fā)生梯度離散或者梯度爆炸時對梯度的處理
- torch.clamp(input, min, max, out=None) → Tensor:將輸入input張量每個元素的夾緊到區(qū)間 [min,max],并返回結(jié)果到一個新張量。
2 Tensor的索引與數(shù)據(jù)篩選
- torch.where(codition,x,y):按照條件從x和y中選出滿足條件的元素組成新的tensor,輸入?yún)?shù)condition:條件限制,如果滿足條件,則選擇a,否則選擇b作為輸出。
- torch.gather(input,dim,index,out=None):在指定維度上按照索引賦值輸出tensor
- torch.inex_select(input,dim,index,out=None):按照指定索引賦值輸出tensor
- torch.masked_select(input,mask,out=None):按照mask輸出tensor,輸出為向量
- torch.take(input,indices):將輸入看成1D-tensor,按照索引得到輸出tensor
- torch.nonzero(input,out=None):輸出非0元素的坐標
- import torch
- #torch.where
- a = torch.rand(4, 4)
- b = torch.rand(4, 4)
- print(a)
- print(b)
- out = torch.where(a > 0.5, a, b)
- print(out)
- print("torch.index_select")
- a = torch.rand(4, 4)
- print(a)
- out = torch.index_select(a, dim=0,
- index=torch.tensor([0, 3, 2]))
- #dim=0按列,index取的是行
- print(out, out.shape)
- print("torch.gather")
- a = torch.linspace(1, 16, 16).view(4, 4)
- print(a)
- out = torch.gather(a, dim=0,
- index=torch.tensor([[0, 1, 1, 1],
- [0, 1, 2, 2],
- [0, 1, 3, 3]]))
- print(out)
- print(out.shape)
- #注:從0開始,第0列的第0個,第一列的第1個,第二列的第1個,第三列的第1個,,,以此類推
- #dim=0, out[i, j, k] = input[index[i, j, k], j, k]
- #dim=1, out[i, j, k] = input[i, index[i, j, k], k]
- #dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
- print("torch.masked_index")
- a = torch.linspace(1, 16, 16).view(4, 4)
- mask = torch.gt(a, 8)
- print(a)
- print(mask)
- out = torch.masked_select(a, mask)
- print(out)
- print("torch.take")
- a = torch.linspace(1, 16, 16).view(4, 4)
- b = torch.take(a, index=torch.tensor([0, 15, 13, 10]))
- print(b)
- #torch.nonzero
- print("torch.take")
- a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
- out = torch.nonzero(a)
- print(out)
- #稀疏表示
3 Tensor的組合/拼接
- torch.cat(seq,dim=0,out=None):按照已經(jīng)存在的維度進行拼接
- torch.stack(seq,dim=0,out=None):沿著一個新維度對輸入張量序列進行連接。序列中所有的張量都應(yīng)該為相同形狀。
- print("torch.stack")
- a = torch.linspace(1, 6, 6).view(2, 3)
- b = torch.linspace(7, 12, 6).view(2, 3)
- print(a, b)
- out = torch.stack((a, b), dim=2)
- print(out)
- print(out.shape)
- print(out[:, :, 0])
- print(out[:, :, 1])
4 Tensor的切片
- torch.chunk(tensor,chunks,dim=0):按照某個維度平均分塊(最后一個可能小于平均值)
- torch.split(tensor,split_size_or_sections,dim=0):按照某個維度依照第二個參數(shù)給出的list或者int進行分割tensor
5 Tensor的變形操作
- torch().reshape(input,shape)
- torch().t(input):只針對2D tensor轉(zhuǎn)置
- torch().transpose(input,dim0,dim1):交換兩個維度
- torch().squeeze(input,dim=None,out=None):去除那些維度大小為1的維度
- torch().unbind(tensor,dim=0):去除某個維度
- torch().unsqueeze(input,dim,out=None):在指定位置添加維度,dim=-1在最后添加
- torch().flip(input,dims):按照給定維度翻轉(zhuǎn)張量
- torch().rot90(input,k,dims):按照指定維度和旋轉(zhuǎn)次數(shù)進行張量旋轉(zhuǎn)
- import torch
- a = torch.rand(2, 3)
- print(a)
- out = torch.reshape(a, (3, 2))
- print(out)
- print(a)
- print(torch.flip(a, dims=[2, 1]))
- print(a)
- print(a.shape)
- out = torch.rot90(a, -1, dims=[0, 2]) #順時針旋轉(zhuǎn)90°
- print(out)
- print(out.shape)
6 Tensor的填充操作
- torch.full((2,3),3.14)
7 Tensor的頻譜操作(傅里葉變換)
責任編輯:龐桂玉
來源:
深度學(xué)習這件小事