2020년 2월 16일 일요일

Pytorch를 이용한 Mnist 학습하기

<2020 02 16>
구글 코랩을 이용하여 라이브러리 '파이토치' 기반으로 Mnist를 학습하였다.


Image result for Mnist
<Mnist>

28x28의 이미지 셋으로 라이브러리 'torchvision'에서 제공해준다.
아래 문구를 통해서 다운 받게 된다.
mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)
   데이터 셋을 다운받은 것을 확인할 수 있다


아래 코드를 통해서 이미지 셋을 직접 확인 할 수 있다.
 def plot_img(image):
    image = image.numpy()[0]
    mean = 0.1307
    std = 0.3081
    image = ((mean * image) + std)
    plt.imshow(image,cmap='gray')
import matplotlib.pyplot as plt
sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])

plot_img(sample_data[0][30])
 

<전체 코드>

import torch
import torchvision
batch_size = 1000
mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

device = torch.device("cuda:0")
linear = torch.nn.Linear(78410, bias=True).to(device)
loss = torch.nn.CrossEntropyLoss().to(device)
SDG = torch.optim.SGD(linear.parameters(), lr=0.1)
total_batch = len(data_loader) # 60 = 60000 / 1000 (total / batch_size)
training_epochs = 10

for epoch in range(training_epochs):
    total_cost = 0
    for X, Y in data_loader:
        X = X.view(-128 * 28).to(device)
        Y = Y.to(device)
        
        hypothesis = linear(X)
        cost = loss(hypothesis, Y)
        SDG.zero_grad()
        cost.backward()
        SDG.step()
        total_cost += cost 
    avg_cost = total_cost / total_batch
    print("Epoch:""%03d" % (epoch+1), "cost =""{:.9f}".format(avg_cost))


with torch.no_grad():
    X_test = mnist_test.data.view(-128 * 28).float().to(device)
    Y_test = mnist_test.targets.to(device)
    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print("Accuracy: ", accuracy.item())


def plot_img(image):
    image = image.numpy()[0]
    mean = 0.1307
    std = 0.3081
    image = ((mean * image) + std)
    plt.imshow(image,cmap='gray')
import matplotlib.pyplot as plt
sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])

댓글 없음:

댓글 쓰기