Notice
Recent Posts
Recent Comments
Link
«   2024/10   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
Archives
Today
Total
10-05 13:28
관리 메뉴

SJ_Koding

Pytorch, 이미지 분류 코드 자세히 이해하기 (3편) - AutoAugment 본문

PyTorch Code/Pytorch

Pytorch, 이미지 분류 코드 자세히 이해하기 (3편) - AutoAugment

성지코딩 2023. 11. 8. 09:59

2023.11.07 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset

 

Pytorch, 이미지 분류 코드 자세히 이해하기 (2편) - Dataset

2023.11.07 - [Deep Learning/Pytorch] - Pytorch, 이미지 분류 코드 자세히 이해하기 (1편) - 데이터 확인 Pytorch, 이미지 분류 코드 자세히 이해하기 (1편) - 데이터 확인 1편 내용: import 문, SEED 고정, DataFrame화, r

sjkoding.tistory.com

이번 글 역시 이전과 이어지는 글입니다.

 

Albumentation Demo 사이트

torchvisionaugmentation함수와 albumentation의 라이브러리는 타 블로그에 너무 잘 정리되어있습니다. 해당 대회에서 초기 albumentation 라이브러리를 사용했다가, AutoAugment를 사용하기로 했습니다.

 

그 전에 albumentation라이브러리의 좋은 사이트를 공유합니다.

https://demo.albumentations.ai/

 

Streamlit

 

demo.albumentations.ai

해당 사이트는 albumentation이 지원하는 증강 기법을 코드 없이 미리 적용해볼 수 있는 사이트입니다. 

demo.albumentation.ai

위 사진처럼 구성되어있으며 좌상단 Professional을 선택하시면 원하는 사진을 업로드 할 수 있고, 여러 증강을 겹쳐서 진행시킬 수 있습니다. 행여 이미지가 너무 훼손되지는 않는지, 데이터셋에 적합하게 증강이 되는지, 회전하면 안되는 사진에 대해 회전을 시켜버렸는지 미리 확인해볼 수 있는 좋은 사이트입니다.

 

AutoAugment 

AutoAugment는 모델의 일반화 성능을 향상시키기 위해 데이터셋에 자동으로 증강을 적용시키는 알고리즘입니다. 다양한 변형을 진행하면서 어떤 변형을 얼마나 적용할지를 강화학습(Rainforcement Learning)을 통해 최적화됩니다. AutoAugment 알고리즘의 핵심은 정책(Policy)라고 불리는 규칙의 집합을 학습하는 것입니다. 각 정책은 하나 이상의 데이터 증강 변환(transformations)과 이러한 변환을 적용할 확률, 변환의 강도를 결정하는 매개변수들을 포함합니다. 이 정책은 탐색 과정(Search Process)을 통해 발견되며, 강화학습에서 에이전트가 환경과 상호작용하면서 보상을 극대화하는 전략을 학습하는 방식과 유사합니다.

 

AutoAugment 5단계

1. 탐색 공간 정의: 가능한 모든 데이터 증강 변환과 그 매개변수의 조합으로 탐색 공간을 정의합니다.

2. 정책 탐색: 강화학습을 사용하여 가장 성능이 좋은 데이터 증강 정책을 탐색합니다. 이때, 보통은 RNN(Recurrent Neural Network)과 같은 모델을 사용하여 각각의 데이터 증강 변환과 매개변수를 선택합니다.

3. 보상 함수: 정책의 성능을 평가하기 위한 보상 함수를 정의합니다. 이는 대개 검증 데이터셋(validation set)에서의 모델 성능으로 측정됩니다.

4. 최적의 정책 선택: 탐색 과정에서 발견된 여러 정책들 중, 보상을 최대화하는 최적의 정책을 선택합니다.

5. 데이터 증강 적용: 최적화된 정책을 사용하여 훈련 데이터셋에 데이터 증강을 적용하고, 이를 통해 훈련된 모델을 최종 평가합니다.

 

 

AutoAugment 사용법

https://github.com/DeepVoltaire/AutoAugment

 

GitHub - DeepVoltaire/AutoAugment: Unofficial implementation of the ImageNet, CIFAR 10 and SVHN Augmentation Policies learned by

Unofficial implementation of the ImageNet, CIFAR 10 and SVHN Augmentation Policies learned by AutoAugment using pillow - GitHub - DeepVoltaire/AutoAugment: Unofficial implementation of the ImageNet...

github.com

!git clone https://github.com/uoguelph-mlrg/Cutout.git
!git clone https://github.com/DeepVoltaire/AutoAugment.git

노트북환경에서는 위의 명령어를 통해 AutoAugment를 clone해옵니다.

 

