torch模型仅更新部分参数(关于Torch中的scatter)

torch模型仅更新部分参数(关于Torch中的scatter)(1)

scatter_ 和 one hot

看了很多博客,中国人写博客有一个特点就是复制来复制去,根本没有讲到重点,好了废话不多扯,今天讲下 scatter_ 函数。

操作一:

import torch # 导入 torch模块,这里操作的都是张量数据 src = torch.arange(1, 11).reshape((2, 5)) # 这里创建一个 2行5列的数据 print(src) # 打印出来 tensor([[ 1, 2, 3, 4, 5], [ 6, 7, 8, 9, 10]])

上面这个是准备数据,是一个两行五列的数据。再创建一个索引数据

index = torch.tensor([[0, 1, 2, 0, 2]]) print(index) tensor([[0, 1, 2, 0, 2]])

在这之前都是很简单的,相比读者肯定能看到,无非就是两个数据,请耐心往下看

result_1 = torch.zeros(3, 5, dtype=src.dtype) # 创建一个3行5列的数据全是0 print(result_1) tensor([[0, 0, 0, 0 0], [0, 0 0, 0, 0], [0, 0, 0, 0, 0]])

解析来就是使用 scatter_函数: 也就是根据相关索引,把result_1的指定位置填充下

result = result_1.scatter_(0, index, src)

这里是什么意思呢, 0 表示按列来处理,result_1 是需要被更改的数据,index是索引位置, src数用来填充的数据,举例子: 如上面描述:

result_1 = tensor([[0, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

index = tensor([[0, 1, 2, 0, 2]])

tensor([[ 1, 2, 3, 4, 5],

[ 6, 7, 8, 9, 10]])

第一个参数 0 表示按列来处理

索引第1个值为0,这表示第1列的第1个数据设置为scr中的第2个数据

ensor([[1, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

索引第2个值为1,这表示第2列的第2个数据设置为scr中的第2个数据

ensor([[1, 0, 0, 0 0],

[0, 2 0, 0, 0],

[0, 0, 0, 0, 0]])

索引第3个值为2,这表示第3列的第3个数据设置为scr中的第3个数据

ensor([[1, 0, 0, 0 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 0]])

索引第4个值为0,这表示第4列的第1个数据设置为scr中的第4个数据

ensor([[1, 0, 0, 4 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 0]])

索引第5个值为2,这表示第5列的第3个数据设置为scr中的第5个数据

ensor([[1, 0, 0, 4 0],

[0, 2 0, 0, 0],

[0, 0, 3, 0, 5]])

以上就是详细的计算流程

操作2:

idx = torch.tensor([[0, 1, 2, 3,4]]) last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)

这里第一步我相信大家都熟悉,就是创建一个数据而已,这里我们理解为索引数据

1、torch.zeros(3, 5, dtype=src.dtype). 表示的是创建一个3行5列的数据矩阵,全是0

tensor([[0, 0, 0, 0 0],

[0, 0 0, 0, 0],

[0, 0, 0, 0, 0]])

2、dim=1,表示是按行计算

3、value,表示相应的位置上设置为某个值

idx = torch.tensor([[0, 1, 2, 3,4]])

表示的是第一行的第 0 1 2 3 4 的位置上全是设置为2,也就是

tensor([[2, 2, 2, 2, 2],

[0, 0, 0, 0, 0],

[0, 0, 0, 0, 0]])

当然,我相信某些人还是一脸懵逼,再继续往下看

idx = torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]]) last = torch.zeros(3, 5, dtype=src.dtype).scatter_(dim=1, index=idx, value=2)

这里我们看到idx为 torch.tensor([[0, 1, 2, 3,4],[0,0,0,0,4]])

这个idx有两行,那么他对应的也是 torch.zeros(3, 5, dtype=src.dtype)中的两行数据,

[0, 1, 2, 3,4] 表示的是第一行的第 0 1 2 3 4 的位置上全是设置为2

[0,0,0,0,4]]表示的是第二行的第 0 、4 的位置上设置为2,其他地方不变

因此整体数据变成了

tensor([[2, 2, 2, 2, 2],

[2, 0, 0, 0, 2],

[0, 0, 0, 0, 0]])

好了,这个函数介绍到此为止,希望能帮到大家

,

免责声明:本文仅代表文章作者的个人观点,与本站无关。其原创性、真实性以及文中陈述文字和内容未经本站证实,对本文以及其中全部或者部分内容文字的真实性、完整性和原创性本站不作任何保证或承诺,请读者仅作参考,并自行核实相关内容。文章投诉邮箱:anhduc.ph@yahoo.com

    分享
    投诉
    首页