1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45
|
import torch import torchvision from torch import nn from torch.nn import ReLU,Sigmoid,Linear,Sequential,Conv2d,MaxPool2d,Flatten from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter
# CIFAR-10 test split (train=False), each image converted to a tensor by
# ToTensor. Stored under ./dataset relative to the working directory and
# downloaded there if it is not already present.
dataset = torchvision.datasets.CIFAR10(
    root="dataset",
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True,
)
# One sample per batch; not consumed in the visible part of this script.
dataloader = DataLoader(dataset=dataset, batch_size=1)
class Seq(nn.Module):
    """Small convolutional classifier producing 10 output scores.

    Three stages of 5x5 convolution ("same" padding) followed by 2x2 max-pool,
    then Flatten and two linear layers (1024 -> 64 -> 10). No activation
    functions are applied between layers.

    Input must have 3 channels, and the flattened feature size after the last
    pool must be 1024, i.e. 64 * (H // 8) * (W // 8) == 1024 -- e.g. 32x32
    images such as CIFAR-10.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) for each conv + pool stage, in order.
        stages = ((3, 32), (32, 32), (32, 64))
        layers = []
        for c_in, c_out in stages:
            # kernel 5 with padding 2 keeps H and W; the pool then halves them.
            layers.append(Conv2d(c_in, c_out, 5, padding=2))
            layers.append(MaxPool2d(2))
        layers += [Flatten(), Linear(1024, 64), Linear(64, 10)]
        # The attribute name "model" and this exact layer order define the
        # state_dict keys (model.0.weight, model.2.weight, ...), so saved
        # checkpoints stay loadable only if both remain unchanged.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the layer stack on a (N, 3, H, W) batch and return (N, 10) scores."""
        return self.model(x)
# Method 1: pickle the entire module object (class reference + parameters).
seq = Seq()
torch.save(seq, "Seq.pth")
# Reading back a whole pickled nn.Module needs the full unpickler. Since
# PyTorch 2.6, torch.load defaults to weights_only=True, which rejects
# arbitrary classes such as Seq, so the opt-out must be explicit. This runs
# arbitrary pickle code: only do it for files you trust (here, the one just
# written above). Unpickling also needs the Seq class to be importable from the
# same module path it was saved from. (`weights_only` requires PyTorch >= 1.13.)
model = torch.load("Seq.pth", weights_only=False)
# Method 2: save only the state_dict (parameter/buffer tensors), then load it
# into a freshly constructed Seq with the same architecture.
torch.save(seq.state_dict(), "Seq_dic.pth")
model1 = Seq()
# A state_dict holds only tensors, so the restricted weights_only loader is
# sufficient and avoids executing arbitrary pickle code. Passing it explicitly
# also silences the FutureWarning emitted by PyTorch 2.4-2.5.
# (`weights_only` requires PyTorch >= 1.13.)
model1.load_state_dict(torch.load("Seq_dic.pth", weights_only=True))
print(model1)
|