<2020 02 16>
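Using Google Colab, I trained a model on MNIST with the PyTorch library.
MNIST is a set of 28x28 grayscale images of handwritten digits, and the torchvision library provides it.
It can be downloaded with the code below.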
import torchvision

# train=True: 60,000 training images / train=False: 10,000 test images
mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)
You can confirm that the dataset has been downloaded into the MNIST_data/ folder.
The code below lets you look at the images directly.
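The `transform=torchvision.transforms.ToTensor()` argument turns each 28x28 `uint8` image (pixel values 0 to 255) into a float tensor of shape (1, 28, 28) with values in [0, 1]. Below is a minimal NumPy sketch of that conversion for one grayscale image, assuming a synthetic image instead of a real MNIST sample; `to_tensor_like` is a hypothetical helper name, not a torchvision function.

```python
import numpy as np

def to_tensor_like(img_uint8):
    """Hypothetical helper that mimics ToTensor() for a 2-D grayscale uint8 image.

    Rescales 0..255 to 0.0..1.0 and adds a leading channel axis: (H, W) -> (1, H, W).
    """
    img = img_uint8.astype(np.float32) / 255.0
    return img[np.newaxis, :, :]

# Synthetic 28x28 image: all black except one white pixel in the middle
raw = np.zeros((28, 28), dtype=np.uint8)
raw[14, 14] = 255

x = to_tensor_like(raw)
print(x.shape)           # (1, 28, 28)
print(x.min(), x.max())  # 0.0 1.0
```

Note that `mnist_test.data` stores the raw `uint8` pixels without this transform applied, so code that reads it directly has to divide by 255 itself to match what the model sees during training.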
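import matplotlib.pyplot as plt

def plot_img(image):
    # (1, 28, 28) tensor -> (28, 28) NumPy array
    image = image.numpy()[0]
    # ToTensor() already put the pixels in [0, 1], so no un-normalization is needed.
    # If Normalize((0.1307,), (0.3081,)) were used, undo it with: image = image * 0.3081 + 0.1307
    plt.imshow(image, cmap='gray')
    plt.show()  # draw each image in its own figure

# data_loader is created in the full code below
sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])
plot_img(sample_data[0][30])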
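The constants 0.1307 and 0.3081 are the values commonly used as the mean and standard deviation of MNIST pixels (after scaling to [0, 1]). `torchvision.transforms.Normalize((0.1307,), (0.3081,))` computes (x - mean) / std, so undoing it means x * std + mean, not mean * x + std. A NumPy sketch of the round trip, assuming a few synthetic pixel values; `normalize` and `denormalize` are hypothetical helper names.

```python
import numpy as np

MEAN, STD = 0.1307, 0.3081

def normalize(x):
    # Same formula Normalize applies per channel: (x - mean) / std
    return (x - MEAN) / STD

def denormalize(z):
    # Inverse of normalize: z * std + mean
    return z * STD + MEAN

pixels = np.array([0.0, 0.5, 1.0])
z = normalize(pixels)
print(np.allclose(denormalize(z), pixels))  # True
```

Since this post's pipeline only uses `ToTensor()`, the inverse is only needed when `Normalize` is added to the transform.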
<Full code>
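import torch
import torchvision

batch_size = 1000

mnist_train = torchvision.datasets.MNIST(root="MNIST_data/", train=True, transform=torchvision.transforms.ToTensor(), download=True)
mnist_test = torchvision.datasets.MNIST(root="MNIST_data/", train=False, transform=torchvision.transforms.ToTensor(), download=True)

# Shuffle every epoch; drop_last=True discards an incomplete final batch
data_loader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)

# Use the GPU when available (Colab GPU runtime), otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# A single linear layer: 784 (= 28 * 28) pixel inputs -> 10 class scores
linear = torch.nn.Linear(784, 10, bias=True).to(device)
# CrossEntropyLoss applies log-softmax internally, so the model outputs raw scores
loss = torch.nn.CrossEntropyLoss().to(device)
SDG = torch.optim.SGD(linear.parameters(), lr=0.1)  # stochastic gradient descent

total_batch = len(data_loader)  # 60 = 60000 / 1000 (total / batch_size)
training_epochs = 10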
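for epoch in range(training_epochs):
    total_cost = 0
    for X, Y in data_loader:
        # Flatten each (1, 28, 28) image into a 784-dim vector
        X = X.view(-1, 28 * 28).to(device)
        Y = Y.to(device)

        hypothesis = linear(X)
        cost = loss(hypothesis, Y)

        SDG.zero_grad()  # clear gradients from the previous step
        cost.backward()  # compute gradients
        SDG.step()       # update the weights

        # .item() converts the loss to a Python float so the autograd graph is not kept alive
        total_cost += cost.item()

    avg_cost = total_cost / total_batch
    print("Epoch:", "%03d" % (epoch + 1), "cost =", "{:.9f}".format(avg_cost))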
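# Evaluate on the test set without tracking gradients
with torch.no_grad():
    # mnist_test.data holds raw uint8 pixels (0-255); divide by 255 to match the ToTensor() scaling used in training
    X_test = (mnist_test.data.view(-1, 28 * 28).float() / 255).to(device)
    Y_test = mnist_test.targets.to(device)

    prediction = linear(X_test)
    correct_prediction = torch.argmax(prediction, 1) == Y_test
    accuracy = correct_prediction.float().mean()
    print("Accuracy: ", accuracy.item())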
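import matplotlib.pyplot as plt

def plot_img(image):
    # (1, 28, 28) tensor -> (28, 28) NumPy array
    image = image.numpy()[0]
    # ToTensor() already put the pixels in [0, 1], so no un-normalization is needed.
    # If Normalize((0.1307,), (0.3081,)) were used, undo it with: image = image * 0.3081 + 0.1307
    plt.imshow(image, cmap='gray')
    plt.show()

sample_data = next(iter(data_loader))
plot_img(sample_data[0][2])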
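The model above is a single linear layer (784 inputs, 10 outputs) trained with `CrossEntropyLoss` and plain SGD. As a rough illustration of what one training step computes, here is a NumPy sketch of the same math on a small random batch. It assumes synthetic data instead of MNIST, a smaller learning rate than the post's lr=0.1 (the dense random inputs make the loss much steeper than MNIST's mostly black images), and hypothetical helper names (`cross_entropy`, `sgd_step`, `accuracy`). `CrossEntropyLoss` is log-softmax followed by negative log-likelihood, and its gradient with respect to the logits is softmax(logits) minus the one-hot labels, divided by the batch size.

```python
import numpy as np

rng = np.random.default_rng(0)

n, d, k = 32, 784, 10           # batch size, input features (28 * 28), classes
X = rng.random((n, d))          # synthetic inputs in [0, 1), like ToTensor output
Y = rng.integers(0, k, size=n)  # synthetic integer labels 0..9

W = np.zeros((d, k))  # weights; torch.nn.Linear stores the transpose, shape (10, 784)
b = np.zeros(k)       # bias

def softmax(logits):
    # Subtract the row max for numerical stability
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(logits, labels):
    # Mean negative log-probability of the correct class (log-softmax + NLL)
    p = softmax(logits)
    return -np.log(p[np.arange(len(labels)), labels]).mean()

def sgd_step(W, b, X, Y, lr):
    logits = X @ W + b
    # d(loss)/d(logits) = (softmax - one_hot) / batch_size
    grad_logits = softmax(logits)
    grad_logits[np.arange(len(Y)), Y] -= 1.0
    grad_logits /= len(Y)
    grad_W = X.T @ grad_logits
    grad_b = grad_logits.sum(axis=0)
    # Plain SGD update: parameter <- parameter - lr * gradient
    return W - lr * grad_W, b - lr * grad_b

def accuracy(logits, labels):
    # Same idea as torch.argmax(prediction, 1) == Y_test followed by .mean()
    return (logits.argmax(axis=1) == labels).mean()

loss_before = cross_entropy(X @ W + b, Y)  # log(10) with all-zero weights
for _ in range(20):
    W, b = sgd_step(W, b, X, Y, lr=0.005)
loss_after = cross_entropy(X @ W + b, Y)
print(loss_before, loss_after, accuracy(X @ W + b, Y))
```

In the PyTorch version, `cost.backward()` computes these gradients automatically and `SDG.step()` applies the same update rule.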