From ef1ca5b2e184c8bdd78c9eac9cc8b72fa18ad4ad Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Fri, 12 Apr 2019 13:13:33 -0700
Subject: [PATCH] Add broadcast optimization rewrite to Grappler.

This generalizes the rewrites for algebraically neutral element
simplifications (e.g., x + zeros(shape) => x) to handle the cases where the
output shape differs from the input shape of the non-trivial argument. This
pattern is sometimes used in TensorFlow for broadcasting.

Examples of new rewrites enabled by this change:

  x * ones_like(y)  => broadcast_to(x, shape(y))
  zeros_like(x) + y => broadcast_to(y, shape(x))

This change also cleans up the code in SimplifyNode to consistently rely on
only graph_modified_ (and possibly an error status) to signal whether the
graph was updated.

PiperOrigin-RevId: 243319718
---
 tensorflow/core/BUILD                         |   5 +-
 .../grappler/optimizers/constant_folding.cc   | 582 ++++++++----------
 .../grappler/optimizers/constant_folding.h    |  28 +-
 .../optimizers/constant_folding_test.cc       |  51 +-
 tensorflow/core/ops/array_grad.cc             |  26 +
 tensorflow/core/ops/array_grad_test.cc        |  35 ++
 6 files changed, 383 insertions(+), 344 deletions(-)

diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 6c9a1b96a5c..c14728e5622 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -136,14 +136,15 @@ load(
     "tf_additional_libdevice_srcs",
     "tf_additional_minimal_lib_srcs",
     "tf_additional_mpi_lib_defines",
+    "tf_additional_numa_copts",
     "tf_additional_numa_deps",
     "tf_additional_numa_lib_defines",
-    "tf_additional_numa_copts",
     "tf_additional_proto_hdrs",
     "tf_additional_proto_srcs",
     "tf_additional_test_deps",
     "tf_additional_test_srcs",
     "tf_additional_verbs_lib_defines",
+    "tf_grpc_service_all",
     "tf_jspb_proto_library",
     "tf_kernel_tests_linkstatic",
     "tf_lib_proto_compiler_deps",
@@ -157,7 +158,6 @@ load(
     "tf_protos_grappler",
     "tf_protos_grappler_impl",
     "tf_pyclif_proto_library",
-    "tf_grpc_service_all",
 )
 load(
     ":platform/default/build_config_root.bzl",
@@ -4919,6 +4919,7 @@ tf_cc_test(
         "//tensorflow/core/kernels:array",
         "//tensorflow/core/kernels:cwise_op",
         "//tensorflow/core/kernels:function_ops",
+        "//tensorflow/core/kernels:math",
         "//third_party/eigen3",
     ],
 )
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index ad82ff704c7..4029e9c314b 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
   } else {
     double value =
         (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
-    bool success = false;
     if (value >= 0) {
       TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
-          value, properties, output_shape, node, graph_, &success));
+          value, properties, output_shape, node, graph_));
     }
   }
   return Status::OK();
@@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
   graph_modified_ = true;
 }

+void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
+    int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
+    GraphDef* graph) {
+  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
+  if (dtype == DT_INVALID) return;
+  const PartialTensorShape shape(
+      properties.GetOutputProperties(node->name())[0].shape());
+  if (!shape.IsFullyDefined()) return;
+
+  // Create constant node with shape.
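+  // The shape constant's name is derived deterministically from this node
+  // and the broadcast input index, so if a node with that name already
+  // exists the rewrite has been applied before and we bail out below
+  // instead of creating a duplicate.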
+ const string const_name = OptimizedNodeName( + *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast)); + if (node_map_->GetNode(const_name) != nullptr) { + return; + } + + Tensor shape_t; + if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return; + NodeDef tmp; + if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return; + NodeDef* const_node = graph->add_node(); + const_node->Swap(&tmp); + const_node->set_device(node->device()); + node_map_->AddNode(const_name, const_node); + // Add a control input on the unused input. + string ctrl_dep = AddControlDependency( + NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get()); + *const_node->add_input() = ctrl_dep; + node_map_->AddOutput(NodeName(ctrl_dep), const_name); + + // Rewrite `node` in-place to BroadcastTo. + node->set_op("BroadcastTo"); + node->clear_attr(); + (*node->mutable_attr())["T"].set_type(dtype); + (*node->mutable_attr())["Tidx"].set_type(DT_INT32); + // Set the designated input to BroadcastTo. + node->mutable_input()->SwapElements(0, input_to_broadcast); + // Keep all other inputs as control dependencies. + for (int i = 1; i < node->input_size(); ++i) { + if (IsControlInput(node->input(i))) { + break; + } + const string ctrl_dep = + AddControlDependency(node->input(i), graph, node_map_.get()); + node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); + node->set_input(i, ctrl_dep); + } + // Add the shape argument. + *node->add_input() = const_node->name(); + node_map_->AddOutput(const_name, node->name()); + node->mutable_input()->SwapElements(1, node->input_size() - 1); + graph_modified_ = true; +} + void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph) { node->set_op("Reciprocal"); @@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node, Status ConstantFolding::ReplaceOperationWithConstant( double value, const GraphProperties& properties, - const TensorShapeProto& shape, NodeDef* node, GraphDef* graph, - bool* success) { + const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) { const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); if (dtype == DT_INVALID) { - *success = false; return Status::OK(); } @@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant( node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); node->set_input(i, ctrl_dep); } - *success = true; graph_modified_ = true; return Status::OK(); } @@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph( return Status::OK(); } +#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \ + TF_RETURN_IF_ERROR(EXPR); \ + if (graph_modified_) return Status::OK() + +#define SET_AND_RETURN_IF_MODIFIED(EXPR) \ + graph_modified_ = EXPR; \ + if (graph_modified_) return Status::OK() + +#define RETURN_IF_MODIFIED(EXPR) \ + EXPR; \ + if (graph_modified_) return Status::OK() + Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, GraphDef* optimized_graph, GraphProperties* properties) { - if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) { - return Status::OK(); - } + bool graph_modified_cached = graph_modified_; + graph_modified_ = false; - bool remove_shuffle_transpose_successful = false; - Status remove_shuffle_transpose_status = - RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph, - node, &remove_shuffle_transpose_successful); - if (!remove_shuffle_transpose_status.ok()) { - return remove_shuffle_transpose_status; - } else if 
(remove_shuffle_transpose_successful) { - return Status::OK(); - } - - if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) { - return Status::OK(); - } - - bool remove_reverse_successful = false; - Status remove_reverse_status = - RemoveReverse(*properties, use_shape_info, optimized_graph, node, - &remove_reverse_successful); - if (!remove_reverse_status.ok()) { - return remove_reverse_status; - } else if (remove_reverse_successful) { - return Status::OK(); - } - - bool simplify_slice_successful = false; - Status simplify_slice_status = - SimplifySlice(*properties, use_shape_info, optimized_graph, node, - &simplify_slice_successful); - if (!simplify_slice_status.ok()) { - return simplify_slice_status; - } else if (simplify_slice_successful) { - return Status::OK(); - } - - bool simplify_strided_slice_successful = false; - Status simplify_strided_slice_status = - SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node, - &simplify_strided_slice_successful); - if (!simplify_strided_slice_status.ok()) { - return simplify_strided_slice_status; - } else if (simplify_strided_slice_successful) { - return Status::OK(); - } - - bool simplify_tile_successful = false; - Status simplify_tile_status = - SimplifyTile(*properties, use_shape_info, optimized_graph, node, - &simplify_tile_successful); - if (!simplify_tile_status.ok()) { - return simplify_tile_status; - } else if (simplify_tile_successful) { - return Status::OK(); - } - - bool simplify_pad_successful = false; - Status simplify_pad_status = - SimplifyPad(*properties, use_shape_info, optimized_graph, node, - &simplify_pad_successful); - if (!simplify_pad_status.ok()) { - return simplify_pad_status; - } else if (simplify_pad_successful) { - return Status::OK(); - } - - if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) { - return Status::OK(); - } - - if (SimplifyPack(optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (MoveConstantsPastEnter(optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (SimplifySwitch(optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (SimplifyReduction(optimized_graph, *properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (SimplifyReshape(*properties, use_shape_info, node)) { - graph_modified_ = true; - return Status::OK(); - } - - bool arithmetic_simplification_succeed = false; - Status simplify_arithmetic_status = - SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph, - node, &arithmetic_simplification_succeed); - if (!simplify_arithmetic_status.ok()) { - return simplify_arithmetic_status; - } else if (arithmetic_simplification_succeed) { - graph_modified_ = true; - return Status::OK(); - } - - if (ReduceDivToReciprocalMul(optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (ConstantPushDown(optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (MulConvPushDown(optimized_graph, node, *properties)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialConstPropThroughIdentityN(node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialAssocOpConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (PartialConcatConstFolding(optimized_graph, properties, node)) { - graph_modified_ = true; - return Status::OK(); - } - - if (MergeConcat(*properties, 
use_shape_info, optimized_graph, node)) { - graph_modified_ = true; - return Status::OK(); - } + RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose( + *properties, use_shape_info, optimized_graph, node)); + RETURN_IF_MODIFIED( + RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED( + RemoveReverse(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED( + SimplifySlice(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED( + SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED( + SimplifyTile(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_ERROR_OR_MODIFIED( + SimplifyPad(*properties, use_shape_info, optimized_graph, node)); + RETURN_IF_MODIFIED( + SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED( + SimplifyReduction(optimized_graph, *properties, node)); + SET_AND_RETURN_IF_MODIFIED( + SimplifyReshape(*properties, use_shape_info, node)); + RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations( + *properties, use_shape_info, optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node)); + SET_AND_RETURN_IF_MODIFIED( + MulConvPushDown(optimized_graph, node, *properties)); + SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node)); + SET_AND_RETURN_IF_MODIFIED( + PartialAssocOpConstFolding(optimized_graph, properties, node)); + SET_AND_RETURN_IF_MODIFIED( + PartialConcatConstFolding(optimized_graph, properties, node)); + SET_AND_RETURN_IF_MODIFIED( + MergeConcat(*properties, use_shape_info, optimized_graph, node)); + graph_modified_ = graph_modified_cached; return Status::OK(); } -bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, +void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, GraphDef* optimized_graph, NodeDef* node) { - if (node->attr().count("num_split") == 0) return false; + if (node->attr().count("num_split") == 0) return; if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { ReplaceOperationWithIdentity(1, properties, node, optimized_graph); - return true; } if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - return true; } - return false; } Status ConstantFolding::RemoveShuffleOrTranspose( const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success) { + GraphDef* optimized_graph, NodeDef* node) { if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && properties.GetInputProperties(node->name()).size() >= 2) { const auto& shape = properties.GetInputProperties(node->name())[0].shape(); @@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose( } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; return Status::OK(); } } } - *success = false; return Status::OK(); } -bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, + +void 
ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { @@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties, if (!shape.unknown_rank() && (shape.dim_size() == 0 || shape.dim(0).size() == 1)) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - return true; } } - return false; } Status ConstantFolding::RemoveReverse(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success) { + GraphDef* optimized_graph, + NodeDef* node) { if (use_shape_info && node->op() == "ReverseV2" && properties.GetInputProperties(node->name()).size() >= 2) { const auto& shape = properties.GetInputProperties(node->name())[0].shape(); @@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; - return Status::OK(); } } } - *success = false; return Status::OK(); } Status ConstantFolding::SimplifySlice(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success) { + GraphDef* optimized_graph, + NodeDef* node) { if (use_shape_info && IsSlice(*node) && properties.GetInputProperties(node->name()).size() == 3) { const auto& input = properties.GetInputProperties(node->name())[0]; @@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; return Status::OK(); } } } - *success = false; return Status::OK(); } Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, - NodeDef* node, bool* success) { + NodeDef* node) { if (use_shape_info && IsStridedSlice(*node) && properties.GetInputProperties(node->name()).size() == 4) { TF_RETURN_IF_ERROR( @@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; - return Status::OK(); } } } - *success = false; return Status::OK(); } Status ConstantFolding::SimplifyTile(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success) { + GraphDef* optimized_graph, NodeDef* node) { if (use_shape_info && IsTile(*node) && properties.GetInputProperties(node->name()).size() == 2) { const auto& m = properties.GetInputProperties(node->name())[1]; @@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; - return Status::OK(); } } } - *success = false; return Status::OK(); } Status ConstantFolding::SimplifyPad(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success) { + GraphDef* optimized_graph, NodeDef* node) { if (use_shape_info && IsPad(*node) && properties.GetInputProperties(node->name()).size() >= 2) { const auto& p = properties.GetInputProperties(node->name())[1]; @@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - *success = true; - return 
Status::OK(); } } } - *success = false; return Status::OK(); } -bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties, +void ConstantFolding::SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { @@ -2263,95 +2202,92 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties, } if (replaceable) { ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - return true; } } - return false; } bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { - if (IsPack(*node) && NumNonControlInputs(*node) == 1 && - !OptimizedNodeExists(*node, "_const_axis")) { - // Create constant axis node. - Tensor axis_t(DT_INT32, TensorShape({})); - NodeDef* axis_node = optimized_graph->add_node(); - axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); - const int axis = - node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i(); - if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || - !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) - .ok()) { - return false; - } - // Add a control dependency to make sure axis_node is in the right frame. - const string ctrl_dep = ConstantFolding::AddControlDependency( - node->input(0), optimized_graph, node_map_.get()); - axis_node->add_input(ctrl_dep); - axis_node->set_device(node->device()); - node->set_op("ExpandDims"); - if (node->attr().count("axis") != 0) { - node->mutable_attr()->erase("axis"); - } - if (node->attr().count("N") != 0) { - node->mutable_attr()->erase("N"); - } - (*node->mutable_attr())["Tdim"].set_type(DT_INT32); - node->add_input(axis_node->name()); - if (node->input_size() > 2) { - node->mutable_input()->SwapElements(1, node->input_size() - 1); - } - return true; + if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 && + !OptimizedNodeExists(*node, "_const_axis"))) { + return false; } - return false; + // Create constant axis node. + Tensor axis_t(DT_INT32, TensorShape({})); + NodeDef* axis_node = optimized_graph->add_node(); + axis_node->set_name(OptimizedNodeName(*node, "_const_axis")); + const int axis = + node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i(); + if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || + !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) { + return false; + } + // Add a control dependency to make sure axis_node is in the right frame. 
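+  // (A constant has no data inputs of its own, so without this anchor the
+  // new axis_node would execute in the root frame; the control edge from
+  // input(0) pins it to the same frame as the Pack's input.)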
+ const string ctrl_dep = ConstantFolding::AddControlDependency( + node->input(0), optimized_graph, node_map_.get()); + axis_node->add_input(ctrl_dep); + axis_node->set_device(node->device()); + node->set_op("ExpandDims"); + if (node->attr().count("axis") != 0) { + node->mutable_attr()->erase("axis"); + } + if (node->attr().count("N") != 0) { + node->mutable_attr()->erase("N"); + } + (*node->mutable_attr())["Tdim"].set_type(DT_INT32); + node->add_input(axis_node->name()); + if (node->input_size() > 2) { + node->mutable_input()->SwapElements(1, node->input_size() - 1); + } + return true; } bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph, NodeDef* node) { - if (IsEnter(*node) && node->input_size() > 0) { - if (node->attr().count("is_constant") == 0 || - !node->attr().at("is_constant").b()) { - return false; - } - const string& node_name = node->name(); - const NodeDef* input = node_map_->GetNode(node->input(0)); - if (input != nullptr && IsReallyConstant(*input) && - !OptimizedNodeExists(*input, "_enter")) { - auto fanouts = node_map_->GetOutputs(node_name); - // Find non-constant nodes that consume the output of *node. - std::vector<NodeDef*> consumers; - for (NodeDef* fanout : fanouts) { - if (!IsConstant(*fanout)) { - for (int i = 0; i < fanout->input_size(); ++i) { - if (fanout->input(i) == node_name) { - consumers.push_back(fanout); - break; - } - } + if (!IsEnter(*node) || node->input_size() == 0 || + node->attr().count("is_constant") == 0 || + !node->attr().at("is_constant").b()) { + return false; + } + const string& node_name = node->name(); + const NodeDef* input = node_map_->GetNode(node->input(0)); + if (input == nullptr || !IsReallyConstant(*input) || + OptimizedNodeExists(*input, "_enter")) { + return false; + } + auto fanouts = node_map_->GetOutputs(node_name); + // Find non-constant nodes that consume the output of *node. 
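+  // (A Const fanout can only be attached to the Enter through a control
+  // edge; those edges are left alone, and only real data consumers are
+  // rewired to the hoisted copy below.)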
+ std::vector<NodeDef*> consumers; + for (NodeDef* fanout : fanouts) { + if (!IsConstant(*fanout)) { + for (int i = 0; i < fanout->input_size(); ++i) { + if (fanout->input(i) == node_name) { + consumers.push_back(fanout); + break; } } - if (!consumers.empty()) { - NodeDef* new_node = optimized_graph->add_node(); - *new_node = *input; - new_node->set_name(OptimizedNodeName(*input, "_enter")); - new_node->set_device(node->device()); - new_node->clear_input(); - new_node->add_input(AsControlDependency(node_name)); - node_map_->AddNode(new_node->name(), new_node); - node_map_->AddOutput(node_name, new_node->name()); - for (NodeDef* consumer : consumers) { - for (int i = 0; i < consumer->input_size(); ++i) { - if (NodeName(consumer->input(i)) == node_name) { - node_map_->UpdateInput(consumer->name(), node_name, - new_node->name()); - consumer->set_input(i, new_node->name()); - } - } - } - return true; - } } } - return false; + if (consumers.empty()) { + return false; + } + graph_modified_ = true; + NodeDef* new_node = optimized_graph->add_node(); + *new_node = *input; + new_node->set_name(OptimizedNodeName(*input, "_enter")); + new_node->set_device(node->device()); + new_node->clear_input(); + new_node->add_input(AsControlDependency(node_name)); + node_map_->AddNode(new_node->name(), new_node); + node_map_->AddOutput(node_name, new_node->name()); + for (NodeDef* consumer : consumers) { + for (int i = 0; i < consumer->input_size(); ++i) { + if (NodeName(consumer->input(i)) == node_name) { + node_map_->UpdateInput(consumer->name(), node_name, new_node->name()); + consumer->set_input(i, new_node->name()); + } + } + } + return true; } bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) { @@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) { return n1->name() < n2->name(); }); // Create constant false & true nodes. - NodeDef* false_node = optimized_graph->add_node(); - false_node->set_name(OptimizedNodeName(*node, "_const_false")); - if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node) + NodeDef tmp_false_node; + tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false")); + if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t), + &tmp_false_node) .ok()) { return false; } - false_node->set_device(node->device()); + tmp_false_node.set_device(node->device()); + NodeDef tmp_true_node; + tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true")); + if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t), + &tmp_true_node) + .ok()) { + return false; + } + tmp_true_node.set_device(node->device()); + // Add const nodes to graph. + NodeDef* false_node = optimized_graph->add_node(); + false_node->Swap(&tmp_false_node); NodeDef* true_node = optimized_graph->add_node(); - true_node->set_name(OptimizedNodeName(*node, "_const_true")); - if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node) - .ok()) { - return false; - } - true_node->set_device(node->device()); + true_node->Swap(&tmp_true_node); // Add controls from the switch ports to the constants, and connect the // constants to the original switch outputs. 
@@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties, Status ConstantFolding::SimplifyArithmeticOperations( const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success) { - *success = false; + GraphDef* optimized_graph, NodeDef* node) { const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node); const bool is_matmul = IsAnyMatMul(*node); - const bool is_quantized_matmul = IsQuantizedMatMul(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_sub = IsSub(*node); const bool is_any_div = IsAnyDiv(*node); @@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations( // of zeros. const TensorShapeProto& y_shape = properties.GetInputProperties(node->name())[1].shape(); - const bool x_is_zero = IsZeros(*x); - const bool x_is_one = x_is_zero ? false : IsOnes(*x); + const TensorShapeProto& x_shape = + properties.GetInputProperties(node->name())[0].shape(); const bool y_matches_output_shape = ShapesSymbolicallyEqual(output_shape, y_shape); - if (y_matches_output_shape && - ((is_mul && x_is_one) || (is_add && x_is_zero))) { + const bool x_matches_output_shape = + ShapesSymbolicallyEqual(output_shape, x_shape); + + const bool x_is_zero = IsZeros(*x); + const bool x_is_one = x_is_zero ? false : IsOnes(*x); + if ((is_mul && x_is_one) || (is_add && x_is_zero)) { // 1 * y = y or 0 + y = y. - ReplaceOperationWithSnapshot(1, properties, node, optimized_graph); - *success = true; + if (y_matches_output_shape) { + ReplaceOperationWithSnapshot(1, properties, node, optimized_graph); + } else if (x_matches_output_shape) { + ReplaceBinaryOperationWithBroadcastTo(1, properties, node, + optimized_graph); + } return Status::OK(); } if (y_matches_output_shape && (is_sub && x_is_zero)) { // Replace 0 - y with Neg(y). ReplaceSubtractionFromZeroByNegation(node, optimized_graph); - *success = true; return Status::OK(); } @@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations( DataType type = node->attr().at("T").type(); if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) { ReplaceDivisionOfOnesByReciprocal(node, optimized_graph); - *success = true; return Status::OK(); } } - const TensorShapeProto& x_shape = - properties.GetInputProperties(node->name())[0].shape(); const bool y_is_zero = IsZeros(*y); const bool y_is_one = y_is_zero ? false : IsOnes(*y); - const bool x_matches_output_shape = - ShapesSymbolicallyEqual(output_shape, x_shape); - if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) || - ((is_add || is_sub) && y_is_zero))) { + if (((is_mul || is_any_div) && y_is_one) || + ((is_add || is_sub) && y_is_zero)) { // x * 1 = x or x / 1 = x or x +/- 0 = x - ReplaceOperationWithSnapshot(0, properties, node, optimized_graph); - *success = true; + if (x_matches_output_shape) { + ReplaceOperationWithSnapshot(0, properties, node, optimized_graph); + } else if (y_matches_output_shape) { + ReplaceBinaryOperationWithBroadcastTo(0, properties, node, + optimized_graph); + } return Status::OK(); } // x OR true = true OR y = true. 
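+    // With a fully known output shape, a LogicalOr with an all-true
+    // constant operand is replaced outright by a constant tensor of trues,
+    // whichever side the constant is on.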
-    bool updated_graph = false;
     const PartialTensorShape shp(output_shape);
     if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
-      bool replace_succeed = false;
-      Status replace_op_status = ReplaceOperationWithConstant(
-          1, properties, output_shape, node, optimized_graph, &replace_succeed);
-      if (!replace_op_status.ok()) {
-        return replace_op_status;
-      } else if (replace_succeed) {
-        updated_graph = true;
-      }
+      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+          1, properties, output_shape, node, optimized_graph));
+      return Status::OK();
     }

     // Simplify multiplication and matmul by zeros.
@@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
     if ((x_is_zero || y_is_zero) &&
         (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
       if (shp.IsFullyDefined()) {
-        bool replace_succeed = false;
-        Status replace_op_status =
-            ReplaceOperationWithConstant(0, properties, output_shape, node,
-                                         optimized_graph, &replace_succeed);
-        if (!replace_op_status.ok()) {
-          return replace_op_status;
-        } else if (replace_succeed) {
-          if (is_quantized_matmul) {
-            TF_RETURN_IF_ERROR(
-                AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
-          }
-          *success = true;
-          return Status::OK();
+        bool is_quantized = IsQuantizedMatMul(*node);
+        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
+            0, properties, output_shape, node, optimized_graph));
+        if (is_quantized && graph_modified_) {
+          TF_RETURN_IF_ERROR(
+              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
         }
+        return Status::OK();
       }
       // Even if an input shape is only partially known, we may know that it
-      // matches the output shape and thus forward the corresponding zero
-      // input.
-      if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
-        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
-        *success = true;
+      // matches the output shape and thus forward or broadcast the
+      // corresponding zero input.
+      if ((is_mul || is_any_div) && x_is_zero) {
+        if (x_matches_output_shape) {
+          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
+        } else if (y_matches_output_shape) {
+          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
+                                                optimized_graph);
+        }
         return Status::OK();
-      } else if (is_mul && y_is_zero && y_matches_output_shape) {
-        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
-        *success = true;
+      } else if (is_mul && y_is_zero) {
+        if (y_matches_output_shape) {
+          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
+        } else if (x_matches_output_shape) {
+          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
+                                                optimized_graph);
+        }
         return Status::OK();
       }
     }
-    if (updated_graph) {
-      *success = true;
-      return Status::OK();
-    }
   }
-  *success = false;
   return Status::OK();
 }

@@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
   auto add_quantized_out = [this, node, optimized_graph](
                                const string& out_const_name, int index) {
     NodeDef* out_node = optimized_graph->add_node();
+    graph_modified_ = true;
     Tensor value(DT_FLOAT, TensorShape({}));
     const bool is_min = index == 1;
     const DataType type_attr = node->attr().at("dtype").type();
@@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
         CreateNodeDef(out_const_name, TensorValue(&value), out_node));
     node_map_->AddNode(out_const_name, out_node);
     out_node->set_device(node->device());
-    // Copy all inputs from node.
out_node->mutable_input()->CopyFrom(node->input()); for (const string& input : out_node->input()) { diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 418176c8932..45b1ca28ceb 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer { void ReplaceOperationWithSnapshot(int input_to_forward, const GraphProperties& properties, NodeDef* node, GraphDef* graph); + void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast, + const GraphProperties& properties, + NodeDef* node, GraphDef* graph); void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); Status ReplaceOperationWithConstant(double value, const GraphProperties& properties, const TensorShapeProto& shape, - NodeDef* node, GraphDef* graph, - bool* success); + NodeDef* node, GraphDef* graph); void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); Status FoldGraph(GraphDef* output, absl::flat_hash_set<string>* nodes_to_not_simplify); @@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer { // was applied. Status SimplifyArithmeticOperations(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success); + GraphDef* optimized_graph, NodeDef* node); // Simplifies a Reshape operation to an Identity operation if applicable. bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, @@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer { bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); // Simplifies a Squeeze operation to an Identity operation if applicable. - bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, + void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); // Simplifies a Pad operation to an Identity operation if applicable. Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success); + GraphDef* optimized_graph, NodeDef* node); // Simplifies a Tile operation to an Identity operation if applicable. Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success); + GraphDef* optimized_graph, NodeDef* node); // Simplifies a StridedSlice operation to an Identity operation if applicable. Status SimplifyStridedSlice(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, - NodeDef* node, bool* success); + NodeDef* node); // Simplifies a Slice operation to an Identity operation if applicable. Status SimplifySlice(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success); + GraphDef* optimized_graph, NodeDef* node); // Removes Reverse op over dimensions with size 1. Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, bool* success); + GraphDef* optimized_graph, NodeDef* node); // Removes RandomShuffle op if it is scalar or first dimension is of size 1. - bool RemoveRandomShuffle(const GraphProperties& properties, + void RemoveRandomShuffle(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node); // Removes Shuffle or Transpose op over dimensions of size 1. 
Status RemoveShuffleOrTranspose(const GraphProperties& properties, bool use_shape_info, - GraphDef* optimized_graph, NodeDef* node, - bool* success); + GraphDef* optimized_graph, NodeDef* node); // Removes Split or SplitV node if possible. - bool RemoveSplitOrSplitV(const GraphProperties& properties, + void RemoveSplitOrSplitV(const GraphProperties& properties, GraphDef* optimized_graph, NodeDef* node); bool MergeConcat(const GraphProperties& properties, bool use_shape_info, diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc index b5e94609e66..22d8cccb1ca 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc @@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) { ops::Placeholder::Shape(TensorShape({2}))); Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2}); Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2}); + Output zeros_const_bcast = + ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2}); Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x); Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f); Output zeros = const_type == kConst ? zeros_const : (const_type == kLike ? zeros_like : zeros_fill); Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2}); + Output ones_const_bcast = + ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2}); Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x); Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f); Output ones = const_type == kConst @@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) { : (const_type == kLike ? 
ones_like : ones_fill); Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); + Output mul1_bcast = + ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast); + Output mul2_bcast = + ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y); Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y); Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d); @@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) { Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b); Output add1 = ops::Add(s.WithOpName("add1"), x, zeros); Output add2 = ops::Add(s.WithOpName("add2"), zeros, y); + Output add1_bcast = + ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast); + Output add2_bcast = + ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y); Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d); Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias); Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros); @@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) { matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2}); GrapplerItem item; TF_CHECK_OK(s.ToGraphDef(&item.graph)); - item.fetch = {"stack", "matmul3", "matmul4"}; + item.fetch = {"stack", "matmul3", "matmul4", "mul1_bcast", + "mul2_bcast", "add1_bcast", "add2_bcast"}; ConstantFolding optimizer(/*cpu_device=*/nullptr); GraphDef output; @@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) { const string ones_name = strings::StrCat("ones", suffix); const string ctrl_zeros_name = strings::StrCat("^zeros", suffix); const string ctrl_ones_name = strings::StrCat("^ones", suffix); - EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size()); + + EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size()); for (int i = 0; i < output.node_size(); ++i) { const NodeDef& node = output.node(i); const string& name = node.name(); @@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("Const", node.op()); EXPECT_EQ(ctrl_zeros_name, node.input(0)); EXPECT_EQ("^y", node.input(1)); + } else if (name == "mul1_bcast") { + EXPECT_EQ("BroadcastTo", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^ones_const_bcast", node.input(2)); + } else if (name == "mul2_bcast") { + EXPECT_EQ("BroadcastTo", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^ones_const_bcast", node.input(2)); } else if (name == "mul3") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("x", node.input(0)); @@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("y", node.input(0)); EXPECT_EQ(ctrl_zeros_name, node.input(1)); + } else if (name == "add1_bcast") { + EXPECT_EQ("BroadcastTo", node.op()); + EXPECT_EQ("x", node.input(0)); + EXPECT_EQ("^zeros_const_bcast", node.input(2)); + } else if (name == "add2_bcast") { + EXPECT_EQ("BroadcastTo", node.op()); + EXPECT_EQ("y", node.input(0)); + EXPECT_EQ("^zeros_const_bcast", node.input(2)); } else if (name == "bias_add1") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("x", node.input(0)); EXPECT_EQ("^zeros_1d", node.input(1)); } else if (name == "bias_add2") { - // We don't eliminate this one, because it requires broadcasting. 
- EXPECT_EQ("BiasAdd", node.op()); - EXPECT_EQ(zeros_name, node.input(0)); - EXPECT_EQ("bias", node.input(1)); + EXPECT_EQ("BroadcastTo", node.op()); + EXPECT_EQ("bias", node.input(0)); + EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1", + node.input(1)); + EXPECT_EQ(ctrl_zeros_name, node.input(2)); + } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") { + EXPECT_EQ("Const", node.op()); + EXPECT_EQ(ctrl_zeros_name, node.input(0)); + EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32); + TensorProto t = node.attr().at("value").tensor(); + EXPECT_EQ(DT_INT32, t.dtype()); + EXPECT_EQ(1, t.tensor_shape().dim_size()); + EXPECT_EQ(2, t.tensor_shape().dim(0).size()); } else if (name == "sub1") { EXPECT_EQ("Identity", node.op()); EXPECT_EQ("x", node.input(0)); diff --git a/tensorflow/core/ops/array_grad.cc b/tensorflow/core/ops/array_grad.cc index 3d03bc1d5fd..f64cf801f22 100644 --- a/tensorflow/core/ops/array_grad.cc +++ b/tensorflow/core/ops/array_grad.cc @@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) { } REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad); +Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) { + DataType itype; + TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype)); + if (itype != DT_INT32) { + return errors::Unimplemented( + "BroadcastToGrad for int64 index are not supported."); + } + std::vector<FDH::Node> nodes = { + {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}}, + {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}}, + {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}}, + {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}}, + {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}}; + *g = FDH::Define( + // Arg defs + {"x: T", "shape: int32", "dy: T"}, + // Ret val defs + {"dx: T", "dshape: Tidx"}, + // Attr defs + {{"T: type"}, {"Tidx: {int32, int64}"}}, + // Nodes + nodes); + return Status::OK(); +} +REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad); + } // end namespace tensorflow diff --git a/tensorflow/core/ops/array_grad_test.cc b/tensorflow/core/ops/array_grad_test.cc index 79d28a83cc4..bcef90c15e3 100644 --- a/tensorflow/core/ops/array_grad_test.cc +++ b/tensorflow/core/ops/array_grad_test.cc @@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) { } } +std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape, + const Tensor& dy) { + auto T = DT_FLOAT; + auto Tidx = DT_INT32; + auto gdef = test::function::GDef( + {f::NDef("x", "Placeholder", {}, {{"dtype", T}}), + f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}), + f::NDef("dy", "Placeholder", {}, {{"dtype", T}}), + f::NDef( + "dx", "SymbolicGradient", {"x", "shape", "dy"}, + {{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})}, + {"Tin", DataTypeSlice{T, Tidx, T}}, + {"Tout", DataTypeSlice{T, Tidx}}})}); + VLOG(1) << DebugStringWhole(gdef); + auto sess = NewSession(); + TF_CHECK_OK(sess->Create(gdef)); + std::vector<Tensor> out; + TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}}, + {"dx:0", "dx:1"}, {}, &out)); + CHECK_EQ(out.size(), 2); + TF_CHECK_OK(sess->Close()); + return out; +} + +TEST(ArrayGradTest, BroadcastToGrad) { + Tensor x(DT_FLOAT, {2, 2}); + x.flat<float>().setZero(); + Tensor shape(DT_INT32, {3}); + test::FillValues<int32>(&shape, {2, 2, 2}); + Tensor dy(DT_FLOAT, {2, 2, 2}); + test::FillIota<float>(&dy, 0); + auto dx = BroadcastToGrad(x, shape, dy); + test::ExpectClose(dx[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2})); + 
test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}, {3})); +} } // namespace } // namespace tensorflow
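
For reference, here is a minimal sketch (not part of the patch) of how the new
rewrite can be observed in isolation. It assumes the same test scaffolding as
constant_folding_test.cc above (tensorflow::Scope, the ops:: builders,
GrapplerItem, and a ConstantFolding optimizer constructed without a
cpu_device); the node names are illustrative, and the generated shape-constant
name is inferred from the OptimizedNodeName pattern exercised by the test.

  Scope s = Scope::NewRootScope();
  // x has shape {2, 2}; the constant of ones has the larger shape {2, 2, 2},
  // so the product's output shape does not match x's shape.
  Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                              ops::Placeholder::Shape(TensorShape({2, 2})));
  Output ones = ops::Const(s.WithOpName("ones"), 1.0f, {2, 2, 2});
  Output mul = ops::Mul(s.WithOpName("mul"), x, ones);

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));
  item.fetch = {"mul"};

  ConstantFolding optimizer(/*cpu_device=*/nullptr);
  GraphDef output;
  TF_CHECK_OK(optimizer.Optimize(/*cluster=*/nullptr, item, &output));
  // Before this change, "mul" was left untouched because x's shape differs
  // from the output shape. With this change it is rewritten in place to
  //   BroadcastTo(x, ConstantFolding/mul-broadcastto_shape-0)
  // with "^ones" kept as a control dependency on both the rewritten node and
  // the generated shape constant.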