Add broadcast optimization rewrite to Grappler.

This generalizes the rewrites for algebraically neutral-element simplifications (e.g. x + zeros(shape) => x) to handle cases where the output shape differs from the input shape of the non-trivial argument. This pattern is sometimes used in TensorFlow to express broadcasting.

Examples of new rewrites enabled by this change:

    x * ones_like(y)  =>  broadcast_to(x, shape(y))
    zeros_like(x) + y =>  broadcast_to(y, shape(x))

This change also cleans up the code in SimplifyNode to rely consistently on graph_modified_ (and possibly an error status) to signal whether the graph was updated.

PiperOrigin-RevId: 243319718
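To make the pattern concrete, here is a minimal Python sketch (illustration only, not part of this change; the tensor names and shapes are made up, and it assumes a TensorFlow build with tf.broadcast_to):

    import tensorflow as tf

    # x has shape [1, 2]; y has shape [3, 2].
    x = tf.constant([[1.0, 2.0]])
    y = tf.zeros([3, 2])

    # Before the rewrite: a Mul whose ones-valued argument exists only to
    # broadcast x up to y's shape; the multiply itself is algebraically neutral.
    z = x * tf.ones_like(y)

    # After the rewrite, Grappler emits the equivalent, cheaper graph:
    z_opt = tf.broadcast_to(x, tf.shape(y))

Both expressions produce the same [3, 2] result; the rewrite drops the elementwise multiply in favor of a single BroadcastTo node.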
Parent: a3d2ee6ada
Commit: ef1ca5b2e1
tensorflow/core/BUILD

@@ -136,14 +136,15 @@ load(
    "tf_additional_libdevice_srcs",
    "tf_additional_minimal_lib_srcs",
    "tf_additional_mpi_lib_defines",
    "tf_additional_numa_copts",
    "tf_additional_numa_deps",
    "tf_additional_numa_lib_defines",
    "tf_additional_numa_copts",
    "tf_additional_proto_hdrs",
    "tf_additional_proto_srcs",
    "tf_additional_test_deps",
    "tf_additional_test_srcs",
    "tf_additional_verbs_lib_defines",
    "tf_grpc_service_all",
    "tf_jspb_proto_library",
    "tf_kernel_tests_linkstatic",
    "tf_lib_proto_compiler_deps",
@@ -157,7 +158,6 @@ load(
    "tf_protos_grappler",
    "tf_protos_grappler_impl",
    "tf_pyclif_proto_library",
    "tf_grpc_service_all",
)
load(
    ":platform/default/build_config_root.bzl",
@@ -4919,6 +4919,7 @@ tf_cc_test(
        "//tensorflow/core/kernels:array",
        "//tensorflow/core/kernels:cwise_op",
        "//tensorflow/core/kernels:function_ops",
        "//tensorflow/core/kernels:math",
        "//third_party/eigen3",
    ],
)
tensorflow/core/grappler/optimizers/constant_folding.cc

@@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
  } else {
    double value =
        (IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
    bool success = false;
    if (value >= 0) {
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          value, properties, output_shape, node, graph_, &success));
          value, properties, output_shape, node, graph_));
    }
  }
  return Status::OK();
@@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
  graph_modified_ = true;
}

void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
    int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
    GraphDef* graph) {
  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
  if (dtype == DT_INVALID) return;
  const PartialTensorShape shape(
      properties.GetOutputProperties(node->name())[0].shape());
  if (!shape.IsFullyDefined()) return;

  // Create constant node with shape.
  const string const_name = OptimizedNodeName(
      *node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
  if (node_map_->GetNode(const_name) != nullptr) {
    return;
  }

  Tensor shape_t;
  if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return;
  NodeDef tmp;
  if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return;
  NodeDef* const_node = graph->add_node();
  const_node->Swap(&tmp);
  const_node->set_device(node->device());
  node_map_->AddNode(const_name, const_node);
  // Add a control input on the unused input.
  string ctrl_dep = AddControlDependency(
      NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get());
  *const_node->add_input() = ctrl_dep;
  node_map_->AddOutput(NodeName(ctrl_dep), const_name);

  // Rewrite `node` in-place to BroadcastTo.
  node->set_op("BroadcastTo");
  node->clear_attr();
  (*node->mutable_attr())["T"].set_type(dtype);
  (*node->mutable_attr())["Tidx"].set_type(DT_INT32);
  // Set the designated input to BroadcastTo.
  node->mutable_input()->SwapElements(0, input_to_broadcast);
  // Keep all other inputs as control dependencies.
  for (int i = 1; i < node->input_size(); ++i) {
    if (IsControlInput(node->input(i))) {
      break;
    }
    const string ctrl_dep =
        AddControlDependency(node->input(i), graph, node_map_.get());
    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
    node->set_input(i, ctrl_dep);
  }
  // Add the shape argument.
  *node->add_input() = const_node->name();
  node_map_->AddOutput(const_name, node->name());
  node->mutable_input()->SwapElements(1, node->input_size() - 1);
  graph_modified_ = true;
}

void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
                                                        GraphDef* graph) {
  node->set_op("Reciprocal");
@@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,

Status ConstantFolding::ReplaceOperationWithConstant(
    double value, const GraphProperties& properties,
    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
    bool* success) {
    const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
  const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
  if (dtype == DT_INVALID) {
    *success = false;
    return Status::OK();
  }

@@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant(
    node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
    node->set_input(i, ctrl_dep);
  }
  *success = true;
  graph_modified_ = true;
  return Status::OK();
}
@@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph(
  return Status::OK();
}

#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
  TF_RETURN_IF_ERROR(EXPR);               \
  if (graph_modified_) return Status::OK()

#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
  graph_modified_ = EXPR;                \
  if (graph_modified_) return Status::OK()

#define RETURN_IF_MODIFIED(EXPR) \
  EXPR;                          \
  if (graph_modified_) return Status::OK()

Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
                                     GraphDef* optimized_graph,
                                     GraphProperties* properties) {
  if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
    return Status::OK();
  }
  bool graph_modified_cached = graph_modified_;
  graph_modified_ = false;

  bool remove_shuffle_transpose_successful = false;
  Status remove_shuffle_transpose_status =
      RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
                               node, &remove_shuffle_transpose_successful);
  if (!remove_shuffle_transpose_status.ok()) {
    return remove_shuffle_transpose_status;
  } else if (remove_shuffle_transpose_successful) {
    return Status::OK();
  }

  if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
    return Status::OK();
  }

  bool remove_reverse_successful = false;
  Status remove_reverse_status =
      RemoveReverse(*properties, use_shape_info, optimized_graph, node,
                    &remove_reverse_successful);
  if (!remove_reverse_status.ok()) {
    return remove_reverse_status;
  } else if (remove_reverse_successful) {
    return Status::OK();
  }

  bool simplify_slice_successful = false;
  Status simplify_slice_status =
      SimplifySlice(*properties, use_shape_info, optimized_graph, node,
                    &simplify_slice_successful);
  if (!simplify_slice_status.ok()) {
    return simplify_slice_status;
  } else if (simplify_slice_successful) {
    return Status::OK();
  }

  bool simplify_strided_slice_successful = false;
  Status simplify_strided_slice_status =
      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
                           &simplify_strided_slice_successful);
  if (!simplify_strided_slice_status.ok()) {
    return simplify_strided_slice_status;
  } else if (simplify_strided_slice_successful) {
    return Status::OK();
  }

  bool simplify_tile_successful = false;
  Status simplify_tile_status =
      SimplifyTile(*properties, use_shape_info, optimized_graph, node,
                   &simplify_tile_successful);
  if (!simplify_tile_status.ok()) {
    return simplify_tile_status;
  } else if (simplify_tile_successful) {
    return Status::OK();
  }

  bool simplify_pad_successful = false;
  Status simplify_pad_status =
      SimplifyPad(*properties, use_shape_info, optimized_graph, node,
                  &simplify_pad_successful);
  if (!simplify_pad_status.ok()) {
    return simplify_pad_status;
  } else if (simplify_pad_successful) {
    return Status::OK();
  }

  if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
    return Status::OK();
  }

  if (SimplifyPack(optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (MoveConstantsPastEnter(optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (SimplifySwitch(optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (SimplifyReduction(optimized_graph, *properties, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (SimplifyReshape(*properties, use_shape_info, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  bool arithmetic_simplification_succeed = false;
  Status simplify_arithmetic_status =
      SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
                                   node, &arithmetic_simplification_succeed);
  if (!simplify_arithmetic_status.ok()) {
    return simplify_arithmetic_status;
  } else if (arithmetic_simplification_succeed) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (ReduceDivToReciprocalMul(optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (ConstantPushDown(optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (MulConvPushDown(optimized_graph, node, *properties)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (PartialConstPropThroughIdentityN(node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (PartialConcatConstFolding(optimized_graph, properties, node)) {
    graph_modified_ = true;
    return Status::OK();
  }

  if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
    graph_modified_ = true;
    return Status::OK();
  }
  RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
      *properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      RemoveReverse(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifySlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyTile(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_ERROR_OR_MODIFIED(
      SimplifyPad(*properties, use_shape_info, optimized_graph, node));
  RETURN_IF_MODIFIED(
      SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReduction(optimized_graph, *properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      SimplifyReshape(*properties, use_shape_info, node));
  RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
      *properties, use_shape_info, optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
  SET_AND_RETURN_IF_MODIFIED(
      MulConvPushDown(optimized_graph, node, *properties));
  SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialAssocOpConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      PartialConcatConstFolding(optimized_graph, properties, node));
  SET_AND_RETURN_IF_MODIFIED(
      MergeConcat(*properties, use_shape_info, optimized_graph, node));

  graph_modified_ = graph_modified_cached;
  return Status::OK();
}

bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
                                          GraphDef* optimized_graph,
                                          NodeDef* node) {
  if (node->attr().count("num_split") == 0) return false;
  if (node->attr().count("num_split") == 0) return;
  if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
    ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
    return true;
  }
  if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
    ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
    return true;
  }
  return false;
}

Status ConstantFolding::RemoveShuffleOrTranspose(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node, bool* success) {
    GraphDef* optimized_graph, NodeDef* node) {
  if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
      properties.GetInputProperties(node->name()).size() >= 2) {
    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}
bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,

void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
                                          bool use_shape_info,
                                          GraphDef* optimized_graph,
                                          NodeDef* node) {
@@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
    if (!shape.unknown_rank() &&
        (shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
      return true;
    }
  }
  return false;
}

Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
                                      bool* success) {
                                      GraphDef* optimized_graph,
                                      NodeDef* node) {
  if (use_shape_info && node->op() == "ReverseV2" &&
      properties.GetInputProperties(node->name()).size() >= 2) {
    const auto& shape = properties.GetInputProperties(node->name())[0].shape();
@@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}

Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
                                      bool* success) {
                                      GraphDef* optimized_graph,
                                      NodeDef* node) {
  if (use_shape_info && IsSlice(*node) &&
      properties.GetInputProperties(node->name()).size() == 3) {
    const auto& input = properties.GetInputProperties(node->name())[0];
@@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}

Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
                                             bool use_shape_info,
                                             GraphDef* optimized_graph,
                                             NodeDef* node, bool* success) {
                                             NodeDef* node) {
  if (use_shape_info && IsStridedSlice(*node) &&
      properties.GetInputProperties(node->name()).size() == 4) {
    TF_RETURN_IF_ERROR(
@@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}

Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
                                     bool use_shape_info,
                                     GraphDef* optimized_graph, NodeDef* node,
                                     bool* success) {
                                     GraphDef* optimized_graph, NodeDef* node) {
  if (use_shape_info && IsTile(*node) &&
      properties.GetInputProperties(node->name()).size() == 2) {
    const auto& m = properties.GetInputProperties(node->name())[1];
@@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}

Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
                                    bool use_shape_info,
                                    GraphDef* optimized_graph, NodeDef* node,
                                    bool* success) {
                                    GraphDef* optimized_graph, NodeDef* node) {
  if (use_shape_info && IsPad(*node) &&
      properties.GetInputProperties(node->name()).size() >= 2) {
    const auto& p = properties.GetInputProperties(node->name())[1];
@@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
      }
      if (replaceable) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }
  }
  *success = false;
  return Status::OK();
}

bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph,
                                      NodeDef* node) {
@@ -2263,95 +2202,92 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
    }
    if (replaceable) {
      ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
      return true;
    }
  }
  return false;
}

bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
  if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
      !OptimizedNodeExists(*node, "_const_axis")) {
    // Create constant axis node.
    Tensor axis_t(DT_INT32, TensorShape({}));
    NodeDef* axis_node = optimized_graph->add_node();
    axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
    const int axis =
        node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
    if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
        !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
             .ok()) {
      return false;
    }
    // Add a control dependency to make sure axis_node is in the right frame.
    const string ctrl_dep = ConstantFolding::AddControlDependency(
        node->input(0), optimized_graph, node_map_.get());
    axis_node->add_input(ctrl_dep);
    axis_node->set_device(node->device());
    node->set_op("ExpandDims");
    if (node->attr().count("axis") != 0) {
      node->mutable_attr()->erase("axis");
    }
    if (node->attr().count("N") != 0) {
      node->mutable_attr()->erase("N");
    }
    (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
    node->add_input(axis_node->name());
    if (node->input_size() > 2) {
      node->mutable_input()->SwapElements(1, node->input_size() - 1);
    }
    return true;
  if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 &&
        !OptimizedNodeExists(*node, "_const_axis"))) {
    return false;
  }
  return false;
  // Create constant axis node.
  Tensor axis_t(DT_INT32, TensorShape({}));
  NodeDef* axis_node = optimized_graph->add_node();
  axis_node->set_name(OptimizedNodeName(*node, "_const_axis"));
  const int axis =
      node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
  if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
      !CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) {
    return false;
  }
  // Add a control dependency to make sure axis_node is in the right frame.
  const string ctrl_dep = ConstantFolding::AddControlDependency(
      node->input(0), optimized_graph, node_map_.get());
  axis_node->add_input(ctrl_dep);
  axis_node->set_device(node->device());
  node->set_op("ExpandDims");
  if (node->attr().count("axis") != 0) {
    node->mutable_attr()->erase("axis");
  }
  if (node->attr().count("N") != 0) {
    node->mutable_attr()->erase("N");
  }
  (*node->mutable_attr())["Tdim"].set_type(DT_INT32);
  node->add_input(axis_node->name());
  if (node->input_size() > 2) {
    node->mutable_input()->SwapElements(1, node->input_size() - 1);
  }
  return true;
}

bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
                                             NodeDef* node) {
  if (IsEnter(*node) && node->input_size() > 0) {
    if (node->attr().count("is_constant") == 0 ||
        !node->attr().at("is_constant").b()) {
      return false;
    }
    const string& node_name = node->name();
    const NodeDef* input = node_map_->GetNode(node->input(0));
    if (input != nullptr && IsReallyConstant(*input) &&
        !OptimizedNodeExists(*input, "_enter")) {
      auto fanouts = node_map_->GetOutputs(node_name);
      // Find non-constant nodes that consume the output of *node.
      std::vector<NodeDef*> consumers;
      for (NodeDef* fanout : fanouts) {
        if (!IsConstant(*fanout)) {
          for (int i = 0; i < fanout->input_size(); ++i) {
            if (fanout->input(i) == node_name) {
              consumers.push_back(fanout);
              break;
            }
          }
  if (!IsEnter(*node) || node->input_size() == 0 ||
      node->attr().count("is_constant") == 0 ||
      !node->attr().at("is_constant").b()) {
    return false;
  }
  const string& node_name = node->name();
  const NodeDef* input = node_map_->GetNode(node->input(0));
  if (input == nullptr || !IsReallyConstant(*input) ||
      OptimizedNodeExists(*input, "_enter")) {
    return false;
  }
  auto fanouts = node_map_->GetOutputs(node_name);
  // Find non-constant nodes that consume the output of *node.
  std::vector<NodeDef*> consumers;
  for (NodeDef* fanout : fanouts) {
    if (!IsConstant(*fanout)) {
      for (int i = 0; i < fanout->input_size(); ++i) {
        if (fanout->input(i) == node_name) {
          consumers.push_back(fanout);
          break;
        }
      }
      if (!consumers.empty()) {
        NodeDef* new_node = optimized_graph->add_node();
        *new_node = *input;
        new_node->set_name(OptimizedNodeName(*input, "_enter"));
        new_node->set_device(node->device());
        new_node->clear_input();
        new_node->add_input(AsControlDependency(node_name));
        node_map_->AddNode(new_node->name(), new_node);
        node_map_->AddOutput(node_name, new_node->name());
        for (NodeDef* consumer : consumers) {
          for (int i = 0; i < consumer->input_size(); ++i) {
            if (NodeName(consumer->input(i)) == node_name) {
              node_map_->UpdateInput(consumer->name(), node_name,
                                     new_node->name());
              consumer->set_input(i, new_node->name());
            }
          }
        }
        return true;
      }
    }
  }
  return false;
  if (consumers.empty()) {
    return false;
  }
  graph_modified_ = true;
  NodeDef* new_node = optimized_graph->add_node();
  *new_node = *input;
  new_node->set_name(OptimizedNodeName(*input, "_enter"));
  new_node->set_device(node->device());
  new_node->clear_input();
  new_node->add_input(AsControlDependency(node_name));
  node_map_->AddNode(new_node->name(), new_node);
  node_map_->AddOutput(node_name, new_node->name());
  for (NodeDef* consumer : consumers) {
    for (int i = 0; i < consumer->input_size(); ++i) {
      if (NodeName(consumer->input(i)) == node_name) {
        node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
        consumer->set_input(i, new_node->name());
      }
    }
  }
  return true;
}

bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
@@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
                return n1->name() < n2->name();
              });
    // Create constant false & true nodes.
    NodeDef* false_node = optimized_graph->add_node();
    false_node->set_name(OptimizedNodeName(*node, "_const_false"));
    if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
    NodeDef tmp_false_node;
    tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
    if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
                       &tmp_false_node)
             .ok()) {
      return false;
    }
    false_node->set_device(node->device());
    tmp_false_node.set_device(node->device());
    NodeDef tmp_true_node;
    tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
    if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
                       &tmp_true_node)
             .ok()) {
      return false;
    }
    tmp_true_node.set_device(node->device());

    // Add const nodes to graph.
    NodeDef* false_node = optimized_graph->add_node();
    false_node->Swap(&tmp_false_node);
    NodeDef* true_node = optimized_graph->add_node();
    true_node->set_name(OptimizedNodeName(*node, "_const_true"));
    if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
             .ok()) {
      return false;
    }
    true_node->set_device(node->device());
    true_node->Swap(&tmp_true_node);

    // Add controls from the switch ports to the constants, and connect the
    // constants to the original switch outputs.
@@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,

Status ConstantFolding::SimplifyArithmeticOperations(
    const GraphProperties& properties, bool use_shape_info,
    GraphDef* optimized_graph, NodeDef* node, bool* success) {
  *success = false;
    GraphDef* optimized_graph, NodeDef* node) {
  const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
  const bool is_matmul = IsAnyMatMul(*node);
  const bool is_quantized_matmul = IsQuantizedMatMul(*node);
  const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
  const bool is_sub = IsSub(*node);
  const bool is_any_div = IsAnyDiv(*node);
@@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations(
    // of zeros.
    const TensorShapeProto& y_shape =
        properties.GetInputProperties(node->name())[1].shape();
    const bool x_is_zero = IsZeros(*x);
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, y_shape);
    if (y_matches_output_shape &&
        ((is_mul && x_is_one) || (is_add && x_is_zero))) {
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);

    const bool x_is_zero = IsZeros(*x);
    const bool x_is_one = x_is_zero ? false : IsOnes(*x);
    if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
      // 1 * y = y or 0 + y = y.
      ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      *success = true;
      if (y_matches_output_shape) {
        ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
      } else if (x_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    if (y_matches_output_shape && (is_sub && x_is_zero)) {
      // Replace 0 - y with Neg(y).
      ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
      *success = true;
      return Status::OK();
    }

@@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations(
      DataType type = node->attr().at("T").type();
      if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
        ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
        *success = true;
        return Status::OK();
      }
    }

    const TensorShapeProto& x_shape =
        properties.GetInputProperties(node->name())[0].shape();
    const bool y_is_zero = IsZeros(*y);
    const bool y_is_one = y_is_zero ? false : IsOnes(*y);
    const bool x_matches_output_shape =
        ShapesSymbolicallyEqual(output_shape, x_shape);
    if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
                                   ((is_add || is_sub) && y_is_zero))) {
    if (((is_mul || is_any_div) && y_is_one) ||
        ((is_add || is_sub) && y_is_zero)) {
      // x * 1 = x or x / 1 = x or x +/- 0 = x
      ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      *success = true;
      if (x_matches_output_shape) {
        ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
      } else if (y_matches_output_shape) {
        ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                              optimized_graph);
      }
      return Status::OK();
    }

    // x OR true = true OR y = true.
    bool updated_graph = false;
    const PartialTensorShape shp(output_shape);
    if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
      bool replace_succeed = false;
      Status replace_op_status = ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph, &replace_succeed);
      if (!replace_op_status.ok()) {
        return replace_op_status;
      } else if (replace_succeed) {
        updated_graph = true;
      }
      TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
          1, properties, output_shape, node, optimized_graph));
      return Status::OK();
    }

    // Simplify multiplication and matmul by zeros.
@@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
    if ((x_is_zero || y_is_zero) &&
        (is_mul || is_matmul || optimize_zeros_divided_by_y)) {
      if (shp.IsFullyDefined()) {
        bool replace_succeed = false;
        Status replace_op_status =
            ReplaceOperationWithConstant(0, properties, output_shape, node,
                                         optimized_graph, &replace_succeed);
        if (!replace_op_status.ok()) {
          return replace_op_status;
        } else if (replace_succeed) {
          if (is_quantized_matmul) {
            TF_RETURN_IF_ERROR(
                AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
          }
          *success = true;
          return Status::OK();
        bool is_quantized = IsQuantizedMatMul(*node);
        TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
            0, properties, output_shape, node, optimized_graph));
        if (is_quantized && graph_modified_) {
          TF_RETURN_IF_ERROR(
              AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
        }
        return Status::OK();
      }
      // Even if an input shape is only partially known, we may known that it
      // matches the output shape and thus forward the corresponding zero
      // input.
      if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
        ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        *success = true;
      // matches the output shape and thus forward or broadcast the
      // corresponding zero input.
      if ((is_mul || is_any_div) && x_is_zero) {
        if (x_matches_output_shape) {
          ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
        } else if (y_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      } else if (is_mul && y_is_zero && y_matches_output_shape) {
        ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        *success = true;
      } else if (is_mul && y_is_zero) {
        if (y_matches_output_shape) {
          ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
        } else if (x_matches_output_shape) {
          ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
                                                optimized_graph);
        }
        return Status::OK();
      }
    }
    if (updated_graph) {
      *success = true;
      return Status::OK();
    }
  }
  *success = false;
  return Status::OK();
}

@@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
  auto add_quantized_out = [this, node, optimized_graph](
                               const string& out_const_name, int index) {
    NodeDef* out_node = optimized_graph->add_node();
    graph_modified_ = true;
    Tensor value(DT_FLOAT, TensorShape({}));
    const bool is_min = index == 1;
    const DataType type_attr = node->attr().at("dtype").type();
@@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
        CreateNodeDef(out_const_name, TensorValue(&value), out_node));
    node_map_->AddNode(out_const_name, out_node);
    out_node->set_device(node->device());

    // Copy all inputs from node.
    out_node->mutable_input()->CopyFrom(node->input());
    for (const string& input : out_node->input()) {
tensorflow/core/grappler/optimizers/constant_folding.h

@@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer {
  void ReplaceOperationWithSnapshot(int input_to_forward,
                                    const GraphProperties& properties,
                                    NodeDef* node, GraphDef* graph);
  void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
                                             const GraphProperties& properties,
                                             NodeDef* node, GraphDef* graph);
  void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
  Status ReplaceOperationWithConstant(double value,
                                      const GraphProperties& properties,
                                      const TensorShapeProto& shape,
                                      NodeDef* node, GraphDef* graph,
                                      bool* success);
                                      NodeDef* node, GraphDef* graph);
  void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
  Status FoldGraph(GraphDef* output,
                   absl::flat_hash_set<string>* nodes_to_not_simplify);
@@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer {
  // was applied.
  Status SimplifyArithmeticOperations(const GraphProperties& properties,
                                      bool use_shape_info,
                                      GraphDef* optimized_graph, NodeDef* node,
                                      bool* success);
                                      GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Reshape operation to an Identity operation if applicable.
  bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
@@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer {
  bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Squeeze operation to an Identity operation if applicable.
  bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
  void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Pad operation to an Identity operation if applicable.
  Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
                     GraphDef* optimized_graph, NodeDef* node, bool* success);
                     GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a Tile operation to an Identity operation if applicable.
  Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
                      GraphDef* optimized_graph, NodeDef* node, bool* success);
                      GraphDef* optimized_graph, NodeDef* node);

  // Simplifies a StridedSlice operation to an Identity operation if applicable.
  Status SimplifyStridedSlice(const GraphProperties& properties,
                              bool use_shape_info, GraphDef* optimized_graph,
                              NodeDef* node, bool* success);
                              NodeDef* node);

  // Simplifies a Slice operation to an Identity operation if applicable.
  Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node, bool* success);
                       GraphDef* optimized_graph, NodeDef* node);

  // Removes Reverse op over dimensions with size 1.
  Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
                       GraphDef* optimized_graph, NodeDef* node, bool* success);
                       GraphDef* optimized_graph, NodeDef* node);

  // Removes RandomShuffle op if it is scalar or first dimension is of size 1.
  bool RemoveRandomShuffle(const GraphProperties& properties,
  void RemoveRandomShuffle(const GraphProperties& properties,
                           bool use_shape_info, GraphDef* optimized_graph,
                           NodeDef* node);

  // Removes Shuffle or Transpose op over dimensions of size 1.
  Status RemoveShuffleOrTranspose(const GraphProperties& properties,
                                  bool use_shape_info,
                                  GraphDef* optimized_graph, NodeDef* node,
                                  bool* success);
                                  GraphDef* optimized_graph, NodeDef* node);

  // Removes Split or SplitV node if possible.
  bool RemoveSplitOrSplitV(const GraphProperties& properties,
  void RemoveSplitOrSplitV(const GraphProperties& properties,
                           GraphDef* optimized_graph, NodeDef* node);

  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
tensorflow/core/grappler/optimizers/constant_folding_test.cc

@@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                              ops::Placeholder::Shape(TensorShape({2})));
    Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
    Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
    Output zeros_const_bcast =
        ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
    Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
    Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
    Output zeros = const_type == kConst
                       ? zeros_const
                       : (const_type == kLike ? zeros_like : zeros_fill);
    Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
    Output ones_const_bcast =
        ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
    Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
    Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
    Output ones = const_type == kConst
@@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                      : (const_type == kLike ? ones_like : ones_fill);
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
    Output mul1_bcast =
        ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
    Output mul2_bcast =
        ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
    Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
    Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
    Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
@@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
    Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
    Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
    Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
    Output add1_bcast =
        ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
    Output add2_bcast =
        ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
    Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
    Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
    Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
@@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
                       matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
    GrapplerItem item;
    TF_CHECK_OK(s.ToGraphDef(&item.graph));
    item.fetch = {"stack", "matmul3", "matmul4"};
    item.fetch = {"stack", "matmul3", "matmul4", "mul1_bcast",
                  "mul2_bcast", "add1_bcast", "add2_bcast"};

    ConstantFolding optimizer(/*cpu_device=*/nullptr);
    GraphDef output;
@@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
    const string ones_name = strings::StrCat("ones", suffix);
    const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
    const string ctrl_ones_name = strings::StrCat("^ones", suffix);
    EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());

    EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
    for (int i = 0; i < output.node_size(); ++i) {
      const NodeDef& node = output.node(i);
      const string& name = node.name();
@@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ("^y", node.input(1));
      } else if (name == "mul1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^ones_const_bcast", node.input(2));
      } else if (name == "mul3") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
@@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ(ctrl_zeros_name, node.input(1));
      } else if (name == "add1_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "add2_bcast") {
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("y", node.input(0));
        EXPECT_EQ("^zeros_const_bcast", node.input(2));
      } else if (name == "bias_add1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
        EXPECT_EQ("^zeros_1d", node.input(1));
      } else if (name == "bias_add2") {
        // We don't eliminate this one, because it requires broadcasting.
        EXPECT_EQ("BiasAdd", node.op());
        EXPECT_EQ(zeros_name, node.input(0));
        EXPECT_EQ("bias", node.input(1));
        EXPECT_EQ("BroadcastTo", node.op());
        EXPECT_EQ("bias", node.input(0));
        EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
                  node.input(1));
        EXPECT_EQ(ctrl_zeros_name, node.input(2));
      } else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
        EXPECT_EQ("Const", node.op());
        EXPECT_EQ(ctrl_zeros_name, node.input(0));
        EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
        TensorProto t = node.attr().at("value").tensor();
        EXPECT_EQ(DT_INT32, t.dtype());
        EXPECT_EQ(1, t.tensor_shape().dim_size());
        EXPECT_EQ(2, t.tensor_shape().dim(0).size());
      } else if (name == "sub1") {
        EXPECT_EQ("Identity", node.op());
        EXPECT_EQ("x", node.input(0));
tensorflow/core/ops/array_grad.cc

@@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);

Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
  DataType itype;
  TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
  if (itype != DT_INT32) {
    return errors::Unimplemented(
        "BroadcastToGrad for int64 index are not supported.");
  }
  std::vector<FDH::Node> nodes = {
      {{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
      {{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
      {{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
      {{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
      {{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
  *g = FDH::Define(
      // Arg defs
      {"x: T", "shape: int32", "dy: T"},
      // Ret val defs
      {"dx: T", "dshape: Tidx"},
      // Attr defs
      {{"T: type"}, {"Tidx: {int32, int64}"}},
      // Nodes
      nodes);
  return Status::OK();
}
REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);

}  // end namespace tensorflow
tensorflow/core/ops/array_grad_test.cc

@@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) {
  }
}

std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape,
                                    const Tensor& dy) {
  auto T = DT_FLOAT;
  auto Tidx = DT_INT32;
  auto gdef = test::function::GDef(
      {f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
       f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}),
       f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
       f::NDef(
           "dx", "SymbolicGradient", {"x", "shape", "dy"},
           {{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})},
            {"Tin", DataTypeSlice{T, Tidx, T}},
            {"Tout", DataTypeSlice{T, Tidx}}})});
  VLOG(1) << DebugStringWhole(gdef);
  auto sess = NewSession();
  TF_CHECK_OK(sess->Create(gdef));
  std::vector<Tensor> out;
  TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}},
                        {"dx:0", "dx:1"}, {}, &out));
  CHECK_EQ(out.size(), 2);
  TF_CHECK_OK(sess->Close());
  return out;
}

TEST(ArrayGradTest, BroadcastToGrad) {
  Tensor x(DT_FLOAT, {2, 2});
  x.flat<float>().setZero();
  Tensor shape(DT_INT32, {3});
  test::FillValues<int32>(&shape, {2, 2, 2});
  Tensor dy(DT_FLOAT, {2, 2, 2});
  test::FillIota<float>(&dy, 0);
  auto dx = BroadcastToGrad(x, shape, dy);
  test::ExpectClose(dx[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2}));
  test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}, {3}));
}
}  // namespace
}  // namespace tensorflow