From d7f7cbef61448f6f084757c531070f9306e17f43 Mon Sep 17 00:00:00 2001 From: Zhanghao Chen Date: Thu, 20 Apr 2017 20:01:51 +0800 Subject: [PATCH] add gradient of MXNet operator batch_dot --- minpy/array_variants/mxnet/mxnet_core.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/minpy/array_variants/mxnet/mxnet_core.py b/minpy/array_variants/mxnet/mxnet_core.py index 6716d74..49343d6 100644 --- a/minpy/array_variants/mxnet/mxnet_core.py +++ b/minpy/array_variants/mxnet/mxnet_core.py @@ -112,6 +112,10 @@ def def_grads(prims): prims('dot').def_grad(lambda ans, a, b: lambda g: mx.nd.dot(g, b, transpose_b=True)) prims('dot').def_grad( lambda ans, a, b: lambda g: mx.nd.dot(a, g, transpose_a=True), argnum=1) + # batch_dot (batched matrix multiplication): gradients mirror those of 'dot' above + prims('batch_dot').def_grad(lambda ans, a, b: lambda g: mx.nd.batch_dot(g, b, transpose_b=True)) + prims('batch_dot').def_grad( + lambda ans, a, b: lambda g: mx.nd.batch_dot(a, g, transpose_a=True), argnum=1) # non-linear prims('tanh').def_grad(lambda ans, x: lambda g: g * (1 - ans ** 2)) prims('exp').def_grad(lambda ans, x: lambda g: g * ans