nn.Linear
是 PyTorch 框架中用于创建线性变换模型的函数。它接受输入张量(Tensor)和权重矩阵,通过矩阵乘法和加法操作,再加上一个可选的偏差向量(bias),然后通过一个激活函数(默认为线性激活函数),得到输出张量。
以下是 nn.Linear
函数的基本用法和参数说明:
import torch
import torch.nn as nn
# 创建一个线性层,输入特征数为2,输出特征数为3
linear_layer = nn.Linear(in_features=2, out_features=3)
# 输入张量,形状为[batch_size, 2]
input_tensor = torch.randn(4, 2)
# 应用线性变换
output_tensor = linear_layer(input_tensor)
# 输出张量形状为[batch_size, 3]
print(output_tensor.size()) # 输出:torch.Size([4, 3])
参数说明:
-
in_features
:输入张量的特征数,即输入的维度大小。 -
out_features
:输出张量的特征数,即输出的维度大小。 -
bias
:一个布尔值,如果设置为False
,则该层不会学习附加偏差,默认值为True
。
nn.Linear
层在深度学习模型中常用于全连接层。在 PyTorch 中,全连接层的输入和输出通常都是二维张量,其形状一般为 [batch_size, size]
,其中 batch_size
是样本数量,size
是特征数。
希望这能帮助你理解 nn.Linear
函数及其用法。