PyTorch常见用法
1. 导入 PyTorch 库
| 1 | import torch | 
2. 检查 GPU 支持
| 1 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | 
3. 创建自定义数据集
| 1 | class CustomDataset(Dataset): | 
4. 使用数据增强
| 1 | transform = transforms.Compose([ | 
5. 数据加载
| 1 | train_dataset = CustomDataset(train_data, train_labels, transform=transform) | 
6. 定义神经网络模型
| 1 | class Net(nn.Module): | 
7. 实例化模型和优化器
| 1 | model = Net().to(device) | 
8. 定义损失函数
| 1 | criterion = nn.CrossEntropyLoss() | 
9. 训练模型
| 1 | for epoch in range(num_epochs): | 
10. 评估模型
| 1 | model.eval() # 设置为评估模式 | 
11. 保存和加载模型
| 1 | # 保存模型 | 
12. 使用 GPU 加速
如果有 NVIDIA GPU,并且已经安装了 CUDA,那么 PyTorch 可以利用它来加速训练过程。
| 1 | model = model.to(device) | 
13. 使用 TensorBoard 进行可视化
TensorBoard 是 TensorFlow 的一个可视化工具,但 PyTorch 用户也可以通过 torch.utils.tensorboard 使用它。
| 1 | from torch.utils.tensorboard import SummaryWriter | 
14. 调整学习率
| 1 | scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1) | 
15. 使用预训练模型
| 1 | import torchvision.models as models | 
16. 迁移学习
| 1 | for param in model.parameters(): | 
PyTorch 高频用法 (1)
数学运算
| 1 | torch.add, torch.sub, torch.mul, torch.div # 张量加法,减法,乘法,除法 | 
张量操作
| 1 | torch.tensor # 创建一个张量 | 
网络层
| 1 | torch.nn.Module # 所有神经网络模块的基类 | 
PyTorch 高频用法 (2)
数据加载
| 1 | torch.utils.data.Dataset # 用于创建自定义数据集的抽象类 | 
损失函数
| 1 | torch.nn.MSELoss, torch.nn.BCELoss, torch.nn.NLLLoss # 均方误差损失,二元交叉熵损失,负对数似然损失 | 
梯度管理
| 1 | tensor.grad # 获取张量的梯度 | 
性能优化
| 1 | torch.jit # 提供 Just-In-Time 编译功能,将 PyTorch 代码转换为更高效的 TorchScript 代码 | 
PyTorch 高频用法 (3)
优化器
| 1 | torch.optim.SGD # 随机梯度下降优化器 | 
学习率
| 1 | torch.optim.lr_scheduler.StepLR # 每隔一定的 epoch 数,学习率乘以一个固定的因子 gamma | 
functional函数
torch.nn.functional 模块包含了一系列用于构建神经网络的函数,它们与在 torch.nn 中的类似,但 torch.nn.functional 中的函数通常是无状态的,意味着它们不包含可训练的参数。这些函数通常用于执行某些操作,比如激活函数、损失函数和各种层的操作。
| 1 | import torch.nn.functional as F | 
分布式训练
| 1 | torch.distributed # 包含了分布式训练的各种工具和方法 | 
CUDA管理
| 1 | torch.device # 指定使用CPU或者GPU等设备 | 
模型管理
| 1 | torch.save # 保存模型或张量到文件 | 
调试和检查
| 1 | torch.version # 获取当前PyTorch的版本信息 | 
其他函数
| 1 | torchviz # 一个小型库,用于可视化PyTorch执行图(计算图)。 | 
 
      
      
    