MXNet Series: Reading the Fully Connected Layer Code
Published: 2019-06-13


The fully connected operation (the fully connected layer), like other operators, has both a forward and a backward pass. The annotated code is walked through below.
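To summarize what the code below computes: the input is flattened to a matrix X of shape (batch, num_input), the weight W has shape (num_hidden, num_input), and the bias b has one entry per output unit. The forward pass is Y = X * W^T + b, with b broadcast over the rows of the batch. Given the output gradient dY, the backward pass produces dW = dY^T * X, db = the sum of dY over the batch dimension, and dX = dY * W.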

  virtual void Forward(const OpContext &ctx,
                       const std::vector<TBlob> &in_data,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &out_data,
                       const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    if (req[fullc::kOut] == kNullOp) return;
    CHECK_EQ(req[fullc::kOut], kWriteTo);
    size_t expected = param_.no_bias ? 2 : 3;
    CHECK_EQ(in_data.size(), expected);
    CHECK_EQ(out_data.size(), 1);
    // TODO(bing): check the BLAS Handle, be careful
    // maybe need blas handle from context
    // TODO(bing): judge shape to remove flatten op
    Stream<xpu> *s = ctx.get_stream<xpu>();
#if defined(__CUDACC__)
    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
        << "Must init CuBLAS handle in stream";
#endif  // __CUDACC__
    const TShape& ishape = in_data[fullc::kData].shape_;
    const TShape& oshape = out_data[fullc::kOut].shape_;
    // input, flattened to (batch, num_input)
    Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
        Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
    // weight
    Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
    // output, flattened to (batch, num_hidden)
    Tensor<xpu, 2, DType> out = out_data[fullc::kOut].get_with_shape<xpu, 2, DType>(
        Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
    out = dot(data, wmat.T());  // matrix product: Y = X * W^T
    if (!param_.no_bias) {
      Tensor<xpu, 1, DType> bias = in_data[fullc::kBias].get<xpu, 1, DType>(s);
      out += repmat(bias, data.size(0));  // broadcast bias over the batch
    }
  }

  virtual void Backward(const OpContext &ctx,
                        const std::vector<TBlob> &out_grad,
                        const std::vector<TBlob> &in_data,
                        const std::vector<TBlob> &out_data,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &in_grad,
                        const std::vector<TBlob> &aux_args) {
    using namespace mshadow;
    using namespace mshadow::expr;
    CHECK_EQ(out_grad.size(), 1);
    size_t expected = param_.no_bias ? 2 : 3;
    CHECK(in_data.size() == expected && in_grad.size() == expected);
    CHECK_EQ(req.size(), expected);
    // TODO(bing): check the BLAS Handle, be careful
    // maybe need blas handle from context
    Stream<xpu> *s = ctx.get_stream<xpu>();
    const TShape& ishape = in_data[fullc::kData].shape_;
    const TShape& oshape = out_grad[fullc::kOut].shape_;
    // input
    Tensor<xpu, 2, DType> data = in_data[fullc::kData].get_with_shape<xpu, 2, DType>(
        Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
    // weight
    Tensor<xpu, 2, DType> wmat = in_data[fullc::kWeight].get<xpu, 2, DType>(s);
    // output gradient dY
    Tensor<xpu, 2, DType> grad = out_grad[fullc::kOut].get_with_shape<xpu, 2, DType>(
        Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s);
#if defined(__CUDACC__)
    CHECK_EQ(s->blas_handle_ownership_, Stream<xpu>::OwnHandle)
        << "Must init CuBLAS handle in stream";
#endif
    // backprop
    CHECK_NE(req[fullc::kWeight], kWriteInplace) << "cannot write weight inplace";
    // gradient of weight
    Tensor<xpu, 2, DType> gwmat = in_grad[fullc::kWeight].get<xpu, 2, DType>(s);
    Assign(gwmat, req[fullc::kWeight], dot(grad.T(), data));  // dW = dY^T * X
    // gradient of bias
    if (!param_.no_bias) {
      Tensor<xpu, 1, DType> gbias = in_grad[fullc::kBias].get<xpu, 1, DType>(s);
      Assign(gbias, req[fullc::kBias], sum_rows(grad));  // db = sum of dY over the batch
    }
    // gradient of data (input)
    Tensor<xpu, 2, DType> gdata = in_grad[fullc::kData].get_with_shape<xpu, 2, DType>(
        Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s);
    Assign(gdata, req[fullc::kData], dot(grad, wmat));  // dX = dY * W
  }
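To make the data flow concrete, here is a minimal standalone sketch, in plain C++ with naive loops rather than MXNet/mshadow calls; the tiny shapes and example values are made up for illustration. It reproduces the same forward and backward math as the operator above on small row-major matrices.

// Forward: Y = X * W^T + b; backward: dW = dY^T * X, db = batch sum of dY, dX = dY * W.
// X is (batch, in), W is (out, in), Y is (batch, out). All buffers are row-major.
#include <cstdio>
#include <vector>

int main() {
  const int batch = 2, in = 3, out = 2;
  std::vector<float> X = {1, 2, 3,
                          4, 5, 6};
  std::vector<float> W = {0.1f, 0.2f, 0.3f,
                          0.4f, 0.5f, 0.6f};
  std::vector<float> b = {0.5f, -0.5f};

  // Forward pass: matches out = dot(data, wmat.T()) + repmat(bias, ...)
  std::vector<float> Y(batch * out, 0.0f);
  for (int i = 0; i < batch; ++i)
    for (int j = 0; j < out; ++j) {
      float acc = b[j];
      for (int k = 0; k < in; ++k) acc += X[i * in + k] * W[j * in + k];
      Y[i * out + j] = acc;
    }

  // Pretend the upstream gradient dY is all ones.
  std::vector<float> dY(batch * out, 1.0f);

  // dW = dY^T * X: matches Assign(gwmat, ..., dot(grad.T(), data))
  std::vector<float> dW(out * in, 0.0f);
  for (int j = 0; j < out; ++j)
    for (int k = 0; k < in; ++k)
      for (int i = 0; i < batch; ++i) dW[j * in + k] += dY[i * out + j] * X[i * in + k];

  // db = sum of dY over the batch: matches Assign(gbias, ..., sum_rows(grad))
  std::vector<float> db(out, 0.0f);
  for (int j = 0; j < out; ++j)
    for (int i = 0; i < batch; ++i) db[j] += dY[i * out + j];

  // dX = dY * W: matches Assign(gdata, ..., dot(grad, wmat))
  std::vector<float> dX(batch * in, 0.0f);
  for (int i = 0; i < batch; ++i)
    for (int k = 0; k < in; ++k)
      for (int j = 0; j < out; ++j) dX[i * in + k] += dY[i * out + j] * W[j * in + k];

  printf("Y[0][0] = %.2f, dW[0][0] = %.2f, db[0] = %.2f, dX[0][0] = %.2f\n",
         Y[0][0], dW[0][0], db[0], dX[0][0]);
  return 0;
}

Each nested-loop block corresponds one-to-one to a dot(...), repmat(...), or sum_rows(...) expression in the operator; the mshadow versions simply dispatch the same matrix products to a BLAS or cuBLAS backend.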

Reprinted from: https://www.cnblogs.com/hellokittyblog/p/9128437.html
