[Pyro] Trying Out Bayesian Regression
pyro.ai/examples/bayesian_regression.html
Pyro is a package that makes it easy to do Bayesian modeling on top of PyTorch.
This feels a lot easier than the first time I looked at Bayesian code!
import torch
import torch.nn as nn
import numpy as np

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
class BayesianRegression(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Linear layer whose weight and bias are treated as random variables with N(0, 1) priors
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        self.linear.weight = PyroSample(dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        # Observation noise and a Gaussian likelihood over the data plate
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
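Just to sanity-check shapes, the model can be traced on dummy data (a throwaway sketch; the 3 input features and 10 rows here are arbitrary, not from my dataset):

test_model = BayesianRegression(3, 1)
dummy_x = torch.randn(10, 3)
trace = pyro.poutine.trace(test_model).get_trace(dummy_x)
# Sample sites registered during one forward pass (weights, bias, sigma, obs, plus the plate)
print([name for name, node in trace.nodes.items() if node["type"] == "sample"])
print(trace.nodes["obs"]["value"].shape)  # torch.Size([10])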
from pyro.infer.autoguide import AutoDiagonalNormal

model = BayesianRegression(input_size, 1)
guide = AutoDiagonalNormal(model)

# train_input_x / y are my own training data; the target is flattened to 1-D
# so that it matches the shape of `mean` returned by the model.
x_th = torch.tensor(train_input_x, dtype=torch.float)
y_th = torch.tensor(y.values.astype(np.float32)).flatten()
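The training loop below calls svi.step, so an SVI object has to be set up first. A minimal sketch using Pyro's standard SVI / Trace_ELBO API (the Adam learning rate is just an illustrative value):

from pyro.infer import SVI, Trace_ELBO

adam = pyro.optim.Adam({"lr": 0.03})  # learning rate is an assumed/illustrative value
svi = SVI(model, guide, adam, loss=Trace_ELBO())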
num_iterations = 10000
pyro.clear_param_store()
for j in range(num_iterations):
    # calculate the loss and take a gradient step
    loss = svi.step(x_th, y_th)
    if j % 100 == 0:
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_th)), end="\r")
from pyro.infer import Predictive

def summary(samples):
    site_stats = {}
    for k, v in samples.items():
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "5%": v.kthvalue(int(len(v) * 0.05), dim=0)[0],
            "95%": v.kthvalue(int(len(v) * 0.95), dim=0)[0],
        }
    return site_stats
predictive = Predictive(model, guide=guide, num_samples=800,
                        return_sites=("linear.weight", "obs", "_RETURN"))
samples = predictive(x_th)
pred_summary = summary(samples)
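pred_summary is keyed by site name, so the posterior-predictive mean and 90% band for y can be pulled out like this (a small sketch; the variable names are only for illustration):

y_stats = pred_summary["obs"]       # posterior predictive for y (includes observation noise)
mu_stats = pred_summary["_RETURN"]  # posterior over the regression mean itself

print(y_stats["mean"][:5])
print(y_stats["5%"][:5], y_stats["95%"][:5])
print(mu_stats["mean"][:5])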
from collections import defaultdict
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites
import warnings
class Predict(torch.nn.Module):
    """Wraps model + guide so one forward pass returns a single posterior sample per site."""
    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        samples = {}
        # Sample latents from the guide, then replay the model against that guide trace
        guide_trace = poutine.trace(self.guide).get_trace(*args, **kwargs)
        model_trace = poutine.trace(poutine.replay(self.model, guide_trace)).get_trace(*args, **kwargs)
        for site in prune_subsample_sites(model_trace).stochastic_nodes:
            samples[site] = model_trace.nodes[site]['value']
        # Sites are returned sorted by name, so the output order is deterministic
        return tuple(v for _, v in sorted(samples.items()))
predict_fn = Predict(model, guide)
predict_module = torch.jit.trace_module(predict_fn, {"forward": (x_th,)}, check_trace=False)
torch.jit.save(predict_module, './pyro_reg_predict.pt')

pred_loaded = torch.jit.load('./pyro_reg_predict.pt')
# Index 2 picks out the "obs" site: the sites come back sorted by name
# ('linear.bias', 'linear.weight', 'obs', 'sigma'), so 'obs' is third.
pred = pred_loaded(x_th)[2].detach().numpy().reshape(-1, 1)  ## ? this part seemed odd at first
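If the indexing into the traced module's output looks opaque, the site order can be checked directly against a fresh trace (a sketch, reusing the model and guide defined above):

guide_trace = poutine.trace(guide).get_trace(x_th)
model_trace = poutine.trace(poutine.replay(model, guide_trace)).get_trace(x_th)
sites = sorted(prune_subsample_sites(model_trace).stochastic_nodes)
print(sites)  # expected: ['linear.bias', 'linear.weight', 'obs', 'sigma'] -> index 2 is 'obs'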