딥러닝/딥러닝: 기초 개념

사전 학습된(Pre-trained) 모델 이해

qordnswnd123 2025. 6. 10. 17:16

1. 사전 학습 모델의 구조

사전 학습된 모델은 대규모 데이터셋에서 학습된 모델로, 컴퓨터 비전에서 자주 사용됩니다. 이 모델들은 다양한 구조와 깊이를 가지며, 각각의 장점이 있습니다.

 

1) ResNet (Residual Network):

ResNet은 '잔차 연결(residual connection)'을 도입하여, 네트워크가 매우 깊어질 때 발생하는 기울기 소실 문제를 해결합니다. 이 구조는 레이어 간 직접 연결을 추가하여, 네트워크가 깊어져도 성능이 향상될 수 있도록 돕습니다.
장점: 깊은 네트워크에서도 성능이 안정적이며, 다양한 이미지 분류 및 인식 작업에서 널리 사용됩 니다.
주요 구성 요소:  기본 블록(Residual Block): 직접 연결로 이루어진 기본 단위, 다양한 깊이의 모델(ResNet18, ResNet34, ResNet50 등)이 존재


2) VGG (Visual Geometry Group):

VGG는 매우 간단한 구조로, 여러 개의 3x3 필터를 쌓아 깊이 있는 CNN(Convolutional Neural Network)을 형성한 모델입니다. VGG는 단순한 구조 덕분에 해석하기 쉬운 장점이 있습니다.
장점: 구조가 단순하고 직관적이어서 이미지 분류에서 널리 사용됩니다.
주요 구성 요소: 여러 개의 3x3 컨볼루션 필터를 연속으로 적용, 일반적으로 VGG16과 VGG19가 사용됨.

 

※ 모델의 각 레이어 이해

 

모델은 여러 레이어로 구성되어 있으며, 각 레이어는 이미지의 특징을 점점 더 추상적으로 변환합니다.

  • 컨볼루션 레이어(Convolution Layer): 이미지에서 중요한 특징(엣지, 색상 등)을 추출하는 역할을 합니다.
  • 풀링 레이어(Pooling Layer): 이미지의 크기를 줄이면서도 중요한 정보는 유지하여 계산 비용을 줄입니다.
  • 완전 연결 레이어(Fully Connected Layer): 이미지 특징을 바탕으로 최종적으로 클래스를 분류하 는 역할을 합니다.
  • 소프트맥스 레이어(Softmax Layer): 최종적으로 이미지가 각 클래스에 속할 확률을 계산하는 레이어입니다.

2. 사전 학습 모델 적용

2.1 torchvision.models의 사전 학습된 모델 불러오기

# dataset_utils.py 파일에서 필요한 함수 불러오기
from dataset_utils import setup_folders_and_extract
setup_folders_and_extract()
import torch
import torchvision.models as models

# ResNet18 사전 학습된 모델 불러오기
model = models.resnet18(pretrained=True)

# 모델을 평가 모드로 설정 (추론 시 필요)
model.eval()

# 모델 구조 확인
print(model)
ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer2): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer3): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (layer4): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (downsample): Sequential(
        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)
        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): BasicBlock(
      (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
  (fc): Linear(in_features=512, out_features=1000, bias=True)
)

 

1) import torch
PyTorch 라이브러리를 불러옵니다. PyTorch는 텐서 연산과 딥러닝 모델을 구현하는데 사용되는 파이썬 라이브러리입 니다.

 

2) import torchvision.models as models

torchvision.models는 PyTorch에서 제공하는 사전 학습된 모델들이 포함된 모듈입니다. 이미지 분류 등의 작업에 사용되는 다양한 사전 학습된 모델을 쉽게 불러올 수 있습니다. 여기서는 models로 모듈을 불러옵니다.

 

3) model = models.resnet18(pretrained=True)

resnet18은 ResNet(Residual Network)이라는 신경망 아키텍처의 18-layer 버전입니다. ResNet은 잔차 학습을 도입하여 매우 깊은 신경망을 효과적으로 학습할 수 있도록 설계된 모델입니다.
pretrained=True는 이 모델이 대규모 데이터셋(예: ImageNet)에서 학습된 가중치를 불러온다는 의미입니다. 즉, 모델이 이미 학습된 상태이므로, 새로운 데이터에 대해 즉시 예측할 수 있습니다. 학습 없이 바로 사용할 수 있도록 모델이 준비되어 있습니다.

 

4) model.eval()

모델을 "평가 모드" 로 설정합니다. PyTorch에는 학습 모드(train())와 평가 모드(eval())가 있습니다. 평가 모드는 추론(예측) 시 사용되며, 이때 드롭아웃(Dropout)이나 배치 정규화(Batch Normalization) 같은 학습 중에만 사용하는 기능들이 비활성화됩니다. 추론 시에는 이러한 기능들이 필요하지 않기 때문에, 평가 모드로 설정하는 것이 중요합니다.


5) print(model)
모델의 전체 구조를 출력합니다. ResNet18 모델은 여러 계층(layer)으로 이루어져 있으며, 각 계층이 어떤 역할을 하는지 구조를 확인할 수 있습니다. 출력되는 내용은 모델의 각 레이어와 그 구성 요소들을 보여줍니다. 이를 통해 모델의 구조를 시각적으로 확인할 수 있습니다.

 

2.2 이미지 데이터셋으로 예측해보기

import torchvision.transforms as transforms
from PIL import Image

# CIFAR-10 이미지 전처리
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 테스트 이미지를 불러오기 (예시로 하나의 이미지 사용)
img = Image.open('cifar_image.jpg')
img_t = transform(img)
batch_t = torch.unsqueeze(img_t, 0)  # 배치 형태로 변환
print(f'배치사이즈: ',batch_t.shape)

# 모델을 사용해 예측
with torch.no_grad():  # 추론 중에는 그래디언트를 계산하지 않음
    out = model(batch_t)

