Add broadcast optimization rewrite to Grappler. This generalizes the rewrites for algebraically neutral element simplifications (e.g. x + zeros(shape) => x) to handle the cases where the output shape differs from the input shape of the non-trivial argument. This pattern is sometimes used in TensorFlow for broadcasting.

Example of new rewrites enabled by this change:

x * ones_like(y) => broadcast_to(x, shape(y))
zeros_like(x) + y => broadcast_to(y, shape(x))

This change also cleans up the code in SimplifyNode to consistently rely on only graph_modified_ (and possibly an error status) to signal if the graph was updated.

PiperOrigin-RevId: 243319718
This commit is contained in:
A. Unique TensorFlower 2019-04-12 13:13:33 -07:00 committed by TensorFlower Gardener
parent a3d2ee6ada
commit ef1ca5b2e1
6 changed files with 383 additions and 344 deletions

View File

@ -136,14 +136,15 @@ load(
"tf_additional_libdevice_srcs", "tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs", "tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines", "tf_additional_mpi_lib_defines",
"tf_additional_numa_copts",
"tf_additional_numa_deps", "tf_additional_numa_deps",
"tf_additional_numa_lib_defines", "tf_additional_numa_lib_defines",
"tf_additional_numa_copts",
"tf_additional_proto_hdrs", "tf_additional_proto_hdrs",
"tf_additional_proto_srcs", "tf_additional_proto_srcs",
"tf_additional_test_deps", "tf_additional_test_deps",
"tf_additional_test_srcs", "tf_additional_test_srcs",
"tf_additional_verbs_lib_defines", "tf_additional_verbs_lib_defines",
"tf_grpc_service_all",
"tf_jspb_proto_library", "tf_jspb_proto_library",
"tf_kernel_tests_linkstatic", "tf_kernel_tests_linkstatic",
"tf_lib_proto_compiler_deps", "tf_lib_proto_compiler_deps",
@ -157,7 +158,6 @@ load(
"tf_protos_grappler", "tf_protos_grappler",
"tf_protos_grappler_impl", "tf_protos_grappler_impl",
"tf_pyclif_proto_library", "tf_pyclif_proto_library",
"tf_grpc_service_all",
) )
load( load(
":platform/default/build_config_root.bzl", ":platform/default/build_config_root.bzl",
@ -4919,6 +4919,7 @@ tf_cc_test(
"//tensorflow/core/kernels:array", "//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:math",
"//third_party/eigen3", "//third_party/eigen3",
], ],
) )

View File

