买专利卖专利找龙图腾,真高效! 查专利查商标用IPTOP,全免费!专利年费监控用IP管家,真方便!
申请/专利权人:国能大渡河大数据服务有限公司
摘要:本发明提供基于阶段训练和注意力融合的多出口架构自蒸馏方法,涉及知识蒸馏领域,包括:根据深度将教师模型划分为多个出口分支,其中,教师模型和学生模型用于图像分类,多个出口分支中,深度最深的分支为教师模型,深度最浅的出口分支为学生模型,其余的出口分支为中间模型;建立总损失函数;基于多个出口分支及注意力融合算法,训练学生模型,基于总损失函数,计算总损失,基于总损失,优化学生模型,直至学生模型满足预设条件,具有提高知识传递的效率,改进知识自蒸馏框架的性能的优点。
主权项:1.基于阶段训练和注意力融合的多出口架构自蒸馏方法,其特征在于,包括:根据深度将教师模型划分为多个出口分支,其中,所述多个出口分支中,深度最深的分支为教师模型,深度最浅的出口分支为学生模型,其余的出口分支为中间模型,所述教师模型和所述学生模型均用于图像分类;建立总损失函数;基于所述多个出口分支及注意力融合算法,训练所述学生模型,基于所述总损失函数,计算总损失,基于所述总损失,优化所述学生模型,直至所述学生模型满足预设条件;其中,基于所述多个出口分支及注意力融合算法,训练所述学生模型,包括:将训练周期划分为多个阶段;在所述多个阶段,对于每个所述出口分支,通过注意力模块计算所述出口分支的注意力信息,将每个所述出口分支的注意力信息整合到学生分支中,对所述学生模型进行级联训练,并对每个所述中间模型进行级联训练,中间模型在训练周期的多个阶段中依次学习更深层次的出口分支;所述总损失函数为: 其中,L为总损失,LCEi为第i个出口分支的交叉熵损失,θi,j,t为取值为0或1的函数,用于根据阶段激活相应的转移路径,LKDi,j为第i个出口分支与第j个出口分支的蒸馏损失,LFi,M为第i个出口分支的特征图损失,M为出口分支的总数,α1、α2及α3均为权重,i和j均为用于求和的索引变量,t为阶段;基于以下公式计算第i个出口分支的交叉熵损失: 其中,N为样本总数,G为图像类别总数,ykg为第k个样本对应第g个图像类别的真实标签,pkg为第i个出口分支预测第k个样本属于第g个图像类别的概率,k为用于求和的索引变量,g为用于求和的索引变量;基于以下公式计算第i个出口分支与第j个出口分支的蒸馏损失: 其中,k为输入至第i个出口分支与第j个出口分支的样本,T为温度参数,pkgk,T为第j个出口分支在给定输入k和温度参数T下,对第g个图像类别的预测概率,qkgk,T为第i个出口分支在给定输入k和温度参数T下,对第g个图像类别的预测概率;基于以下公式计算第i个出口分支的特征图损失: 其中,β为超参数,FT为教师模型输出的特征图,FSi为第i个出口分支输出的特征图,表示计算教师模型输出的特征图与第i个出口分支输出的特征图之间的k阶欧几里得距离。
全文数据:
权利要求:
百度查询: 国能大渡河大数据服务有限公司 基于阶段训练和注意力融合的多出口架构自蒸馏方法
免责声明
1、本报告根据公开、合法渠道获得相关数据和信息,力求客观、公正,但并不保证数据的最终完整性和准确性。
2、报告中的分析和结论仅反映本公司于发布本报告当日的职业理解,仅供参考使用,不能作为本公司承担任何法律责任的依据或者凭证。