# 결과 출력
print(out)
_, predicted_class = torch.max(out, 1)
print(f"Predicted class: {predicted_class.item()}")

 

1) import torchvision.transforms as transforms

torchvision.transforms는 이미지 데이터를 전처리하는 데 사용되는 모듈입니다. 이미지를 텐서로 변환하거나 크기 조정, 정규화 등을 수행할 수 있습니다.

 

2) from PIL import Image

Python Imaging Library(PIL)를 사용하여 이미지 파일을 열거나 처리할 수 있습니다.

 

3) 이미지 전처리 정의 (transform = transforms.Compose([...]))

전처리 과정은 이미지 데이터를 모델에 적합하게 만드는 중요한 단계입니다.
Resize(224): CIFAR-10 이미지 크기는 32x32이지만, ResNet18 모델의 입력 크기는 224x224입니다. 따라서 이미지를 224x224로 크기를 조정합니다.
ToTensor():이미지를 PyTorch에서 사용하는 텐서(tensor)로 변환합니다.
Normalize(mean, std): 이미지를 정규화합니다. 사전 학습된 모델(ResNet18)은 ImageNet 데이터셋에서 학습되었기 때문에, 해당 데이터셋에 맞는 평균값과 표준편차로 이미지를 정규화합니다. 각 채널(R, G, B)에 대해 다른 평균과 표준편차를 사용합니다.

 

4) 이미지 불러오기 (img=Image.open('cifar_image.jpg"))

cifar_image.jpg라는 이미지 파일을 열어 이를 전처리할 준비를 합니다.

 

5) 이미지 전처리 적용 (img_transform(img))

CIFAR-10 이미지를 전처리(크기 조정 및 정규화)한 후, PyTorch 텐서로 변환합니다.

 

6) 배치 형태로 변환 (batch_torch.unsqueeze(img_t, 0))

모델은 여러 이미지를 한 번에 처리하기 위해 입력을 배치(batch)로 받습니다. unsqueeze 함수는 차원을 하나 추가하여 (1, C, H, W) 형태의 배치 데이터를 만듭니다. 여기서 1은 배치 크기, C는 채널 수, H는 높이, W는 너비입니다.


7) 모델을 사용한 예측 (out=model(batch_t))

torch.no_grad()는 추론 과정에서 불필요한 그래디언트 계산을 방지합니다. 이는 메모리 사용량을 줄여주고, 추론 속도를 높여줍니다. 그 후, 전처리된 이미지를 사전 학습된 모델에 입력하여 예측 결과를 얻습니다.


8) 결과 출력 (print(out))

모델이 예측한 결과값을 출력합니다. 이 결과는 사전 학습된 ResNet18 모델이 ImageNet 데이터셋에서 학습된 구조로 예측한 값이므로, 1000개의 클래스에 대한 예측 점수(로짓, logits) 가 포함됩니다. 각 클래스별로 이 이미지가 해당 클래스에 속할 가능성을 나타내는 점수를 제공합니다.

 

