From ef1ca5b2e184c8bdd78c9eac9cc8b72fa18ad4ad Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 12 Apr 2019 13:13:33 -0700
Subject: [PATCH] Add broadcast optimization rewrite to Grappler. This
 generalizes the rewrites for algebraically neutral element simplifications
 (e.g. x + zeros(shape) => x) to handle cases where the output shape
 differs from the input shape of the non-trivial argument. This pattern is
 sometimes used in TensorFlow for broadcasting.

Examples of new rewrites enabled by this change:

x * ones_like(y) => broadcast_to(x, shape(y))
zeros_like(x) + y => broadcast_to(y, shape(x))
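
As a minimal sketch of the pattern this enables (illustrative TF 1.x Python;
the tensor names here are hypothetical, not part of this change):

  import tensorflow as tf

  x = tf.placeholder(tf.float32, shape=[2, 2])
  y = tf.placeholder(tf.float32, shape=[2, 2, 2])
  # Before: multiplying by ones_like(y) materializes a ones tensor and runs
  # a full elementwise Mul just to broadcast x to y's shape.
  z = x * tf.ones_like(y)
  # After this change, Grappler rewrites the Mul to the equivalent
  #   z = tf.broadcast_to(x, tf.shape(y))
  # with the shape folded to a constant (the rewrite requires a fully
  # defined output shape).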

This change also cleans up the code in SimplifyNode so that it consistently
relies on graph_modified_ (and, where applicable, an error status) to signal
whether the graph was updated.

PiperOrigin-RevId: 243319718
---
 tensorflow/core/BUILD                         |   5 +-
 .../grappler/optimizers/constant_folding.cc   | 582 ++++++++----------
 .../grappler/optimizers/constant_folding.h    |  28 +-
 .../optimizers/constant_folding_test.cc       |  51 +-
 tensorflow/core/ops/array_grad.cc             |  26 +
 tensorflow/core/ops/array_grad_test.cc        |  35 ++
 6 files changed, 383 insertions(+), 344 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6c9a1b96a5c..c14728e5622 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -136,14 +136,15 @@ load(
     "tf_additional_libdevice_srcs",
     "tf_additional_minimal_lib_srcs",
     "tf_additional_mpi_lib_defines",
+    "tf_additional_numa_copts",
     "tf_additional_numa_deps",
     "tf_additional_numa_lib_defines",
-    "tf_additional_numa_copts",
     "tf_additional_proto_hdrs",
     "tf_additional_proto_srcs",
     "tf_additional_test_deps",
     "tf_additional_test_srcs",
     "tf_additional_verbs_lib_defines",
+    "tf_grpc_service_all",
     "tf_jspb_proto_library",
     "tf_kernel_tests_linkstatic",
     "tf_lib_proto_compiler_deps",
@@ -157,7 +158,6 @@ load(
     "tf_protos_grappler",
     "tf_protos_grappler_impl",
     "tf_pyclif_proto_library",
-    "tf_grpc_service_all",
 )
 load(
     ":platform/default/build_config_root.bzl",
@@ -4919,6 +4919,7 @@ tf_cc_test(
         "//tensorflow/core/kernels:array",
         "//tensorflow/core/kernels:cwise_op",
         "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:math",
         "//third_party/eigen3",
     ],
 )
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ad82ff704c7..4029e9c314b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
   } else {
     double value =
         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
-    bool success = false;
     if (value >= 0) {
       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-          value, properties, output_shape, node, graph_, &success));
+          value, properties, output_shape, node, graph_));
     }
   }
   return Status::OK();
@@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
   graph_modified_ = true;
 }
 
+void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
+    int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+  const PartialTensorShape shape(
+      properties.GetOutputProperties(node->name())[0].shape());
+  if (!shape.IsFullyDefined()) return;
+
+  // Create a constant node holding the target output shape.
+  const string const_name = OptimizedNodeName(
+      *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
+  if (node_map_->GetNode(const_name) != nullptr) {
+    return;
+  }
+
+  Tensor shape_t;
+  if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return;
+  NodeDef tmp;
+  if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return;
+  NodeDef* const_node = graph->add_node();
+  const_node->Swap(&tmp);
+  const_node->set_device(node->device());
+  node_map_->AddNode(const_name, const_node);
+  // Add a control dependency on the otherwise unused input.
+  string ctrl_dep = AddControlDependency(
+      NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get());
+  *const_node->add_input() = ctrl_dep;
+  node_map_->AddOutput(NodeName(ctrl_dep), const_name);
+
+  // Rewrite `node` in-place to BroadcastTo.
+  node->set_op("BroadcastTo");
+  node->clear_attr();
+  (*node->mutable_attr())["T"].set_type(dtype);
+  (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
+  // Move the input to be broadcast into the first input position.
+  node->mutable_input()->SwapElements(0, input_to_broadcast);
+  // Convert all remaining non-control inputs to control dependencies.
+  for (int i = 1; i < node->input_size(); ++i) {
+    if (IsControlInput(node->input(i))) {
+      break;
+    }
+    const string ctrl_dep =
+        AddControlDependency(node->input(i), graph, node_map_.get());
+    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
+    node->set_input(i, ctrl_dep);
+  }
+  // Add the shape argument.
+  *node->add_input() = const_node->name();
+  node_map_->AddOutput(const_name, node->name());
+  node->mutable_input()->SwapElements(1, node->input_size() - 1);
+  graph_modified_ = true;
+}
+
 void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
                                                         GraphDef* graph) {
   node->set_op("Reciprocal");
@@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
 
 Status ConstantFolding::ReplaceOperationWithConstant(
     double value, const GraphProperties& properties,
-    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
-    bool* success) {
+    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
   const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
   if (dtype == DT_INVALID) {
-    *success = false;
     return Status::OK();
   }
 
@@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant(
     node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
     node->set_input(i, ctrl_dep);
   }
-  *success = true;
   graph_modified_ = true;
   return Status::OK();
 }
@@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph(
   return Status::OK();
 }
 
+#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
+  TF_RETURN_IF_ERROR(EXPR);               \
+  if (graph_modified_) return Status::OK()
+
+#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
+  graph_modified_ = EXPR;                \
+  if (graph_modified_) return Status::OK()
+
+#define RETURN_IF_MODIFIED(EXPR) \
+  EXPR;                          \
+  if (graph_modified_) return Status::OK()
+
 Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
                                      GraphDef* optimized_graph,
                                      GraphProperties* properties) {
-  if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
-    return Status::OK();
-  }
+  bool graph_modified_cached = graph_modified_;
+  graph_modified_ = false;
 
-  bool remove_shuffle_transpose_successful = false;
-  Status remove_shuffle_transpose_status =
-      RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
-                               node, &remove_shuffle_transpose_successful);
-  if (!remove_shuffle_transpose_status.ok()) {
-    return remove_shuffle_transpose_status;
-  } else if (remove_shuffle_transpose_successful) {
-    return Status::OK();
-  }
-
-  if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
-    return Status::OK();
-  }
-
-  bool remove_reverse_successful = false;
-  Status remove_reverse_status =
-      RemoveReverse(*properties, use_shape_info, optimized_graph, node,
-                    &remove_reverse_successful);
-  if (!remove_reverse_status.ok()) {
-    return remove_reverse_status;
-  } else if (remove_reverse_successful) {
-    return Status::OK();
-  }
-
-  bool simplify_slice_successful = false;
-  Status simplify_slice_status =
-      SimplifySlice(*properties, use_shape_info, optimized_graph, node,
-                    &simplify_slice_successful);
-  if (!simplify_slice_status.ok()) {
-    return simplify_slice_status;
-  } else if (simplify_slice_successful) {
-    return Status::OK();
-  }
-
-  bool simplify_strided_slice_successful = false;
-  Status simplify_strided_slice_status =
-      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
-                           &simplify_strided_slice_successful);
-  if (!simplify_strided_slice_status.ok()) {
-    return simplify_strided_slice_status;
-  } else if (simplify_strided_slice_successful) {
-    return Status::OK();
-  }
-
-  bool simplify_tile_successful = false;
-  Status simplify_tile_status =
-      SimplifyTile(*properties, use_shape_info, optimized_graph, node,
-                   &simplify_tile_successful);
-  if (!simplify_tile_status.ok()) {
-    return simplify_tile_status;
-  } else if (simplify_tile_successful) {
-    return Status::OK();
-  }
-
-  bool simplify_pad_successful = false;
-  Status simplify_pad_status =
-      SimplifyPad(*properties, use_shape_info, optimized_graph, node,
-                  &simplify_pad_successful);
-  if (!simplify_pad_status.ok()) {
-    return simplify_pad_status;
-  } else if (simplify_pad_successful) {
-    return Status::OK();
-  }
-
-  if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
-    return Status::OK();
-  }
-
-  if (SimplifyPack(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (MoveConstantsPastEnter(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (SimplifySwitch(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (SimplifyReduction(optimized_graph, *properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (SimplifyReshape(*properties, use_shape_info, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  bool arithmetic_simplification_succeed = false;
-  Status simplify_arithmetic_status =
-      SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
-                                   node, &arithmetic_simplification_succeed);
-  if (!simplify_arithmetic_status.ok()) {
-    return simplify_arithmetic_status;
-  } else if (arithmetic_simplification_succeed) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (ReduceDivToReciprocalMul(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (ConstantPushDown(optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (MulConvPushDown(optimized_graph, node, *properties)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConstPropThroughIdentityN(node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
-
-  if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
-    graph_modified_ = true;
-    return Status::OK();
-  }
+  RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
+      *properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_MODIFIED(
+      RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(
+      RemoveReverse(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(
+      SimplifySlice(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(
+      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(
+      SimplifyTile(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_ERROR_OR_MODIFIED(
+      SimplifyPad(*properties, use_shape_info, optimized_graph, node));
+  RETURN_IF_MODIFIED(
+      SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      SimplifyReduction(optimized_graph, *properties, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      SimplifyReshape(*properties, use_shape_info, node));
+  RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
+      *properties, use_shape_info, optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      MulConvPushDown(optimized_graph, node, *properties));
+  SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
+  SET_AND_RETURN_IF_MODIFIED(
+      PartialAssocOpConstFolding(optimized_graph, properties, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      PartialConcatConstFolding(optimized_graph, properties, node));
+  SET_AND_RETURN_IF_MODIFIED(
+      MergeConcat(*properties, use_shape_info, optimized_graph, node));
 
+  graph_modified_ = graph_modified_cached;
   return Status::OK();
 }
 
-bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
+void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
                                           GraphDef* optimized_graph,
                                           NodeDef* node) {
-  if (node->attr().count("num_split") == 0) return false;
+  if (node->attr().count("num_split") == 0) return;
   if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
-    return true;
   }
   if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
     ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-    return true;
   }
-  return false;
 }
 
 Status ConstantFolding::RemoveShuffleOrTranspose(
     const GraphProperties& properties, bool use_shape_info,
-    GraphDef* optimized_graph, NodeDef* node, bool* success) {
+    GraphDef* optimized_graph, NodeDef* node) {
   if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
       properties.GetInputProperties(node->name()).size() >= 2) {
     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
         return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
-bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
+
+void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
                                           bool use_shape_info,
                                           GraphDef* optimized_graph,
                                           NodeDef* node) {
@@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
     if (!shape.unknown_rank() &&
         (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-      return true;
     }
   }
-  return false;
 }
 
 Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
                                       bool use_shape_info,
-                                      GraphDef* optimized_graph, NodeDef* node,
-                                      bool* success) {
+                                      GraphDef* optimized_graph,
+                                      NodeDef* node) {
   if (use_shape_info && node->op() == "ReverseV2" &&
       properties.GetInputProperties(node->name()).size() >= 2) {
     const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
-        return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
 
 Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
                                       bool use_shape_info,
-                                      GraphDef* optimized_graph, NodeDef* node,
-                                      bool* success) {
+                                      GraphDef* optimized_graph,
+                                      NodeDef* node) {
   if (use_shape_info && IsSlice(*node) &&
       properties.GetInputProperties(node->name()).size() == 3) {
     const auto& input = properties.GetInputProperties(node->name())[0];
@@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
         return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
 
 Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
                                              bool use_shape_info,
                                              GraphDef* optimized_graph,
-                                             NodeDef* node, bool* success) {
+                                             NodeDef* node) {
   if (use_shape_info && IsStridedSlice(*node) &&
       properties.GetInputProperties(node->name()).size() == 4) {
     TF_RETURN_IF_ERROR(
@@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
-        return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
 
 Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
                                      bool use_shape_info,
-                                     GraphDef* optimized_graph, NodeDef* node,
-                                     bool* success) {
+                                     GraphDef* optimized_graph, NodeDef* node) {
   if (use_shape_info && IsTile(*node) &&
       properties.GetInputProperties(node->name()).size() == 2) {
     const auto& m = properties.GetInputProperties(node->name())[1];
@@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
-        return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
 
 Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
                                     bool use_shape_info,
-                                    GraphDef* optimized_graph, NodeDef* node,
-                                    bool* success) {
+                                    GraphDef* optimized_graph, NodeDef* node) {
   if (use_shape_info && IsPad(*node) &&
       properties.GetInputProperties(node->name()).size() >= 2) {
     const auto& p = properties.GetInputProperties(node->name())[1];
@@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
       }
       if (replaceable) {
         ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
-        return Status::OK();
       }
     }
   }
-  *success = false;
   return Status::OK();
 }
 
-bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
+void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
                                       bool use_shape_info,
                                       GraphDef* optimized_graph,
                                       NodeDef* node) {
@@ -2263,95 +2202,92 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
     }
     if (replaceable) {
       ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-      return true;
     }
   }
-  return false;
 }
 
 bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
-  if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
-      !OptimizedNodeExists(*node, "_const_axis")) {
-    // Create constant axis node.
-    Tensor axis_t(DT_INT32, TensorShape({}));
-    NodeDef* axis_node = optimized_graph->add_node();
-    axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
-    const int axis =
-        node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
-    if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
-        !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
-             .ok()) {
-      return false;
-    }
-    // Add a control dependency to make sure axis_node is in the right frame.
-    const string ctrl_dep = ConstantFolding::AddControlDependency(
-        node->input(0), optimized_graph, node_map_.get());
-    axis_node->add_input(ctrl_dep);
-    axis_node->set_device(node->device());
-    node->set_op("ExpandDims");
-    if (node->attr().count("axis") != 0) {
-      node->mutable_attr()->erase("axis");
-    }
-    if (node->attr().count("N") != 0) {
-      node->mutable_attr()->erase("N");
-    }
-    (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
-    node->add_input(axis_node->name());
-    if (node->input_size() > 2) {
-      node->mutable_input()->SwapElements(1, node->input_size() - 1);
-    }
-    return true;
+  if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 &&
+        !OptimizedNodeExists(*node, "_const_axis"))) {
+    return false;
   }
-  return false;
+  // Create constant axis node.
+  Tensor axis_t(DT_INT32, TensorShape({}));
+  NodeDef* axis_node = optimized_graph->add_node();
+  axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
+  const int axis =
+      node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
+  if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
+      !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) {
+    return false;
+  }
+  // Add a control dependency to make sure axis_node is in the right frame.
+  const string ctrl_dep = ConstantFolding::AddControlDependency(
+      node->input(0), optimized_graph, node_map_.get());
+  axis_node->add_input(ctrl_dep);
+  axis_node->set_device(node->device());
+  node->set_op("ExpandDims");
+  if (node->attr().count("axis") != 0) {
+    node->mutable_attr()->erase("axis");
+  }
+  if (node->attr().count("N") != 0) {
+    node->mutable_attr()->erase("N");
+  }
+  (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
+  node->add_input(axis_node->name());
+  if (node->input_size() > 2) {
+    node->mutable_input()->SwapElements(1, node->input_size() - 1);
+  }
+  return true;
 }
 
 bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                              NodeDef* node) {
-  if (IsEnter(*node) && node->input_size() > 0) {
-    if (node->attr().count("is_constant") == 0 ||
-        !node->attr().at("is_constant").b()) {
-      return false;
-    }
-    const string& node_name = node->name();
-    const NodeDef* input = node_map_->GetNode(node->input(0));
-    if (input != nullptr && IsReallyConstant(*input) &&
-        !OptimizedNodeExists(*input, "_enter")) {
-      auto fanouts = node_map_->GetOutputs(node_name);
-      // Find non-constant nodes that consume the output of *node.
-      std::vector<NodeDef*> consumers;
-      for (NodeDef* fanout : fanouts) {
-        if (!IsConstant(*fanout)) {
-          for (int i = 0; i < fanout->input_size(); ++i) {
-            if (fanout->input(i) == node_name) {
-              consumers.push_back(fanout);
-              break;
-            }
-          }
+  if (!IsEnter(*node) || node->input_size() == 0 ||
+      node->attr().count("is_constant") == 0 ||
+      !node->attr().at("is_constant").b()) {
+    return false;
+  }
+  const string& node_name = node->name();
+  const NodeDef* input = node_map_->GetNode(node->input(0));
+  if (input == nullptr || !IsReallyConstant(*input) ||
+      OptimizedNodeExists(*input, "_enter")) {
+    return false;
+  }
+  auto fanouts = node_map_->GetOutputs(node_name);
+  // Find non-constant nodes that consume the output of *node.
+  std::vector<NodeDef*> consumers;
+  for (NodeDef* fanout : fanouts) {
+    if (!IsConstant(*fanout)) {
+      for (int i = 0; i < fanout->input_size(); ++i) {
+        if (fanout->input(i) == node_name) {
+          consumers.push_back(fanout);
+          break;
         }
       }
-      if (!consumers.empty()) {
-        NodeDef* new_node = optimized_graph->add_node();
-        *new_node = *input;
-        new_node->set_name(OptimizedNodeName(*input, "_enter"));
-        new_node->set_device(node->device());
-        new_node->clear_input();
-        new_node->add_input(AsControlDependency(node_name));
-        node_map_->AddNode(new_node->name(), new_node);
-        node_map_->AddOutput(node_name, new_node->name());
-        for (NodeDef* consumer : consumers) {
-          for (int i = 0; i < consumer->input_size(); ++i) {
-            if (NodeName(consumer->input(i)) == node_name) {
-              node_map_->UpdateInput(consumer->name(), node_name,
-                                     new_node->name());
-              consumer->set_input(i, new_node->name());
-            }
-          }
-        }
-        return true;
-      }
     }
   }
-  return false;
+  if (consumers.empty()) {
+    return false;
+  }
+  graph_modified_ = true;
+  NodeDef* new_node = optimized_graph->add_node();
+  *new_node = *input;
+  new_node->set_name(OptimizedNodeName(*input, "_enter"));
+  new_node->set_device(node->device());
+  new_node->clear_input();
+  new_node->add_input(AsControlDependency(node_name));
+  node_map_->AddNode(new_node->name(), new_node);
+  node_map_->AddOutput(node_name, new_node->name());
+  for (NodeDef* consumer : consumers) {
+    for (int i = 0; i < consumer->input_size(); ++i) {
+      if (NodeName(consumer->input(i)) == node_name) {
+        node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
+        consumer->set_input(i, new_node->name());
+      }
+    }
+  }
+  return true;
 }
 
 bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
@@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
                   return n1->name() < n2->name();
                 });
       // Create constant false & true nodes.
-      NodeDef* false_node = optimized_graph->add_node();
-      false_node->set_name(OptimizedNodeName(*node, "_const_false"));
-      if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
+      NodeDef tmp_false_node;
+      tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
+      if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
+                         &tmp_false_node)
                .ok()) {
         return false;
       }
-      false_node->set_device(node->device());
+      tmp_false_node.set_device(node->device());
+      NodeDef tmp_true_node;
+      tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
+      if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
+                         &tmp_true_node)
+               .ok()) {
+        return false;
+      }
+      tmp_true_node.set_device(node->device());
 
+      // Add const nodes to graph.
+      NodeDef* false_node = optimized_graph->add_node();
+      false_node->Swap(&tmp_false_node);
       NodeDef* true_node = optimized_graph->add_node();
-      true_node->set_name(OptimizedNodeName(*node, "_const_true"));
-      if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
-               .ok()) {
-        return false;
-      }
-      true_node->set_device(node->device());
+      true_node->Swap(&tmp_true_node);
 
       // Add controls from the switch ports to the constants, and connect the
       // constants to the original switch outputs.
@@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
 
 Status ConstantFolding::SimplifyArithmeticOperations(
     const GraphProperties& properties, bool use_shape_info,
-    GraphDef* optimized_graph, NodeDef* node, bool* success) {
-  *success = false;
+    GraphDef* optimized_graph, NodeDef* node) {
   const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
   const bool is_matmul = IsAnyMatMul(*node);
-  const bool is_quantized_matmul = IsQuantizedMatMul(*node);
   const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
   const bool is_sub = IsSub(*node);
   const bool is_any_div = IsAnyDiv(*node);
@@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations(
     // of zeros.
     const TensorShapeProto& y_shape =
         properties.GetInputProperties(node->name())[1].shape();
-    const bool x_is_zero = IsZeros(*x);
-    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
+    const TensorShapeProto& x_shape =
+        properties.GetInputProperties(node->name())[0].shape();
     const bool y_matches_output_shape =
         ShapesSymbolicallyEqual(output_shape, y_shape);
-    if (y_matches_output_shape &&
-        ((is_mul && x_is_one) || (is_add && x_is_zero))) {
+    const bool x_matches_output_shape =
+        ShapesSymbolicallyEqual(output_shape, x_shape);
+
+    const bool x_is_zero = IsZeros(*x);
+    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
+    if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
       // 1 * y = y or 0 + y = y.
-      ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
-      *success = true;
+      if (y_matches_output_shape) {
+        ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
+      } else if (x_matches_output_shape) {
+        ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
+                                              optimized_graph);
+      }
       return Status::OK();
     }
 
     if (y_matches_output_shape && (is_sub && x_is_zero)) {
       // Replace 0 - y with Neg(y).
       ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
-      *success = true;
       return Status::OK();
     }
 
@@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations(
       DataType type = node->attr().at("T").type();
       if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
         ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
-        *success = true;
         return Status::OK();
       }
     }
 
-    const TensorShapeProto& x_shape =
-        properties.GetInputProperties(node->name())[0].shape();
     const bool y_is_zero = IsZeros(*y);
     const bool y_is_one = y_is_zero ? false : IsOnes(*y);
-    const bool x_matches_output_shape =
-        ShapesSymbolicallyEqual(output_shape, x_shape);
-    if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
-                                   ((is_add || is_sub) && y_is_zero))) {
+    if (((is_mul || is_any_div) && y_is_one) ||
+        ((is_add || is_sub) && y_is_zero)) {
       // x * 1 = x or x / 1 = x or x +/- 0 = x
-      ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
-      *success = true;
+      if (x_matches_output_shape) {
+        ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
+      } else if (y_matches_output_shape) {
+        ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
+                                              optimized_graph);
+      }
       return Status::OK();
     }
 
     // x OR true = true OR y = true.
-    bool updated_graph = false;
     const PartialTensorShape shp(output_shape);
     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
-      bool replace_succeed = false;
-      Status replace_op_status = ReplaceOperationWithConstant(
-          1, properties, output_shape, node, optimized_graph, &replace_succeed);
-      if (!replace_op_status.ok()) {
-        return replace_op_status;
-      } else if (replace_succeed) {
-        updated_graph = true;
-      }
+      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+          1, properties, output_shape, node, optimized_graph));
+      return Status::OK();
     }
 
     // Simplify multiplication and matmul by zeros.
@@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
     if ((x_is_zero || y_is_zero) &&
         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
       if (shp.IsFullyDefined()) {
-        bool replace_succeed = false;
-        Status replace_op_status =
-            ReplaceOperationWithConstant(0, properties, output_shape, node,
-                                         optimized_graph, &replace_succeed);
-        if (!replace_op_status.ok()) {
-          return replace_op_status;
-        } else if (replace_succeed) {
-          if (is_quantized_matmul) {
-            TF_RETURN_IF_ERROR(
-                AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
-          }
-          *success = true;
-          return Status::OK();
+        bool is_quantized = IsQuantizedMatMul(*node);
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            0, properties, output_shape, node, optimized_graph));
+        if (is_quantized && graph_modified_) {
+          TF_RETURN_IF_ERROR(
+              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
         }
+        return Status::OK();
       }
       // Even if an input shape is only partially known, we may know that it
-      // matches the output shape and thus forward the corresponding zero
-      // input.
-      if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
+      // matches the output shape and thus forward or broadcast the
+      // corresponding zero input.
+      if ((is_mul || is_any_div) && x_is_zero) {
+        if (x_matches_output_shape) {
+          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        } else if (y_matches_output_shape) {
+          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
+                                                optimized_graph);
+        }
         return Status::OK();
-      } else if (is_mul && y_is_zero && y_matches_output_shape) {
-        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
-        *success = true;
+      } else if (is_mul && y_is_zero) {
+        if (y_matches_output_shape) {
+          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+        } else if (x_matches_output_shape) {
+          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
+                                                optimized_graph);
+        }
         return Status::OK();
       }
     }
-    if (updated_graph) {
-      *success = true;
-      return Status::OK();
-    }
   }
-  *success = false;
   return Status::OK();
 }
 
@@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
   auto add_quantized_out = [this, node, optimized_graph](
                                const string& out_const_name, int index) {
     NodeDef* out_node = optimized_graph->add_node();
+    graph_modified_ = true;
     Tensor value(DT_FLOAT, TensorShape({}));
     const bool is_min = index == 1;
     const DataType type_attr = node->attr().at("dtype").type();
@@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
         CreateNodeDef(out_const_name, TensorValue(&value), out_node));
     node_map_->AddNode(out_const_name, out_node);
     out_node->set_device(node->device());
-
     // Copy all inputs from node.
     out_node->mutable_input()->CopyFrom(node->input());
     for (const string& input : out_node->input()) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index 418176c8932..45b1ca28ceb 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer {
   void ReplaceOperationWithSnapshot(int input_to_forward,
                                     const GraphProperties& properties,
                                     NodeDef* node, GraphDef* graph);
+  void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
+                                             const GraphProperties& properties,
+                                             NodeDef* node, GraphDef* graph);
   void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
   Status ReplaceOperationWithConstant(double value,
                                       const GraphProperties& properties,
                                       const TensorShapeProto& shape,
-                                      NodeDef* node, GraphDef* graph,
-                                      bool* success);
+                                      NodeDef* node, GraphDef* graph);
   void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
   Status FoldGraph(GraphDef* output,
                    absl::flat_hash_set<string>* nodes_to_not_simplify);
@@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer {
   // was applied.
   Status SimplifyArithmeticOperations(const GraphProperties& properties,
                                       bool use_shape_info,
-                                      GraphDef* optimized_graph, NodeDef* node,
-                                      bool* success);
+                                      GraphDef* optimized_graph, NodeDef* node);
 
   // Simplifies a Reshape operation to an Identity operation if applicable.
   bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
@@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer {
   bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
 
   // Simplifies a Squeeze operation to an Identity operation if applicable.
-  bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
+  void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
                        GraphDef* optimized_graph, NodeDef* node);
 
   // Simplifies a Pad operation to an Identity operation if applicable.
   Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
-                     GraphDef* optimized_graph, NodeDef* node, bool* success);
+                     GraphDef* optimized_graph, NodeDef* node);
 
   // Simplifies a Tile operation to an Identity operation if applicable.
   Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
-                      GraphDef* optimized_graph, NodeDef* node, bool* success);
+                      GraphDef* optimized_graph, NodeDef* node);
 
   // Simplifies a StridedSlice operation to an Identity operation if applicable.
   Status SimplifyStridedSlice(const GraphProperties& properties,
                               bool use_shape_info, GraphDef* optimized_graph,
-                              NodeDef* node, bool* success);
+                              NodeDef* node);
 
   // Simplifies a Slice operation to an Identity operation if applicable.
   Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
-                       GraphDef* optimized_graph, NodeDef* node, bool* success);
+                       GraphDef* optimized_graph, NodeDef* node);
 
   // Removes Reverse op over dimensions with size 1.
   Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
-                       GraphDef* optimized_graph, NodeDef* node, bool* success);
+                       GraphDef* optimized_graph, NodeDef* node);
 
   // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
-  bool RemoveRandomShuffle(const GraphProperties& properties,
+  void RemoveRandomShuffle(const GraphProperties& properties,
                            bool use_shape_info, GraphDef* optimized_graph,
                            NodeDef* node);
 
   // Removes Shuffle or Transpose op over dimensions of size 1.
   Status RemoveShuffleOrTranspose(const GraphProperties& properties,
                                   bool use_shape_info,
-                                  GraphDef* optimized_graph, NodeDef* node,
-                                  bool* success);
+                                  GraphDef* optimized_graph, NodeDef* node);
 
   // Removes Split or SplitV node if possible.
-  bool RemoveSplitOrSplitV(const GraphProperties& properties,
+  void RemoveSplitOrSplitV(const GraphProperties& properties,
                            GraphDef* optimized_graph, NodeDef* node);
 
   bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b5e94609e66..22d8cccb1ca 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                                    ops::Placeholder::Shape(TensorShape({2})));
     Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
     Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
+    Output zeros_const_bcast =
+        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
     Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
     Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
     Output zeros = const_type == kConst
                        ? zeros_const
                        : (const_type == kLike ? zeros_like : zeros_fill);
     Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
+    Output ones_const_bcast =
+        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
     Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
     Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
     Output ones = const_type == kConst
@@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                       : (const_type == kLike ? ones_like : ones_fill);
     Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
     Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
+    Output mul1_bcast =
+        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
+    Output mul2_bcast =
+        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
     Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
     Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
     Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
@@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
     Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
     Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
+    Output add1_bcast =
+        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
+    Output add2_bcast =
+        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
     Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
     Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
     Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
@@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                     matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
     GrapplerItem item;
     TF_CHECK_OK(s.ToGraphDef(&item.graph));
-    item.fetch = {"stack", "matmul3", "matmul4"};
+    item.fetch = {"stack",      "matmul3",    "matmul4",   "mul1_bcast",
+                  "mul2_bcast", "add1_bcast", "add2_bcast"};
 
     ConstantFolding optimizer(/*cpu_device=*/nullptr);
     GraphDef output;
@@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
     const string ones_name = strings::StrCat("ones", suffix);
     const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
     const string ctrl_ones_name = strings::StrCat("^ones", suffix);
-    EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
+
+    EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
     for (int i = 0; i < output.node_size(); ++i) {
       const NodeDef& node = output.node(i);
       const string& name = node.name();
@@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ("Const", node.op());
         EXPECT_EQ(ctrl_zeros_name, node.input(0));
         EXPECT_EQ("^y", node.input(1));
+      } else if (name == "mul1_bcast") {
+        EXPECT_EQ("BroadcastTo", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^ones_const_bcast", node.input(2));
+      } else if (name == "mul2_bcast") {
+        EXPECT_EQ("BroadcastTo", node.op());
+        EXPECT_EQ("y", node.input(0));
+        EXPECT_EQ("^ones_const_bcast", node.input(2));
       } else if (name == "mul3") {
         EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
@@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
         EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("y", node.input(0));
         EXPECT_EQ(ctrl_zeros_name, node.input(1));
+      } else if (name == "add1_bcast") {
+        EXPECT_EQ("BroadcastTo", node.op());
+        EXPECT_EQ("x", node.input(0));
+        EXPECT_EQ("^zeros_const_bcast", node.input(2));
+      } else if (name == "add2_bcast") {
+        EXPECT_EQ("BroadcastTo", node.op());
+        EXPECT_EQ("y", node.input(0));
+        EXPECT_EQ("^zeros_const_bcast", node.input(2));
       } else if (name == "bias_add1") {
         EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
         EXPECT_EQ("^zeros_1d", node.input(1));
       } else if (name == "bias_add2") {
-        // We don't eliminate this one, because it requires broadcasting.
-        EXPECT_EQ("BiasAdd", node.op());
-        EXPECT_EQ(zeros_name, node.input(0));
-        EXPECT_EQ("bias", node.input(1));
+        EXPECT_EQ("BroadcastTo", node.op());
+        EXPECT_EQ("bias", node.input(0));
+        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
+                  node.input(1));
+        EXPECT_EQ(ctrl_zeros_name, node.input(2));
+      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
+        EXPECT_EQ("Const", node.op());
+        EXPECT_EQ(ctrl_zeros_name, node.input(0));
+        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
+        TensorProto t = node.attr().at("value").tensor();
+        EXPECT_EQ(DT_INT32, t.dtype());
+        EXPECT_EQ(1, t.tensor_shape().dim_size());
+        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
       } else if (name == "sub1") {
         EXPECT_EQ("Identity", node.op());
         EXPECT_EQ("x", node.input(0));
diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc
index 3d03bc1d5fd..f64cf801f22 100644
--- a/tensorflow/core/ops/array_grad.cc
+++ b/tensorflow/core/ops/array_grad.cc
@@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
 }
 REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
 
+Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
+  DataType itype;
+  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
+  if (itype != DT_INT32) {
+    return errors::Unimplemented(
+        "BroadcastToGrad for int64 index are not supported.");
+  }
+  std::vector<FDH::Node> nodes = {
+      {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
+      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
+      {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
+      {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
+      {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
+  *g = FDH::Define(
+      // Arg defs
+      {"x: T", "shape: int32", "dy: T"},
+      // Ret val defs
+      {"dx: T", "dshape: Tidx"},
+      // Attr defs
+      {{"T: type"}, {"Tidx: {int32, int64}"}},
+      // Nodes
+      nodes);
+  return Status::OK();
+}
+REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
+
 }  // end namespace tensorflow
diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc
index 79d28a83cc4..bcef90c15e3 100644
--- a/tensorflow/core/ops/array_grad_test.cc
+++ b/tensorflow/core/ops/array_grad_test.cc
@@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) {
   }
 }
 
+std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape,
+                                    const Tensor& dy) {
+  auto T = DT_FLOAT;
+  auto Tidx = DT_INT32;
+  auto gdef = test::function::GDef(
+      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
+       f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}),
+       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
+       f::NDef(
+           "dx", "SymbolicGradient", {"x", "shape", "dy"},
+           {{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})},
+            {"Tin", DataTypeSlice{T, Tidx, T}},
+            {"Tout", DataTypeSlice{T, Tidx}}})});
+  VLOG(1) << DebugStringWhole(gdef);
+  auto sess = NewSession();
+  TF_CHECK_OK(sess->Create(gdef));
+  std::vector<Tensor> out;
+  TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}},
+                        {"dx:0", "dx:1"}, {}, &out));
+  CHECK_EQ(out.size(), 2);
+  TF_CHECK_OK(sess->Close());
+  return out;
+}
+
+TEST(ArrayGradTest, BroadcastToGrad) {
+  Tensor x(DT_FLOAT, {2, 2});
+  x.flat<float>().setZero();
+  Tensor shape(DT_INT32, {3});
+  test::FillValues<int32>(&shape, {2, 2, 2});
+  Tensor dy(DT_FLOAT, {2, 2, 2});
+  test::FillIota<float>(&dy, 0);
+  auto dx = BroadcastToGrad(x, shape, dy);
+  test::ExpectClose(dx[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2}));
+  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}, {3}));
+}
 }  // namespace
 }  // namespace tensorflow