
[Saved for reference] train_loader

by 미스터탁 2019. 2. 8.

How to manipulate the datasets that the torch libraries provide (MNIST, CIFAR10, etc.)





import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader

## Training transforms: random crop + horizontal flip for augmentation,
## then tensor conversion and per-channel normalization with commonly used CIFAR10 statistics
transforms_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

## Test transforms: no augmentation, only tensor conversion and normalization
transforms_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

dataset_train = CIFAR10(root='../data', train=True,
                        download=True, transform=transforms_train)
dataset_test = CIFAR10(root='../data', train=False,
                       download=True, transform=transforms_test)

train_loader = DataLoader(dataset_train, batch_size=128,
                          shuffle=True, num_workers=4)
test_loader = DataLoader(dataset_test, batch_size=128,
                         shuffle=False, num_workers=4)
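For reference, the two random augmentations in transforms_train can be sketched with plain NumPy on a single (H, W, C) uint8 image, which is the format CIFAR10 stores its images in. This illustrates the idea and is not torchvision's implementation. random_crop_with_padding and random_horizontal_flip are hypothetical names. RandomCrop(32, padding=4) with its default constant zero fill pads each side by 4 pixels and then cuts out a random 32x32 window. RandomHorizontalFlip() mirrors the image with probability 0.5 by default.

```python
import numpy as np

def random_crop_with_padding(img, size=32, padding=4, rng=None):
    """Zero-pad an (H, W, C) image on every side, then crop a random size x size window."""
    rng = np.random.default_rng() if rng is None else rng
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode="constant")
    h, w = padded.shape[:2]
    # Random top-left corner so the window stays inside the padded image
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return padded[top:top + size, left:left + size]

def random_horizontal_flip(img, p=0.5, rng=None):
    """Mirror an (H, W, C) image left-right with probability p."""
    rng = np.random.default_rng() if rng is None else rng
    return img[:, ::-1] if rng.random() < p else img

# Usage on a random 32x32 RGB image
rng = np.random.default_rng(0)
img = rng.integers(0, 256, size=(32, 32, 3)).astype(np.uint8)
out = random_horizontal_flip(random_crop_with_padding(img, rng=rng), rng=rng)
print(out.shape, out.dtype)  # (32, 32, 3) uint8
```

The output keeps the input size and dtype. Only the content shifts by up to 4 pixels, with zero-filled borders, and is possibly mirrored.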


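## Double the training set by appending an augmented copy of every image
import numpy as np

## Newer torchvision versions store CIFAR10 images in .data (uint8 array of shape (50000, 32, 32, 3))
## and labels in .targets (a Python list); older versions named them train_data and train_labels
td = train_loader.dataset.data
new_td = td.copy()  ## copy the original data first, in case remap modifies its input in place
td3 = remap(td)  ## remap is a user-defined image augmentation function; it must return a uint8 array of the same shape
train_loader.dataset.data = np.append(new_td, td3, axis=0)  ## concatenate the images

la = train_loader.dataset.targets
la2 = la + la  ## list concatenation: the labels repeat in the same order as the images
train_loader.dataset.targets = la2  ## concatenate the labels as well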
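The doubling step can be sketched end to end with NumPy alone on synthetic data. The post's own remap is user-defined and not shown, so the remap below is a hypothetical stand-in (a horizontal flip), and double_dataset is likewise a hypothetical helper. The sketch enforces two requirements. First, the augmented array must keep the same shape and uint8 dtype, because CIFAR10 converts each stored row with PIL's Image.fromarray when it is indexed. Second, the labels must be repeated in the same order as the images.

```python
import numpy as np

def remap(images):
    """Hypothetical stand-in for the user-defined remap: a horizontal flip.

    images: uint8 array of shape (N, H, W, C), the layout CIFAR10 uses.
    """
    # Reverse the width axis; .copy() returns a contiguous array
    return images[:, :, ::-1, :].copy()

def double_dataset(data, targets, augment):
    """Append an augmented copy of every image and repeat the labels in the same order."""
    augmented = augment(data.copy())  # work on a copy so the originals stay untouched
    if augmented.shape != data.shape or augmented.dtype != data.dtype:
        raise ValueError("augment must return an array with the same shape and dtype")
    new_data = np.concatenate([data, augmented], axis=0)
    new_targets = list(targets) + list(targets)
    return new_data, new_targets

# Usage with a tiny synthetic stand-in for CIFAR10: 4 random 32x32 RGB images
rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(4, 32, 32, 3)).astype(np.uint8)
targets = [3, 1, 4, 1]
new_data, new_targets = double_dataset(data, targets, remap)
print(new_data.shape)  # (8, 32, 32, 3)
print(new_targets)     # [3, 1, 4, 1, 3, 1, 4, 1]
```

Image i and image i + N then share the same label, which is what la + la relies on.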


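### How to check

from tqdm import tqdm

print(len(train_loader))  ## number of batches per epoch

progress_bar = tqdm(train_loader)
for i, (images, labels) in enumerate(progress_bar):
    print(i)

## For CIFAR10 the doubled training set has 100,000 images; with batch size 128 that is
## ceil(100000 / 128) = 782 batches (the last one holds only 32 images), so i should run from 0 up to 781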
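The expected count follows from how DataLoader batches a map-style dataset. With the default drop_last=False, the final partial batch is kept, so an epoch has ceil(N / batch_size) batches, and len(train_loader) reports the same number. Here is a minimal sketch of the arithmetic. num_batches is a hypothetical helper, not a PyTorch function.

```python
import math

def num_batches(num_samples, batch_size, drop_last=False):
    """Batches per epoch for a map-style dataset under the usual batch sampling."""
    if drop_last:
        return num_samples // batch_size  # the incomplete last batch is discarded
    return math.ceil(num_samples / batch_size)  # the incomplete last batch is kept

n = num_batches(100_000, 128)
print(n)                                          # 782
print(n - 1)                                      # 781, the last index from enumerate
print(100_000 - (n - 1) * 128)                    # 32, images in the final partial batch
print(num_batches(100_000, 128, drop_last=True))  # 781
print(num_batches(50_000, 128))                   # 391, the original un-doubled set
```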


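### Image plotting

import torch
import torchvision
import matplotlib.pyplot as plt
import numpy as np

## Undo transforms.Normalize (x * std + mean, with the same statistics as above)
## and show a (C, H, W) image tensor
def imshow(img):
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)
    img = img * std + mean
    npimg = img.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()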
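transforms.Normalize maps each channel to (x - mean) / std. To undo it for display, compute x * std + mean with the same statistics. The shortcut img/2+0.5 is the exact inverse only when mean = std = 0.5. Below is a NumPy sketch of the round trip. normalize and denormalize are hypothetical helpers that mirror that formula.

```python
import numpy as np

# Channel statistics used in transforms_train / transforms_test, shaped (C, 1, 1) for broadcasting
CIFAR_MEAN = np.array([0.4914, 0.4822, 0.4465]).reshape(3, 1, 1)
CIFAR_STD = np.array([0.2023, 0.1994, 0.2010]).reshape(3, 1, 1)

def normalize(x, mean=CIFAR_MEAN, std=CIFAR_STD):
    # Same per-channel formula as transforms.Normalize: (x - mean) / std
    return (x - mean) / std

def denormalize(x, mean=CIFAR_MEAN, std=CIFAR_STD):
    # Exact inverse of normalize: x * std + mean
    return x * std + mean

rng = np.random.default_rng(0)
img = rng.random((3, 32, 32))  # a (C, H, W) image in [0, 1), as produced by ToTensor

z = normalize(img)
print(np.allclose(denormalize(z), img))  # True
print(np.allclose(z / 2 + 0.5, img))     # False: the shortcut does not invert these statistics

half = np.full((3, 1, 1), 0.5)
z_half = normalize(img, half, half)
print(np.allclose(z_half / 2 + 0.5, img))  # True: exact only when mean = std = 0.5
```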


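## First original image (index 0)
## Shown with transforms_test so that RandomCrop/RandomHorizontalFlip do not alter it;
## the stored images are NumPy arrays, so convert them to PIL images first, as CIFAR10 itself does
from PIL import Image

imshow(torchvision.utils.make_grid(transforms_test(Image.fromarray(train_loader.dataset.data[0]))))

## First augmented image (index 50000, the remapped copy of image 0, assuming remap keeps the order)
imshow(torchvision.utils.make_grid(transforms_test(Image.fromarray(train_loader.dataset.data[50000]))))

## Print the normalized tensors to compare the values directly
print(transforms_test(Image.fromarray(train_loader.dataset.data[0])))
print(transforms_test(Image.fromarray(train_loader.dataset.data[50000])))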