배치사이즈:  torch.Size([1, 3, 224, 224])
tensor([[ 4.0363e+00,  1.4217e+00, -1.4910e+00, -4.6077e+00, -3.7689e+00,
          2.1729e+00, -3.9440e+00,  3.7153e-01,  1.1757e+00, -8.5295e-01,
          4.3050e+00, -1.6868e+00, -2.3042e+00, -3.8390e+00, -3.3013e+00,
         -3.3915e+00, -3.2870e+00, -3.5190e+00, -5.4709e+00, -4.4180e+00,
         -1.0408e+00,  1.0697e+00, -2.7281e+00, -1.4952e+00, -1.8120e+00,
          1.1644e+00,  3.8687e+00,  2.7678e+00,  3.6982e+00,  1.2574e+00,
         -5.5958e-01,  1.4652e-01,  3.0336e+00, -2.1949e+00, -2.2737e+00,
          4.0378e-01,  2.4157e+00, -8.8572e-01,  2.9774e+00, -6.9137e-01,
          1.6671e-01,  2.1812e+00,  2.0316e+00,  7.9398e+00,  6.8402e-01,
          8.5255e-01, -1.6696e+00,  2.8611e-01,  7.9614e-01,  1.4602e+00,
         -4.0792e-01,  5.5829e+00,  9.9238e-01, -1.8718e+00, -9.1278e-01,
         -2.5580e+00, -1.5337e+00, -3.2204e-01, -1.6695e+00,  3.5092e+00,
          2.3370e+00,  3.0000e-01,  7.2230e+00,  4.0047e+00,  2.2438e+00,
          1.0196e+00,  2.5192e+00, -1.0026e+00,  4.5172e+00, -1.2814e+00,
         -6.7791e-01, -2.4161e-01, -2.2462e+00,  5.0070e+00, -3.3987e+00,
         -2.7169e+00, -1.4485e+00, -2.3588e+00,  2.5738e+00,  2.0204e+00,
          3.3845e+00, -1.9207e+00,  5.8443e-01,  3.0611e+00, -1.9362e+00,
         -1.8143e+00,  3.2615e+00, -3.0637e+00, -3.0488e+00, -3.7584e+00,
         -3.1768e+00, -1.4058e+00, -8.3663e-01, -3.4277e+00, -1.9191e+00,
          1.0270e+00, -4.6133e+00, -2.4990e+00,  4.7423e+00,  8.4242e-01,
         -3.0897e+00,  1.7751e+00,  1.7500e+00,  5.7154e+00,  3.9101e+00,
          2.3539e+00,  1.4941e+00, -3.1182e+00, -9.4268e-01, -1.4475e+00,
          2.4440e-01,  3.8121e+00, -2.2369e+00,  5.3495e-01, -4.2638e-01,
         -1.4380e+00,  2.4672e+00,  2.2342e+00, -4.3091e-01, -1.3303e+00,
          4.6910e-01, -1.4048e+00,  1.7228e-01, -1.3793e+00,  2.9842e-02,
         -1.0154e+00,  6.3468e+00, -4.0953e+00,  1.0411e+00, -3.2539e+00,
         -4.1446e+00, -3.0250e+00, -3.5614e+00,  1.3593e+00,  5.1085e-01,
          4.5109e+00, -1.4874e+00, -3.1361e+00, -2.4848e+00, -2.8119e+00,
         -6.7911e-02, -2.3294e+00,  1.3028e+00, -3.7070e+00, -4.4216e+00,
         -2.2953e+00, -3.6504e+00,  2.7516e-01, -1.5312e+00,  1.0329e+00,
          1.6986e+00,  1.0926e+00,  3.1453e+00, -2.5195e+00,  7.0777e-01,
         -2.1779e+00,  4.8002e-01,  1.0864e+00,  5.3333e+00,  3.0182e+00,
          2.1296e+00, -2.4911e-01,  3.7452e-01,  1.3576e+00,  2.7960e+00,
          4.7715e+00,  4.6966e+00,  5.8780e+00,  4.9724e+00,  1.4017e+00,
          1.1715e-01,  1.1570e+00,  9.7247e-01,  3.7660e+00, -3.5374e-01,
          5.3870e+00,  1.8733e+00,  1.3488e+00,  1.5701e-01,  1.3194e+00,
          2.4476e+00,  4.7868e+00,  3.8580e-01,  4.1415e+00,  1.7112e+00,
          9.1178e-01,  1.9637e+00,  2.6043e+00,  5.0375e+00,  2.4099e+00,
          3.4662e+00,  2.9921e+00, -1.0670e+00,  2.9855e+00,  2.1285e+00,
         -3.1565e+00, -1.2838e+00, -1.4083e+00, -4.3876e-01, -1.7141e+00,
          6.5106e-01,  4.6627e-02,  3.7456e-02, -4.2657e+00, -3.8156e+00,
         -1.0612e-01,  4.5214e+00,  2.4259e+00,  2.1159e+00,  2.3831e+00,
          4.9002e+00,  2.9549e+00,  3.7447e-01,  4.3923e-01,  4.1411e-01,
          3.1014e+00,  3.9818e+00, -3.0046e+00,  1.2006e-01, -1.5955e+00,
          4.9289e+00,  4.9212e+00, -7.4503e-01, -1.6728e+00,  5.8697e-01,
          7.6339e-01,  1.4665e+00,  3.6282e+00,  4.4763e+00,  1.3277e-01,
          2.2040e+00,  2.2729e+00,  7.9955e-01, -1.3576e+00,  3.0251e-01,
          2.9353e+00,  2.2951e+00,  2.7335e-01, -3.4370e-01, -3.1842e+00,
          2.2757e-01,  2.9045e+00, -6.5820e-01,  2.0302e+00,  6.5038e-01,
         -2.5714e+00,  1.0506e+00, -3.2692e-01,  7.5327e-01, -4.0986e-01,
         -3.7339e-01, -1.9370e+00,  1.0226e+00,  2.9642e+00, -3.5735e+00,
         -3.9063e-01, -9.5224e-01, -9.0553e-01, -3.5998e+00, -8.5623e-01,
         -5.5581e-01, -1.6097e+00,  3.1551e+00,  1.6440e-01,  8.8740e-01,
          1.4377e+00,  1.4128e+00,  2.1339e+00,  2.3060e+00,  4.1313e+00,
         -2.6825e-02,  4.6921e+00,  5.9979e+00,  5.9484e+00,  6.0289e+00,
          3.7618e+00,  2.8032e+00,  5.8450e+00,  6.9106e+00,  7.5400e-01,
          4.8485e+00,  1.4223e+00,  3.5231e+00,  6.1993e-01, -1.0582e+00,
          1.2797e+00,  1.2017e+00,  8.1467e-01,  4.5466e+00, -3.9064e-01,
          4.0466e+00,  5.0219e+00,  3.8093e+00,  5.0940e+00,  4.0430e+00,
          1.3473e+00, -6.6035e-01,  4.2360e+00,  5.4334e+00,  3.5826e+00,
         -7.8393e-01, -1.4106e+00, -3.0390e-01, -1.8406e+00, -1.0724e+00,
         -2.9764e+00, -1.1473e+00, -6.1573e-01, -3.8948e+00, -1.2840e+00,
          5.0359e+00, -8.5460e-01, -2.5694e-01, -1.6913e+00, -1.2486e+00,
         -3.1060e+00, -7.4996e-01, -1.7208e+00,  1.6582e+00, -2.7176e-01,
         -1.9111e+00, -4.4775e+00, -1.9240e-02, -3.4165e+00, -1.6902e+00,
          3.9422e-01, -1.9802e+00, -2.0545e+00, -4.3745e+00, -2.5023e+00,
          2.8182e+00,  2.4417e+00, -3.5029e-01, -1.3586e+00,  1.1151e+00,
          8.9772e+00,  1.3039e+00,  2.6353e+00,  2.6715e+00,  5.0287e+00,
          6.1989e-02,  5.9847e+00,  4.5653e+00,  7.0762e-01,  2.2882e+00,
          3.1422e+00,  8.3222e-01,  2.6250e-02, -2.2207e+00,  1.7435e+00,
          2.3166e+00,  2.9999e+00,  1.5348e+00,  2.9033e+00,  2.7961e+00,
         -3.0774e-02,  5.4938e+00,  2.5429e+00,  5.2993e+00,  3.6272e+00,
          3.5885e+00,  2.1741e+00,  2.4554e+00,  2.1363e+00,  3.9784e+00,
          3.3007e+00, -1.3785e-01,  1.6065e+00,  2.3569e+00,  3.1353e+00,
          4.7172e+00,  6.6717e+00,  4.5755e+00,  6.3232e+00,  4.5126e+00,
          3.4396e+00,  6.5726e+00,  5.9761e+00,  2.4850e+00,  1.1028e+00,
          3.0083e+00,  4.0241e+00,  4.2778e+00,  2.4054e+00,  2.5673e+00,
          2.5481e+00,  2.0542e+00,  1.7572e+00, -1.0559e+00,  2.2605e+00,
          3.3502e+00, -1.8078e+00,  2.8648e+00, -3.4072e+00, -4.4099e-01,
          8.6917e-01, -3.1597e+00,  4.3505e-03, -3.4202e+00,  2.2676e+00,
          1.8040e+00, -2.6961e+00, -1.4948e+00, -2.7205e+00,  1.4263e+00,
         -7.0154e-01, -1.6391e+00, -3.2150e-01,  1.5105e+00,  6.5849e-01,
          2.2449e+00, -1.7156e+00, -1.9151e+00,  2.8912e+00, -5.3091e+00,
         -1.0485e+00,  3.1883e-01, -1.4563e+00, -6.4363e-01,  7.7295e-01,
         -9.9246e-01,  1.0047e+00, -3.7758e+00, -6.0719e-01, -1.3846e+00,
         -1.8416e+00,  2.8273e+00,  3.8228e+00, -2.2721e+00,  1.9359e+00,
         -6.0560e-01, -2.2836e+00,  1.4141e+00, -2.7925e+00, -4.5538e-01,
         -3.4649e+00, -1.7872e+00, -2.5605e+00,  1.2902e+00,  3.8112e+00,
         -1.4774e-01,  6.8579e-01, -2.0493e+00, -2.1508e+00, -2.1789e+00,
         -3.2259e+00,  5.1316e-01,  1.8947e+00, -4.3626e-02,  1.0860e+00,
         -1.3503e+00,  3.4146e+00, -1.8879e+00, -8.1823e-01, -3.6476e+00,
         -9.3557e-01,  8.3229e-01, -1.5754e+00, -2.1934e+00, -2.8384e+00,
         -4.0511e+00,  1.4369e+00, -1.7986e+00, -1.1583e+00, -6.4415e-01,
          3.2715e+00, -1.2289e-02, -3.6119e+00, -1.5522e+00, -6.1746e-01,
         -5.6079e-01, -2.3689e+00, -1.9240e+00,  2.3631e+00, -3.5209e-01,
         -1.7098e+00,  1.1808e-01, -1.7534e+00, -2.7750e+00, -1.6499e+00,
         -9.3644e-01,  2.9077e+00,  3.2064e+00,  2.8315e-01, -2.7487e+00,
         -2.5682e+00, -9.1283e-01, -1.4617e+00,  2.0557e+00, -7.9230e-01,
          2.8463e+00,  3.3286e+00, -9.7562e-01,  3.4900e+00,  1.7402e+00,
         -1.0615e+00, -2.2899e+00, -3.1057e+00, -1.4053e+00,  2.1772e+00,
         -6.0025e-01,  2.0461e+00,  3.2619e+00,  1.2287e+00, -1.8939e+00,
         -2.2300e+00, -2.6463e+00,  1.1879e+00, -2.8672e+00, -2.5418e+00,
         -8.6025e-03, -3.0566e+00, -4.9650e-02,  1.9595e+00,  1.2116e+00,
          9.9951e-01,  1.0263e+00, -2.6785e+00,  1.1045e-01,  1.7066e+00,
         -4.0042e+00,  1.4881e+00,  3.4077e+00, -2.4497e+00,  1.3536e+00,
         -1.9033e-01, -3.7571e+00, -1.3913e+00, -2.2429e+00, -3.4855e+00,
         -1.5903e+00, -3.4654e-01, -1.1485e+00, -3.9036e+00,  1.2231e+00,
         -2.9163e+00, -2.1354e+00,  4.6208e-01,  8.2880e-01, -5.1095e-01,
          1.1536e-01,  6.3784e-01, -4.3888e-01, -2.1631e+00, -4.8468e+00,
         -2.8093e+00, -1.4983e+00,  2.4125e-01,  1.8328e+00, -1.4801e+00,
          2.4542e+00,  1.4957e+00, -2.0550e+00, -9.4710e-01, -2.3425e+00,
         -1.6167e+00,  1.1723e+00, -2.3402e+00,  2.7106e+00, -5.0909e-01,
          1.2496e+00, -3.0539e+00, -1.6845e+00, -2.6310e+00, -1.8425e+00,
         -1.0306e+00,  1.6149e+00,  8.7133e-01, -1.8734e+00, -1.9397e+00,
         -4.3622e-02, -3.9947e+00, -1.8514e+00, -2.2750e+00,  7.1429e-01,
         -2.7171e+00, -3.7736e-01,  5.0865e+00, -3.5437e+00, -2.9734e+00,
         -4.2014e+00, -3.9062e+00, -1.7179e+00,  2.8286e+00,  3.7565e-01,
          2.6294e+00, -1.4285e+00, -1.6840e+00,  8.7809e-01, -9.8851e-01,
          5.2214e-01, -1.3151e+00, -4.9838e+00,  2.3201e+00,  1.0655e+00,
         -2.6878e+00,  2.1620e+00, -1.4497e+00, -8.8794e-01,  2.1993e+00,
         -1.4009e+00,  1.0319e+00,  1.0195e+00, -1.0420e+00,  4.1793e+00,
         -3.6820e+00,  7.7563e-01, -1.8220e+00, -3.7486e+00, -1.7956e+00,
          2.7132e+00, -7.6194e-01,  7.5770e-02, -2.5529e+00, -2.9162e+00,
         -6.7331e-02, -1.5395e+00,  8.1019e-02, -2.2016e+00, -5.3855e-03,
         -1.2226e-01, -3.2332e+00, -3.2035e+00,  1.1766e+00, -1.0256e+00,
         -2.7953e+00, -2.7819e+00, -8.3783e-01, -2.0465e+00, -5.3437e-02,
          5.7348e-01,  2.8144e+00, -1.3536e+00,  3.5754e+00, -1.7817e+00,
          2.5033e+00,  7.0974e-01, -1.8831e+00, -2.6024e+00,  2.3711e-01,
         -4.2370e+00,  2.9335e+00, -1.8004e+00,  6.6646e-03, -1.1804e+00,
         -2.1948e-01,  2.2517e-01, -1.7099e-01, -1.3441e+00, -2.0923e-01,
         -2.0404e+00,  3.5048e-01, -4.0940e-01,  5.7071e+00,  1.4134e-01,
         -3.6617e+00, -6.4405e-01, -4.6943e+00, -3.0810e+00, -1.1468e+00,
          3.5561e+00, -2.7631e+00, -2.6514e-01,  2.5405e+00, -1.9279e+00,
         -4.4700e-01,  6.1255e-01, -9.2394e-01, -1.3497e+00, -2.2631e+00,
         -2.3926e+00, -8.9756e-01,  1.3908e+00, -2.3339e+00,  2.9049e+00,
          4.3401e+00,  5.3769e+00, -1.9619e-01,  4.4967e+00, -3.4348e+00,
          4.2548e+00, -2.6750e+00, -2.2297e+00,  3.7306e+00,  5.8954e+00,
         -5.5631e+00,  4.3577e-01, -1.5471e+00, -1.6416e+00,  1.8229e+00,
          1.4159e+00,  1.2981e+00,  9.9607e-01,  2.3337e-01, -1.5940e+00,
          3.6870e-01, -2.9174e+00, -3.4824e+00,  6.8108e-03,  4.1740e+00,
         -2.2417e+00, -1.5190e+00,  1.3314e+00, -2.0728e+00, -3.2213e+00,
          4.0047e+00, -2.7519e+00, -4.3043e+00,  9.6466e-01, -5.6879e+00,
         -3.8702e+00, -1.0704e+00,  3.7324e+00, -2.2507e+00,  3.2163e+00,
          4.7908e+00, -3.1216e+00, -2.6723e+00, -6.4822e-02, -9.4663e-01,
          7.4396e-01, -4.4728e+00,  2.4080e+00, -1.9523e+00,  1.3251e+00,
         -1.3260e+00,  3.9355e+00,  4.5479e+00,  2.2717e+00,  3.5264e+00,
         -2.5284e+00,  1.6657e+00, -2.3184e+00, -1.4976e-01, -1.5522e+00,
         -1.9631e+00, -1.3398e+00, -1.6649e+00, -3.2594e+00,  1.5380e+00,
          9.0220e-01,  1.4513e+00, -2.1855e+00,  2.8651e+00, -1.3091e+00,
         -3.7919e+00, -7.2512e-01,  2.2642e+00, -3.5373e+00, -1.5909e-01,
         -5.3881e-01, -1.1196e+00, -2.9185e+00, -2.6798e+00, -2.5427e+00,
         -1.5037e+00, -1.0114e+00,  5.3918e-01,  3.1429e+00, -1.0977e+00,
         -3.5961e+00, -3.1428e+00, -1.3114e+00,  4.4193e+00, -1.1253e+00,
         -2.1211e+00,  2.0082e+00, -3.1486e+00, -6.7187e-01, -1.9046e-01,
         -3.9804e-01,  1.3332e+00,  6.5631e-01, -1.6670e+00, -1.4635e+00,
         -1.5953e+00,  1.5277e+00, -1.1803e+00, -2.1749e+00, -8.7321e-02,
         -1.8674e+00,  2.1217e-02,  2.6700e+00, -5.0065e-01, -1.6876e+00,
         -3.8195e+00, -3.7781e+00,  1.8585e-01, -3.0489e+00,  1.8335e+00,
         -3.5381e-01, -3.9283e+00, -3.8722e+00, -2.4858e-01, -3.1298e+00,
         -4.0906e-01,  1.2380e+00, -7.9705e-01,  2.5155e+00, -1.3384e+00,
          6.5294e-01, -1.7719e+00, -2.0720e+00, -2.3575e+00,  9.3613e-01,
          1.5122e+00, -3.6336e+00, -2.2244e-01,  3.4154e+00, -7.5472e-01,
          5.1156e-01, -1.0881e+00, -8.3044e-01,  1.4990e+00, -3.5490e-01,
         -2.4819e+00, -5.1773e+00, -9.0198e-01,  9.0847e-01, -2.9587e+00,
         -3.3460e+00,  3.9728e+00,  1.3789e-01, -5.8135e-01, -1.3604e+00,
         -2.0820e+00, -9.7013e-01, -9.1665e-01, -2.1086e+00, -4.7670e-01,
         -1.3492e-01,  2.3217e+00, -9.1501e-01, -2.3560e+00, -9.8535e-01,
         -8.8625e-01,  1.8024e-02, -3.7257e+00,  4.0296e+00, -1.3412e+00,
         -3.8241e+00,  3.7125e+00, -1.3221e+00, -2.5030e+00,  2.8173e-01,
         -1.1291e+00,  2.9085e+00, -2.5943e+00, -2.2848e+00, -3.1479e+00,
         -2.7947e+00,  7.8793e-02,  1.1727e+00,  2.6063e+00,  3.0375e+00,
         -6.8344e-02,  2.4661e+00,  1.7997e+00, -3.4357e+00, -1.3802e+00,
          5.5325e+00, -6.9763e-01,  4.3667e-01, -2.9679e+00, -1.5235e+00,
         -8.0080e-01, -2.7789e+00,  4.4549e-01, -2.9397e+00,  1.9675e+00,
         -3.0291e+00, -1.6409e+00,  6.3866e-01, -2.0124e+00, -2.8464e+00,
          1.2282e+00, -4.4804e+00, -1.6477e+00, -4.7806e-01, -2.1873e+00,
         -2.5692e+00,  1.2861e+00, -2.6052e+00, -2.0822e-01, -5.5836e-01,
          1.7130e+00, -7.3496e-01,  1.6556e+00, -4.4779e-01, -2.5546e+00,
         -7.3203e-02,  9.4179e-01,  1.3264e+00, -2.0776e+00, -9.7878e-01,
         -3.1767e+00, -1.8428e+00, -6.9790e-01, -1.8298e+00,  8.8667e-01,
         -3.5495e+00,  6.2226e+00,  2.6371e+00,  4.3817e-01, -4.0457e-01,
         -1.3056e+00,  6.0260e+00, -8.4489e-01, -3.1315e+00, -1.6091e+00,
         -1.4279e+00, -8.4846e-01,  3.3562e+00, -1.5579e+00,  4.1209e-01,
         -9.5306e-01,  8.1445e-01,  8.8128e-01, -3.6374e+00, -1.4746e+00,
         -2.9242e+00,  2.1526e+00, -4.3233e+00,  4.6648e-01, -4.5788e-01,
          3.3972e+00,  2.4538e+00, -3.2441e-01, -4.0041e+00, -3.3553e+00,
          2.3856e+00,  2.7954e+00,  7.7657e-01, -1.6941e+00, -1.4962e+00,
         -2.4674e+00,  7.5072e-01, -2.4597e-02, -1.7681e+00, -3.3524e+00,
          2.8682e+00, -1.8553e+00, -1.0839e+00, -2.1442e+00, -1.4042e+00,
         -1.5107e+00,  2.7547e+00,  2.3144e-01,  6.6906e-02,  2.7082e-01,
         -3.6828e+00, -9.7991e-01, -1.5579e-01, -1.6396e+00, -1.3616e+00,
          1.8850e+00,  8.9905e-02, -2.4172e+00, -5.2220e-01,  5.4670e-02,
         -1.9375e+00, -1.7611e+00,  1.0370e+00, -4.2226e+00,  1.2115e+00,
         -4.6829e-02, -9.7030e-01, -8.8701e-01, -2.7261e+00,  4.0279e+00,
          1.7349e-01, -9.6608e-01, -7.8391e-01, -3.1282e+00, -1.4323e+00,
          4.9900e-01, -8.7294e-01, -7.8081e-01, -5.3196e-01, -5.6814e-01,
         -2.6103e+00,  1.0516e+00, -9.9031e-01, -3.2472e+00, -2.4508e+00,
         -1.5461e+00,  4.8962e-01,  9.0596e-01, -5.2621e-01, -2.7542e+00,
          1.4469e+00, -3.0322e+00, -7.2181e-01,  2.5004e+00,  5.9578e-01,
          2.5014e+00,  2.0930e+00,  3.6783e+00,  1.8409e+00, -3.0624e+00]])
