이번에 포스팅할 논문은 RadialGAN (Leveraging multiple datasets to improve target-specific predictive models using Generative Adversarial Networks) 입니다. 제목에서 알수 있듯이 predictive model (classifier)의 성능을 높이기 위한 GAN으로 2018ICML에 나온 논문입니다.
사실 GAN을 이용해서 데이터를 더 생성 해서 classifer의 성능을 높일 수 있을까라는 의문은 많은 사람들이 가지고 있던 의문이였습니다. GAN의 성능을 높일 수 있는 알고리즘들은 계속 발전해왔지만 이런 task에 대해서 연구가 활발히 진행되어 오지는 않았습니다. 본 논문에서는 여러 도메인(비슷한)에서 특정 도메인으로 transfer시켜서 데이터를 부풀리면 classifer의 성능을 높일수 있다라고 주장합니다.
본 논문에서는 많은 실험데이터를 사용한 것은 아닙니다. 여러 병원의 환자데이터를 활용해 classification 문제를 풀고자 하였으며, 여기에는 두가지의 main challenges가 있습니다. 첫번째로 Feature mismatch 문제 (병원마다 가지고 있는 환자의 feature가 다름) 가 있고 두번째로 Distribution mismatch 문제 ( 병원의 특성에 따라 환자의 분포가 다름, 종합병원, 개인병원등) 가 있습니다. Feature와 distribution이 다르기 때문에 단순히 데이터를 합치기에는 문제가 발생하게 되고 이를 GAN을 통해 해결하고자 합니다.
In this paper, we propose a novel approach for utilizing datasets from multiple sources to effectively enlarge a target dataset. The proposed model, which we call RadialGAN, provides a natural solution to the two challenges outlined above and moreover is able to jointly perform the task for each dataset. We use multiple GAN architectures to translate the patient information from one hospital to another, leveraging the adversarial framework to ensure that the learned translation respects the distribution of the target hospital.
함수 f1을 X1을 X2로 만드는 것으로 정의 한다면 함수 f1은 다음과 같이 쓸 수 있고, 매우 쉬운 문제라는 것을 알 수 있습니다.
함수 f2을 Z를 X2로 만드는 것으로 정의 한다면 함수 f2는 다음과 같이 쓸 수 있고, 쉬운 문제가 아니라는 것 또한 직관적으로 알 수 있습니다. 물론 GAN을 통해 mapping시키는 것은 가능하지만, 위 함수 f1에 대한 문제보다는 상대적으로 어렵다는 것을 알 수 있습니다.
함수 fi를 linear function으로 approximation 한다면 f1이 f2에 비해 좀 더 잘 approximation될 것입니다(함수 f1는 linear space에 있기 때문이죠). 함수 fi를 neural net으로 approximation 한다면 f2를 학습시키기 위해서는 large capacity가 요구됩니다. 그러나 neural net의 large capacity를 충족하기 위해서는 large training samples이 필요하죠. 아래 그림을 보시면 위의 예제를 그림으로 확인할 수가 있습니다. normal 분포를 target 분포로 하여 위는 같은 normal (평균이 다른) 분포, 아래는 uniform분포로 초기화를 시켜서 GAN을 학습시키면 적은 capacity를 가질때 부터 normal 분포로 초기화시킨것이 target 분포를 따라가지만 uniform분포로 초기화를 시킨 것은 그렇지 않은 것을 확인 할 수 있죠.
즉, Source 도메인과 target 도메인이 비슷한 분포를 가졌다면, 랜덤 분포보다 비슷한 분포에서 Generate시키는 것이 더 잘 될 것이라는 이야기입니다. (이는 CycleGAN, DiscoGAN과 같은 domain transfer GAN을 통해서 이미 직관적으로 알고 있는 사실이기는 합니다.) 아래 그림에서도 위 그림과 비슷한 현상을 확인할 수 있습니다.
Cycle GAN이나 DiscoGAN과 같은 경우는 도메인 A에서 도메인 B로 바로 transfer시키는데 RadialGAN은 transfer 시키기전에 중간에 Latent space로 일차적으로 Mapping 시키고 그걸 다시 transfer 시킵니다. (논문에서는 뜬금없이 latent space가 나오는 느낌이 없잖아 있는데, 두세번 정도 읽어보니 사실 이부분이 이 논문의 핵심이라는 것 같습니다. Latent space에 대해서 구체적으로 정의를 해주면 좋았을테데...) 그리고
를 아래와 같은 식으로 정의 하는데, Source 도메인과 Target 도메인의 데이터 수가 다르다보니, GAN을 학습시킬때 그때그때 다른 가중치를 주기 위한 것이라고 이해하면 될것 같습니다.
i 도메인 데이터를 latent space Z로 mapping하는 encoder을 Fi로 정의하고 latent space에서 다시 도메인 i로 mapping하는 decoder를 Gi로 정의합니다. Wj는 latent space Z에서 j 도메인을 latent space로 encoding하는 함수 Fi를 취하는 확률 변수로 정의를 합니다. 그리고 확률 변수 Zi 가 Wj일 확률은 위에서 정의한
와 같습니다. 즉, source domain i에 대하여 확률 변수 zi로 부터 샘플링하는 것은 i를 제외한 나머지 모든 데이터에 대해서 uniform하게 샘플링하고 그 j에 대하여 Fj를 적용 시키는 것과 같습니다. 확률 변수 zi는 자신 i를 제외한 나머지 index값을 가질수 있는 확률 변수입니다. 총 네개의 데이터가 있다고 가정을 한다면 확률 변수 Z1이 가질수 있는 값은 W2, W3, W4 세가지인 것이고, 각각의 W들은 각 데이터를 latent space로 mapping하는 함수입니다. 확률 변수 Z1이 이 W2, W3, W4를 가질수 있는 확률은 위에서 언급한
에 의해서 결정이 되는 것이죠(세개의 데이터의 비중, 위에서 든 예제로 보면 각각 2/9, 3/9, 4/9의 값을 가짐). 결국에 Z1으로 부터 sampling한다는 것은 Encoder F2, F3, F4를 샘플링한다는 의미라고 볼수 있습니다.
RadialGAN의 구조는 아래 그림과 같습니다. 그림만 보면 어느정도 직관적으로 네트워크 구조를 이해 할 수 있습니다.
Encoder와 Decoder에 대한 구체적인 구조와 loss는 다음과 같습니다. Adversarial loss와 Cycle loss를 추가해서 최종적인 loss를 구성합니다.
Discriminator입장에서 i번째 도메인 데이터가 real로 구분하고 Encoder와 Decoder를 거쳐 transfer 시킨 데이터가 real이 아닌지 구별하고 Generator입장에서 transfer시킨 데이터가 Discriminator가 real로 판별하도록 만든 loss와 (adv loss)
i의 Source 도메인이 encoder와 decoder로 다시 복원된 i의 source 도메인이 같고 i를 제외한 나머지 도메인들에 대하여 각각 encoding 시킨 것과 그것을 다시 decoding, encoding시킨 것이 같도록(i의 latent space와 latent space를 decoding하고 encoding한것이 같도록) loss를 구성합니다. i의 latent space는 i를 제외한 나머지 j에 대하여 encoding시킨 것과 같기 때문에 아래와같은 식이 성립하게 됩니다. 이 Cycle loss의 두번째 term이 사실 이해하기 가장 어려운 부분이기도 하지만 가장 핵심이 되는 부분이라고 생각됩니다. 모든 데이터 셋이 다 공유할 수 있는 latent space를 학습하기 위한 것이라 생각됩니다.
D :
총 M개의 데이터셋에 대하여 :
i번째 데이터 샘플 추출 (소스도메인)
나머지 데이터 샘플 추출 (타겟도메인)
나머지 데이터에 대하여 각각의 decoder에 넣은 후, i번째 encoder로 복원
i번째 D 학습 : 복원된 데이터들이 i번째 데이터인지 아닌지 구분
G :
총 M개의 데이터셋에 대하여 :
i번째 데이터 샘플 추출 (소스도메인)
나머지 데이터 샘플 추출 (타겟도메인)
모든 Generator G,F update
i로 복원된 j데이터가 i도메인으로 D가 속고
i를 다시 i로 잘복원시키고 j를 encoder에 넣은 것과 그것을 다시 i의 decoder와 encoder로 복원하도록
아래표는 14개의 dataset중 random하게 3,5,7개를 선택하고 classifier의 성능을 측정합니다. 1번부터 14번 도메인의 test data를 예측하는데 i번째 도메인을 제외한 나머지 데이터중 랜덤하게 선택한 도메인 데이터(transfer된)를 합쳐서 분류 한 것 같습니다. (논문에서는 random하게 m개의 dataset를 선택해서 사용했다라고 만 되어있는데, 저 해석이 맞지 않나 싶습니다. 결과를 보시면 RadialGAN이 가장 성능이 좋고 transfer된 데이터를 더 늘릴 때마다 성능이 더 좋아진다라고 강조하고 있습니다.
전체 데이터셋을 다 쓸때도 다른 알고리즘 대비 더 좋은 모습을 보이고 있습니다. 다만, Simple-combined도 좋은 성능을 보인다라는 것은 도메인 마다 매우 다른 특성을 가지는 것은 아닌 것 같다라는 생각은 듭니다.
추가적으로 feature와 distribution이 안맞는 상황을 가정해서 실험을 해보았는데
Setting A : 모든 Feature가 일치하는 데이터셋 다섯개만 학습
Setting B : 같은 데이터셋에서 feature를 random하게 33% 제거해서 학습
여기서도 RadialGAN이 가장 좋은 성능을 보였습니다.
논문을 읽으면서 궁금했던 점은 네트워크 구조가 어떻게 되느냐였는데 그 부분에 대한 설명은 전혀 없습니다. 각 Generator들의 input과 output의 dimension이 같아야 할텐데.. 그럼 feature의 maximum값으로 설정했을런지... 그래야 될 것 같긴 합니다만. 더 생각을 해보면 transfer된 데이터의 feature의 수는 test data의 feature의 수와 같아야 할텐데... test data의 feature수도 maximum값으로 설정하고 나머진 0으로 채웠는지... 좀더 친절하게 설명해주면 좋았을법한 논문 이였습니다만, Data Science 분야에서도 GAN으로 classifier의 성능을 높일 수 있다라는 사실을 증명한 논문이라 나름 임팩트가 있는것 같습니다.
'딥러닝' 카테고리의 다른 글
[딥러닝 논문 리뷰] Understanding Deep Learning Requires Rethinking Generalization2 (0) | 2022.12.10 |
---|---|
[딥러닝 논문 리뷰] Understanding Deep Learning Requires Rethinking Generalization (0) | 2022.12.10 |
GAN (Generative Adversarial Networks) 기본 개념과 학습 원리 (1) | 2022.12.10 |
자연어처리(NLP)분야의 다양한 Task와 데이터 (0) | 2020.12.10 |
[딥러닝 논문 리뷰] DOMAIN GENERALIZATION WITH MIXSTYLE (1) | 2020.10.28 |
댓글