해당 코드는 ImageNetPolicy, CIFAR10Policy, SVHNPolicy를 지원하며, 강화학습으로 찾아낸 최적의 증강 종류 및 강도를 사용해볼 수 있습니다. 이번 데이터셋은 CIFAR10과 매우 유사하므로 아래의 증강을 적용시켰습니다. (아래 코드는 clone으로 가져온 소스코드의 일부이므로, 구현할 필요가 없습니다.)

class CIFAR10Policy(object):
    """ Randomly choose one of the best 25 Sub-policies on CIFAR10.

        Example:
        >>> policy = CIFAR10Policy()
        >>> transformed = policy(image)

        Example as a PyTorch Transform:
        >>> transform=transforms.Compose([
        >>>     transforms.Resize(256),
        >>>     CIFAR10Policy(),
        >>>     transforms.ToTensor()])
    """
    def __init__(self, fillcolor=(128, 128, 128)):
        self.policies = [
            SubPolicy(0.1, "invert", 7, 0.2, "contrast", 6, fillcolor),
            SubPolicy(0.7, "rotate", 2, 0.3, "translateX", 9, fillcolor),
            SubPolicy(0.8, "sharpness", 1, 0.9, "sharpness", 3, fillcolor),
            SubPolicy(0.5, "shearY", 8, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.5, "autocontrast", 8, 0.9, "equalize", 2, fillcolor),

            SubPolicy(0.2, "shearY", 7, 0.3, "posterize", 7, fillcolor),
            SubPolicy(0.4, "color", 3, 0.6, "brightness", 7, fillcolor),
            SubPolicy(0.3, "sharpness", 9, 0.7, "brightness", 9, fillcolor),
            SubPolicy(0.6, "equalize", 5, 0.5, "equalize", 1, fillcolor),
            SubPolicy(0.6, "contrast", 7, 0.6, "sharpness", 5, fillcolor),

            SubPolicy(0.7, "color", 7, 0.5, "translateX", 8, fillcolor),
            SubPolicy(0.3, "equalize", 7, 0.4, "autocontrast", 8, fillcolor),
            SubPolicy(0.4, "translateY", 3, 0.2, "sharpness", 6, fillcolor),
            SubPolicy(0.9, "brightness", 6, 0.2, "color", 8, fillcolor),
            SubPolicy(0.5, "solarize", 2, 0.0, "invert", 3, fillcolor),

            SubPolicy(0.2, "equalize", 0, 0.6, "autocontrast", 0, fillcolor),
            SubPolicy(0.2, "equalize", 8, 0.6, "equalize", 4, fillcolor),
            SubPolicy(0.9, "color", 9, 0.6, "equalize", 6, fillcolor),
            SubPolicy(0.8, "autocontrast", 4, 0.2, "solarize", 8, fillcolor),
            SubPolicy(0.1, "brightness", 3, 0.7, "color", 0, fillcolor),

            SubPolicy(0.4, "solarize", 5, 0.9, "autocontrast", 3, fillcolor),
            SubPolicy(0.9, "translateY", 9, 0.7, "translateY", 9, fillcolor),
            SubPolicy(0.9, "autocontrast", 2, 0.8, "solarize", 3, fillcolor),
            SubPolicy(0.8, "equalize", 8, 0.1, "invert", 3, fillcolor),
            SubPolicy(0.7, "translateY", 9, 0.9, "autocontrast", 1, fillcolor)
        ]

    def __call__(self, img):
        policy_idx = random.randint(0, len(self.policies) - 1)
        return self.policies[policy_idx](img)

    def __repr__(self):
        return "AutoAugment CIFAR10 Policy"

Flips + CIFAR10Policy

torchvision의 transforms를 사용했으며, Compose 메소드로 여러 증강을 하나로 묶어 변수로 저장할 수 있습니다. 이전 게시글에서 선언한 CustomDataset의 transforms 인자로 들어가게 됩니다.

from torchvision import transforms

from AutoAugment.autoaugment import CIFAR10Policy
policy = CIFAR10Policy()
train_aug=transforms.Compose([
                         transforms.RandomHorizontalFlip(),
                         transforms.RandomVerticalFlip(),
                         CIFAR10Policy(),
			             transforms.ToTensor(),
                         transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761))]
                        )


valid_aug=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize(mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761))
                        ])

이렇게 train_aug, valid_aug를 선언했습니다.

*참고
Normalize값에 대해 사전에 알려진 평균 및 표준편차값을 적용합니다.
mean = {
'cifar10': (0.4914, 0.4822, 0.4465),
'cifar100': (0.5071, 0.4867, 0.4408),
}

std = {
'cifar10': (0.2023, 0.1994, 0.2010),
'cifar100': (0.2675, 0.2565, 0.2761),
}

 

다음은 모델 선언 및 모델 학습에 대해 다룹니다.

 

- 3편 끝 -