当前位置: 代码网 > it编程>编程语言>Asp.net > Pytorch中的masked_fill基本知识详解

Pytorch中的masked_fill基本知识详解

2024年10月27日 Asp.net 我要评论
1. 基本知识基本的原理知识如下:输入张量和掩码:masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码掩码的形状必须与输入张量相同,true 表示需要填充的位置,false 表示保持原

1. 基本知识

基本的原理知识如下:

输入张量和掩码
masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码
掩码的形状必须与输入张量相同,true 表示需要填充的位置,false 表示保持原值

掩码操作
在执行 masked_fill 操作时,函数会检查掩码中每个元素的值
如果掩码对应的位置为 true,则在输出张量中填充指定的值;
如果为 false,则保留输入张量中对应位置的值

输出结果
最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样

在代码逻辑上

创建掩码
mask 是一个布尔张量,标识了哪些位置需要填充:

[[false, true, false],
 [true, false, true],
 [false, false, true]]

执行 masked_fill
当调用 tensor.masked_fill(mask, -1) 时,pytorch 会遍历掩码中的每个元素:对于 mask 中的每个 true 值,tensor 在对应位置的值会被替换为 -1,对于 false 值,保持原值不变

masked_fill 操作是基于 c/c++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时

2. demo

demo 1: 基本用法

import torch

# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])

# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[false, true, false],
                     [true, false, true],
                     [false, false, true]])

# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)

print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)

截图如下:

demo 2: 与条件结合使用

import torch
# 创建一个随机张量
tensor = torch.randn(3, 3)
# 创建掩码:标记负值的位置
mask = tensor < 0
# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)

截图如下:

demo 3: 结合计算

import torch
# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
                       [40, 50, 60],
                       [70, 80, 90]])
# 创建掩码:标记大于 50 的位置
mask = tensor > 50
# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)

截图如下:

到此这篇关于pytorch中的masked_fill基本知识的文章就介绍到这了,更多相关pytorch masked_fill内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!

(0)

相关文章:

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com