Predicted class: 335

 

※ 결과해석

배치 크기: torch.Size([1, 3, 224, 224])는 입력된 이미지가 1개의 이미지로 구성된 배치이며, 각각의 이미지가 3개의 색상 채널(RGB)과 224x224 픽셀 크기를 가진다는 것을 나타냅니다. 이 크기는 ResNet 모델에 맞게 전처리된 형태입니다.


출력된 텐서: 출력된 값들은 로짓(logits)으로, 모델이 각 클래스에 대해 예측한 "점수"입니다. 이 점수는 음수나 양수일 수 있으며, 값이 클수록 해당 클래스에 속할 가능성이 높다는 뜻입니다. 이 경우, 총 1000개의 클래스에 대한 예측 점수가 출력되었습니다. 이 모델은 ImageNet 데이터셋에서 학습된 ResNet18 모델이므로, 1000개의 ImageNet 클 래스 중 하나를 예측하게 됩니다.


Predicted class: 335: torch.max(out, 1)를 사용해 1000개의 클래스 중 가장 높은 점수를 받은 클래스를 예측했으며, 그 결과 335번 클래스가 선택되었습니다. 이 숫자는 ImageNet 데이터셋에서 사전 학습된 모델의 클래스 인덱스를 의미하며, CIFAR-10 클래스와는 다릅니다. 즉, ImageNet의 335번 클래스가 CIFAR-10 이미지를 보고 가장 가능성이 높다고 판단된 클래스입니다.

 

