tensorflow 分类损失函数问题(有点坑)

el/2024/4/19 23:30:31

tf.nn.softmax_cross_entropy_with_logits(记为f1) 和
tf.nn.sparse_softmax_cross_entropy_with_logits(记为f3),以及
tf.nn.softmax_cross_entropy_with_logits_v2(记为f2)
之间的区别。

f1和f3对于参数logits的要求都是一样的,即未经处理的,直接由神经网络输出的数值, 比如 [3.5,2.1,7.89,4.4]。两个函数不一样的地方在于labels格式的要求,f1的要求labels的格式和logits类似,比如[0,0,1,0]。而f3的要求labels是一个数值,这个数值记录着ground truth所在的索引。以[0,0,1,0]为例,这里真值1的索引为2。所以f3要求labels的输入为数字2(tensor)。一般可以用tf.argmax()来从[0,0,1,0]中取得真值的索引。

f1和f2之间很像,实际上官方文档已经标记出f1已经是deprecated 状态,推荐使用f2。两者唯一的区别在于f1在进行反向传播的时候,只对logits进行反向传播,labels保持不变。而f2在进行反向传播的时候,同时对logits和labels都进行反向传播,如果将labels传入的tensor设置为stop_gradients,就和f1一样了。
那么问题来了,一般我们在进行监督学习的时候,labels都是标记好的真值,什么时候会需要改变label?f2存在的意义是什么?实际上在应用中labels并不一定都是人工手动标注的,有的时候还可能是神经网络生成的,一个实际的例子就是对抗生成网络(GAN)。

测试用代码:

import tensorflow as tf
import numpy as np

Truth = np.array([0,0,1,0])
Pred_logits = np.array([3.5,2.1,7.89,4.4])

loss = tf.nn.softmax_cross_entropy_with_logits(labels=Truth,logits=Pred_logits)
loss2 = tf.nn.softmax_cross_entropy_with_logits_v2(labels=Truth,logits=Pred_logits)
loss3 = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(Truth),logits=Pred_logits)

with tf.Session() as sess:
    print(sess.run(loss))
    print(sess.run(loss2))
    print(sess.run(loss3))

 

参考:

    https://www.tensorflow.org/api_docs/
    https://stats.stackexchange.com/questions/327348/how-is-softmax-cross-entropy-with-logits-different-from-softmax-cross-entropy-wi
---------------------  
作者:史丹利复合田  
来源:CSDN  
原文:https://blog.csdn.net/tsyccnh/article/details/81069308  
版权声明:本文为博主原创文章,转载请附上博文链接!


http://www.ngui.cc/el/5181836.html

相关文章

OpenMP并行程序设计——for循环并行化详解

转载请声明出处http://blog.csdn.net/zhongkejingwang/article/details/40018735 在C/C中使用OpenMP优化代码方便又简单,代码中需要并行处理的往往是一些比较耗时的for循环,所以重点介绍一下OpenMP中for循环的应用。个人感觉只要掌握了文中讲的这些就足…

音视频技术网站

http://blog.yundiantech.com/

c++ 函数式编程(简单示例)

C中一个函数作为作为另一个函数的参数 2016年12月17日 15:59:36 initiallysunny 阅读数:13266 版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/Initiallysunny/article/details/53708466 C中一个函数作为作为…

._bootstrap' has no attribute 'SourceFileLoader' 和 'socketio' has no attribute 'Server' 分析解决

之前运行别人的代码,报错缺少各种包,于是直接pip install安装,后来发现,报下面两个错误,很是纠结,网上查阅资料都不能正确的解决问题。 File "/usr/local/lib/python3.6/dist-packages/pkg_resources…

x264 参数详解

https://blog.csdn.net/zhubosa/article/details/51321783

opencv reshape 深拷贝 浅拷贝之坑

今天学习reshape遇见了一个坎,浪费了不少时间,希望后学者不要未该问题浪费过多时间。通常情况下,Opencv 的reshape函数跟Matlab是一致的。A.reshape(0,N),代表通道不变,行数变为N的变形。但是即便上两个参数没有问题&a…

x264_stack_align 对齐函数

看到x264中对于字节对齐的函数x264_stack_align( x264_slice_write, h ),为什么要字节对齐呢?因为x264中用到的指令集优化SSE2,而SSE2寄存器是128位寄存器,SSE2的指令是对16字节(128/8)同时处理&#xff0c…

x264 理论与代码系列

https://www.cnblogs.com/TaigaCon/category/1189649.html

统计数据(包括金融、气候、交通等,全面的医疗数据,如EEG、ECG、血压等)

金融、气候、交通等统计数据: https://datamarket.com/data/list/?qprovider:tsdl 权威的医疗数据(EEG、ECG、血压,各种医学影像资料等): https://physionet.org/physiobank/database/ (格式为.dat&am…

反卷积中( conv2d_transpose)的stride参数

https://blog.csdn.net/u012938704/article/details/52838902 https://blog.csdn.net/qq_38906523/article/details/80520950