@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
} else { } else {
double value = double value =
(IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0)); (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
bool success = false;
if (value >= 0) { if (value >= 0) {
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant( TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
value, properties, output_shape, node, graph_, &success)); value, properties, output_shape, node, graph_));
} }
} }
return Status::OK(); return Status::OK();
@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
graph_modified_ = true; graph_modified_ = true;
} }
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
GraphDef* graph) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) return;
const PartialTensorShape shape(
properties.GetOutputProperties(node->name())[0].shape());
if (!shape.IsFullyDefined()) return;
// Create constant node with shape.
const string const_name = OptimizedNodeName(
*node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
if (node_map_->GetNode(const_name) != nullptr) {
return;
}
Tensor shape_t;
if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return;
NodeDef tmp;
if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return;
NodeDef* const_node = graph->add_node();
const_node->Swap(&tmp);
const_node->set_device(node->device());
node_map_->AddNode(const_name, const_node);
// Add a control input on the unused input.
string ctrl_dep = AddControlDependency(
NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get());
*const_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
// Rewrite `node` in-place to BroadcastTo.
node->set_op("BroadcastTo");
node->clear_attr();
(*node->mutable_attr())["T"].set_type(dtype);
(*node->mutable_attr())["Tidx"].set_type(DT_INT32);
// Set the designated input to BroadcastTo.
node->mutable_input()->SwapElements(0, input_to_broadcast);
// Keep all other inputs as control dependencies.
for (int i = 1; i < node->input_size(); ++i) {
if (IsControlInput(node->input(i))) {
break;
}
const string ctrl_dep =
AddControlDependency(node->input(i), graph, node_map_.get());
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
// Add the shape argument.
*node->add_input() = const_node->name();
node_map_->AddOutput(const_name, node->name());
node->mutable_input()->SwapElements(1, node->input_size() - 1);
graph_modified_ = true;
}
void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node, void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
GraphDef* graph) { GraphDef* graph) {
node->set_op("Reciprocal"); node->set_op("Reciprocal");
@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
Status ConstantFolding::ReplaceOperationWithConstant( Status ConstantFolding::ReplaceOperationWithConstant(
double value, const GraphProperties& properties, double value, const GraphProperties& properties,
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph, const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
bool* success) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties); const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) { if (dtype == DT_INVALID) {
*success = false;
return Status::OK(); return Status::OK();
} }
@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant(
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep); node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep); node->set_input(i, ctrl_dep);
} }
*success = true;
graph_modified_ = true; graph_modified_ = true;
return Status::OK(); return Status::OK();
} }
@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph(
return Status::OK(); return Status::OK();
} }
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
TF_RETURN_IF_ERROR(EXPR); \
if (graph_modified_) return Status::OK()
#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
graph_modified_ = EXPR; \
if (graph_modified_) return Status::OK()
#define RETURN_IF_MODIFIED(EXPR) \
EXPR; \
if (graph_modified_) return Status::OK()
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
GraphDef* optimized_graph, GraphDef* optimized_graph,
GraphProperties* properties) { GraphProperties* properties) {
if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) { bool graph_modified_cached = graph_modified_;
return Status::OK(); graph_modified_ = false;
}
bool remove_shuffle_transpose_successful = false; RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
Status remove_shuffle_transpose_status = RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph, *properties, use_shape_info, optimized_graph, node));
node, &remove_shuffle_transpose_successful); RETURN_IF_MODIFIED(
if (!remove_shuffle_transpose_status.ok()) { RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
return remove_shuffle_transpose_status; RETURN_IF_ERROR_OR_MODIFIED(
} else if (remove_shuffle_transpose_successful) { RemoveReverse(*properties, use_shape_info, optimized_graph, node));
return Status::OK(); RETURN_IF_ERROR_OR_MODIFIED(
} SimplifySlice(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) { SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
return Status::OK(); RETURN_IF_ERROR_OR_MODIFIED(
} SimplifyTile(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
bool remove_reverse_successful = false; SimplifyPad(*properties, use_shape_info, optimized_graph, node));
Status remove_reverse_status = RETURN_IF_MODIFIED(
RemoveReverse(*properties, use_shape_info, optimized_graph, node, SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
&remove_reverse_successful); SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
if (!remove_reverse_status.ok()) { SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
return remove_reverse_status; SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
} else if (remove_reverse_successful) { SET_AND_RETURN_IF_MODIFIED(
return Status::OK(); SimplifyReduction(optimized_graph, *properties, node));
} SET_AND_RETURN_IF_MODIFIED(
SimplifyReshape(*properties, use_shape_info, node));
bool simplify_slice_successful = false; RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
Status simplify_slice_status = *properties, use_shape_info, optimized_graph, node));
SimplifySlice(*properties, use_shape_info, optimized_graph, node, SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
&simplify_slice_successful); SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
if (!simplify_slice_status.ok()) { SET_AND_RETURN_IF_MODIFIED(
return simplify_slice_status; MulConvPushDown(optimized_graph, node, *properties));
} else if (simplify_slice_successful) { SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
return Status::OK(); SET_AND_RETURN_IF_MODIFIED(
} PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
bool simplify_strided_slice_successful = false; PartialConcatConstFolding(optimized_graph, properties, node));
Status simplify_strided_slice_status = SET_AND_RETURN_IF_MODIFIED(
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node, MergeConcat(*properties, use_shape_info, optimized_graph, node));
&simplify_strided_slice_successful);
if (!simplify_strided_slice_status.ok()) {
return simplify_strided_slice_status;
} else if (simplify_strided_slice_successful) {
return Status::OK();
}
bool simplify_tile_successful = false;
Status simplify_tile_status =
SimplifyTile(*properties, use_shape_info, optimized_graph, node,
&simplify_tile_successful);
if (!simplify_tile_status.ok()) {
return simplify_tile_status;
} else if (simplify_tile_successful) {
return Status::OK();
}
bool simplify_pad_successful = false;
Status simplify_pad_status =
SimplifyPad(*properties, use_shape_info, optimized_graph, node,
&simplify_pad_successful);
if (!simplify_pad_status.ok()) {
return simplify_pad_status;
} else if (simplify_pad_successful) {
return Status::OK();
}
if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
return Status::OK();
}
if (SimplifyPack(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MoveConstantsPastEnter(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifySwitch(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifyReduction(optimized_graph, *properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifyReshape(*properties, use_shape_info, node)) {
graph_modified_ = true;
return Status::OK();
}
bool arithmetic_simplification_succeed = false;
Status simplify_arithmetic_status =
SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
node, &arithmetic_simplification_succeed);
if (!simplify_arithmetic_status.ok()) {
return simplify_arithmetic_status;
} else if (arithmetic_simplification_succeed) {
graph_modified_ = true;
return Status::OK();
}
if (ReduceDivToReciprocalMul(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (ConstantPushDown(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MulConvPushDown(optimized_graph, node, *properties)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConstPropThroughIdentityN(node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
graph_modified_ = graph_modified_cached;
return Status::OK(); return Status::OK();
} }
bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (node->attr().count("num_split") == 0) return false; if (node->attr().count("num_split") == 0) return;
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) { if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph); ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
return true;
} }
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) { if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
} }
return false;
} }
Status ConstantFolding::RemoveShuffleOrTranspose( Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info, const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) { GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
properties.GetInputProperties(node->name()).size() >= 2) { properties.GetInputProperties(node->name()).size() >= 2) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape(); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK(); return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
if (!shape.unknown_rank() && if (!shape.unknown_rank() &&
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) { (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
} }
} }
return false;
} }
Status ConstantFolding::RemoveReverse(const GraphProperties& properties, Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph,
bool* success) { NodeDef* node) {
if (use_shape_info && node->op() == "ReverseV2" && if (use_shape_info && node->op() == "ReverseV2" &&
properties.GetInputProperties(node->name()).size() >= 2) { properties.GetInputProperties(node->name()).size() >= 2) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape(); const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifySlice(const GraphProperties& properties, Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph,
bool* success) { NodeDef* node) {
if (use_shape_info && IsSlice(*node) && if (use_shape_info && IsSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 3) { properties.GetInputProperties(node->name()).size() == 3) {
const auto& input = properties.GetInputProperties(node->name())[0]; const auto& input = properties.GetInputProperties(node->name())[0];
@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK(); return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node, bool* success) { NodeDef* node) {
if (use_shape_info && IsStridedSlice(*node) && if (use_shape_info && IsStridedSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 4) { properties.GetInputProperties(node->name()).size() == 4) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifyTile(const GraphProperties& properties, Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph, NodeDef* node) {
bool* success) {
if (use_shape_info && IsTile(*node) && if (use_shape_info && IsTile(*node) &&
properties.GetInputProperties(node->name()).size() == 2) { properties.GetInputProperties(node->name()).size() == 2) {
const auto& m = properties.GetInputProperties(node->name())[1]; const auto& m = properties.GetInputProperties(node->name())[1];
@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
Status ConstantFolding::SimplifyPad(const GraphProperties& properties, Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph, NodeDef* node) {
bool* success) {
if (use_shape_info && IsPad(*node) && if (use_shape_info && IsPad(*node) &&
properties.GetInputProperties(node->name()).size() >= 2) { properties.GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties.GetInputProperties(node->name())[1]; const auto& p = properties.GetInputProperties(node->name())[1];
@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
} }
} }
} }
*success = false;
return Status::OK(); return Status::OK();
} }
bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties, void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
@ -2263,15 +2202,15 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
} }
if (replaceable) { if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
} }
} }
return false;
} }
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) { bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
if (IsPack(*node) && NumNonControlInputs(*node) == 1 && if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 &&
!OptimizedNodeExists(*node, "_const_axis")) { !OptimizedNodeExists(*node, "_const_axis"))) {
return false;
}
// Create constant axis node. // Create constant axis node.
Tensor axis_t(DT_INT32, TensorShape({})); Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node(); NodeDef* axis_node = optimized_graph->add_node();
@ -2279,8 +2218,7 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
const int axis = const int axis =
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i(); node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() || if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node) !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) {
.ok()) {
return false; return false;
} }
// Add a control dependency to make sure axis_node is in the right frame. // Add a control dependency to make sure axis_node is in the right frame.
@ -2301,21 +2239,21 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
node->mutable_input()->SwapElements(1, node->input_size() - 1); node->mutable_input()->SwapElements(1, node->input_size() - 1);
} }
return true; return true;
}
return false;
} }
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph, bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
NodeDef* node) { NodeDef* node) {
if (IsEnter(*node) && node->input_size() > 0) { if (!IsEnter(*node) || node->input_size() == 0 ||
if (node->attr().count("is_constant") == 0 || node->attr().count("is_constant") == 0 ||
!node->attr().at("is_constant").b()) { !node->attr().at("is_constant").b()) {
return false; return false;
} }
const string& node_name = node->name(); const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0)); const NodeDef* input = node_map_->GetNode(node->input(0));
if (input != nullptr && IsReallyConstant(*input) && if (input == nullptr || !IsReallyConstant(*input) ||
!OptimizedNodeExists(*input, "_enter")) { OptimizedNodeExists(*input, "_enter")) {
return false;
}
auto fanouts = node_map_->GetOutputs(node_name); auto fanouts = node_map_->GetOutputs(node_name);
// Find non-constant nodes that consume the output of *node. // Find non-constant nodes that consume the output of *node.
std::vector<NodeDef*> consumers; std::vector<NodeDef*> consumers;
@ -2329,7 +2267,10 @@ bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
} }
} }
} }
if (!consumers.empty()) { if (consumers.empty()) {
return false;
}
graph_modified_ = true;
NodeDef* new_node = optimized_graph->add_node(); NodeDef* new_node = optimized_graph->add_node();
*new_node = *input; *new_node = *input;
new_node->set_name(OptimizedNodeName(*input, "_enter")); new_node->set_name(OptimizedNodeName(*input, "_enter"));
@ -2341,17 +2282,12 @@ bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
for (NodeDef* consumer : consumers) { for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) { for (int i = 0; i < consumer->input_size(); ++i) {
if (NodeName(consumer->input(i)) == node_name) { if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name, node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
new_node->name());
consumer->set_input(i, new_node->name()); consumer->set_input(i, new_node->name());
} }
} }
} }
return true; return true;
}
}
}
return false;
} }
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) { bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
return n1->name() < n2->name(); return n1->name() < n2->name();
}); });
// Create constant false & true nodes. // Create constant false & true nodes.
NodeDef* false_node = optimized_graph->add_node(); NodeDef tmp_false_node;
false_node->set_name(OptimizedNodeName(*node, "_const_false")); tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node) if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
&tmp_false_node)
.ok()) { .ok()) {
return false; return false;
} }
false_node->set_device(node->device()); tmp_false_node.set_device(node->device());
NodeDef tmp_true_node;
tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
&tmp_true_node)
.ok()) {
return false;
}
tmp_true_node.set_device(node->device());
// Add const nodes to graph.
NodeDef* false_node = optimized_graph->add_node();
false_node->Swap(&tmp_false_node);
NodeDef* true_node = optimized_graph->add_node(); NodeDef* true_node = optimized_graph->add_node();
true_node->set_name(OptimizedNodeName(*node, "_const_true")); true_node->Swap(&tmp_true_node);
if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
.ok()) {
return false;
}
true_node->set_device(node->device());
// Add controls from the switch ports to the constants, and connect the // Add controls from the switch ports to the constants, and connect the
// constants to the original switch outputs. // constants to the original switch outputs.
@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
Status ConstantFolding::SimplifyArithmeticOperations( Status ConstantFolding::SimplifyArithmeticOperations(
const GraphProperties& properties, bool use_shape_info, const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) { GraphDef* optimized_graph, NodeDef* node) {
*success = false;
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node); const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsAnyMatMul(*node); const bool is_matmul = IsAnyMatMul(*node);
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node); const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node); const bool is_sub = IsSub(*node);
const bool is_any_div = IsAnyDiv(*node); const bool is_any_div = IsAnyDiv(*node);
@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations(
// of zeros. // of zeros.
const TensorShapeProto& y_shape = const TensorShapeProto& y_shape =
properties.GetInputProperties(node->name())[1].shape(); properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x); const TensorShapeProto& x_shape =
const bool x_is_one = x_is_zero ? false : IsOnes(*x); properties.GetInputProperties(node->name())[0].shape();
const bool y_matches_output_shape = const bool y_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, y_shape); ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape && const bool x_matches_output_shape =
((is_mul && x_is_one) || (is_add && x_is_zero))) { ShapesSymbolicallyEqual(output_shape, x_shape);
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
// 1 * y = y or 0 + y = y. // 1 * y = y or 0 + y = y.
if (y_matches_output_shape) {
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph); ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
*success = true; } else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK(); return Status::OK();
} }
if (y_matches_output_shape && (is_sub && x_is_zero)) { if (y_matches_output_shape && (is_sub && x_is_zero)) {
// Replace 0 - y with Neg(y). // Replace 0 - y with Neg(y).
ReplaceSubtractionFromZeroByNegation(node, optimized_graph); ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
*success = true;
return Status::OK(); return Status::OK();
} }
@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations(
DataType type = node->attr().at("T").type(); DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) { if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph); ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
*success = true;
return Status::OK(); return Status::OK();
} }
} }
const TensorShapeProto& x_shape =
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y); const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y); const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape = if (((is_mul || is_any_div) && y_is_one) ||
ShapesSymbolicallyEqual(output_shape, x_shape); ((is_add || is_sub) && y_is_zero)) {
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
// x * 1 = x or x / 1 = x or x +/- 0 = x // x * 1 = x or x / 1 = x or x +/- 0 = x
if (x_matches_output_shape) {
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph); ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
*success = true; } else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK(); return Status::OK();
} }
// x OR true = true OR y = true. // x OR true = true OR y = true.
bool updated_graph = false;
const PartialTensorShape shp(output_shape); const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) { if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
bool replace_succeed = false; TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
Status replace_op_status = ReplaceOperationWithConstant( 1, properties, output_shape, node, optimized_graph));
1, properties, output_shape, node, optimized_graph, &replace_succeed); return Status::OK();
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
updated_graph = true;
}
} }
// Simplify multiplication and matmul by zeros. // Simplify multiplication and matmul by zeros.
@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
if ((x_is_zero || y_is_zero) && if ((x_is_zero || y_is_zero) &&
(is_mul || is_matmul || optimize_zeros_divided_by_y)) { (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
if (shp.IsFullyDefined()) { if (shp.IsFullyDefined()) {
bool replace_succeed = false; bool is_quantized = IsQuantizedMatMul(*node);
Status replace_op_status = TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
ReplaceOperationWithConstant(0, properties, output_shape, node, 0, properties, output_shape, node, optimized_graph));
optimized_graph, &replace_succeed); if (is_quantized && graph_modified_) {
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
if (is_quantized_matmul) {
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph)); AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
} }
*success = true;
return Status::OK(); return Status::OK();
} }
}
// Even if an input shape is only partially known, we may known that it // Even if an input shape is only partially known, we may known that it
// matches the output shape and thus forward the corresponding zero // matches the output shape and thus forward or broadcast the
// input. // corresponding zero input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) { if ((is_mul || is_any_div) && x_is_zero) {
if (x_matches_output_shape) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph); ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true; } else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK(); return Status::OK();
} else if (is_mul && y_is_zero && y_matches_output_shape) { } else if (is_mul && y_is_zero) {
if (y_matches_output_shape) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph); ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
*success = true; } else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK(); return Status::OK();
} }
} }
if (updated_graph) {
*success = true;
return Status::OK();
} }
}
*success = false;
return Status::OK(); return Status::OK();
} }
@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
auto add_quantized_out = [this, node, optimized_graph]( auto add_quantized_out = [this, node, optimized_graph](
const string& out_const_name, int index) { const string& out_const_name, int index) {
NodeDef* out_node = optimized_graph->add_node(); NodeDef* out_node = optimized_graph->add_node();
graph_modified_ = true;
Tensor value(DT_FLOAT, TensorShape({})); Tensor value(DT_FLOAT, TensorShape({}));
const bool is_min = index == 1; const bool is_min = index == 1;
const DataType type_attr = node->attr().at("dtype").type(); const DataType type_attr = node->attr().at("dtype").type();
@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
CreateNodeDef(out_const_name, TensorValue(&value), out_node)); CreateNodeDef(out_const_name, TensorValue(&value), out_node));
node_map_->AddNode(out_const_name, out_node); node_map_->AddNode(out_const_name, out_node);
out_node->set_device(node->device()); out_node->set_device(node->device());
// Copy all inputs from node. // Copy all inputs from node.
out_node->mutable_input()->CopyFrom(node->input()); out_node->mutable_input()->CopyFrom(node->input());
for (const string& input : out_node->input()) { for (const string& input : out_node->input()) {

View File

@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer {
void ReplaceOperationWithSnapshot(int input_to_forward, void ReplaceOperationWithSnapshot(int input_to_forward,
const GraphProperties& properties, const GraphProperties& properties,
NodeDef* node, GraphDef* graph); NodeDef* node, GraphDef* graph);
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
const GraphProperties& properties,
NodeDef* node, GraphDef* graph);
void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph); void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
Status ReplaceOperationWithConstant(double value, Status ReplaceOperationWithConstant(double value,
const GraphProperties& properties, const GraphProperties& properties,
const TensorShapeProto& shape, const TensorShapeProto& shape,
NodeDef* node, GraphDef* graph, NodeDef* node, GraphDef* graph);
bool* success);
void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph); void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
Status FoldGraph(GraphDef* output, Status FoldGraph(GraphDef* output,
absl::flat_hash_set<string>* nodes_to_not_simplify); absl::flat_hash_set<string>* nodes_to_not_simplify);
@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer {
// was applied. // was applied.
Status SimplifyArithmeticOperations(const GraphProperties& properties, Status SimplifyArithmeticOperations(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph, NodeDef* node);
bool* success);
// Simplifies a Reshape operation to an Identity operation if applicable. // Simplifies a Reshape operation to an Identity operation if applicable.
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info, bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer {
bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node); bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Squeeze operation to an Identity operation if applicable. // Simplifies a Squeeze operation to an Identity operation if applicable.
bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info, void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node); GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Pad operation to an Identity operation if applicable. // Simplifies a Pad operation to an Identity operation if applicable.
Status SimplifyPad(const GraphProperties& properties, bool use_shape_info, Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success); GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Tile operation to an Identity operation if applicable. // Simplifies a Tile operation to an Identity operation if applicable.
Status SimplifyTile(const GraphProperties& properties, bool use_shape_info, Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success); GraphDef* optimized_graph, NodeDef* node);
// Simplifies a StridedSlice operation to an Identity operation if applicable. // Simplifies a StridedSlice operation to an Identity operation if applicable.
Status SimplifyStridedSlice(const GraphProperties& properties, Status SimplifyStridedSlice(const GraphProperties& properties,
bool use_shape_info, GraphDef* optimized_graph, bool use_shape_info, GraphDef* optimized_graph,
NodeDef* node, bool* success); NodeDef* node);
// Simplifies a Slice operation to an Identity operation if applicable. // Simplifies a Slice operation to an Identity operation if applicable.
Status SimplifySlice(const GraphProperties& properties, bool use_shape_info, Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success); GraphDef* optimized_graph, NodeDef* node);
// Removes Reverse op over dimensions with size 1. // Removes Reverse op over dimensions with size 1.
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info, Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success); GraphDef* optimized_graph, NodeDef* node);
// Removes RandomShuffle op if it is scalar or first dimension is of size 1. // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
bool RemoveRandomShuffle(const GraphProperties& properties, void RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info, GraphDef* optimized_graph, bool use_shape_info, GraphDef* optimized_graph,
NodeDef* node); NodeDef* node);
// Removes Shuffle or Transpose op over dimensions of size 1. // Removes Shuffle or Transpose op over dimensions of size 1.
Status RemoveShuffleOrTranspose(const GraphProperties& properties, Status RemoveShuffleOrTranspose(const GraphProperties& properties,
bool use_shape_info, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, GraphDef* optimized_graph, NodeDef* node);
bool* success);
// Removes Split or SplitV node if possible. // Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties, void RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node); GraphDef* optimized_graph, NodeDef* node);
bool MergeConcat(const GraphProperties& properties, bool use_shape_info, bool MergeConcat(const GraphProperties& properties, bool use_shape_info,

View File

@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
ops::Placeholder::Shape(TensorShape({2}))); ops::Placeholder::Shape(TensorShape({2})));
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2}); Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2}); Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
Output zeros_const_bcast =
ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x); Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f); Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
Output zeros = const_type == kConst Output zeros = const_type == kConst
? zeros_const ? zeros_const
: (const_type == kLike ? zeros_like : zeros_fill); : (const_type == kLike ? zeros_like : zeros_fill);
Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2}); Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
Output ones_const_bcast =
ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x); Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f); Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
Output ones = const_type == kConst Output ones = const_type == kConst
@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
: (const_type == kLike ? ones_like : ones_fill); : (const_type == kLike ? ones_like : ones_fill);
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros); Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y); Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul1_bcast =
ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
Output mul2_bcast =
ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones); Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y); Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d); Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b); Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
Output add1 = ops::Add(s.WithOpName("add1"), x, zeros); Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
Output add2 = ops::Add(s.WithOpName("add2"), zeros, y); Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
Output add1_bcast =
ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
Output add2_bcast =
ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d); Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias); Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros); Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2}); matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
GrapplerItem item; GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph)); TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"stack", "matmul3", "matmul4"}; item.fetch = {"stack", "matmul3", "matmul4", "mul1_bcast",
"mul2_bcast", "add1_bcast", "add2_bcast"};
ConstantFolding optimizer(/*cpu_device=*/nullptr); ConstantFolding optimizer(/*cpu_device=*/nullptr);
GraphDef output; GraphDef output;
@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
const string ones_name = strings::StrCat("ones", suffix); const string ones_name = strings::StrCat("ones", suffix);
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix); const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
const string ctrl_ones_name = strings::StrCat("^ones", suffix); const string ctrl_ones_name = strings::StrCat("^ones", suffix);
EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
for (int i = 0; i < output.node_size(); ++i) { for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i); const NodeDef& node = output.node(i);
const string& name = node.name(); const string& name = node.name();
@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("Const", node.op()); EXPECT_EQ("Const", node.op());
EXPECT_EQ(ctrl_zeros_name, node.input(0)); EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1)); EXPECT_EQ("^y", node.input(1));
} else if (name == "mul1_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^ones_const_bcast", node.input(2));
} else if (name == "mul2_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^ones_const_bcast", node.input(2));
} else if (name == "mul3") { } else if (name == "mul3") {
EXPECT_EQ("Identity", node.op()); EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0)); EXPECT_EQ("x", node.input(0));
@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("Identity", node.op()); EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0)); EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1)); EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add1_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_const_bcast", node.input(2));
} else if (name == "add2_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^zeros_const_bcast", node.input(2));
} else if (name == "bias_add1") { } else if (name == "bias_add1") {
EXPECT_EQ("Identity", node.op()); EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0)); EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_1d", node.input(1)); EXPECT_EQ("^zeros_1d", node.input(1));
} else if (name == "bias_add2") { } else if (name == "bias_add2") {
// We don't eliminate this one, because it requires broadcasting. EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("BiasAdd", node.op()); EXPECT_EQ("bias", node.input(0));
EXPECT_EQ(zeros_name, node.input(0)); EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
EXPECT_EQ("bias", node.input(1)); node.input(1));
EXPECT_EQ(ctrl_zeros_name, node.input(2));
} else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(DT_INT32, t.dtype());
EXPECT_EQ(1, t.tensor_shape().dim_size());
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
} else if (name == "sub1") { } else if (name == "sub1") {
EXPECT_EQ("Identity", node.op()); EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0)); EXPECT_EQ("x", node.input(0));