주의 사항:
현재 사용한 모델은 ImageNet 데이터셋에 맞춰져 있어 1000개의 클래스 중 하나로 예측하게 됩니다. 하지만 CIFAR- 10 데이터셋은 10개의 클래스만 있으므로, 이 모델은 CIFAR-10 이미지에 대해 정확한 예측을 하기 어렵습니다. 정확한 예측을 위해서는 모델의 마지막 출력 레이어를 CIFAR-10에 맞게 10개의 클래스로 수정하고, CIFAR-10 데이터셋에 맞춰 미세 조정(fine-tuning)을 진행해야 합니다.


요약:
현재 결과는 ImageNet 클래스 중 하나를 예측한 것이며, CIFAR-10과는 맞지 않을 수 있습니다. CIFAR-10 데이터에 맞는 예측을 하려면 모델의 마지막 레이어를 10개 클래스에 맞춰 수정한 후 다시 학습해야 합니다.

 

3. 실습: 사전 학습 모델을 활용한 기본 이미지 분류

3.1 ResNet18을 활용한 CIFAR-10 이미지 분류

해당 코드는 사전 학습된 ResNet18 모델을 CIFAR-10 데이터셋에 맞춰 수정하고, 데이터를 통해 예측을 수행하는 과정입니다.
ResNet 모델은 224x224 크기의 이미지를 입력으로 받으므로 CIFAR-10 이미지 크기를 조정하고, 마지막 fully connected 레이어를 CIFAR-10의 10개의 클래스에 맞게 수정합니다.

