pytorch) dataloader sampler

2021. 4. 19. 23:53 | 분석 Python/Pytorch

728x90

dataloader 참고 자료: https://blog.csdn.net/loveliuzz/article/details/108756253

OverSampler / StratifiedSampler 구현물

 

OverSampler는 다른 코드를 참고해서 수정해봤습니다.

 

OverSampler

import numpy as np
import torch
from torch.utils.data import Sampler
class OverSampler(Sampler):
    """Sample dataset indices with replacement, favouring rare classes.

    Every sample gets the weight ``N / count(its class)``, so each class
    carries the same total probability mass and minority classes are
    over-sampled. One pass over the sampler yields ``len(class_vector)``
    indices, which matches ``__len__``.
    """

    def __init__(self, class_vector, batch_size):
        """
        Arguments
        ---------
        class_vector : torch tensor
            a 1-D vector of class labels, one entry per dataset sample
        batch_size : integer
            batch size; only used to derive ``n_splits``
            (number of full batches per pass)
        """
        self.n_splits = int(class_vector.size(0) / batch_size)
        self.class_vector = class_vector
        self.indices = list(range(len(self.class_vector)))
        self.batch_size = batch_size

        # Take the counts straight from torch.unique so that they are always
        # aligned with the labels that actually occur. Pairing unique labels
        # with torch.bincount breaks for label sets that are not exactly
        # 0..K-1, because bincount also reports the empty bins.
        labels = class_vector.detach().cpu()
        _, inverse, counts = torch.unique(
            labels, return_inverse=True, return_counts=True)
        # counts[inverse] is the class frequency of every individual sample.
        self.weights = len(self.indices) / counts[inverse].double()

    def __iter__(self):
        num_samples = len(self)
        if num_samples == 0:
            # torch.multinomial cannot draw from an empty distribution.
            return iter(())
        drawn = torch.multinomial(self.weights, num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return len(self.class_vector)

StratifiedSampler


class StratifiedSampler(Sampler):
    """Stratified Sampling

    Yields every dataset index exactly once per pass. The order comes from a
    single scikit-learn ``StratifiedShuffleSplit`` (30% / 70%): the indices of
    the first part are followed by the indices of the second part, and each
    part keeps the class proportions of the full label vector.
    """

    def __init__(self, class_vector, batch_size):
        """
        Arguments
        ---------
        class_vector : torch tensor
            a vector of class labels
        batch_size : integer
            batch size; only used to derive ``n_splits``
        """
        self.n_splits = int(class_vector.size(0) / batch_size)
        self.class_vector = class_vector

    def gen_sample_array(self):
        """Return a numpy array holding one stratified ordering of all indices.

        Raises
        ------
        ImportError
            if scikit-learn is not installed
        """
        try:
            from sklearn.model_selection import StratifiedShuffleSplit
        except ImportError as err:
            # Fail here with a clear message instead of printing and then
            # crashing with a NameError a few lines further down.
            raise ImportError("Need scikit-learn for this functionality") from err
        import numpy as np

        y = self.class_vector.detach().cpu().numpy()
        # The splitter only needs the number of rows of X, so a zero
        # placeholder replaces the random feature matrix.
        X = np.zeros((len(y), 1))
        # Only the first split is consumed, so a single split is requested.
        # Asking for self.n_splits splits produced nothing at all when the
        # dataset was smaller than one batch (n_splits == 0).
        splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.7)

        train_index, test_index = next(splitter.split(X, y))
        return np.hstack([train_index, test_index])

    def __iter__(self):
        return iter(self.gen_sample_array())

    def __len__(self):
        return len(self.class_vector)
from sklearn.utils.class_weight import compute_class_weight
class StratifiedSampler(Sampler):
    """Class-balanced over-sampling with optional per-class rescaling.

    Every sample is weighted by the "balanced" class weight
    ``n_samples / (n_classes * count(class))`` and indices are then drawn
    with replacement. One pass over the sampler yields
    ``len(class_vector)`` indices, which matches ``__len__``.

    NOTE: this class has the same name as the StratifiedSampler defined
    earlier in this file and replaces it once both definitions are executed.
    """

    def __init__(self, class_vector, batch_size, class_scale=None):
        """
        Arguments
        ---------
        class_vector : torch tensor
            a 1-D vector of class labels, one entry per dataset sample
        batch_size : integer
            batch size; only used to derive ``n_splits``
        class_scale : dict, optional
            maps a class label to a factor applied to that class's balanced
            weight. Labels that do not occur in ``class_vector`` are ignored.
            Defaults to ``{1: 0.5}``, i.e. the weight of class 1 is halved.
        """
        self.n_splits = int(class_vector.size(0) / batch_size)
        self.class_vector = class_vector
        self.indices = list(range(len(self.class_vector)))
        self.batch_size = batch_size

        labels = class_vector.detach().cpu()
        classes, inverse, counts = torch.unique(
            labels, return_inverse=True, return_counts=True)
        # Same formula as scikit-learn's class_weight='balanced'.
        balanced = labels.numel() / (classes.numel() * counts.double())

        class_list = classes.tolist()
        self.mapping_sampling_weights = dict(zip(class_list, balanced.tolist()))
        scale = {1: 0.5} if class_scale is None else class_scale
        for label, factor in scale.items():
            # Skip labels that are absent so that a dataset without class 1
            # no longer fails with a KeyError.
            if label in self.mapping_sampling_weights:
                self.mapping_sampling_weights[label] *= factor

        per_class = torch.tensor(
            [self.mapping_sampling_weights[c] for c in class_list],
            dtype=torch.double)
        # inverse maps every sample to the position of its class in classes.
        self.weights = per_class[inverse]

    def __iter__(self):
        num_samples = len(self)
        if num_samples == 0:
            # torch.multinomial cannot draw from an empty distribution.
            return iter(())
        drawn = torch.multinomial(self.weights, num_samples, replacement=True)
        return (self.indices[i] for i in drawn.tolist())

    def __len__(self):
        return len(self.class_vector)

 

 

 

How to use

# Usage sketch. `label`, `batch_size` and `DataLoader` are not defined in this
# snippet and have to come from the surrounding training script.
# Placeholder left by the author: build the custom Dataset instance here.
traindataset = ## DataSet Custom Class
# The labels are handed to the sampler as a torch tensor
# (presumably `label` is a pandas object -- confirm against the caller).
trainsampler = OverSampler(class_vector=torch.from_numpy(label.values.squeeze()),
                                 batch_size=batch_size)
# The sampler decides the index order, so shuffle is left False here.
# NOTE(review): DataLoader is documented to reject `sampler` together with
# shuffle=True -- confirm against the installed PyTorch version.
trainloader = DataLoader(traindataset, batch_size=batch_size, shuffle=False,
                         sampler=trainsampler,num_workers =5,drop_last= True)
728x90