当前位置: 代码网 > it编程>编程语言>Asp.net > 详解PyTorch nn.Embedding() 嵌入

详解PyTorch nn.Embedding() 嵌入

2024年11月05日 Asp.net 我要评论
在对文本序列进行分词(tokenize)并映射后,字符串序列就转变为了数字(token id)序列,这些 token id 可以直接输入到模型中,但需要明白的是,模型并不能直接从一个纯粹的数字中获取丰

在对文本序列进行分词(tokenize)并映射后,字符串序列就转变为了数字(token id)序列,这些 token id 可以直接输入到模型中,但需要明白的是,模型并不能直接从一个纯粹的数字中获取丰富的信息。类比到人类的认知,我们理解一个字或词并不是仅靠符号,而是其背后的含义。

nn.embedding 嵌入层

torch.nn.embedding(num_embeddings, embedding_dim, padding_idx=none, max_norm=none, norm_type=2.0, scale_grad_by_freq=false, sparse=false, _weight=none, _freeze=false, device=none, dtype=none)

a simple lookup table that stores embeddings of a fixed dictionary and size.

一个简单的查找表,用于存储固定大小的字典中每个词的嵌入向量。

参数

  • num_embeddings (int): 嵌入字典的大小,即词汇表的大小 (vocab size)。
  • embedding_dim (int): 每个嵌入向量的维度大小。
  • padding_idx (int, 可选): 指定填充对应的索引值。该索引对应的嵌入向量在训练过程中不会更新,即梯度不参与反向传播,通常作为“填充”标记使用。对于新构建的 embedding 模块,此索引的嵌入向量默认值为全零,但可以更改为其他值。
  • max_norm (float, 可选): 如果设置,超过此值的嵌入向量范数将被重新归一化,使其最大范数等于 max_norm
  • norm_type (float, 可选): 用于计算 max_norm 的 p-范数,默认为 2,即计算 2 范数。
  • scale_grad_by_freq (bool, 可选): 如果为 true,梯度将根据单词在 mini-batch 中的频率的倒数进行缩放,适用于高频词的梯度调整。默认为 false
  • sparse (bool, 可选): 如果设置为 true,则权重矩阵的梯度为稀疏张量,适合大规模词汇表的内存优化。
  • 变量 weight (tensor): 模块的可学习权重,形状为 (num_embeddings, embedding_dim),初始值从正态分布 n(0, 1) 中采样。

方法

from_pretrained(embeddings, freeze=true, padding_idx=none, max_norm=none, norm_type=2.0, scale_grad_by_freq=false, sparse=false)

create embedding instance from given 2-dimensional floattensor.

用于从给定的 2 维浮点张量(floattensor)创建一个 embedding 实例。

参数

  • embeddings (tensor): 一个包含嵌入权重的 floattensor。第一个维度代表 num_embeddings(词汇表大小),第二个维度代表 embedding_dim(嵌入向量维度)。
  • freeze (bool, 可选): 如果为 true,则嵌入张量在训练过程中保持不变,相当于设置 embedding.weight.requires_grad = false。默认值为 true
  • 其余参数参考之前定义。

要点示例未完待续…(预计 11.6 前上传)

qa

q1:对于神经网络来说,什么是“符号”及其“背后的含义”?

答案是:token idembedding

那么,什么是 embedding?

我们可以通过 pytorch 中的 nn.embedding 类来理解它,先跳过繁琐的介绍,运行代码来直观感受:

import torch
import torch.nn as nn
# 设置随机种子以确保结果可复现
torch.manual_seed(42)
# 定义嵌入层参数
num_embeddings = 5  # 假设词汇表中有 5 个 token
embedding_dim = 3   # 每个 token 对应 3 维嵌入向量
# 初始化嵌入层
embedding = nn.embedding(num_embeddings, embedding_dim)
# 定义整数索引
input_indices = torch.tensor([0, 2, 4])
# 查找嵌入向量
output = embedding(input_indices)
# 打印结果
print("权重矩阵:")
print(embedding.weight.data)
print("\nembedding 输出:")
print(output)

输出:

权重矩阵:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 0.2303, -1.1229, -0.1863],
        [ 2.2082, -0.6380,  0.4617],
        [ 0.2674,  0.5349,  0.8094],
        [ 1.1103, -1.6898, -0.9890]])