import torch
import torch.nn as nn
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
from torchvision.datasets import ImageFolder

# CIFAR-10 데이터셋 불러오기
transform = transforms.Compose([
    transforms.Resize(224),  # 입력 크기를 ResNet 모델에 맞추기
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = ImageFolder(root='./train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# 사전 학습된 ResNet18 모델을 불러오고 CIFAR-10에 맞게 출력 레이어 수정
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features 
model.fc = nn.Linear(num_ftrs, 10) 

model.eval()  

# 학습 루프에서 예측 수행
for inputs, labels in train_loader:
    with torch.no_grad():  
        outputs = model(inputs)  
    print(outputs)  
    break

 

1) import torch

PyTorch의 핵심 모듈을 가져옵니다. 모델을 불러오거나 데이터를 처리하는 데 사용됩니다.

 

2) import torch.nn as nn

신경망 레이어 정의를 위한 모듈입니다. CIFAR-10 데이터셋에 맞게 마지막 레이어를 수정하는 데 사용됩니다.

 

3) import torchvision.datasets as datasets

torchvision.datasets는 여러 유명한 이미지 데이터셋을 제공하는 모듈로, 여기서는 CIFAR-10 데이터셋을 불러오는 데 사용됩니다.

 

4) from torch.utils.data import DataLoader

