[PyTorch 강의 12강] 파이토치 MLP Fashion MNIST 예제
오늘 우리는 PyTorch를 사용하여 MLP(Multi-Layer Perceptron)를 구축하고 Fashion MNIST 데이터셋을 이용하여 학습해 보겠습니다. Fashion MNIST는 의류와 악세사리 이미지로 구성된 데이터셋으로, 10개의 카테고리로 분류된 이미지 데이터를 포함하고 있습니다.
먼저 필요한 패키지들을 import 합니다.
import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader from torchvision.datasets import FashionMNIST
다음으로 데이터를 불러와서 DataLoader를 만들어줍니다.
transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ]) train_dataset = FashionMNIST(root='./data', train=True, transform=transform, download=True) test_dataset = FashionMNIST(root='./data', train=False, transform=transform, download=True) train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
이제 MLP 모델을 정의해줍니다.
class MLP(nn.Module): def __init__(self): super(MLP, self).__init__() self.fc1 = nn.Linear(784, 256) self.fc2 = nn.Linear(256, 128) self.fc3 = nn.Linear(128, 10) def forward(self, x): x = x.view(x.size(0), -1) x = torch.relu(self.fc1(x)) x = torch.relu(self.fc2(x)) x = self.fc3(x) return x
마지막으로 모델을 학습하고 테스트해봅니다.
model = MLP() criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(model.parameters(), lr=0.001) for epoch in range(10): model.train() for inputs, labels in train_loader: optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() model.eval() total_correct = 0 total_samples = 0 with torch.no_grad(): for inputs, labels in test_loader: outputs = model(inputs) _, predicted = torch.max(outputs, 1) total_samples += labels.size(0) total_correct += (predicted == labels).sum().item() accuracy = total_correct / total_samples print(f'Epoch {epoch+1}, Accuracy: {accuracy}')
이렇게 간단한 MLP 모델을 이용하여 Fashion MNIST 데이터셋을 분류해보았습니다. PyTorch의 강력한 기능을 활용하여 더 복잡한 모델을 구축하고 성능을 향상시킬 수 있습니다. 계속해서 공부하고 실습해보면서 꾸준한 발전을 이루어나갈 수 있을 것입니다.
파이토치 MLP Fashion MNIST 예제코드입니다.
https://github.com/neowizard2018/neowizard/blob/master/PyTorch/PyTorch_LEC11_MLP_FashionMNIST_Example.ipynb 입니다. 실습하실때 참고하시기 바랍니다
감사합니다