Tensorflow에서 tf.data 사용하는 방법

2020. 3. 22. 14:08분석 Python/Tensorflow

728x90

tf.data에 데이터를 넣는 방법는 방법은 4가지 정도 있는 것 같다.

Data Loading

1. numpy에서 불러오기

# create a random vector of shape (100,2)
x = np.random.sample((100,2))
# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
##
features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))
dataset = tf.data.Dataset.from_tensor_slices((features,labels))

2. tensor에서 불러오기

dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([100, 2]))

3. Placeholder에서 불러오기

x = tf.placeholder(tf.float32, shape=[None,2])
dataset = tf.data.Dataset.from_tensor_slices(x)

4. generator에서 불러오기

def our_generator():
    for i in range(2):
        x = np.random.rand(28,28)
        y = np.random.randint(1,10, size=1)
        yield x,y
dataset = tf.data.Dataset.from_generator(our_generator, (tf.float32, tf.int16))
iter = dataset.make_initializable_iterator()
el = iter.get_next()
with tf.Session() as sess:
    sess.run(iter.initializer)
    print(sess.run(el))
    print(sess.run(el))
    print(sess.run(el))

csv 에서 바로 가져오기

# load a csv
total_var = ['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'default payment next month', 'sep_idx']
in_var = ["LIMIT_BAL", "SEX", "EDUCATION", "MARRIAGE", "AGE", "PAY_0", "PAY_2", "PAY_3", "PAY_4", "PAY_5", "PAY_6",
          "BILL_AMT1", "BILL_AMT2", "BILL_AMT3", "BILL_AMT4", "BILL_AMT5", "BILL_AMT6",
          "PAY_AMT1", "PAY_AMT2", "PAY_AMT3", "PAY_AMT4", "PAY_AMT5", "PAY_AMT6"]
target_var = ['default payment next month']
in_var + target_var

dataset = tf.contrib.data.make_csv_dataset(path, 
                                           column_names= total_var , 
                                           select_columns = in_var + target_var,
                                           label_name= target_var[0] , 
                                           batch_size=32)
def pack_features_vector(features, labels):
    """피처를 하나의 배열로 패킹합니다."""
    features = tf.stack(list(features.values()), axis=1)
    print(features)
    return features, labels
train_dataset = dataset.map(pack_features_vector)
iter = train_dataset.make_one_shot_iterator()
features , labels = iter.get_next()
with  tf.Session() as sess:
    print(sess.run(features).shape)
    
## (32, 23)

 

Iterator

tf.data에서 Iterator는 총 4가지 방법이 있다고 한다. 여기서 필자는 Initializable 로 한 것을 공유하겠다.

https://cyc1am3n.github.io/2018/09/13/how-to-use-dataset-in-tensorflow.html

여기서 큰 구조는 epoch은 개별적으로 돌게 하고, batch를 tensorflow에서 설정하게 하였다.
왜냐하면 그냥 하게 되면, 결국 batch 단위로만 돌고 epoch 단위로는 돌지 않았기 때문이다.
꼭 이 방법이 정답은 아니니, 참고만 하면 될 것 같다.

X = tf.placeholder(tf.float32 , shape = [None,data.shape[1]])
y = tf.placeholder(tf.float32 , shape = [None,labels.shape[1]])

def _function(feature, label) :
    """전처리 함수"""
    print(feature)
    return feature , label
    
data_tuple = (X,y)
dataset = tf.data.Dataset.from_tensor_slices(data_tuple)
#dataset.map(_function)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(batch_size=10, drop_remainder=True)
EPOCHS = 10


아직 완벽히 해결되지 않은 것 중에 하나는 중간에 같은 값을 계속 뽑는 것은 안되는 것 같다. 
에를 들어, 2번째 배치를 2번을 중복해서 사용하고 싶은 경우

암튼 또 같은 데이터를 2개로 같이 쓰고 싶은 경우에는 dataset 초기화 하는 것을 하나 더 만들어주면 된다.

iter = dataset.make_initializable_iterator()
feature_batch , label_batch = iter.get_next()

iter2 = dataset.make_initializable_iterator()
feature_batch_2 , label_batch_2 = iter2.get_next()

with tf.Session() as sess:
    for i in range(EPOCHS):
        print(i)
        sess.run(iter.initializer , feed_dict= {X : data.values,
                                            y : labels.values
                                           }) # switch to train dataset
        sess.run(iter2.initializer , feed_dict= {X : data.values,
                                                y : labels.values
                                               }) # switch to train dataset
        while True :
            try :
                sess.run(feature_batch)
#                 print(sess.run(feature_batch).shape)
#                 print(sess.run(feature_batch_2).shape)
            except tf.errors.OutOfRangeError :
                break

 


#참고

https://sknadig.me/TensorFlow2.0-dataset/

 

Tensorflow 2.0 tf.data.Dataset.from_generator

My findings about the new TensorFlow 2.0 Dataset API

sknadig.me

https://cyc1am3n.github.io/2018/09/13/how-to-use-dataset-in-tensorflow.html

 

TensorFlow에서 Dataset을 사용하는 방법

The built-in Input Pipeline. Never use ‘feed-dict’ anymore

cyc1am3n.github.io

 

728x90