DataLoader는 데이터셋을 배치(batch) 단위로 처리하고, 학습 및 추론 중에 데이터를 효율적으로 불러오는 기능을 제공합니다.

 

5) 데이터셋 불러오기 및 전처리

이미지 데이터셋을 불러와 전처리하고, 데이터 크기를 ResNet18 모델에 맞게 조정합니다(224x224).
데이터를 배치 단위로 처리할 수 있도록 DataLoader를 사용합니다.

 

6) 사전 학습된 ResNet18 모델 불러오기 및 출력 레이어 수정

사전학습된 ResNet18 모델을 불러옵니다.

CIFAR-10은 10개의 클래스를 가지므로, ResNet18의 마지막 레이어(fully connected layer)를 CIFAR-10에 맞게 10개의 출력으로 변경합니다.
모델을 평가 모드(eval())로 설정하여 추론을 수행합니다.


7) 예측 수행

CIFAR-10 데이터셋의 첫 번째 배치를 불러와 모델에 입력한 후, 예측 결과를 출력합니다.
torch.no_grad()를 사용하여 추론 시 불필요한 그래디언트 계산을 방지합니다.

 

tensor([[-0.0152, -0.0788,  0.4812,  0.5106, -0.5678,  0.1927,  0.2802, -0.1230,
         -0.5314,  0.0459],
        [-0.9083, -0.1452, -0.1747,  0.1228, -0.3857,  0.5641,  0.3648, -0.2856,
         -0.7403, -1.0223],
        [-0.1251, -0.1639, -0.1007,  1.1067, -1.0362,  0.2851, -0.4113,  0.0754,
         -0.3897,  0.3234],
        [ 0.1661, -0.4323, -0.0657,  0.8099, -0.3302,  0.2585,  0.4583,  0.0758,
         -0.9323, -0.6302],
        [ 0.4421, -0.7657, -0.3361,  0.4012, -0.2645,  0.3685,  0.5388,  0.4013,
         -0.1352, -0.7587],
        [-0.4590, -0.3892, -0.2192,  0.5293, -0.7777,  1.0611,  0.3986, -0.1574,
         -0.2812, -0.2127],
        [-0.6801, -0.4555, -0.5584,  0.2609, -0.4927,  1.4467,  0.1117,  0.2582,
         -0.6613, -0.9880],
        [-0.1635, -0.3159,  0.1936,  0.1466, -0.3465,  0.0702, -0.0961,  0.4681,
         -0.3177, -0.3550],
        [-0.6842, -0.3324,  0.5520, -0.3942,  0.0730,  0.8447,  1.7206,  0.6903,
          0.0174, -0.8974],
        [ 0.0827,  0.8983, -0.0185,  0.5335, -1.0426,  0.8056,  0.3651, -0.0690,
         -0.8736, -1.1491],
        [ 0.1884, -0.7979, -0.3789,  1.1218, -0.1817,  0.4228,  0.5152,  0.1077,
         -0.0980, -0.2291],
        [-0.4438, -0.1639, -0.4012,  0.6672, -0.5961,  0.5247,  0.0709, -0.2124,
         -0.6690, -0.8518],
        [-0.0252, -0.6953, -0.2626,  0.3592, -0.4600,  0.2609,  0.1365, -0.0195,
         -0.4981, -0.1351],
        [-0.0974, -0.3856,  0.0862,  0.1975, -0.4841,  0.1648,  0.2228,  0.2049,
         -0.5927,  0.2846],
        [-0.6462, -0.6984,  0.2349,  0.6264, -0.1520, -0.0270,  0.0744,  0.7470,
         -0.6638, -0.4921],
        [-0.0750, -0.3446,  0.3228,  0.8842, -0.4226,  0.5595, -0.2669,  0.5053,
         -0.8270, -1.0437],
        [ 0.1198, -0.2546, -0.4446,  0.7845, -0.2185,  0.7140,  0.7767,  0.3926,
         -0.4172, -0.7034],
        [ 0.2477, -0.5734, -1.1343,  0.4480, -0.3945,  0.2017,  0.5794, -0.3363,
         -0.5125, -0.0073],
        [-0.0868, -1.0720, -0.0891,  0.8438, -0.4599,  0.1149,  0.6257,  0.2839,
         -0.4985, -0.6350],
        [-0.5463, -0.5971,  0.0878,  1.2156, -0.4918,  0.6928,  0.0571, -0.3310,
         -0.6277, -0.5784],
        [-0.4145, -1.2115,  0.0130,  0.5647, -0.4880,  0.1930,  1.2356, -0.5412,
         -0.2987, -0.3017],
        [-0.8434, -0.1832,  0.1639,  1.2609, -1.1015,  0.0567,  0.3911,  0.7918,
         -0.8879,  0.6071],
        [ 0.2296, -0.5050, -0.5915,  0.8969, -0.5574,  0.4475, -0.1084,  0.2053,
         -0.3700, -0.0624],
        [ 0.2754, -0.6075, -0.3027,  0.2675, -0.0463,  0.4633,  1.1827, -0.0969,
         -0.5176, -0.2548],
        [ 0.1921, -0.8294,  0.0769,  0.1889, -0.0720,  0.4840,  0.3048,  0.5556,
         -0.8595, -0.7720],
        [-0.0942, -0.6982,  0.1471,  0.8397, -0.1127,  0.0428,  0.7941, -0.3751,
         -0.6646, -0.8243],
        [-0.7187,  0.0886, -0.0710,  0.5683, -0.3736,  0.2871,  0.2745,  0.4877,
         -0.6660, -0.5355],
        [ 0.2336, -1.0465, -0.0671,  0.6099,  0.0996,  0.4694,  0.1808,  0.6425,
         -0.6757, -1.0143],
        [ 0.0464, -0.5646, -0.2539,  0.3788, -0.0446,  0.7635,  0.5817,  0.0257,
         -0.6694,  0.0111],
        [-0.4847, -0.6977, -0.0578,  0.1414, -0.5136,  0.3634,  1.4024, -0.7030,
         -0.3273, -0.0170],
        [-0.0192, -0.6165, -0.4976,  0.5407, -0.1591,  0.5681,  0.2570,  0.1294,
         -0.3070, -1.0675],
        [-0.2466, -0.3795, -0.0718,  0.7862, -0.5584,  0.6998,  0.3789, -0.2491,
         -0.5380, -0.5180]])

 

