您的位置:宽带测速网 > 编程知识 > PyTorch中如何使用预训练的模型

PyTorch中如何使用预训练的模型

2025-06-23 19:26来源:互联网 [ ]

在PyTorch中使用预训练的模型可以通过torchvision库中的models模块实现。该模块包含了一些常用的预训练模型,如ResNet、VGG、AlexNet等。以下是一个使用预训练的ResNet模型的示例:

import torchimport torchvision.models as modelsimport torchvision.transforms as transformsfrom PIL import Image# 加载预训练的ResNet模型model = models.resnet18(pretrained=True)model.eval()# 加载一张图片进行推理transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])img = Image.open('image.jpg')img = transform(img)img = img.unsqueeze(0)# 添加一个维度作为batch# 进行推理output = model(img)

在上面的示例中,我们首先加载了预训练的ResNet模型,并设置为evaluation模式。然后,我们加载了一张图片,并对其进行预处理,最后通过模型进行推理得到输出。需要注意的是,我们在推理之前还需要调用model.eval()来将模型设置为evaluation模式。