[ Tensorflow ] Binary, MultiClass Loss

2019. 11. 9. 20:44, Analysis Python/Tensorflow


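While coding, I kept getting confused about which loss to use and how, so I wanted to write it all down in one place!

The code mainly targets Tensorflow 1.4, which is what I mostly use. Tensorflow 2.0 will come later.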

CrossEntropy vs Balanced CrossEntropy vs Focal Loss

Binary

CrossEntropy & Weighted CrossEntropy

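import tensorflow as tf  # TensorFlow 1.x

# Binary labels (0 or 1) and raw scores, both of shape (?, 1)
one_hot = tf.placeholder(tf.float32, shape=[None, 1])
logits = tf.placeholder(tf.float32, shape=[None, 1])
weight = tf.constant([2.0])
## Weighted CrossEntropy
WCE = tf.nn.weighted_cross_entropy_with_logits(targets=one_hot, logits=logits, pos_weight=weight)
## CrossEntropy
CE = tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=one_hot)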

https://alexisalulema.com/2017/12/15/classification-loss-functions-part-ii/

 

Classification Loss Functions (Part II) (alexisalulema.com)

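Weighted cross entropy is an option when the classes are imbalanced.

pos_weight multiplies the loss on positive examples, so when the positive class is the minority it effectively up-weights that class:

pos_weight > 1 pushes the model to get the smaller (positive) class right, which increases recall.

The opposite, pos_weight < 1, increases precision.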

[Figure: cross entropy]
[Figure: weighted cross entropy]
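To make the effect of pos_weight concrete, here is a minimal sketch in plain Python (no TensorFlow needed) of the per-example formula that `tf.nn.weighted_cross_entropy_with_logits` documents: `-[pos_weight * z * log(sigmoid(x)) + (1 - z) * log(1 - sigmoid(x))]`. The helper names `sigmoid` and `weighted_sigmoid_ce` and the example logit are made up for illustration, and the sketch skips the numerically stable rewriting that the library uses.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def weighted_sigmoid_ce(target, logit, pos_weight=1.0):
    """Per-example loss: -[pos_weight * z * log(p) + (1 - z) * log(1 - p)]."""
    p = sigmoid(logit)
    return -(pos_weight * target * math.log(p)
             + (1.0 - target) * math.log(1.0 - p))

logit = -1.0  # hypothetical score: the model leans towards the negative class

# Plain cross entropy is the special case pos_weight = 1.
ce_pos = weighted_sigmoid_ce(1.0, logit)
wce_pos = weighted_sigmoid_ce(1.0, logit, pos_weight=2.0)
ce_neg = weighted_sigmoid_ce(0.0, logit)
wce_neg = weighted_sigmoid_ce(0.0, logit, pos_weight=2.0)

print(ce_pos, wce_pos)  # a missed positive costs twice as much with pos_weight=2
print(ce_neg, wce_neg)  # negatives are unaffected by pos_weight
```

Only the positive term is scaled, so raising pos_weight makes false negatives more expensive relative to false positives, which is why recall tends to go up.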

 

Multiclass

softmax_cross_entropy_with_logits & sparse_softmax_cross_entropy

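## softmax cross entropy
one_hot = tf.placeholder(tf.float32, shape=[None, 3])  # (?, 3) one-hot labels
logits = tf.placeholder(tf.float32, shape=[None, 3])   # (?, 3) raw scores
CE = tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=one_hot)

## sparse softmax cross entropy with per-class weights
class_weights = tf.constant([0.6, 0.3, 0.1])
# Class index of each sample: argmax over the class axis (axis=1), not the batch axis
labels = tf.argmax(one_hot, axis=1)         # (?,)
# Look up the weight of each sample's true class
weights = tf.gather(class_weights, labels)  # (?,)
# Returns a scalar: the weighted loss reduced over the batch
SparseCE = tf.losses.sparse_softmax_cross_entropy(labels, logits, weights)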
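The class-weighting steps can be traced by hand on a tiny batch. The following is a rough sketch in plain Python, not the TensorFlow implementation: the `softmax` helper and the sample values are invented for illustration, and it stops at the per-sample weighted losses without reproducing the reduction that `tf.losses` applies afterwards.

```python
import math

def softmax(row):
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

class_weights = [0.6, 0.3, 0.1]
# Hypothetical batch of 3 samples and 3 classes
one_hot = [[1, 0, 0], [0, 0, 1], [0, 1, 0]]
logits = [[2.0, 0.5, 0.1], [0.2, 0.1, 3.0], [1.0, 1.0, 1.0]]

# argmax over the class axis (axis=1): one class index per sample
labels = [row.index(max(row)) for row in one_hot]
# the gather step: pick each sample's class weight
weights = [class_weights[k] for k in labels]
# unweighted sparse softmax cross entropy: -log p[true class]
losses = [-math.log(softmax(z)[k]) for z, k in zip(logits, labels)]
# per-sample weighted loss
weighted = [w * l for w, l in zip(weights, losses)]

print(labels)
print(weights)
print(weighted)
```

The argmax has to run over the class axis so that it yields one index per sample. Running it over the batch axis would return one index per class column instead, which no longer lines up with the rows of `logits`.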

 

 

https://code-examples.net/ko/q/2a7f0a5

 

python: unbalanced data and weighted cross entropy (code-examples.net)

 

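The post below also covers Focal Loss.

Focal Loss can also be applied to imbalanced data: it down-weights easy, well-classified examples so that training focuses on the hard ones.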
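As a quick numeric illustration of that down-weighting, here is a small sketch in plain Python. It assumes the convention from the Focal Loss paper (Focal Loss for Dense Object Detection), where alpha weights the positive class and 1 - alpha the negative class; the function names and the probabilities 0.95 and 0.10 are hypothetical.

```python
import math

def binary_ce(target, p):
    """Cross entropy for one example; p is the predicted positive-class probability."""
    p_t = p if target == 1 else 1.0 - p
    return -math.log(p_t)

def binary_focal_loss(target, p, alpha=0.25, gamma=2.0):
    """Focal loss for one example: -alpha_t * (1 - p_t)**gamma * log(p_t)."""
    p_t = p if target == 1 else 1.0 - p              # probability given to the true class
    alpha_t = alpha if target == 1 else 1.0 - alpha  # class-balancing weight
    return -alpha_t * (1.0 - p_t) ** gamma * math.log(p_t)

easy = 0.95  # hypothetical well-classified positive
hard = 0.10  # hypothetical badly misclassified positive

# Ratio focal / cross entropy = alpha_t * (1 - p_t)**gamma
ratio_easy = binary_focal_loss(1, easy) / binary_ce(1, easy)
ratio_hard = binary_focal_loss(1, hard) / binary_ce(1, hard)

print(ratio_easy)  # tiny: the easy example is almost ignored
print(ratio_hard)  # much larger: the hard example keeps most of its loss
```

With gamma = 0 the modulating factor disappears and the loss reduces to alpha-weighted cross entropy, which is the balanced cross entropy mentioned at the top of this post.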

https://gombru.github.io/2018/05/23/cross_entropy_loss/

 

Understanding Categorical Cross-Entropy Loss, Binary Cross-Entropy Loss, Softmax Loss, Logistic Loss, Focal Loss and all those c... (gombru.github.io)

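## Probably correct (not fully verified)
import tensorflow as tf  # TensorFlow 1.x

def focal_loss_sigmoid(labels, logits, alpha=0.25, gamma=2):
    """Binary focal loss; returns a per-example tensor shaped like `logits`."""
    y_pred = tf.nn.sigmoid(logits)
    labels = tf.to_float(labels)
    # The focusing term is a power (** gamma), not a product (* gamma).
    # Note: (1 - alpha) weights positives and alpha weights negatives here,
    # which is the reverse of the alpha_t convention in the Focal Loss paper.
    L = -labels * (1 - alpha) * ((1 - y_pred) ** gamma) * tf.log(tf.maximum(y_pred, 1e-14)) - \
        (1 - labels) * alpha * (y_pred ** gamma) * tf.log(tf.maximum(1 - y_pred, 1e-14))
    return L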
    
def focal_loss_softmax(labels, logits, gamma=2):
    """
    Compute focal loss for multi-class classification.
    Args:
      labels: An int32 tensor of shape [batch_size].
      logits: A float32 tensor of shape [batch_size, num_classes].
      gamma: A scalar for the focal loss gamma hyper-parameter.
    Returns:
      A tensor of the same shape as `labels`.
    """
    y_pred = tf.nn.softmax(logits, dim=-1)  # [batch_size, num_classes]
    labels = tf.one_hot(labels, depth=y_pred.shape[1])
    # Clip before the log so that a zero probability cannot produce NaN (0 * -inf)
    L = -labels * ((1 - y_pred) ** gamma) * tf.log(tf.maximum(y_pred, 1e-14))
    L = tf.reduce_sum(L, axis=1)
    return L
    
## test
if __name__ == '__main__':
    logits = tf.random_uniform(shape=[5], minval=-1, maxval=1, dtype=tf.float32)
    labels = tf.Variable([0, 1, 0, 0, 1])
    loss1 = focal_loss_sigmoid(labels=labels, logits=logits)

    logits2 = tf.random_uniform(shape=[5, 4], minval=-1, maxval=1, dtype=tf.float32)
    labels2 = tf.Variable([1, 0, 2, 3, 1])
    # Pass labels as a keyword argument (the original `labels==labels2` was a comparison)
    loss2 = focal_loss_softmax(labels=labels2, logits=logits2)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(loss1))
        print(sess.run(loss2))

https://github.com/fudannlp16/focal-loss/blob/master/focal_loss.py

 

fudannlp16/focal-loss: Tensorflow implementation of focal loss for binary and multi-class classification (github.com)

References

https://m.blog.naver.com/PostView.nhn?blogId=sogangori&logNo=221087066947&proxyReferer=https%3A%2F%2Fwww.google.com%2F

 

Focal Loss for Dense Object Detection (Tsung-Yi Lin, Priya Goyal, Ross Girshick, Kaiming H...), blog.naver.com
