[PyTorch] MixtureSameFamily를 사용해서 bimodal distribution 만들기

2020. 5. 5. 23:33 · 분석 Python/PyTorch

728x90

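`MixtureSameFamily`는 PyTorch 1.5.0 버전부터 추가되었다. (PyTorch 1.8부터는 분포 파라미터 검증이 기본으로 켜져 있어서 `Normal`의 `scale`에 0 이하의 값이 들어가면 `ValueError`가 발생하므로, scale은 반드시 양수로 만들어야 한다.)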

 

bimodal Gaussian distribution 만들기

import torch
import seaborn as sns
from torch.distributions import Normal
from torch import distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

# Number of independent mixtures; one sample is drawn from each of them.
n = 1000

# Centre of each mode. Add more entries here to get an n-modal distribution.
centers = [-3.0, 0.5]
k = len(centers)

# Per-row component means: column j is drawn around centers[j].
# Each sample is (n, 1), concatenated along dim=1 -> shape (n, k).
d = torch.cat(
    [Normal(loc=c, scale=1.0).sample((n, 1)) for c in centers],
    dim=1,
)

# Random mixing weights per row. Passing the raw scores as `logits` makes
# Categorical normalise them with a softmax over the last dim -- the same
# result as the old `torch.nn.Softmax()(...)` call on a 2-D tensor, without
# its implicit-dim deprecation warning.
mix = D.Categorical(logits=torch.rand(n, k))

# Component scales must be strictly positive: since PyTorch 1.8, Normal
# validates `scale > 0` by default, so a raw N(0, 0.5) draw (negative about
# half the time) raises ValueError. Take |x| (a half-normal draw) and floor
# it so the constraint always holds.
scales = Normal(loc=0.0, scale=0.5).sample((n, k, 1)).abs().clamp_min(1e-3)

# Normal batch shape is (n, k, 1); Independent(..., 1) moves the trailing
# dim into the event -> batch_shape (n, k), event_shape (1,).
comp = D.Independent(D.Normal(d.unsqueeze(dim=2), scales), 1)

# Mixture over the k components of each row -> batch_shape (n,), event_shape (1,).
gmm = MixtureSameFamily(mix, comp)

# gmm.sample() has shape (n, 1); flatten it to 1-D for plotting.
# `sns.distplot` is deprecated since seaborn 0.11; histplot with kde=True and
# stat="density" is its documented replacement.
samples = gmm.sample().reshape(-1).numpy()
sns.histplot(x=samples, kde=True, stat="density")

성분(component) 수를 늘리기만 하면 쉽게 n-modal 분포로 확장할 수 있다.

trimodal distribution 만들기

import torch
import seaborn as sns
from torch.distributions import Normal
from torch import distributions as D
from torch.distributions.mixture_same_family import MixtureSameFamily

# Number of independent mixtures; one sample is drawn from each of them.
n = 1000

# Centre of each mode. Three entries -> trimodal; add more for n-modal.
centers = [-3.0, 0.5, 5.5]
k = len(centers)

# Per-row component means: column j is drawn around centers[j].
# Each sample is (n, 1), concatenated along dim=1 -> shape (n, k).
d = torch.cat(
    [Normal(loc=c, scale=1.0).sample((n, 1)) for c in centers],
    dim=1,
)

# Random mixing weights per row. Passing the raw scores as `logits` makes
# Categorical normalise them with a softmax over the last dim -- the same
# result as the old `torch.nn.Softmax()(...)` call on a 2-D tensor, without
# its implicit-dim deprecation warning.
mix = D.Categorical(logits=torch.rand(n, k))

# Component scales must be strictly positive: since PyTorch 1.8, Normal
# validates `scale > 0` by default, so a raw N(0, 0.5) draw (negative about
# half the time) raises ValueError. Take |x| (a half-normal draw) and floor
# it so the constraint always holds.
scales = Normal(loc=0.0, scale=0.5).sample((n, k, 1)).abs().clamp_min(1e-3)

# Normal batch shape is (n, k, 1); Independent(..., 1) moves the trailing
# dim into the event -> batch_shape (n, k), event_shape (1,).
comp = D.Independent(D.Normal(d.unsqueeze(dim=2), scales), 1)

# Mixture over the k components of each row -> batch_shape (n,), event_shape (1,).
gmm = MixtureSameFamily(mix, comp)

# gmm.sample() has shape (n, 1); flatten it to 1-D for plotting.
# `sns.distplot` is deprecated since seaborn 0.11; histplot with kde=True and
# stat="density" is its documented replacement.
samples = gmm.sample().reshape(-1).numpy()
sns.histplot(x=samples, kde=True, stat="density")

728x90