持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。
其它Tensor操作如 einsum等见:待更新。
用到两个函数:
torch.cat
torch.stack
一、List Tensor转Tensor (torch.cat)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 | / / An highlighted block >>> t1 = torch.FloatTensor([[ 1 , 2 ],[ 5 , 6 ]]) >>> t2 = torch.FloatTensor([[ 3 , 4 ],[ 7 , 8 ]]) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> ta = torch.cat(l,dim = 0 ) >>> ta = torch.cat(l,dim = 0 ).reshape( 2 , 2 , 2 ) >>> tb = torch.cat(l,dim = 1 ).reshape( 2 , 2 , 2 ) >>> ta tensor([[[ 1. , 2. ], [ 5. , 6. ]], [[ 3. , 4. ], [ 7. , 8. ]]]) >>> tb tensor([[[ 1. , 2. ], [ 3. , 4. ]], [[ 5. , 6. ], [ 7. , 8. ]]]) |
高维tensor
** 如果理解了2D to 3DTensor,以此类推,不难理解3D to 4D,看下面代码即可明白:**
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | >>> t1 = torch. range ( 1 , 8 ).reshape( 2 , 2 , 2 ) >>> t2 = torch. range ( 11 , 18 ).reshape( 2 , 2 , 2 ) >>> l = [] >>> l.append(t1) >>> l.append(t2) >>> torch.cat(l,dim = 2 ).reshape( 2 , 2 , 2 , 2 ) tensor([[[[ 1. , 2. ], [ 11. , 12. ]], [[ 3. , 4. ], [ 13. , 14. ]]], [[[ 5. , 6. ], [ 15. , 16. ]], [[ 7. , 8. ], [ 17. , 18. ]]]]) >>> torch.cat(l,dim = 1 ).reshape( 2 , 2 , 2 , 2 ) tensor([[[[ 1. , 2. ], [ 3. , 4. ]], [[ 11. , 12. ], [ 13. , 14. ]]], [[[ 5. , 6. ], [ 7. , 8. ]], [[ 15. , 16. ], [ 17. , 18. ]]]]) >>> torch.cat(l,dim = 0 ).reshape( 2 , 2 , 2 , 2 ) tensor([[[[ 1. , 2. ], [ 3. , 4. ]], [[ 5. , 6. ], [ 7. , 8. ]]], [[[ 11. , 12. ], [ 13. , 14. ]], [[ 15. , 16. ], [ 17. , 18. ]]]]) |
二、List Tensor转Tensor (torch.stack)
代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 | import torch t1 = torch.FloatTensor([[ 1 , 2 ],[ 5 , 6 ]]) t2 = torch.FloatTensor([[ 3 , 4 ],[ 7 , 8 ]]) l = [t1, t2] t3 = torch.stack(l, dim = 2 ) print (t3.shape) print (t3) ## output: ## torch.Size([2, 2, 2]) ## tensor([[[1., 3.], ## [2., 4.]], ## [[5., 7.], ## [6., 8.]]]) |
以上为个人经验,希望能给大家一个参考,也希望大家多多支持IT俱乐部。