[ Tensorflow ] Binary, MutliClass Loss
2019. 11. 9. 20:44ㆍ분석 Python/Tensorflow
코딩을 하다 보니 Loss를 사용하는 데 있어서 헷갈리는 부분이 있어서 따로 정리를 하고 싶어 짐!
코드는 주로 사용하는 Tensorflow 1.4 위주로... Tensorflow 2.0 은 추후에...
Binary
CrossEntropy & Weight Cross Entropy
one_hot = (?,1)
logits = (?,1)
weight = tf.constant( [2.0] )
## Wegith CrossEntropy
WCE = tf.nn.weighted_cross_entropy_with_logits( targets = one_hot , logits = logits , pos_weight = weight)
## CrossEntropy
CE = tf.nn.sigmoid_cross_entropy_with_logits(logits = logits , labels = one_hot )
https://alexisalulema.com/2017/12/15/classification-loss-functions-part-ii/
imbalanced 상황에서 쓸 수 있는 weight cross entropy
결국 pos_weight 적은 class에 가중치를 주는 것이기 때문에
pos_weight > 1 더 작은 클래스를 잘 맞추고자 하는 것이고 recall 증가
반대는 precision 증가
Multiclass
softmax_cross_entropy_with_logits & sparse_softmax_cross_entropy
## softmax cross entropy
one_hot = (?,3)
logits = (?,3)
tf.nn.softmax_cross_entropy_with_logits(logits = logits , labels = one_hot))
## sparse softmax cross entropy
class_weights = tf.constant([0.6,0.3,0.1])
one_hot = (?,3)
logits = (?,3)
labels = tf.argmax(one_hot , axis = 0 )
weights = tf.gather(class_weights, labels)
SparseCE = tf.losses.sparse_softmax_cross_entropy(labels, logits , weights)
https://code-examples.net/ko/q/2a7f0a5
아래는 Focal Loss도 이야기 나옴
Focal Loss 같은 경우 Imbalanced 데이터에 적용 가능함.
https://gombru.github.io/2018/05/23/cross_entropy_loss/
## 아마 맞을 듯
def focal_loss_sigmoid(labels,logits,alpha=0.25 , gamma=2):
y_pred=tf.nn.sigmoid(logits)
labels=tf.to_float(labels)
L=-labels*(1-alpha)*((1-y_pred)*gamma)*tf.log( tf.maximum(y_pred , 1e-14 ) )-\
(1-labels)*alpha*(y_pred**gamma)*tf.log( tf.maximum( 1-y_pred , 1e-14 ) )
return L
def focal_loss_softmax(labels,logits,gamma=2):
"""
Computer focal loss for multi classification
Args:
labels: A int32 tensor of shape [batch_size].
logits: A float32 tensor of shape [batch_size,num_classes].
gamma: A scalar for focal loss gamma hyper-parameter.
Returns:
A tensor of the same shape as `lables`
"""
y_pred=tf.nn.softmax(logits,dim=-1) # [batch_size,num_classes]
labels=tf.one_hot(labels,depth=y_pred.shape[1])
L=-labels*((1-y_pred)**gamma)*tf.log(y_pred)
L=tf.reduce_sum(L,axis=1)
return L
## test
if __name__ == '__main__':
logits=tf.random_uniform(shape=[5],minval=-1,maxval=1,dtype=tf.float32)
labels=tf.Variable([0,1,0,0,1])
loss1=focal_loss_sigmoid(labels=labels,logits=logits)
logits2=tf.random_uniform(shape=[5,4],minval=-1,maxval=1,dtype=tf.float32)
labels2=tf.Variable([1,0,2,3,1])
loss2=focal_loss_softmax(labels==labels2,logits=logits2)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print sess.run(loss1)
print sess.run(loss2)
https://github.com/fudannlp16/focal-loss/blob/master/focal_loss.py
참고
728x90
'분석 Python > Tensorflow' 카테고리의 다른 글
Tensorflow Adanet Tabular Data 적용해보기 (0) | 2019.12.29 |
---|---|
remote server 로부터 Tensorboard 사용하는 방법 (6) | 2019.12.28 |
TensorFlow 유용한 구현 모음 (아직 테스트는 안해봄) (0) | 2019.10.27 |
TensorFlow gpu cuda 설치 공식 문서 (Windows / Ubuntu 16.04 ,18.04) (0) | 2019.10.03 |
tensorflow에서 Loss 가 nan 발생한 경우 정리 (개인 생각) (0) | 2019.09.28 |