※ 결과 해석

출력된 결과는 모델이 각 배치에 대해 CIFAR-10의 10개 클래스에 대해 예측한 로짓(logits) 값입니다. 로짓은 각 클래스에 대해 모델이 예측한 원시 점수를 나타내며, 이 값이 클수록 해당 클래스에 속할 가능성이 높다고 모델이 판단한 것입니다. 각 로우는 CIFAR-10 데이터셋의 한 이미지에 대한 예측 결과이며, 10개의 숫 자로 이루어져 있습니다.
예를 들어, 출력된 첫 번째 로우 [0.1959, -0.5907, 0.1424, 0.6853, -0.9062, -0.2954, 0.2337, -0.4838, 0.3977, -0.0388] 에서 가장 큰 값은 0.6853로, 이 값에 해당하는 클래스가 모델이 해당 이미지 에 대해 가장 높은 확률로 예측한 클래스입니다.
이를 확인하려면 torch.max()를 사용하여 가장 높은 로짓 값을 가진 인덱스를 확인할 수 있습니다.

 

3.2.최종 예측 클래스 추출

_, predicted_class = torch.max(outputs, 1)
print(f"Predicted class: {predicted_class}")

 

1) torch.max(outputs, 1)

outputs는 모델의 예측 결과로, 각 데이터 샘플에 대해 10개의 클래스에 대한 로짓(logits) 값을 포함합니다.
torch.max(outputs, 1)은 두 가지 값을 반환합니다: 첫 번째 값은 각 샘플에서 가장 큰 로짓 값을 의미합니다.두 번째 값은 그 로짓 값이 위치한 인덱스를 의미합니다.
여기서 두 번째 값(인덱스)가 predicted_class로 저장됩니다. 즉, 각 샘플에서 모델이 가장 높은 점수를 준 클래스 인덱스를 나타냅니다.


2) print(f"Predicted class: {predicted_class}")

predicted_class는 모델이 예측한 각 샘플의 클래스 인덱스를 포함한 텐서입니다. 이 텐서를 출력하여 각 샘플이 어떤 클래스로 예측되었는지 확인할 수 있습니다.

Predicted class: tensor([3, 5, 3, 3, 6, 5, 5, 7, 6, 1, 3, 3, 3, 9, 7, 3, 3, 6, 3, 3, 6, 3, 3, 6,
        7, 3, 3, 7, 5, 6, 5, 3])

 

3.3 예측 결과 시각화 및 정답 비교

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

# CIFAR-10 이미지 전처리
transform = transforms.Compose([
    transforms.Resize(64), 
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# CIFAR-10 테스트 데이터셋 로드
testset = ImageFolder(root='./train', transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=8,
                                         shuffle=False, num_workers=2)

# CIFAR-10 클래스 정의
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# 이미지를 시각화하는 함수 정의
def imshow(img):
    img = img / 2 + 0.5  # 정규화 해제
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

# 모델을 평가 모드로 설정
model.eval()

# 8장의 이미지를 가져와 예측
dataiter = iter(testloader)
images, labels = next(dataiter)

# 모델을 사용해 예측
with torch.no_grad():
    outputs = model(images)

_, predicted = torch.max(outputs, 1)

# 이미지와 함께 실제 및 예측된 클래스를 시각화
imshow(torchvision.utils.make_grid(images, nrow=4))  # 2x4 형태로 이미지 배치

# 정답과 예측을 비교하여 함께 출력
for i in range(8):
    print(f"Image {i+1}: GroundTruth = {classes[labels[i]]}, Predicted = {classes[predicted[i]]}")
Image 1: GroundTruth = plane, Predicted = horse
Image 2: GroundTruth = plane, Predicted = dog
Image 3: GroundTruth = plane, Predicted = ship
Image 4: GroundTruth = plane, Predicted = car
Image 5: GroundTruth = plane, Predicted = dog
Image 6: GroundTruth = plane, Predicted = dog
Image 7: GroundTruth = plane, Predicted = truck
Image 8: GroundTruth = plane, Predicted = bird

 

※ 결론
이 코드는 사전학습된 모델을 CIFAR-10 데이터셋에 적용한 예시로, 사전 학습된 모델을 그대로 사용하면 예측 성능이 저조할 수 있음을 보여줍니다. CIFAR-10 데이터셋에 맞게 모델을 미세 조정(fine-tuning) 하면 더 나은 성능을 기대할 수 있습니다.