[Pyro] Bayesian Regression 해보기

2020. 9. 29. 22:39 · ML(머신러닝)/Bayesian

pyro.ai/examples/bayesian_regression.html

 

Bayesian Regression - Introduction (Part 1) — Pyro Tutorials 1.4.0 documentation

Model In order to make our linear regression Bayesian, we need to put priors on the parameters \(w\) and \(b\). These are distributions that represent our prior belief about reasonable values for \(w\) and \(b\) (before observing any data). Making a Bayesi

pyro.ai

Pyro는 PyTorch에서 베이지안(Bayesian) 모델을 쉽게 만들 수 있게 해 주는 패키지이다.

예전에 처음 베이지안 코드를 봤을 때보다는 쉬워진 것 같다!

import numpy as np
import torch
from torch import nn

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule
from pyro.nn import PyroSample
class BayesianRegression(PyroModule):
    """Bayesian linear regression: ``obs ~ Normal(linear(x), sigma)``.

    The weight and bias of the wrapped ``nn.Linear`` are latent sample sites
    ("linear.weight", "linear.bias") with independent ``Normal(0, 1)`` priors;
    the observation noise ``sigma`` has a ``Uniform(0, 10)`` prior.

    Args:
        in_features: number of input features.
        out_features: number of regression outputs.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = PyroModule[nn.Linear](in_features, out_features)
        # Replace the deterministic nn.Linear parameters with PyroSample
        # priors so they become random variables of the model.
        self.linear.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2))
        self.linear.bias = PyroSample(
            dist.Normal(0., 1.).expand([out_features]).to_event(1))

    def forward(self, x, y=None):
        """Run the generative model once.

        Args:
            x: input tensor; ``x.shape[0]`` is the number of data points.
            y: observed targets, or ``None`` to sample ``obs`` from the model.
                A trailing singleton dimension (e.g. shape ``(N, 1)``) is
                dropped so that ``y`` lines up with ``mean``.

        Returns:
            ``self.linear(x)`` with a trailing size-1 dimension squeezed away.
        """
        sigma = pyro.sample("sigma", dist.Uniform(0., 10.))
        mean = self.linear(x).squeeze(-1)
        if y is not None and y.dim() == mean.dim() + 1 and y.shape[-1] == 1:
            # Without this, an (N, 1) target broadcasts against the (N,) mean
            # into an (N, N) log-likelihood, silently scoring every mean
            # against every target and training on the wrong objective.
            y = y.squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mean, sigma), obs=y)
        return mean
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

# NOTE(review): input_size, train_input_x and y come from data preparation that
# is not shown here; y is assumed to be a pandas object (it exposes .values).
model = BayesianRegression(input_size, 1)
# Variational guide with a diagonal-covariance Normal over all latent sites.
guide = AutoDiagonalNormal(model)

x_th = torch.tensor(train_input_x, dtype=torch.float)
# BayesianRegression.forward squeezes its mean to shape (N,), so the targets
# must be 1-D as well: an (N, 1) target would broadcast against the (N,) mean
# into an (N, N) log-likelihood and optimise the wrong objective.
y_th = torch.tensor(y.values.reshape(-1).astype(np.float32), dtype=torch.float)

# `svi` was used in the loop below without being defined in the code shown;
# build the stochastic variational inference driver here (lr is tunable).
svi = SVI(model, guide, Adam({"lr": 0.03}), loss=Trace_ELBO())

num_iterations = 10000
pyro.clear_param_store()
for j in range(num_iterations):
    # Calculate the ELBO loss and take one gradient step.
    loss = svi.step(x_th, y_th)
    if j % 100 == 0:
        # Per-data-point loss, overwritten in place on the console via "\r".
        print("[iteration %04d] loss: %.4f" % (j + 1, loss / len(x_th)), end="\r")
        
from pyro.infer import Predictive
def summary(samples):
    """Summarise sampled values per site.

    Args:
        samples: mapping from site name to a tensor whose first dimension
            indexes the draws.

    Returns:
        dict mapping each site name to a dict with keys ``"mean"``, ``"std"``,
        ``"5%"`` and ``"95%"``, each reduced over dimension 0. The percentile
        entries are order statistics picked with ``kthvalue``.
    """

    def rank(n, q):
        # kthvalue's k is 1-based and must lie in [1, n]; int(n * q) is 0 for
        # small n (e.g. fewer than 20 draws at q=0.05), which kthvalue rejects.
        return min(max(int(n * q), 1), n)

    site_stats = {}
    for k, v in samples.items():
        n = len(v)
        site_stats[k] = {
            "mean": torch.mean(v, 0),
            "std": torch.std(v, 0),
            "5%": v.kthvalue(rank(n, 0.05), dim=0)[0],
            "95%": v.kthvalue(rank(n, 0.95), dim=0)[0],
        }
    return site_stats
# Draw 800 samples of the weight site, the "obs" site and the "_RETURN" site
# for the training inputs, then reduce them to per-site statistics.
predictive = Predictive(
    model,
    guide=guide,
    num_samples=800,
    return_sites=("linear.weight", "obs", "_RETURN"),
)
samples = predictive(x_th)
pred_summary = summary(samples)

from collections import defaultdict
from pyro import poutine
from pyro.poutine.util import prune_subsample_sites
import warnings


class Predict(torch.nn.Module):
    """torch.nn.Module wrapper that draws one joint sample from a model/guide pair.

    Each call runs ``guide`` once, replays ``model`` against that guide trace,
    and returns the value of every remaining stochastic site of the model
    trace as a tuple ordered by site name.
    """

    def __init__(self, model, guide):
        super().__init__()
        self.model = model
        self.guide = guide

    def forward(self, *args, **kwargs):
        """Return the sampled site values, ordered by site name."""
        guide_tr = poutine.trace(self.guide).get_trace(*args, **kwargs)
        replayed_model = poutine.replay(self.model, guide_tr)
        model_tr = poutine.trace(replayed_model).get_trace(*args, **kwargs)
        # Only stochastic nodes left after prune_subsample_sites are returned.
        site_names = prune_subsample_sites(model_tr).stochastic_nodes
        values = {name: model_tr.nodes[name]['value'] for name in site_names}
        return tuple(values[name] for name in sorted(values))

predict_fn = Predict(model, guide)
# Trace into TorchScript so the predictor can be saved and reloaded on its own.
# check_trace=False: presumably because every call draws fresh random samples,
# so re-running the trace for verification could never reproduce its outputs.
predict_module = torch.jit.trace_module(predict_fn, {"forward": (x_th,)}, check_trace=False)
torch.jit.save(predict_module, './pyro_reg_predict.pt')
pred_loaded = torch.jit.load('./pyro_reg_predict.pt')

# The original author flagged the next line as looking wrong. Predict.forward
# returns one value per stochastic site, ordered by site name, so find the
# position of "obs" by name instead of relying on a hard-coded index of 2.
_guide_tr = poutine.trace(guide).get_trace(x_th)
_model_tr = poutine.trace(poutine.replay(model, _guide_tr)).get_trace(x_th)
site_order = sorted(prune_subsample_sites(_model_tr).stochastic_nodes)
# "obs" is a single draw from Normal(mean, sigma), i.e. the regression mean
# plus observation noise, so individual predictions are noisy by construction.
pred = pred_loaded(x_th)[site_order.index("obs")].detach().numpy().reshape(-1, 1)
728x90