250x250
250x250
JinSeopKim
Hello World!
JinSeopKim
전체 방문자
오늘
어제
  • 분류 전체보기 (168)
    • Artificial intelligence (14)
      • DeepDiveToAI (3)
      • Pytorch (3)
      • Etc (8)
    • Back-end (19)
      • Spring (10)
      • JPA (9)
    • Language (24)
      • Python (3)
      • Java (11)
      • Swift (10)
    • Math (4)
      • Linear Algebra (4)
    • CodingTest (79)
      • Algolithm (12)
      • Backjoon (25)
      • Programmers (42)
    • Etc (27)
      • Book Review (3)
      • Adsp (6)
      • Life (2)
      • Docker (1)
      • odds and ends (15)

블로그 메뉴

  • 홈
  • 태그
  • 방명록
  • GitHub

인기 글

태그

  • ADsP
  • 파이썬
  • java
  • BFS
  • data
  • swift
  • 자바
  • 개발자
  • AI
  • JAVA8
  • BOJ
  • 브루트포스
  • 개발
  • ios
  • 문제풀이
  • 프로그래머스
  • 코딩테스트
  • 머신러닝
  • SpringMVC
  • 백준
  • uArm
  • 선형대수
  • 구현
  • certificate
  • 카카오
  • Front-end
  • Python
  • 알고리즘
  • JPA
  • DP

최근 댓글

최근 글

티스토리

hELLO · Designed By 정상우.
JinSeopKim

Hello World!

[Pytorch] nn.module을 상속받을 때 super().__init__()을 하는 이유
Artificial intelligence/Pytorch

[Pytorch] nn.module을 상속받을 때 super().__init__()을 하는 이유

2022. 8. 26. 01:38
728x90
728x90

파이토치에서 클래스로 Layer나 Model을 구현해주면 항상 생성자에서 super(class이름, self).__init__()을 입력해줍니다. 왜 이것을 입력해야 하는지 궁금하여 알아보았습니다.

 

super().__init__()이 없다면?

import torch

class Test(torch.nn.Module):
    def __init__(self):
        self.linear = torch.nn.Linear(3,2)

    def forward(self,x):
        return self.linear(x)

Test를 위해 굉장히 간단한 torch.nn.Module을 상속받는 클래스를 만들어보았습니다.

그리고 이 클래스를 활용해 모델을 만들어보겠습니다.

model = Test()

생성을 하게 된다면 AttributeError: cannot assign module before Module.__init__() call의 에러메세지를 받을 수 있습니다.

에러 내용을 해석해보면 module을 사용하기 위해서는 Module.__init__()을 먼저 수행하라 즉 super().__init__()을 수행하라는 것을 알 수 있습니다.(super가 nn.Module이므로)

 

그렇다면? torch.nn모듈을 사용하지 않는다면? 어떻게 될까요?

import torch

class Test(torch.nn.Module):
    def __init__(self):
        self.out = 5
        #self.linear = torch.nn.Linear(3,2)

    def forward(self,x):
        return self.out
        #return self.linear(x)

Test클래스에서 torch.nn 모듈을 사용하지 않았습니다.

model = Test()

그 결과는 정상적으로 작동되는 것을 알 수 있습니다. 

즉 이 결과로 우리는 torch.nn의 모듈을 사용하기 위해서는 반드시 super.__init__()을 해줘야함을 유추해볼 수 있습니다.

 

Super().__init__()이란?

super()라는 것은 상속받은 부모클래스를 의미합니다. 부모 클래스를 불러와서 __init__()을 수행한다는 것은 부모클래스의 생성자를 불러주는 것을 의미합니다. 

 

그렇다면 nn.Module의 __init__()을 직접 들어가서 확인해보겠습니다.

nn.Module의 __init__()

__init__을 보면 다양한 변수가 선언되어있는 것을 볼 수있습니다. 

 

super().__init__()은 위 변수들을 상속받는 역할을 해줍니다.

위 변수는 클래스 내 함수로 접근하는 변수들입니다.

정상적으로 작동되기 위해서는 이러한 변수들을 상속받고 사용해야하므로 반드시 super().__init__()을 사용해주어야 하는 것 입니다.

 

Super().__init__() vs Super().__init__(class_name, self)

import torch

class Test(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(3,2)

    def forward(self,x):
        return self.linear(x)

위 처럼 super.__init__()을 사용할 수도 있고

import torch

class Test(torch.nn.Module):
    def __init__(self):
        super(Test,self).__init__()
        self.linear = torch.nn.Linear(3,2)

    def forward(self,x):
        return self.linear(x)

super(파생클래스이름, self).__init__()로 만들수도 있는데 사용하는데 다른점은 없다고 합니다.

단순히 현재 사용하는 클래스가 어떤 클래스인지 알리는 용도로 작성된다고 하네요.

Reference

  • https://stackoverflow.com/questions/63058355/why-is-the-super-constructor-necessary-in-pytorch-custom-modules
  • https://daebaq27.tistory.com/60
  • https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
728x90
728x90
저작자표시 비영리 (새창열림)

'Artificial intelligence > Pytorch' 카테고리의 다른 글

Pytorch Dataset과 DataLoader  (0) 2022.09.03
torchvision.transforms (ToTensor, Normalize, Resize, RandomCrop,Compose)  (0) 2022.09.02
    'Artificial intelligence/Pytorch' 카테고리의 다른 글
    • Pytorch Dataset과 DataLoader
    • torchvision.transforms (ToTensor, Normalize, Resize, RandomCrop,Compose)
    JinSeopKim
    JinSeopKim
    기록📚

    티스토리툴바