최근에는 Latent diffusion model with Transformer architecture가 고퀄리티 이미지를 잘 생성해낸다.
하지만 최근의 연구들은 two-stage 디자인의 최적화 딜레마에 빠진다.
- visual tokenizer에서 토큰당 dim을 높이면 recon은 쉬워지지만, generation performance를 높이려면 생성단계에서 더 큰 모델과 더 많은 학습이 필요하다.
결국 기존의 연구들은 종종 sub-optimal solution을 채택한다.
- dim을 작게하면 artifact가 생기고 generation quality의 상한이 정해지며
- 크게하면 이런 문제가 없는대신 충분히 수렴하지 못한다. 혹은 컴퓨팅 연산이 더 든다.
우리는 이런 딜레다가 제약없는 high-dimensional latent space를 학습해야하는 내제된 어려움 때문이라고 생각한다.
이걸 해결하기 위해서 VAE를 학습할 때 latent space를 pre-trained vision foundation model에 align 시킨다.
→
VA-VAE
(Vision foundation model Aligned Variational AutoEncoder)→ dim이 클 때도 DiT가 빠르게 수렴하도록 해준다!
그리고 VA-VAE에 더해 DiT 단에도 몇가지 학습 전략과 아키텍처 수정을 해서 최적을 만듦 →
LightningDiT
ImageNet 256*256에서 64 epoch로 량 2.11 찍었다(단순 DiT 보다 21배 빠름)

dimension을 늘려서 reconFID가 좋아질수록, generation FID는 안좋아진다.
직관적이지만 효율적인 접근법 - AR에서 코드북 크기를 늘리면 전체 코드북이 활용되지 않는 문제가 있다.
VAE latent를 시각화해보면(위에서처럼), 큰 dim으로 학습할수록 latent들이 몰려있다. → 표현력에 한계
이걸 아무 제약없이 쌩으로 학습해야하니 발생하는 학습 최적화 문제로 봤다.
따라서 latent space에서 vision foundation model의 rep과 align 되도록 loss를 준다(VF Loss)
이전 연구들을 보면 단순히 inital module로 사용하면 어짜피 recond을 하느라 의미가 없었다 → carefully designed joint recon and align loss가 중요하다.
- element-wise, pair-wise 유사도를 둘다 적용한다. feature space에서 global 구조, local 구조를 각각 잘 반영
- flexibility를 위해 약간의 margin을 둔다(너무 강하게 따르지는 않아도 되도록)
결론부터
- optimization dilemma 문제를 해결하고 high-dim vae로도 DiT 학습을 2.5배 빠르게 만들었다.
- 64 에포크 학습으로 FID 2.11을 찍었고, 단순 DiT보다 21배 빠르다.

VAE의 모델 아키텍처를 수정하는일 없이 추가 loss 계산 하나만 더하기 때문에 복잡하지 않다.
VF Loss는 두개로 이루어진다.
- Marginal cosine similarity loss
given image
I
는 인코더에서 들어가고, vision foundation model(DINO)로도 들어간다. → Z, F그리고 Z를 linear 태워서 F와 같은 차원으로 맞춘다.

similarity는 local하게 각 위치(h, w)에서 전부 개별적으로 계산되며(latent 단위로), m1을 빼고 ReLU를 태워서 similarity가 m1보다 낮은 경우에만 반영된다.
- Marginal distance matrix similarity loss
각 개별 단위가 아니라 전체적으로 분포가 비슷하도록하는 Loss도 추가했다.

→ Z안에서 각 lante끼리의 유사도와 F안에서 각 latent 끼리의 유사도가 비슷하도록
import torch
import torch.nn.functional as F
def mdms_loss(z, f, margin=0.1):
# z, f: (N, D)
z_norm = F.normalize(z, dim=1) # (N, D)
f_norm = F.normalize(f, dim=1)
# Cosine similarity matrix: (N, N)
sim_z = z_norm @ z_norm.T
sim_f = f_norm @ f_norm.T
# Absolute difference and margin
diff = torch.abs(sim_z - sim_f)
loss = F.relu(diff - margin)
return loss.mean()
근데 두 방법 다 computation cost가 좀 든다.
- Adaptive weighting
두 loss가 갖는 영향이 비슷하도록 크기로 나눠서 자동으로 가중치가 조정되게 했다.
Empirically, we set m1 = 0.5, m2 = 0.25, and whyper = 0.1.
- Improved DiT
직접 뭔가 찾진않고 기존에 나온 모든 연구들을 실험해서 최적을 찾았다.
80 epoch, dopri5 integrator, NFE는 크진않은걸로 실험한 결과

오히려 이렇게 정리해주니까 더 좋다!
no cfg, bfloat16, torch.compile 사용. MaskDiT는 사용안함
This optimized architecture has been of great help in our following rapid experiment validation.
Experiment

이걸 어떻게 받아들여야하지.. 일단 dim을 늘리고 그대로 하는 것 대비 VA-VAE 형태로 했을 때 recond 퀄리티는 유지하면서 generation 퀄리티가 좋아졌다. 값이 저정도 차이나면 학습 시간은 꽤 차이난다.
근데 모델이 커져야하는건 여전히 있어서.. d16말고 d64를 쓰면서 VF Loss를 써야하는 이유가 있는건가? recon 성능이 훨씬 좋긴하지만 이게 가지는 의의를 크게 모르겠음.
일단 연구자체는 성공!

Ablation Study

결론
- 높은 dimension을 쓰면 recon 성능은 좋지만 generation 성능은 안좋고, 더 큰 모델과 많은 학습이 필요하다.
- 여기서는 high-dim의 경우에도 Vision foundation model alignment를 통해 generation이 좋아지도록 했다.
- → 그런데 여전히 low-dim(d=16)에서 generation 성능이 더 좋긴하다..
그럼 실제 상황에서 d=16을 쓰면 되는걸 굳이 d=64를 만들고 VF Loss를 써야하는 이유는? 잘 모르겠다.
d를 높이고 down factor를 높이고 싶을 때나 더 복잡하고 큰 데이터셋으로 학습할 때 이점이 있을 수도. 어쨌든 high-dim을 잘 다루는 법을 알려준건 맞으니까
Vision Foundation Model을 활용하는 좋은 예시라는건 알겠다.
Share article
Subscribe to our newsletter