Tensor: ... @overload def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -。关于pytorch中scatter_add_函数的分析、理解与实现。" />

关于pytorch中scatter_add_函数的分析、理解与实现

import torch
import numpy as np
from torch import Tensor
"""
@overload
def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...
def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
对pytorch中的scatter_add函数的理解和简单测试:
参数:tensor,dim,index,tensor 返回:tensor 功能:将other_tensor的值累加到self_tensor的相应位置,用index_tensor对应位置的值替换掉self_tensor下标的dim维 举例:

self_tensor= [[1, 2], [3, 4]] shape=(2,2) other_tensor = [[5, 6], [7, 8]] shape=(2,2) index_tensor = [[0, 0], [1, 1]] shape=(2,2) dim = 1 以上三个tensor的shape必须一致,下标为:[0,0] [0,1] [1,0] [1,1] dim=1,那么,self_tensor的第1维下标由index_tensor表示,[0,0] [0,0] [1,1] [1,1] 则: self_tensor[0,0] = 1 + 5 + 6 = 12 self_tensor[0,1] = 2 self_tensor[1,0] = 3 self_tensor[1,1] = 4 + 7 + 8 = 19

【关于pytorch中scatter_add_函数的分析、理解与实现】"""
def scatter_add(input_tensor: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
# tensor的维数是不确定的,因此无法用for循环的方式 # 如果tensor是2维,[金属期货](https://www.gendan5.com/cf/mf.html)那么dim=0或1,两层for循环,用other对self进行填充 # 如果tensor是3维,那么dim=0、1、2,需要三层for循环来遍历other if input_tensor.dim() == 2: for i in range(index_tensor.size()[0]): for j in range(index_tensor.size()[1]): if dim == 0:# self矩阵的第0维索引 self_tensor[index_tensor[i][j]][j] += other_tensor[i][j] elif dim == 1:# self矩阵的第1维索引 self_tensor[i][index_tensor[i][j]] += other_tensor[i][j] elif input_tensor.dim() == 3: pass return self_tensor

if name == '__main__':
index_tensor = torch.tensor([[0, 0], [1, 1]]) print('index_tensor: \n', index_tensor.dim()) self_tensor = torch.arange(1, 5).view(2, 2) print('self_tensor: \n', self_tensor) other_tensor = torch.arange(5, 9).view(2, 2) print('other_tensor: \n', other_tensor) dim = 1 for i in range(index_tensor.size()[0]): for j in range(index_tensor.size()[1]): replace_index = index_tensor[i][j] print(i, j, replace_index) if dim == 0: # self矩阵的第0维索引 self_tensor[replace_index][j] += other_tensor[i][j] elif dim == 1: # self矩阵的第1维索引 self_tensor[i][replace_index] += other_tensor[i][j] print(self_tensor) index_tensor = torch.tensor([[0, 1], [1, 1]]) print('index_tensor: \n', index_tensor) self_tensor = torch.arange(0, 4).view(2, 2) print('self_tensor: \n', self_tensor) other_tensor = torch.arange(5, 9).view(2, 2) print('other_tensor: \n', other_tensor) self_tensor.scatter_add_(dim=0, index=index_tensor, src=https://www.it610.com/article/other_tensor) print(self_tensor)

    推荐阅读