nn.linear函数

nn.Linear 是 PyTorch 框架中用于创建线性变换模型的函数。它接受输入张量(Tensor)和权重矩阵,通过矩阵乘法和加法操作,再加上一个可选的偏差向量(bias),然后通过一个激活函数(默认为线性激活函数),得到输出张量。

以下是 nn.Linear 函数的基本用法和参数说明:

import torch
import torch.nn as nn

# 创建一个线性层,输入特征数为2,输出特征数为3
linear_layer = nn.Linear(in_features=2, out_features=3)

# 输入张量,形状为[batch_size, 2]
input_tensor = torch.randn(4, 2)

# 应用线性变换
output_tensor = linear_layer(input_tensor)

# 输出张量形状为[batch_size, 3]
print(output_tensor.size())  # 输出:torch.Size([4, 3])

参数说明:

  • in_features:输入张量的特征数,即输入的维度大小。

  • out_features:输出张量的特征数,即输出的维度大小。

  • bias:一个布尔值,如果设置为 False,则该层不会学习附加偏差,默认值为 True

nn.Linear 层在深度学习模型中常用于全连接层。在 PyTorch 中,全连接层的输入和输出通常都是二维张量,其形状一般为 [batch_size, size],其中 batch_size 是样本数量,size 是特征数。

希望这能帮助你理解 nn.Linear 函数及其用法。

Top