파이토치에서 데이터들을 학습할 때 굉장히 유용한 기능으로 DataLoader가 있습니다.
DataLoader는 파이토치에서 데이터들을 원하는 batch size로 잘라줍니다. DataLoader를 사용하면 batch size에 맞추어 학습을 굉장히 쉽게 학습을 수행할 수 있습니다. 이 때 DataLoader에 넣어주어야 하는 값이 Dataset이 됩니다.
How To Use Dataset
from torchvision import datasets, transforms
train_dataset = datasets.MNIST(
root = "data",
download = True,
train = True,
transform = transforms.Compose([
transforms.ToTensor()
])
)
test_dataset = datasets.MNIST(
root = "data",
download = True,
train = False,
transform = transforms.Compose([
transforms.ToTensor()
])
)
TorchVision의 MNIST 데이터를 불러오는 코드입니다.
- root : 가져올 데이터가 저장되어 있는 위치를 의미합니다. 현재의 위치 './data'에 있는 데이터를 가져오게 됩니다.
- download : 만약 root 위치에 데이터가 없다면 데이터를 다운로드 받습니다.
- train : train 데이터를 가져올 것인지 test 데이터를 가져올 것인지 결정합니다.
- transform : 데이터의 해당 transform을 수행합니다.
Make Custom Dataset
TorchVision에서 제공하는 Dataset 이 외에 데이터를 사용하고 싶다면 직접 DataSet을 정의해서 사용할 수 있습니다.
Custom Dataset을 만들기 위해서는 '__init__', '__len__', '__getitem__'을 가지고 있는 Class를 정의해주면 됩니다.
이 때 해당 클래스는 torch.utils.data.Dataset을 상속 받아야합니다.
__init__은 생성자로 데이터셋을 구성하기 위해 필요한 정보들을 받고 저장합니다.
__len__은 데이터의 길이정보를 retrun합니다.
__getitem__은 idx에 대한 데이터 하나를 정제과정을 거쳐 넘겨줍니다.
class CustomDataset(Dataset):
def __init__(self, label, data, transform=None, target_transform=None):
self.label =label
self.data = data
self.transform = transform
self.target_transform = target_transform
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return_data = self.data[idx]
return_label = self.labels[idx]
if self.transform:
return_data = self.transform(return_data)
if self.target_transform:
return_label = self.target_transform(return_label)
return return_data, return_label
__init__에서 label과 data 정보를 list로 받는다고 가정하였습니다. 이는 다르게 받을 수도 있습니다.(이미지라면 파일의 디렉토리) 필요한 정보들을 __getitem__과 __len__에서 사용하기 위해 저장해둡니다.
__getitem__은 idx의 데이터와 라벨을 가져와 return해주게 됩니다.
Make DataLoader
from torch.utils.data import DataLoader
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
DataLoader를 정의하는 방법은 간단합니다.
Dataset을 넣어주면 됩니다.
batch_size는 데이터를 어느 batch size로 가져올지를 결정합니다. 만약 batch_size가 64라면 64개의 batch사이즈 씩 가져오게 됩니다.
shuffle을 True로 지정하면 데이터의 배치를 랜덤하게 가져오게 됩니다.(순서대로 학습하면 그 순서를 모델이 학습해서 예측해버릴 수도 있기 때문)
References
https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
https://tutorials.pytorch.kr/beginner/basics/data_tutorial.html
'Artificial intelligence > Pytorch' 카테고리의 다른 글
torchvision.transforms (ToTensor, Normalize, Resize, RandomCrop,Compose) (0) | 2022.09.02 |
---|---|
[Pytorch] nn.module을 상속받을 때 super().__init__()을 하는 이유 (0) | 2022.08.26 |