View File

@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
} }
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad); REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
DataType itype;
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
if (itype != DT_INT32) {
return errors::Unimplemented(
"BroadcastToGrad for int64 index are not supported.");
}
std::vector<FDH::Node> nodes = {
{{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
{{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
{{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
{{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
{{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
*g = FDH::Define(
// Arg defs
{"x: T", "shape: int32", "dy: T"},
// Ret val defs
{"dx: T", "dshape: Tidx"},
// Attr defs
{{"T: type"}, {"Tidx: {int32, int64}"}},
// Nodes
nodes);
return Status::OK();
}
REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
} // end namespace tensorflow } // end namespace tensorflow

View File

@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) {
} }
} }
std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape,
const Tensor& dy) {
auto T = DT_FLOAT;
auto Tidx = DT_INT32;
auto gdef = test::function::GDef(
{f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}),
f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
f::NDef(
"dx", "SymbolicGradient", {"x", "shape", "dy"},
{{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})},
{"Tin", DataTypeSlice{T, Tidx, T}},
{"Tout", DataTypeSlice{T, Tidx}}})});
VLOG(1) << DebugStringWhole(gdef);
auto sess = NewSession();
TF_CHECK_OK(sess->Create(gdef));
std::vector<Tensor> out;
TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}},
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
return out;
}
TEST(ArrayGradTest, BroadcastToGrad) {
Tensor x(DT_FLOAT, {2, 2});
x.flat<float>().setZero();
Tensor shape(DT_INT32, {3});
test::FillValues<int32>(&shape, {2, 2, 2});
Tensor dy(DT_FLOAT, {2, 2, 2});
test::FillIota<float>(&dy, 0);
auto dx = BroadcastToGrad(x, shape, dy);
test::ExpectClose(dx[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2}));
test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}, {3}));
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow