파이토치에서 클래스로 Layer나 Model을 구현해주면 항상 생성자에서 super(class이름, self).__init__()을 입력해줍니다. 왜 이것을 입력해야 하는지 궁금하여 알아보았습니다.
super().__init__()이 없다면?
import torch
class Test(torch.nn.Module):
def __init__(self):
self.linear = torch.nn.Linear(3,2)
def forward(self,x):
return self.linear(x)
Test를 위해 굉장히 간단한 torch.nn.Module을 상속받는 클래스를 만들어보았습니다.
그리고 이 클래스를 활용해 모델을 만들어보겠습니다.
model = Test()
생성을 하게 된다면 AttributeError: cannot assign module before Module.__init__() call의 에러메세지를 받을 수 있습니다.
에러 내용을 해석해보면 module을 사용하기 위해서는 Module.__init__()을 먼저 수행하라 즉 super().__init__()을 수행하라는 것을 알 수 있습니다.(super가 nn.Module이므로)
그렇다면? torch.nn모듈을 사용하지 않는다면? 어떻게 될까요?
import torch
class Test(torch.nn.Module):
def __init__(self):
self.out = 5
#self.linear = torch.nn.Linear(3,2)
def forward(self,x):
return self.out
#return self.linear(x)
Test클래스에서 torch.nn 모듈을 사용하지 않았습니다.
model = Test()
그 결과는 정상적으로 작동되는 것을 알 수 있습니다.
즉 이 결과로 우리는 torch.nn의 모듈을 사용하기 위해서는 반드시 super.__init__()을 해줘야함을 유추해볼 수 있습니다.
Super().__init__()이란?
super()라는 것은 상속받은 부모클래스를 의미합니다. 부모 클래스를 불러와서 __init__()을 수행한다는 것은 부모클래스의 생성자를 불러주는 것을 의미합니다.
그렇다면 nn.Module의 __init__()을 직접 들어가서 확인해보겠습니다.
__init__을 보면 다양한 변수가 선언되어있는 것을 볼 수있습니다.
super().__init__()은 위 변수들을 상속받는 역할을 해줍니다.
위 변수는 클래스 내 함수로 접근하는 변수들입니다.
정상적으로 작동되기 위해서는 이러한 변수들을 상속받고 사용해야하므로 반드시 super().__init__()을 사용해주어야 하는 것 입니다.
Super().__init__() vs Super().__init__(class_name, self)
import torch
class Test(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(3,2)
def forward(self,x):
return self.linear(x)
위 처럼 super.__init__()을 사용할 수도 있고
import torch
class Test(torch.nn.Module):
def __init__(self):
super(Test,self).__init__()
self.linear = torch.nn.Linear(3,2)
def forward(self,x):
return self.linear(x)
super(파생클래스이름, self).__init__()로 만들수도 있는데 사용하는데 다른점은 없다고 합니다.
단순히 현재 사용하는 클래스가 어떤 클래스인지 알리는 용도로 작성된다고 하네요.
Reference
'Artificial intelligence > Pytorch' 카테고리의 다른 글
Pytorch Dataset과 DataLoader (0) | 2022.09.03 |
---|---|
torchvision.transforms (ToTensor, Normalize, Resize, RandomCrop,Compose) (0) | 2022.09.02 |