博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
损失函数Center Loss 代码解析
阅读量:6050 次
发布时间:2019-06-20

本文共 1227 字,大约阅读时间需要 4 分钟。

center loss来自ECCV2016的一篇论文:A Discriminative Feature Learning Approach for Deep Face Recognition。 

论文链接: 
代码链接:

理论解析请参看 https://blog.csdn.net/u014380165/article/details/76946339

下面给出centerloss的计算公式以及更新公式

 

 

下面的代码是facenet作者利用tensorflow实现的centerloss代码

def center_loss(features, label, alfa, nrof_classes):    """Center loss based on the paper "A Discriminative Feature Learning Approach for Deep Face Recognition"       (http://ydwen.github.io/papers/WenECCV16.pdf)       https://blog.csdn.net/u014380165/article/details/76946339    """    nrof_features = features.get_shape()[1]   #训练过程中,需要保存当前所有类中心的全连接预测特征centers, 每个batch的计算都要先读取已经保存的centers    centers = tf.get_variable('centers', [nrof_classes, nrof_features], dtype=tf.float32,        initializer=tf.constant_initializer(0), trainable=False)     label = tf.reshape(label, [-1])    centers_batch = tf.gather(centers, label)#获取当前batch对应的类中心特征    diff = (1 - alfa) * (centers_batch - features)#计算当前的类中心与特征的差异,用于Cj的的梯度更新,这里facenet的作者做了一个 1-alfa操作,比较奇怪,和原论文不同    centers = tf.scatter_sub(centers, label, diff)#更新梯度Cj,对于上图中步骤6,tensorflow会将该变量centers保留下来,用于计算下一个batch的centerloss    loss = tf.reduce_mean(tf.square(features - centers_batch))#计算当前的centerloss 对应于Lc    return loss, centers

 

你可能感兴趣的文章
使用 openSSL 实现CA
查看>>
【SCCM排错篇】手动注册SPN提示权限不足
查看>>
TypeScript基础入门 - 泛型 - 泛型类
查看>>
python设计模式(二)--策略模式(中)
查看>>
CrontrolTier 项目
查看>>
CSS基础学习
查看>>
配置iptables只允许访问服务器的固定端口
查看>>
C++著名类库
查看>>
企业单点登录解决方案(CAS)之三安全指南
查看>>
Java Robot对象实现服务器屏幕远程监视
查看>>
用报表软件自定义地图
查看>>
CentOS7.4安装Gitlab10.5.1及,汉化,修改端口,url,安装runner
查看>>
开源流媒体系统:OBS ( Open Broadcaster Software ) 介绍
查看>>
如何绕过安卓SSL证书的强校验
查看>>
haproxy根据用户客户端做ACL的文件例子
查看>>
Linux系统运行级管理
查看>>
在选择数据库的路上,我们遇到过哪些坑?(1)
查看>>
微服务扩展新途径:Messaging
查看>>
Windows 7样式地址栏(Address Bar)控件实现
查看>>
[ffmpeg]通过Qt调用ffmpeg命令
查看>>