Add broadcast optimization rewrite to Grappler. This generalizes the rewrites for algebraically neutral element simplifications (e.g. x + zeros(shape) => x) to handle the cases where the output shape differs from the input shape of the non-trivial argument. This pattern is sometimes used in tensorflow for broadcasting.

Example of new rewrites enabled by this change:

x * ones_like(y) => broadcast_to(x, shape(y))
zeros_like(x) + y => broadcast_to(y, shape(x))

This change also cleans up the code in SimplifyNode to consistently rely on only graph_modified_ (and possibly an error status) to signal if the graph was updated.

PiperOrigin-RevId: 243319718
This commit is contained in:
A. Unique TensorFlower 2019-04-12 13:13:33 -07:00 committed by TensorFlower Gardener
parent a3d2ee6ada
commit ef1ca5b2e1
6 changed files with 383 additions and 344 deletions

View File

@ -136,14 +136,15 @@ load(
"tf_additional_libdevice_srcs",
"tf_additional_minimal_lib_srcs",
"tf_additional_mpi_lib_defines",
"tf_additional_numa_copts",
"tf_additional_numa_deps",
"tf_additional_numa_lib_defines",
"tf_additional_numa_copts",
"tf_additional_proto_hdrs",
"tf_additional_proto_srcs",
"tf_additional_test_deps",
"tf_additional_test_srcs",
"tf_additional_verbs_lib_defines",
"tf_grpc_service_all",
"tf_jspb_proto_library",
"tf_kernel_tests_linkstatic",
"tf_lib_proto_compiler_deps",
@ -157,7 +158,6 @@ load(
"tf_protos_grappler",
"tf_protos_grappler_impl",
"tf_pyclif_proto_library",
"tf_grpc_service_all",
)
load(
":platform/default/build_config_root.bzl",
@ -4919,6 +4919,7 @@ tf_cc_test(
"//tensorflow/core/kernels:array",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:function_ops",
"//tensorflow/core/kernels:math",
"//third_party/eigen3",
],
)

View File

@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
} else {
double value =
(IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
bool success = false;
if (value >= 0) {
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
value, properties, output_shape, node, graph_, &success));
value, properties, output_shape, node, graph_));
}
}
return Status::OK();
@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
graph_modified_ = true;
}
// Rewrites the binary op `node` in place into a BroadcastTo of its input at
// position `input_to_broadcast`, broadcasting it to the node's (fully defined)
// output shape. The output shape is materialized as a new Const node, the
// other data input is demoted to a control dependency, and graph_modified_ is
// set on success. Returns early (leaving the graph untouched) if the dtype or
// output shape cannot be determined, or if the shape constant was already
// created by a previous invocation.
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
    int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
    GraphDef* graph) {
  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
  if (dtype == DT_INVALID) return;
  // The target shape must be fully known so it can be stored in a constant.
  const PartialTensorShape shape(
      properties.GetOutputProperties(node->name())[0].shape());
  if (!shape.IsFullyDefined()) return;
  // Create constant node with shape. The name encodes input_to_broadcast so
  // rewrites of different inputs of the same node do not collide.
  const string const_name = OptimizedNodeName(
      *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
  if (node_map_->GetNode(const_name) != nullptr) {
    // Already rewritten in an earlier pass; nothing to do.
    return;
  }
  Tensor shape_t;
  if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return;
  // Build the Const in a temporary and swap it into the graph-owned node so a
  // failed CreateNodeDef leaves the graph unchanged.
  NodeDef tmp;
  if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return;
  NodeDef* const_node = graph->add_node();
  const_node->Swap(&tmp);
  const_node->set_device(node->device());
  node_map_->AddNode(const_name, const_node);
  // Add a control input on the unused input, so the shape constant stays
  // ordered after (and in the same frame as) the input it replaces.
  string ctrl_dep = AddControlDependency(
      NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get());
  *const_node->add_input() = ctrl_dep;
  node_map_->AddOutput(NodeName(ctrl_dep), const_name);
  // Rewrite `node` in-place to BroadcastTo.
  node->set_op("BroadcastTo");
  node->clear_attr();
  (*node->mutable_attr())["T"].set_type(dtype);
  (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
  // Set the designated input to BroadcastTo: move it to input slot 0.
  node->mutable_input()->SwapElements(0, input_to_broadcast);
  // Keep all other inputs as control dependencies.
  for (int i = 1; i < node->input_size(); ++i) {
    if (IsControlInput(node->input(i))) {
      // Remaining inputs are already control edges; stop converting.
      break;
    }
    const string ctrl_dep =
        AddControlDependency(node->input(i), graph, node_map_.get());
    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
    node->set_input(i, ctrl_dep);
  }
  // Add the shape argument, then swap it into slot 1 (BroadcastTo's second
  // data input) ahead of any trailing control inputs.
  *node->add_input() = const_node->name();
  node_map_->AddOutput(const_name, node->name());
  node->mutable_input()->SwapElements(1, node->input_size() - 1);
  graph_modified_ = true;
}
void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
GraphDef* graph) {
node->set_op("Reciprocal");
@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
Status ConstantFolding::ReplaceOperationWithConstant(
double value, const GraphProperties& properties,
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
bool* success) {
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
if (dtype == DT_INVALID) {
*success = false;
return Status::OK();
}
@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant(
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
node->set_input(i, ctrl_dep);
}
*success = true;
graph_modified_ = true;
return Status::OK();
}
@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph(
return Status::OK();
}
// Helper macros for SimplifyNode: each runs one simplification step and
// returns early from the enclosing Status-returning function as soon as the
// graph has been changed, so only one rewrite is applied per call.

// For steps returning Status that set graph_modified_ themselves: propagate
// any error, then return OK if the step modified the graph.
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
  TF_RETURN_IF_ERROR(EXPR);               \
  if (graph_modified_) return Status::OK()

// For steps returning bool: record the result in graph_modified_ and return
// OK if the step reported a modification.
#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
  graph_modified_ = EXPR;                \
  if (graph_modified_) return Status::OK()

// For void steps that set graph_modified_ themselves: evaluate, then return
// OK if the graph was modified.
#define RETURN_IF_MODIFIED(EXPR) \
  EXPR;                          \
  if (graph_modified_) return Status::OK()
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
GraphDef* optimized_graph,
GraphProperties* properties) {
if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
return Status::OK();
}
bool graph_modified_cached = graph_modified_;
graph_modified_ = false;
bool remove_shuffle_transpose_successful = false;
Status remove_shuffle_transpose_status =
RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
node, &remove_shuffle_transpose_successful);
if (!remove_shuffle_transpose_status.ok()) {
return remove_shuffle_transpose_status;
} else if (remove_shuffle_transpose_successful) {
return Status::OK();
}
if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
return Status::OK();
}
bool remove_reverse_successful = false;
Status remove_reverse_status =
RemoveReverse(*properties, use_shape_info, optimized_graph, node,
&remove_reverse_successful);
if (!remove_reverse_status.ok()) {
return remove_reverse_status;
} else if (remove_reverse_successful) {
return Status::OK();
}
bool simplify_slice_successful = false;
Status simplify_slice_status =
SimplifySlice(*properties, use_shape_info, optimized_graph, node,
&simplify_slice_successful);
if (!simplify_slice_status.ok()) {
return simplify_slice_status;
} else if (simplify_slice_successful) {
return Status::OK();
}
bool simplify_strided_slice_successful = false;
Status simplify_strided_slice_status =
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
&simplify_strided_slice_successful);
if (!simplify_strided_slice_status.ok()) {
return simplify_strided_slice_status;
} else if (simplify_strided_slice_successful) {
return Status::OK();
}
bool simplify_tile_successful = false;
Status simplify_tile_status =
SimplifyTile(*properties, use_shape_info, optimized_graph, node,
&simplify_tile_successful);
if (!simplify_tile_status.ok()) {
return simplify_tile_status;
} else if (simplify_tile_successful) {
return Status::OK();
}
bool simplify_pad_successful = false;
Status simplify_pad_status =
SimplifyPad(*properties, use_shape_info, optimized_graph, node,
&simplify_pad_successful);
if (!simplify_pad_status.ok()) {
return simplify_pad_status;
} else if (simplify_pad_successful) {
return Status::OK();
}
if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
return Status::OK();
}
if (SimplifyPack(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MoveConstantsPastEnter(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifySwitch(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifyReduction(optimized_graph, *properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (SimplifyReshape(*properties, use_shape_info, node)) {
graph_modified_ = true;
return Status::OK();
}
bool arithmetic_simplification_succeed = false;
Status simplify_arithmetic_status =
SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
node, &arithmetic_simplification_succeed);
if (!simplify_arithmetic_status.ok()) {
return simplify_arithmetic_status;
} else if (arithmetic_simplification_succeed) {
graph_modified_ = true;
return Status::OK();
}
if (ReduceDivToReciprocalMul(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (ConstantPushDown(optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MulConvPushDown(optimized_graph, node, *properties)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConstPropThroughIdentityN(node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
graph_modified_ = true;
return Status::OK();
}
if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
graph_modified_ = true;
return Status::OK();
}
RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
*properties, use_shape_info, optimized_graph, node));
RETURN_IF_MODIFIED(
RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
RemoveReverse(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifySlice(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyTile(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_ERROR_OR_MODIFIED(
SimplifyPad(*properties, use_shape_info, optimized_graph, node));
RETURN_IF_MODIFIED(
SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
SimplifyReduction(optimized_graph, *properties, node));
SET_AND_RETURN_IF_MODIFIED(
SimplifyReshape(*properties, use_shape_info, node));
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
*properties, use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
MulConvPushDown(optimized_graph, node, *properties));
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
MergeConcat(*properties, use_shape_info, optimized_graph, node));
graph_modified_ = graph_modified_cached;
return Status::OK();
}
bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph,
NodeDef* node) {
if (node->attr().count("num_split") == 0) return false;
if (node->attr().count("num_split") == 0) return;
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
return true;
}
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
}
return false;
}
Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) {
GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
if (!shape.unknown_rank() &&
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
}
}
return false;
}
Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success) {
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && node->op() == "ReverseV2" &&
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success) {
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && IsSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 3) {
const auto& input = properties.GetInputProperties(node->name())[0];
@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node, bool* success) {
NodeDef* node) {
if (use_shape_info && IsStridedSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 4) {
TF_RETURN_IF_ERROR(
@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success) {
GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && IsTile(*node) &&
properties.GetInputProperties(node->name()).size() == 2) {
const auto& m = properties.GetInputProperties(node->name())[1];
@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success) {
GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && IsPad(*node) &&
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties.GetInputProperties(node->name())[1];
@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
return Status::OK();
}
}
}
*success = false;
return Status::OK();
}
bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
@ -2263,95 +2202,92 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return true;
}
}
return false;
}
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
!OptimizedNodeExists(*node, "_const_axis")) {
// Create constant axis node.
Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node();
axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
const int axis =
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
.ok()) {
return false;
}
// Add a control dependency to make sure axis_node is in the right frame.
const string ctrl_dep = ConstantFolding::AddControlDependency(
node->input(0), optimized_graph, node_map_.get());
axis_node->add_input(ctrl_dep);
axis_node->set_device(node->device());
node->set_op("ExpandDims");
if (node->attr().count("axis") != 0) {
node->mutable_attr()->erase("axis");
}
if (node->attr().count("N") != 0) {
node->mutable_attr()->erase("N");
}
(*node->mutable_attr())["Tdim"].set_type(DT_INT32);
node->add_input(axis_node->name());
if (node->input_size() > 2) {
node->mutable_input()->SwapElements(1, node->input_size() - 1);
}
return true;
if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 &&
!OptimizedNodeExists(*node, "_const_axis"))) {
return false;
}
return false;
// Create constant axis node.
Tensor axis_t(DT_INT32, TensorShape({}));
NodeDef* axis_node = optimized_graph->add_node();
axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
const int axis =
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) {
return false;
}
// Add a control dependency to make sure axis_node is in the right frame.
const string ctrl_dep = ConstantFolding::AddControlDependency(
node->input(0), optimized_graph, node_map_.get());
axis_node->add_input(ctrl_dep);
axis_node->set_device(node->device());
node->set_op("ExpandDims");
if (node->attr().count("axis") != 0) {
node->mutable_attr()->erase("axis");
}
if (node->attr().count("N") != 0) {
node->mutable_attr()->erase("N");
}
(*node->mutable_attr())["Tdim"].set_type(DT_INT32);
node->add_input(axis_node->name());
if (node->input_size() > 2) {
node->mutable_input()->SwapElements(1, node->input_size() - 1);
}
return true;
}
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
NodeDef* node) {
if (IsEnter(*node) && node->input_size() > 0) {
if (node->attr().count("is_constant") == 0 ||
!node->attr().at("is_constant").b()) {
return false;
}
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input != nullptr && IsReallyConstant(*input) &&
!OptimizedNodeExists(*input, "_enter")) {
auto fanouts = node_map_->GetOutputs(node_name);
// Find non-constant nodes that consume the output of *node.
std::vector<NodeDef*> consumers;
for (NodeDef* fanout : fanouts) {
if (!IsConstant(*fanout)) {
for (int i = 0; i < fanout->input_size(); ++i) {
if (fanout->input(i) == node_name) {
consumers.push_back(fanout);
break;
}
}
if (!IsEnter(*node) || node->input_size() == 0 ||
node->attr().count("is_constant") == 0 ||
!node->attr().at("is_constant").b()) {
return false;
}
const string& node_name = node->name();
const NodeDef* input = node_map_->GetNode(node->input(0));
if (input == nullptr || !IsReallyConstant(*input) ||
OptimizedNodeExists(*input, "_enter")) {
return false;
}
auto fanouts = node_map_->GetOutputs(node_name);
// Find non-constant nodes that consume the output of *node.
std::vector<NodeDef*> consumers;
for (NodeDef* fanout : fanouts) {
if (!IsConstant(*fanout)) {
for (int i = 0; i < fanout->input_size(); ++i) {
if (fanout->input(i) == node_name) {
consumers.push_back(fanout);
break;
}
}
if (!consumers.empty()) {
NodeDef* new_node = optimized_graph->add_node();
*new_node = *input;
new_node->set_name(OptimizedNodeName(*input, "_enter"));
new_node->set_device(node->device());
new_node->clear_input();
new_node->add_input(AsControlDependency(node_name));
node_map_->AddNode(new_node->name(), new_node);
node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name,
new_node->name());
consumer->set_input(i, new_node->name());
}
}
}
return true;
}
}
}
return false;
if (consumers.empty()) {
return false;
}
graph_modified_ = true;
NodeDef* new_node = optimized_graph->add_node();
*new_node = *input;
new_node->set_name(OptimizedNodeName(*input, "_enter"));
new_node->set_device(node->device());
new_node->clear_input();
new_node->add_input(AsControlDependency(node_name));
node_map_->AddNode(new_node->name(), new_node);
node_map_->AddOutput(node_name, new_node->name());
for (NodeDef* consumer : consumers) {
for (int i = 0; i < consumer->input_size(); ++i) {
if (NodeName(consumer->input(i)) == node_name) {
node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
consumer->set_input(i, new_node->name());
}
}
}
return true;
}
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
return n1->name() < n2->name();
});
// Create constant false & true nodes.
NodeDef* false_node = optimized_graph->add_node();
false_node->set_name(OptimizedNodeName(*node, "_const_false"));
if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
NodeDef tmp_false_node;
tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
&tmp_false_node)
.ok()) {
return false;
}
false_node->set_device(node->device());
tmp_false_node.set_device(node->device());
NodeDef tmp_true_node;
tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
&tmp_true_node)
.ok()) {
return false;
}
tmp_true_node.set_device(node->device());
// Add const nodes to graph.
NodeDef* false_node = optimized_graph->add_node();
false_node->Swap(&tmp_false_node);
NodeDef* true_node = optimized_graph->add_node();
true_node->set_name(OptimizedNodeName(*node, "_const_true"));
if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
.ok()) {
return false;
}
true_node->set_device(node->device());
true_node->Swap(&tmp_true_node);
// Add controls from the switch ports to the constants, and connect the
// constants to the original switch outputs.
@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
Status ConstantFolding::SimplifyArithmeticOperations(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success) {
*success = false;
GraphDef* optimized_graph, NodeDef* node) {
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
const bool is_matmul = IsAnyMatMul(*node);
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
const bool is_sub = IsSub(*node);
const bool is_any_div = IsAnyDiv(*node);
@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations(
// of zeros.
const TensorShapeProto& y_shape =
properties.GetInputProperties(node->name())[1].shape();
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
const TensorShapeProto& x_shape =
properties.GetInputProperties(node->name())[0].shape();
const bool y_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, y_shape);
if (y_matches_output_shape &&
((is_mul && x_is_one) || (is_add && x_is_zero))) {
const bool x_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, x_shape);
const bool x_is_zero = IsZeros(*x);
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
// 1 * y = y or 0 + y = y.
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
*success = true;
if (y_matches_output_shape) {
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
} else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK();
}
if (y_matches_output_shape && (is_sub && x_is_zero)) {
// Replace 0 - y with Neg(y).
ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
*success = true;
return Status::OK();
}
@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations(
DataType type = node->attr().at("T").type();
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
*success = true;
return Status::OK();
}
}
const TensorShapeProto& x_shape =
properties.GetInputProperties(node->name())[0].shape();
const bool y_is_zero = IsZeros(*y);
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
const bool x_matches_output_shape =
ShapesSymbolicallyEqual(output_shape, x_shape);
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero))) {
if (((is_mul || is_any_div) && y_is_one) ||
((is_add || is_sub) && y_is_zero)) {
// x * 1 = x or x / 1 = x or x +/- 0 = x
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
*success = true;
if (x_matches_output_shape) {
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
} else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK();
}
// x OR true = true OR y = true.
bool updated_graph = false;
const PartialTensorShape shp(output_shape);
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
bool replace_succeed = false;
Status replace_op_status = ReplaceOperationWithConstant(
1, properties, output_shape, node, optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
updated_graph = true;
}
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
1, properties, output_shape, node, optimized_graph));
return Status::OK();
}
// Simplify multiplication and matmul by zeros.
@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
if ((x_is_zero || y_is_zero) &&
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
if (shp.IsFullyDefined()) {
bool replace_succeed = false;
Status replace_op_status =
ReplaceOperationWithConstant(0, properties, output_shape, node,
optimized_graph, &replace_succeed);
if (!replace_op_status.ok()) {
return replace_op_status;
} else if (replace_succeed) {
if (is_quantized_matmul) {
TF_RETURN_IF_ERROR(
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
}
*success = true;
return Status::OK();
bool is_quantized = IsQuantizedMatMul(*node);
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
0, properties, output_shape, node, optimized_graph));
if (is_quantized && graph_modified_) {
TF_RETURN_IF_ERROR(
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
}
return Status::OK();
}
// Even if an input shape is only partially known, we may known that it
// matches the output shape and thus forward the corresponding zero
// input.
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
*success = true;
// matches the output shape and thus forward or broadcast the
// corresponding zero input.
if ((is_mul || is_any_div) && x_is_zero) {
if (x_matches_output_shape) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} else if (y_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
optimized_graph);
}
return Status::OK();
} else if (is_mul && y_is_zero && y_matches_output_shape) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
*success = true;
} else if (is_mul && y_is_zero) {
if (y_matches_output_shape) {
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
} else if (x_matches_output_shape) {
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
optimized_graph);
}
return Status::OK();
}
}
if (updated_graph) {
*success = true;
return Status::OK();
}
}
*success = false;
return Status::OK();
}
@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
auto add_quantized_out = [this, node, optimized_graph](
const string& out_const_name, int index) {
NodeDef* out_node = optimized_graph->add_node();
graph_modified_ = true;
Tensor value(DT_FLOAT, TensorShape({}));
const bool is_min = index == 1;
const DataType type_attr = node->attr().at("dtype").type();
@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
CreateNodeDef(out_const_name, TensorValue(&value), out_node));
node_map_->AddNode(out_const_name, out_node);
out_node->set_device(node->device());
// Copy all inputs from node.
out_node->mutable_input()->CopyFrom(node->input());
for (const string& input : out_node->input()) {

View File

@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer {
void ReplaceOperationWithSnapshot(int input_to_forward,
const GraphProperties& properties,
NodeDef* node, GraphDef* graph);
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
const GraphProperties& properties,
NodeDef* node, GraphDef* graph);
void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
Status ReplaceOperationWithConstant(double value,
const GraphProperties& properties,
const TensorShapeProto& shape,
NodeDef* node, GraphDef* graph,
bool* success);
NodeDef* node, GraphDef* graph);
void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
Status FoldGraph(GraphDef* output,
absl::flat_hash_set<string>* nodes_to_not_simplify);
@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer {
// was applied.
Status SimplifyArithmeticOperations(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Reshape operation to an Identity operation if applicable.
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer {
bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Squeeze operation to an Identity operation if applicable.
bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Pad operation to an Identity operation if applicable.
Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Simplifies a Tile operation to an Identity operation if applicable.
Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Simplifies a StridedSlice operation to an Identity operation if applicable.
Status SimplifyStridedSlice(const GraphProperties& properties,
bool use_shape_info, GraphDef* optimized_graph,
NodeDef* node, bool* success);
NodeDef* node);
// Simplifies a Slice operation to an Identity operation if applicable.
Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Removes Reverse op over dimensions with size 1.
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node, bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Removes RandomShuffle op if it is scalar or first dimension is of size 1.
bool RemoveRandomShuffle(const GraphProperties& properties,
void RemoveRandomShuffle(const GraphProperties& properties,
bool use_shape_info, GraphDef* optimized_graph,
NodeDef* node);
// Removes Shuffle or Transpose op over dimensions of size 1.
Status RemoveShuffleOrTranspose(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node,
bool* success);
GraphDef* optimized_graph, NodeDef* node);
// Removes Split or SplitV node if possible.
bool RemoveSplitOrSplitV(const GraphProperties& properties,
void RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,

View File

@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
ops::Placeholder::Shape(TensorShape({2})));
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
Output zeros_const_bcast =
ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
Output zeros = const_type == kConst
? zeros_const
: (const_type == kLike ? zeros_like : zeros_fill);
Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
Output ones_const_bcast =
ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
Output ones = const_type == kConst
@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
: (const_type == kLike ? ones_like : ones_fill);
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
Output mul1_bcast =
ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
Output mul2_bcast =
ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
Output add1_bcast =
ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
Output add2_bcast =
ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
item.fetch = {"stack", "matmul3", "matmul4"};
item.fetch = {"stack", "matmul3", "matmul4", "mul1_bcast",
"mul2_bcast", "add1_bcast", "add2_bcast"};
ConstantFolding optimizer(/*cpu_device=*/nullptr);
GraphDef output;
@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
const string ones_name = strings::StrCat("ones", suffix);
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
for (int i = 0; i < output.node_size(); ++i) {
const NodeDef& node = output.node(i);
const string& name = node.name();
@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ("^y", node.input(1));
} else if (name == "mul1_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^ones_const_bcast", node.input(2));
} else if (name == "mul2_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^ones_const_bcast", node.input(2));
} else if (name == "mul3") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ(ctrl_zeros_name, node.input(1));
} else if (name == "add1_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_const_bcast", node.input(2));
} else if (name == "add2_bcast") {
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("y", node.input(0));
EXPECT_EQ("^zeros_const_bcast", node.input(2));
} else if (name == "bias_add1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));
EXPECT_EQ("^zeros_1d", node.input(1));
} else if (name == "bias_add2") {
// We don't eliminate this one, because it requires broadcasting.
EXPECT_EQ("BiasAdd", node.op());
EXPECT_EQ(zeros_name, node.input(0));
EXPECT_EQ("bias", node.input(1));
EXPECT_EQ("BroadcastTo", node.op());
EXPECT_EQ("bias", node.input(0));
EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
node.input(1));
EXPECT_EQ(ctrl_zeros_name, node.input(2));
} else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
EXPECT_EQ("Const", node.op());
EXPECT_EQ(ctrl_zeros_name, node.input(0));
EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
TensorProto t = node.attr().at("value").tensor();
EXPECT_EQ(DT_INT32, t.dtype());
EXPECT_EQ(1, t.tensor_shape().dim_size());
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
} else if (name == "sub1") {
EXPECT_EQ("Identity", node.op());
EXPECT_EQ("x", node.input(0));

View File

@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
  // Gradient function for BroadcastTo. Given x broadcast to `shape`:
  //   dx     = reshape(sum(dy, rx), shape(x))  -- sum dy over the broadcast
  //            dimensions computed by BroadcastGradientArgs, then restore
  //            the original shape of x.
  //   dshape = zeros_like(shape)               -- `shape` is an integer
  //            input, so it has no useful gradient.
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
  if (itype != DT_INT32) {
    // BroadcastGradientArgs below is only wired up for int32 shapes here.
    return errors::Unimplemented(
        "BroadcastToGrad for int64 index is not supported.");
  }
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
      {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
      {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
  *g = FDH::Define(
      // Arg defs
      {"x: T", "shape: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dshape: Tidx"},
      // Attr defs
      {{"T: type"}, {"Tidx: {int32, int64}"}},
      // Nodes
      nodes);
  return Status::OK();
}
REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
} // end namespace tensorflow

View File

@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) {
}
}
std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape,
const Tensor& dy) {
auto T = DT_FLOAT;
auto Tidx = DT_INT32;
auto gdef = test::function::GDef(
{f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}),
f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
f::NDef(
"dx", "SymbolicGradient", {"x", "shape", "dy"},
{{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})},
{"Tin", DataTypeSlice{T, Tidx, T}},
{"Tout", DataTypeSlice{T, Tidx}}})});
VLOG(1) << DebugStringWhole(gdef);
auto sess = NewSession();
TF_CHECK_OK(sess->Create(gdef));
std::vector<Tensor> out;
TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}},
{"dx:0", "dx:1"}, {}, &out));
CHECK_EQ(out.size(), 2);
TF_CHECK_OK(sess->Close());
return out;
}
TEST(ArrayGradTest, BroadcastToGrad) {
  // x of shape [2, 2] is broadcast to [2, 2, 2]; the gradient w.r.t. x
  // sums dy over the new leading dimension, and the shape input gets a
  // zero gradient.
  Tensor shape(DT_INT32, {3});
  test::FillValues<int32>(&shape, {2, 2, 2});
  Tensor x(DT_FLOAT, {2, 2});
  x.flat<float>().setZero();
  Tensor dy(DT_FLOAT, {2, 2, 2});
  test::FillIota<float>(&dy, 0);
  const auto grads = BroadcastToGrad(x, shape, dy);
  // dy = iota(0..7); summing the two [2, 2] slices gives {0+4, 1+5, 2+6, 3+7}.
  test::ExpectClose(grads[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2}));
  test::ExpectTensorEqual<int32>(grads[1],
                                 test::AsTensor<int32>({0, 0, 0}, {3}));
}
} // namespace
} // namespace tensorflow