SJ_Koding

Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset 본문

PyTorch Code/Pytorch

Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset

성지코딩 2023. 11. 7. 22:48

2023.11.07 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드 자세히 이해하기 (1편) - 데이터 확인

 

Pytorch, 이미지 분류 코드 자세히 이해하기 (1편) - 데이터 확인

1편 내용: import 문, SEED 고정, DataFrame화, rsplit, natsort, countplot 1편은 pytorch문법이 나오지는 않으나 반드시 필수적으로 처리해야하는 부분들입니다. 교내에서 진행한 AI경진대회에서 30개의 클래스

sjkoding.tistory.com

* 이번 포스트는 위의 글과 이어지는 포스트입니다. 코드의 흐름이 이어지므로 꼭 확인바랍니다.

Custom Dataset 설정

*아래 코드를 잘 기억해 두었다가 추후 생성자의 인자가 어느형태로 어떻게 전달되는지 흐름을 이해하시기 바랍니다. 

class CustomDataset(Dataset):
    def __init__(self, img_paths, target=None, transforms=None, is_test=False):
        self.img_paths = img_paths
        self.target = target
        self.transforms = transforms
        self.is_test = is_test

    def __getitem__(self, idx):
        x = Image.open(self.img_paths[idx]).convert('RGB')

        if not self.is_test:
            y = self.target[idx]

        if self.transforms is not None:
            x = self.transforms(x)

        if self.is_test:
            return x

        return x, y

    def __len__(self):
        return len(self.img_paths)

해당 코드는 torch.utils.data.Dataset을 상속받아 새로운 커스텀 서브클래스를 만듭니다. 해당 클래스는 PyTorch 모델을 훈련할 때, 데이터 로딩 과정을 커스터마이즈하는데 사용됩니다.

 

Dataset은 아래 세가지 함수를 반드시 구현해야합니다.

def __ init__(self, <초기화할 인자들>) : 클래스의 생성자입니다.
def __getitem__(self, idx): 특정 인덱스의 데이터를 가져옵니다.
def __len__(self): 데이터셋의 총 데이터수를 반환하는 메서드입니다.

 

Dataset을 상속받는 이유는 추후 DataLoader를 사용하기 위함인데 지금 간단히 말씀드리면 배치처리(Batching), 데이터 섞기(Shuffling), 병렬처리(Parallelism), 자동화된 반복(Iteration)등을 지원하기 때문에 PyTorch를 사용함에 있어 반드시 필요합니다.

DataLoader를 사용할 시점에서 다시 자세히 말씀드리겠습니다.

 

코드 자세히 살펴보기

미리 스포일러하자면 위의 Dataset은 다음과 같이 사용됩니다. (실제 코드에서는 학습을 진행하는 시점에서 아래 코드를 실행합니다.)

all_dataset = CustomDataset(all_df['img_path'], all_df['class'], train_aug)

all_df 는 이전 게시글(https://sjkoding.tistory.com/27) 에서 만든 train.csv를 pandas의 read_csv로 불러온 DataFrame입니다.

현재 인자로 'img_path'와 'class'를 담은 Series형태의 데이터를 보내고 있으며 train_aug라는 커스텀 augmentation함수를 전달합니다. (augmentation은 다음 게시글에서 언급합니다.)

 

def __init__(self, img_paths, target=None, transforms=None, is_test=False):
    self.img_paths = img_paths
    self.target = target
    self.transforms = transforms
    self.is_test = is_test

따라서 클래스의 인스턴스들을 직관적으로 이해할 수 있습니다. (헷갈리면 댓글 부탁드립니다.) 

 

*is_test: 해당 코드는 AI경진대회를 진행한 코드입니다. 거의 모든 AI경진대회는 다음과 같은 내용이 적용됩니다.
주어진 test셋의 라벨은 공개되지 않으며, 모델이 추정한 값을 제출하면 내부 시스템에 저장된 답과 비교하여 성능을 추정하여 겨루는 시스템입니다. 따라서 test셋의 라벨값은 없으므로 __getitem__() 부분에서 라벨이 아닌 이미지만 return하기 위함입니다.
def __getitem__(self, idx):
    x = Image.open(self.img_paths[idx]).convert('RGB')
	# 이미지를 열어 numpy형태로 저장하며 이를 RGB형태의 이미지로 변환합니다. 
    # Image라이브러리는 from PIL import Image로 불러옵니다.
    
    if not self.is_test: # test셋이 아니라면
        y = self.target[idx] # idx번째 데이터의 class 저장

    if self.transforms is not None: 
        x = self.transforms(x) # 이미지에 증강 적용

    if self.is_test: # test셋이면
        return x # 이미지만 반환

    return x, y # 나머지는 이미지와 클래스 반환

__getitem__(self, idx)는 idx번째 데이터를 반환하기 위한 메소드입니다. 주석을 통해 쉽게 이해할 수 있습니다.

 

def __len__(self):
    return len(self.img_paths)

__len__(self)는 데이터의 개수를 반환하는데, 이는 DataLoader에서 내부적으로 사용됩니다. 데이터의 개수이므로  len(self.img_paths)도 좋고 len(self.target)도 좋습니다. 어차피 같은 개수의 데이터가 인자로 주어졌기 때문입니다.

 

좋아요, 이제 DataLoader를 사용할 준비와 이미지를 불러올 준비가 되었습니다. 다음 게시글에서는 모델의 일반화를 위한 augmentation을 중점으로 다루겠습니다. 헷갈리는 부분, 어려운 부분은 댓글로 남겨주시면 답글 남겨드리겠습니다.

 

-2편 끝-

썸네일용