From 47ab68d265a96b6e7be06afd1b4b47e0114c0ee9 Mon Sep 17 00:00:00 2001 From: Anudhyan Boral Date: Fri, 29 Mar 2019 13:30:57 -0700 Subject: [PATCH] Add broadcasting support to tf.matmul. Add Numpy-style broadcasting in the batch dimensions for tf.matmul op. The last two dimensions of both operands constitute the matrix dimensions. The dimensions beyond these are broadcasted to form a common output shape with the standard NumPy broadcasting rules. (https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html) Note: This implementation differs from Numpy's behavior in that vectors (rank-1 Tensors) are not promoted to matrices (rank-2 Tensors) by appending/prepending dimensions. PiperOrigin-RevId: 241040476 --- tensorflow/contrib/makefile/tf_op_files.txt | 1 + .../base_api/api_def_BatchMatMulV2.pbtxt | 59 ++++ .../python_api/api_def_BatchMatMulV2.pbtxt | 4 + tensorflow/core/framework/common_shape_fns.cc | 37 ++ tensorflow/core/framework/common_shape_fns.h | 4 + .../core/framework/common_shape_fns_test.cc | 68 ++++ tensorflow/core/kernels/BUILD | 17 + .../core/kernels/batch_matmul_op_common.cc | 76 ++++ .../core/kernels/batch_matmul_op_common.h | 70 ++++ .../kernels/batch_matmul_op_common_test.cc | 138 ++++++++ .../core/kernels/batch_matmul_op_impl.h | 331 +++++++++++++----- .../core/kernels/batch_matmul_op_test.cc | 124 ++++++- tensorflow/core/ops/math_ops.cc | 11 + tensorflow/python/BUILD | 1 + .../kernel_tests/batch_matmul_op_test.py | 269 +++++++++----- tensorflow/python/ops/math_grad.py | 38 ++ tensorflow/python/ops/math_ops.py | 19 +- .../api/golden/v1/tensorflow.raw_ops.pbtxt | 4 + .../api/golden/v2/tensorflow.raw_ops.pbtxt | 4 + 19 files changed, 1093 insertions(+), 182 deletions(-) create mode 100644 tensorflow/core/api_def/base_api/api_def_BatchMatMulV2.pbtxt create mode 100644 tensorflow/core/api_def/python_api/api_def_BatchMatMulV2.pbtxt create mode 100644 tensorflow/core/kernels/batch_matmul_op_common.cc create mode 100644 tensorflow/core/kernels/batch_matmul_op_common.h create mode 100644 tensorflow/core/kernels/batch_matmul_op_common_test.cc diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index a9a6053a272..07275436dd7 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -8,6 +8,7 @@ tensorflow/contrib/boosted_trees/ops/training_ops.cc tensorflow/core/kernels/aggregate_ops.cc tensorflow/core/kernels/argmax_op.cc tensorflow/core/kernels/avgpooling_op.cc +tensorflow/core/kernels/batch_matmul_op_common.cc tensorflow/core/kernels/batch_matmul_op_real.cc tensorflow/core/kernels/batch_norm_op.cc tensorflow/core/kernels/batchtospace_op.cc diff --git a/tensorflow/core/api_def/base_api/api_def_BatchMatMulV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_BatchMatMulV2.pbtxt new file mode 100644 index 00000000000..b1ccd38d748 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_BatchMatMulV2.pbtxt @@ -0,0 +1,59 @@ +op { + graph_op_name: "BatchMatMulV2" + in_arg { + name: "x" + description: <