MNIST/CIFAR10 등 torchvision(PyTorch) 라이브러리에서 지원하는 데이터셋 조작법
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader
# Shared normalization statistics for both pipelines (per-channel RGB).
# NOTE(review): (0.2023, 0.1994, 0.2010) is a widely copied CIFAR-10 "std";
# the per-pixel channel std of the training set is usually quoted as about
# (0.2470, 0.2435, 0.2616). Kept unchanged here; verify before switching, and
# keep any un-normalization used for plotting in sync with these values.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)


def _to_normalized_tensor():
    """Return new ToTensor + Normalize steps, the common tail of both pipelines."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]


# Training pipeline: random 32x32 crop from a 4-pixel padded image and a
# random horizontal flip, followed by tensor conversion and normalization.
transforms_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    + _to_normalized_tensor()
)
# Evaluation pipeline: no augmentation, only tensor conversion + normalization.
transforms_test = transforms.Compose(_to_normalized_tensor())

# CIFAR-10 train/test splits stored under ../data (downloaded if missing).
_DATA_ROOT = '../data'
dataset_train = CIFAR10(root=_DATA_ROOT, train=True, download=True,
                        transform=transforms_train)
dataset_test = CIFAR10(root=_DATA_ROOT, train=False, download=True,
                       transform=transforms_test)

# Both loaders use batches of 128 and 4 worker processes; only the training
# batches are shuffled.
_LOADER_OPTIONS = {'batch_size': 128, 'num_workers': 4}
train_loader = DataLoader(dataset_train, shuffle=True, **_LOADER_OPTIONS)
test_loader = DataLoader(dataset_test, shuffle=False, **_LOADER_OPTIONS)
## Double the training set: append an augmented copy of every training image.
import numpy as np  # the file's main numpy import comes later than this block

_train_set = train_loader.dataset
# torchvision >= 0.3 renamed CIFAR10's train_data/train_labels to data/targets;
# use whichever pair the installed version provides.
_data_attr = 'data' if hasattr(_train_set, 'data') else 'train_data'
_label_attr = 'targets' if hasattr(_train_set, 'targets') else 'train_labels'

td = getattr(_train_set, _data_attr)
new_td = td.copy()  # keep a pristine copy of the originals in case remap mutates its input
# remap is the user-defined augmentation (not defined in this file). It must
# return one augmented image per input, in the same order, so that augmented
# image i can reuse label i.
td3 = np.asarray(remap(td))
if td3.shape != new_td.shape or td3.dtype != new_td.dtype:
    # A different count would misalign the doubled labels; a different dtype
    # would be upcast by np.append and then break PIL conversion when loading.
    raise ValueError(
        'remap must return one augmented image per input with the same shape '
        'and dtype: got %s %s, expected %s %s'
        % (td3.shape, td3.dtype, new_td.shape, new_td.dtype))
setattr(_train_set, _data_attr, np.append(new_td, td3, axis=0))  # originals first, then copies

# The labels must be doubled as well: CIFAR10.__getitem__ reads image and label
# with the same index, so leaving the labels at their original length raises
# IndexError for every index past the original images.
setattr(_train_set, _label_attr, list(getattr(_train_set, _label_attr)) * 2)
la = getattr(_train_set, _label_attr)
### Sanity check: pull every batch from the training loader once.
from tqdm import tqdm

progress_bar = tqdm(train_loader)
for i, (images, labels) in enumerate(progress_bar):
    # Finishing the loop without an IndexError confirms that every index the
    # sampler draws has both an image and a label.
    # tqdm.write prints above the progress bar; a bare print() would break it up.
    tqdm.write(str(i))
### image plotting
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import PIL
def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Show a normalized image tensor with matplotlib.

    Inverts ``transforms.Normalize(mean, std)`` channel by channel, then moves
    channels last for ``plt.imshow``. The defaults are the statistics used by
    ``transforms_train``/``transforms_test`` in this file. The previous
    ``img / 2 + 0.5`` only inverts ``Normalize((0.5,)*3, (0.5,)*3)``; pass
    ``mean=(0.5,)*3, std=(0.5,)*3`` to get that behavior back.

    Args:
        img: image tensor of shape (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``.
        mean: per-channel means that Normalize subtracted.
        std: per-channel standard deviations that Normalize divided by.
    """
    # detach/cpu let this accept tensors that require grad or live on a GPU.
    npimg = img.detach().cpu().numpy()
    # Normalize computed (x - mean) / std, so undo it with x * std + mean;
    # reshaping to (C, 1, 1) broadcasts the stats over height and width.
    npimg = npimg * np.reshape(std, (-1, 1, 1)) + np.reshape(mean, (-1, 1, 1))
    # Floating-point rounding can leave values slightly outside [0, 1], which
    # matplotlib would clip with a warning; clip explicitly instead.
    npimg = np.clip(npimg, 0.0, 1.0)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
## Show a training image next to its augmented copy.
from PIL import Image

_train_set = train_loader.dataset
# torchvision >= 0.3 stores the images in `data`; older releases used `train_data`.
_images = _train_set.data if hasattr(_train_set, 'data') else _train_set.train_data
idx = 1
# The doubling step earlier in this file appended the augmented copies after
# all originals, so the copy of image `idx` sits half-way along
# (50000 + idx for CIFAR-10).
aug_idx = len(_images) // 2 + idx
# CIFAR10 stores HWC uint8 arrays and converts them with Image.fromarray before
# applying transforms, so do the same here. transforms_test (ToTensor +
# Normalize only) is used so the displayed pixels are exactly the stored ones,
# without an extra random crop/flip from transforms_train.
pair = [transforms_test(Image.fromarray(_images[k])) for k in (idx, aug_idx)]
imshow(torchvision.utils.make_grid(pair))  # left: original, right: augmented
plt.show()
'파이썬' 카테고리의 다른 글
- Sort index 찾는 방법 [저장용] (0) | 2020.05.26
댓글