Tensorflow에서 tf.data 사용하는 방법
2020. 3. 22. 14:08ㆍ분석 Python/Tensorflow
tf.data에 데이터를 넣는 방법는 방법은 4가지 정도 있는 것 같다.
Data Loading
1. numpy에서 불러오기
# create a random vector of shape (100,2)
x = np.random.sample((100,2))
# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
##
features, labels = (np.random.sample((100,2)), np.random.sample((100,1)))
dataset = tf.data.Dataset.from_tensor_slices((features,labels))
2. tensor에서 불러오기
dataset = tf.data.Dataset.from_tensor_slices(tf.random_uniform([100, 2]))
3. Placeholder에서 불러오기
x = tf.placeholder(tf.float32, shape=[None,2])
dataset = tf.data.Dataset.from_tensor_slices(x)
4. generator에서 불러오기
def our_generator():
for i in range(2):
x = np.random.rand(28,28)
y = np.random.randint(1,10, size=1)
yield x,y
dataset = tf.data.Dataset.from_generator(our_generator, (tf.float32, tf.int16))
iter = dataset.make_initializable_iterator()
el = iter.get_next()
with tf.Session() as sess:
sess.run(iter.initializer)
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
csv 에서 바로 가져오기
# load a csv
total_var = ['ID', 'LIMIT_BAL', 'SEX', 'EDUCATION', 'MARRIAGE', 'AGE', 'PAY_0', 'PAY_2', 'PAY_3', 'PAY_4', 'PAY_5', 'PAY_6', 'BILL_AMT1', 'BILL_AMT2', 'BILL_AMT3', 'BILL_AMT4', 'BILL_AMT5', 'BILL_AMT6', 'PAY_AMT1', 'PAY_AMT2', 'PAY_AMT3', 'PAY_AMT4', 'PAY_AMT5', 'PAY_AMT6', 'default payment next month', 'sep_idx']
in_var = ["LIMIT_BAL", "SEX", "EDUCATION", "MARRIAGE", "AGE", "PAY_0", "PAY_2", "PAY_3", "PAY_4", "PAY_5", "PAY_6",
"BILL_AMT1", "BILL_AMT2", "BILL_AMT3", "BILL_AMT4", "BILL_AMT5", "BILL_AMT6",
"PAY_AMT1", "PAY_AMT2", "PAY_AMT3", "PAY_AMT4", "PAY_AMT5", "PAY_AMT6"]
target_var = ['default payment next month']
in_var + target_var
dataset = tf.contrib.data.make_csv_dataset(path,
column_names= total_var ,
select_columns = in_var + target_var,
label_name= target_var[0] ,
batch_size=32)
def pack_features_vector(features, labels):
"""피처를 하나의 배열로 패킹합니다."""
features = tf.stack(list(features.values()), axis=1)
print(features)
return features, labels
train_dataset = dataset.map(pack_features_vector)
iter = train_dataset.make_one_shot_iterator()
features , labels = iter.get_next()
with tf.Session() as sess:
print(sess.run(features).shape)
## (32, 23)
Iterator
tf.data에서 Iterator는 총 4가지 방법이 있다고 한다. 여기서 필자는 Initializable 로 한 것을 공유하겠다.
여기서 큰 구조는 epoch은 개별적으로 돌게 하고, batch를 tensorflow에서 설정하게 하였다.
왜냐하면 그냥 하게 되면, 결국 batch 단위로만 돌고 epoch 단위로는 돌지 않았기 때문이다.
꼭 이 방법이 정답은 아니니, 참고만 하면 될 것 같다.
X = tf.placeholder(tf.float32 , shape = [None,data.shape[1]])
y = tf.placeholder(tf.float32 , shape = [None,labels.shape[1]])
def _function(feature, label) :
"""전처리 함수"""
print(feature)
return feature , label
data_tuple = (X,y)
dataset = tf.data.Dataset.from_tensor_slices(data_tuple)
#dataset.map(_function)
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(batch_size=10, drop_remainder=True)
EPOCHS = 10
아직 완벽히 해결되지 않은 것 중에 하나는 중간에 같은 값을 계속 뽑는 것은 안되는 것 같다.
에를 들어, 2번째 배치를 2번을 중복해서 사용하고 싶은 경우
암튼 또 같은 데이터를 2개로 같이 쓰고 싶은 경우에는 dataset 초기화 하는 것을 하나 더 만들어주면 된다.
iter = dataset.make_initializable_iterator()
feature_batch , label_batch = iter.get_next()
iter2 = dataset.make_initializable_iterator()
feature_batch_2 , label_batch_2 = iter2.get_next()
with tf.Session() as sess:
for i in range(EPOCHS):
print(i)
sess.run(iter.initializer , feed_dict= {X : data.values,
y : labels.values
}) # switch to train dataset
sess.run(iter2.initializer , feed_dict= {X : data.values,
y : labels.values
}) # switch to train dataset
while True :
try :
sess.run(feature_batch)
# print(sess.run(feature_batch).shape)
# print(sess.run(feature_batch_2).shape)
except tf.errors.OutOfRangeError :
break
#참고
https://sknadig.me/TensorFlow2.0-dataset/
https://cyc1am3n.github.io/2018/09/13/how-to-use-dataset-in-tensorflow.html
728x90
'분석 Python > Tensorflow' 카테고리의 다른 글
tf.data 삽질해보기 (two iterator, feed_dict, GAN) (0) | 2020.04.08 |
---|---|
tf.data로 데이터 파이프라인 만들고 추론하는 것 까지 해보기 (0) | 2020.03.22 |
tf.stop_gradient 사용해서 학습시킬 가중치 조절해보기 (3) | 2020.03.21 |
Shared Weight AutoEncoder 구현해보기 (0) | 2020.02.24 |
[ Tensorflow ] AUC 구하기 (0) | 2020.02.21 |