From b199a9709793f0c30149218b840bd23a00b6448f Mon Sep 17 00:00:00 2001
From: Anudhyan Boral <anudhyan@google.com>
Date: Tue, 23 Apr 2019 16:11:15 -0700
Subject: [PATCH] Move MatMulBCast class to core/util.

Export it under the core/framework target (same as core/util/bcast.h) instead of core/kernels:batch_matmul_op.

As an aside, this allows TFLite use this class without adding a dependency on core/kernels when it just needs the util.

PiperOrigin-RevId: 244945168
---
 tensorflow/contrib/makefile/tf_op_files.txt      |  1 -
 tensorflow/core/BUILD                            |  2 ++
 tensorflow/core/kernels/BUILD                    | 16 ----------------
 tensorflow/core/kernels/batch_matmul_op_impl.h   |  2 +-
 .../matmul_bcast.cc}                             |  2 +-
 .../matmul_bcast.h}                              |  6 +++---
 .../matmul_bcast_test.cc}                        |  2 +-
 7 files changed, 8 insertions(+), 23 deletions(-)
 rename tensorflow/core/{kernels/batch_matmul_op_common.cc => util/matmul_bcast.cc} (98%)
 rename tensorflow/core/{kernels/batch_matmul_op_common.h => util/matmul_bcast.h} (93%)
 rename tensorflow/core/{kernels/batch_matmul_op_common_test.cc => util/matmul_bcast_test.cc} (98%)

diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt
index c472b2764da..ac54c0c3a80 100644
--- a/tensorflow/contrib/makefile/tf_op_files.txt
+++ b/tensorflow/contrib/makefile/tf_op_files.txt
@@ -8,7 +8,6 @@ tensorflow/contrib/boosted_trees/ops/training_ops.cc
 tensorflow/core/kernels/aggregate_ops.cc
 tensorflow/core/kernels/argmax_op.cc
 tensorflow/core/kernels/avgpooling_op.cc
-tensorflow/core/kernels/batch_matmul_op_common.cc
 tensorflow/core/kernels/batch_matmul_op_real.cc
 tensorflow/core/kernels/batch_norm_op.cc
 tensorflow/core/kernels/batchtospace_op.cc
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 745cd148d9a..ae1e4b135de 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -947,6 +947,7 @@ tf_cuda_library(
         "util/activation_mode.h",
         "util/batch_util.h",
         "util/bcast.h",
+        "util/matmul_bcast.h",
         "util/cuda_kernel_helper.h",
         "util/device_name_utils.h",
         "util/dump_graph.h",
@@ -3977,6 +3978,7 @@ tf_cc_tests(
         "util/events_writer_test.cc",
         "util/example_proto_fast_parsing_test.cc",
         "util/example_proto_helper_test.cc",
+        "util/matmul_bcast_test.cc",
         "util/memmapped_file_system_test.cc",
         "util/presized_cuckoo_map_test.cc",
         "util/reffed_status_callback_test.cc",
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index e462e346e86..f2cc4bfc508 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3682,19 +3682,6 @@ tf_cuda_cc_test(
     ],
 )
 
-tf_cc_test(
-    name = "batch_matmul_op_common_test",
-    size = "small",
-    srcs = ["batch_matmul_op_common_test.cc"],
-    deps = [
-        ":batch_matmul_op",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core:testlib",
-    ],
-)
-
 tf_cuda_cc_test(
     name = "batch_matmul_op_test",
     size = "small",
@@ -5606,8 +5593,6 @@ filegroup(
     name = "mobile_srcs",
     srcs = [
         "avgpooling_op.h",
-        "batch_matmul_op_common.cc",
-        "batch_matmul_op_common.h",
         "batch_util.h",
         "cwise_ops.h",
         "cwise_ops_common.h",
@@ -6108,7 +6093,6 @@ filegroup(
             "*_3d*",
             "*.cu.*",
             # Ops already in android_srcs
-            "batch_matmul_op_common.cc",
             "pooling_ops_common.cc",
             # Ops which we are currently excluding because they are likely
             # not used on Android. Those ops also do not compile if included,
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 617951b5204..1798b272e0c 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -30,12 +30,12 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/type_traits.h"
 #include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/kernels/batch_matmul_op_common.h"
 #include "tensorflow/core/kernels/fill_functor.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/inlined_vector.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 #include "tensorflow/core/util/work_sharder.h"
 
 #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
diff --git a/tensorflow/core/kernels/batch_matmul_op_common.cc b/tensorflow/core/util/matmul_bcast.cc
similarity index 98%
rename from tensorflow/core/kernels/batch_matmul_op_common.cc
rename to tensorflow/core/util/matmul_bcast.cc
index 27963f3b264..3e5c5cf1750 100644
--- a/tensorflow/core/kernels/batch_matmul_op_common.cc
+++ b/tensorflow/core/util/matmul_bcast.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/batch_matmul_op_common.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 
 namespace tensorflow {
 namespace {
diff --git a/tensorflow/core/kernels/batch_matmul_op_common.h b/tensorflow/core/util/matmul_bcast.h
similarity index 93%
rename from tensorflow/core/kernels/batch_matmul_op_common.h
rename to tensorflow/core/util/matmul_bcast.h
index 99e6d937072..611ef237de6 100644
--- a/tensorflow/core/kernels/batch_matmul_op_common.h
+++ b/tensorflow/core/util/matmul_bcast.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
-#define TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
+#ifndef TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_
+#define TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_
 
 #include <vector>
 
@@ -67,4 +67,4 @@ class MatMulBCast {
 
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_CORE_KERNELS_BATCH_MATMUL_OP_COMMON_H_
+#endif  // TENSORFLOW_CORE_UTIL_MATMUL_BCAST_H_
diff --git a/tensorflow/core/kernels/batch_matmul_op_common_test.cc b/tensorflow/core/util/matmul_bcast_test.cc
similarity index 98%
rename from tensorflow/core/kernels/batch_matmul_op_common_test.cc
rename to tensorflow/core/util/matmul_bcast_test.cc
index d6334b7d394..1de62297f70 100644
--- a/tensorflow/core/kernels/batch_matmul_op_common_test.cc
+++ b/tensorflow/core/util/matmul_bcast_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#include "tensorflow/core/kernels/batch_matmul_op_common.h"
+#include "tensorflow/core/util/matmul_bcast.h"
 
 #include "tensorflow/core/lib/strings/str_util.h"
 #include "tensorflow/core/lib/strings/strcat.h"