embedding 输出:
tensor([[ 0.3367,  0.1288,  0.2345],
        [ 2.2082, -0.6380,  0.4617],
        [ 1.1103, -1.6898, -0.9890]], grad_fn=<embeddingbackward0>)

在这里,input_indices = [0, 2, 4] 从权重矩阵中选择第 0、2 和 4 行作为对应的嵌入表示。是的没错,embedding 的获取就是这么简单。

接下来,构建一个 embedding 类进行理解:

class embedding():
    def __init__(self, num_embeddings, embedding_dim):
        self.weight = torch.nn.parameter(torch.randn(num_embeddings, embedding_dim))
    def forward(self, indices):
        return self.weight[indices]  # 没错,就是返回对应的行

可以看出,embedding 类的本质是一个查找表(lookup table)。在上面的示例中,embedding.weight 中存储了 5 个(num_embeddings)嵌入向量,每个向量有 3 个维度(embedding_dim)。当提供 input_indices 时,查找表返回对应的嵌入向量(权重矩阵的行)。

q2: 最初的权重矩阵是什么?最终的嵌入向量由什么决定?

最初的权重矩阵是一般随机初始化的,在训练过程中会更新权重,使其能有效地表达背后的含义。

q3: 什么是语义?

举个简单的例子来理解“语义”关系:像“猫”和“狗”在向量空间中的表示应该非常接近,因为它们都是宠物;“男人”和“女人”之间的向量差异可能代表性别的区别。此外,不同语言的词汇,如“男人”(中文)和“man”(英文),如果在相同的嵌入空间中,它们的向量也会非常接近,反映出跨语言的语义相似性。同时,【“女人”和“woman”(中文-英文)】与【“男人”和“man”(中文-英文)】之间的差异也可能非常相似。

本文“狭义”地解读了与 token id 一起出现的 embedding,这个概念在自然语言处理(nlp)中有着更具体的称呼:word embedding。

到此这篇关于pytorch nn.embedding() 嵌入详解的文章就介绍到这了,更多相关pytorch nn.embedding() 嵌入内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!

(0)

相关文章:

  • C# WPF自制白板工具

    C# WPF自制白板工具

    随着电子屏幕技术的发展,普通的黑板已不再适用现在的教学和演示环境,电子白板应运而生。本篇使用wpf开发了一个电子白板工具,功能丰富,非常使用日常免费使用,或者进... [阅读全文]
  • C# WPF自制简单的批注工具

    C# WPF自制简单的批注工具

    在教学和演示中,我们通常需要对重点进行批注,下载安装第三方工具批注显得很麻烦。本篇使用wpf开发了一个批注工具,工具小巧,功能丰富,非常使用日常免费使用,或者进... [阅读全文]
  • C#处理XML文件的示例详解

    C#处理XML文件的示例详解

    一、基本介绍可扩展标记语言(英语:extensible markup language,简称:xml),是一种标记语言。标记指计算机所能理解的信息符号,通过此种... [阅读全文]
  • 使用C#和OpenXML读取大型Excel文件

    使用C#和OpenXML读取大型Excel文件

    介绍高效读取大型 excel 文件可能具有挑战性,尤其是在处理需要高性能和可扩展性的应用程序时。microsoft 的 openxml sdk 提供了一套强大的... [阅读全文]
  • C#中Task任务类用法详解

    C#中Task任务类用法详解

    前言task类是.net平台上进行多线程和异步操作的重要工具。它提供了简洁而强大的api支持,使得开发者能够更加高效地利用系统资源,实现复杂的并行和异步操作。无... [阅读全文]
  • C#文字识别API场景解析、表格识别提取功能实现

    C#文字识别API场景解析、表格识别提取功能实现

    在快节奏的工作与生活环境中,如何提高企业工作效率、提升用户体验成为了人们追求的共同目标。针对市场发展需求,一种将任意场景图片中的文字转换为可编辑文本的文字识别技... [阅读全文]

版权声明:本文内容由互联网用户贡献,该文观点仅代表作者本人。本站仅提供信息存储服务,不拥有所有权,不承担相关法律责任。 如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 2386932994@qq.com 举报,一经查实将立刻删除。

发表评论

验证码:
Copyright © 2017-2025  代码网 保留所有权利. 粤ICP备2024248653号
站长QQ:2386932994 | 联系邮箱:2386932994@qq.com