Add broadcast optimization rewrite to Grappler. This generalizes the rewrites for algebraically neutral element simplications (e.g. x + zeros(shape) => x ) to handle the cases where the output shape differs from the input shape of the non-trivial argument. This pattern is sometimes used in tensorflow for broadcasting.
Example of new rewrites enabled by this change: x * ones_like(y) => broadcast_to(x, shape(y)) zeros_like(x) + y => broadcast_to(y, shape(x)) This change also cleans up the code in SimplifyNode to consistently rely on only graph_modified_ (and possibly an error status) to signal if the graph was updated. PiperOrigin-RevId: 243319718
This commit is contained in:
parent
a3d2ee6ada
commit
ef1ca5b2e1
@ -136,14 +136,15 @@ load(
|
|||||||
"tf_additional_libdevice_srcs",
|
"tf_additional_libdevice_srcs",
|
||||||
"tf_additional_minimal_lib_srcs",
|
"tf_additional_minimal_lib_srcs",
|
||||||
"tf_additional_mpi_lib_defines",
|
"tf_additional_mpi_lib_defines",
|
||||||
|
"tf_additional_numa_copts",
|
||||||
"tf_additional_numa_deps",
|
"tf_additional_numa_deps",
|
||||||
"tf_additional_numa_lib_defines",
|
"tf_additional_numa_lib_defines",
|
||||||
"tf_additional_numa_copts",
|
|
||||||
"tf_additional_proto_hdrs",
|
"tf_additional_proto_hdrs",
|
||||||
"tf_additional_proto_srcs",
|
"tf_additional_proto_srcs",
|
||||||
"tf_additional_test_deps",
|
"tf_additional_test_deps",
|
||||||
"tf_additional_test_srcs",
|
"tf_additional_test_srcs",
|
||||||
"tf_additional_verbs_lib_defines",
|
"tf_additional_verbs_lib_defines",
|
||||||
|
"tf_grpc_service_all",
|
||||||
"tf_jspb_proto_library",
|
"tf_jspb_proto_library",
|
||||||
"tf_kernel_tests_linkstatic",
|
"tf_kernel_tests_linkstatic",
|
||||||
"tf_lib_proto_compiler_deps",
|
"tf_lib_proto_compiler_deps",
|
||||||
@ -157,7 +158,6 @@ load(
|
|||||||
"tf_protos_grappler",
|
"tf_protos_grappler",
|
||||||
"tf_protos_grappler_impl",
|
"tf_protos_grappler_impl",
|
||||||
"tf_pyclif_proto_library",
|
"tf_pyclif_proto_library",
|
||||||
"tf_grpc_service_all",
|
|
||||||
)
|
)
|
||||||
load(
|
load(
|
||||||
":platform/default/build_config_root.bzl",
|
":platform/default/build_config_root.bzl",
|
||||||
@ -4919,6 +4919,7 @@ tf_cc_test(
|
|||||||
"//tensorflow/core/kernels:array",
|
"//tensorflow/core/kernels:array",
|
||||||
"//tensorflow/core/kernels:cwise_op",
|
"//tensorflow/core/kernels:cwise_op",
|
||||||
"//tensorflow/core/kernels:function_ops",
|
"//tensorflow/core/kernels:function_ops",
|
||||||
|
"//tensorflow/core/kernels:math",
|
||||||
"//third_party/eigen3",
|
"//third_party/eigen3",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
@ -806,10 +806,9 @@ Status ConstantFolding::MaterializeConstantValuedNode(
|
|||||||
} else {
|
} else {
|
||||||
double value =
|
double value =
|
||||||
(IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
|
(IsZerosLike(*node) ? 0.0 : (IsOnesLike(*node) ? 1.0 : -1.0));
|
||||||
bool success = false;
|
|
||||||
if (value >= 0) {
|
if (value >= 0) {
|
||||||
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
|
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
|
||||||
value, properties, output_shape, node, graph_, &success));
|
value, properties, output_shape, node, graph_));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
@ -1672,6 +1671,60 @@ void ConstantFolding::ReplaceOperationWithSnapshot(
|
|||||||
graph_modified_ = true;
|
graph_modified_ = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ConstantFolding::ReplaceBinaryOperationWithBroadcastTo(
|
||||||
|
int input_to_broadcast, const GraphProperties& properties, NodeDef* node,
|
||||||
|
GraphDef* graph) {
|
||||||
|
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
|
||||||
|
if (dtype == DT_INVALID) return;
|
||||||
|
const PartialTensorShape shape(
|
||||||
|
properties.GetOutputProperties(node->name())[0].shape());
|
||||||
|
if (!shape.IsFullyDefined()) return;
|
||||||
|
|
||||||
|
// Create constant node with shape.
|
||||||
|
const string const_name = OptimizedNodeName(
|
||||||
|
*node, strings::StrCat("-broadcastto_shape-", input_to_broadcast));
|
||||||
|
if (node_map_->GetNode(const_name) != nullptr) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor shape_t;
|
||||||
|
if (!ConvertShapeToConstant("Shape", DT_INT32, shape, &shape_t).ok()) return;
|
||||||
|
NodeDef tmp;
|
||||||
|
if (!CreateNodeDef(const_name, TensorValue(&shape_t), &tmp).ok()) return;
|
||||||
|
NodeDef* const_node = graph->add_node();
|
||||||
|
const_node->Swap(&tmp);
|
||||||
|
const_node->set_device(node->device());
|
||||||
|
node_map_->AddNode(const_name, const_node);
|
||||||
|
// Add a control input on the unused input.
|
||||||
|
string ctrl_dep = AddControlDependency(
|
||||||
|
NodeName(node->input(1 - input_to_broadcast)), graph, node_map_.get());
|
||||||
|
*const_node->add_input() = ctrl_dep;
|
||||||
|
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
|
||||||
|
|
||||||
|
// Rewrite `node` in-place to BroadcastTo.
|
||||||
|
node->set_op("BroadcastTo");
|
||||||
|
node->clear_attr();
|
||||||
|
(*node->mutable_attr())["T"].set_type(dtype);
|
||||||
|
(*node->mutable_attr())["Tidx"].set_type(DT_INT32);
|
||||||
|
// Set the designated input to BroadcastTo.
|
||||||
|
node->mutable_input()->SwapElements(0, input_to_broadcast);
|
||||||
|
// Keep all other inputs as control dependencies.
|
||||||
|
for (int i = 1; i < node->input_size(); ++i) {
|
||||||
|
if (IsControlInput(node->input(i))) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const string ctrl_dep =
|
||||||
|
AddControlDependency(node->input(i), graph, node_map_.get());
|
||||||
|
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
|
||||||
|
node->set_input(i, ctrl_dep);
|
||||||
|
}
|
||||||
|
// Add the shape argument.
|
||||||
|
*node->add_input() = const_node->name();
|
||||||
|
node_map_->AddOutput(const_name, node->name());
|
||||||
|
node->mutable_input()->SwapElements(1, node->input_size() - 1);
|
||||||
|
graph_modified_ = true;
|
||||||
|
}
|
||||||
|
|
||||||
void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
|
void ConstantFolding::ReplaceDivisionOfOnesByReciprocal(NodeDef* node,
|
||||||
GraphDef* graph) {
|
GraphDef* graph) {
|
||||||
node->set_op("Reciprocal");
|
node->set_op("Reciprocal");
|
||||||
@ -1696,11 +1749,9 @@ void ConstantFolding::ReplaceSubtractionFromZeroByNegation(NodeDef* node,
|
|||||||
|
|
||||||
Status ConstantFolding::ReplaceOperationWithConstant(
|
Status ConstantFolding::ReplaceOperationWithConstant(
|
||||||
double value, const GraphProperties& properties,
|
double value, const GraphProperties& properties,
|
||||||
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph,
|
const TensorShapeProto& shape, NodeDef* node, GraphDef* graph) {
|
||||||
bool* success) {
|
|
||||||
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
|
const DataType dtype = GetDataTypeFromNodeOrProps(*node, properties);
|
||||||
if (dtype == DT_INVALID) {
|
if (dtype == DT_INVALID) {
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1721,7 +1772,6 @@ Status ConstantFolding::ReplaceOperationWithConstant(
|
|||||||
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
|
node_map_->UpdateInput(node->name(), node->input(i), ctrl_dep);
|
||||||
node->set_input(i, ctrl_dep);
|
node->set_input(i, ctrl_dep);
|
||||||
}
|
}
|
||||||
*success = true;
|
|
||||||
graph_modified_ = true;
|
graph_modified_ = true;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
@ -1746,173 +1796,81 @@ Status ConstantFolding::SimplifyGraph(
|
|||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define RETURN_IF_ERROR_OR_MODIFIED(EXPR) \
|
||||||
|
TF_RETURN_IF_ERROR(EXPR); \
|
||||||
|
if (graph_modified_) return Status::OK()
|
||||||
|
|
||||||
|
#define SET_AND_RETURN_IF_MODIFIED(EXPR) \
|
||||||
|
graph_modified_ = EXPR; \
|
||||||
|
if (graph_modified_) return Status::OK()
|
||||||
|
|
||||||
|
#define RETURN_IF_MODIFIED(EXPR) \
|
||||||
|
EXPR; \
|
||||||
|
if (graph_modified_) return Status::OK()
|
||||||
|
|
||||||
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||||
GraphDef* optimized_graph,
|
GraphDef* optimized_graph,
|
||||||
GraphProperties* properties) {
|
GraphProperties* properties) {
|
||||||
if (RemoveSplitOrSplitV(*properties, optimized_graph, node)) {
|
bool graph_modified_cached = graph_modified_;
|
||||||
return Status::OK();
|
graph_modified_ = false;
|
||||||
}
|
|
||||||
|
|
||||||
bool remove_shuffle_transpose_successful = false;
|
RETURN_IF_MODIFIED(RemoveSplitOrSplitV(*properties, optimized_graph, node));
|
||||||
Status remove_shuffle_transpose_status =
|
RETURN_IF_ERROR_OR_MODIFIED(RemoveShuffleOrTranspose(
|
||||||
RemoveShuffleOrTranspose(*properties, use_shape_info, optimized_graph,
|
*properties, use_shape_info, optimized_graph, node));
|
||||||
node, &remove_shuffle_transpose_successful);
|
RETURN_IF_MODIFIED(
|
||||||
if (!remove_shuffle_transpose_status.ok()) {
|
RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node));
|
||||||
return remove_shuffle_transpose_status;
|
RETURN_IF_ERROR_OR_MODIFIED(
|
||||||
} else if (remove_shuffle_transpose_successful) {
|
RemoveReverse(*properties, use_shape_info, optimized_graph, node));
|
||||||
return Status::OK();
|
RETURN_IF_ERROR_OR_MODIFIED(
|
||||||
}
|
SimplifySlice(*properties, use_shape_info, optimized_graph, node));
|
||||||
|
RETURN_IF_ERROR_OR_MODIFIED(
|
||||||
if (RemoveRandomShuffle(*properties, use_shape_info, optimized_graph, node)) {
|
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node));
|
||||||
return Status::OK();
|
RETURN_IF_ERROR_OR_MODIFIED(
|
||||||
}
|
SimplifyTile(*properties, use_shape_info, optimized_graph, node));
|
||||||
|
RETURN_IF_ERROR_OR_MODIFIED(
|
||||||
bool remove_reverse_successful = false;
|
SimplifyPad(*properties, use_shape_info, optimized_graph, node));
|
||||||
Status remove_reverse_status =
|
RETURN_IF_MODIFIED(
|
||||||
RemoveReverse(*properties, use_shape_info, optimized_graph, node,
|
SimplifySqueeze(*properties, use_shape_info, optimized_graph, node));
|
||||||
&remove_reverse_successful);
|
SET_AND_RETURN_IF_MODIFIED(SimplifyPack(optimized_graph, node));
|
||||||
if (!remove_reverse_status.ok()) {
|
SET_AND_RETURN_IF_MODIFIED(MoveConstantsPastEnter(optimized_graph, node));
|
||||||
return remove_reverse_status;
|
SET_AND_RETURN_IF_MODIFIED(SimplifySwitch(optimized_graph, node));
|
||||||
} else if (remove_reverse_successful) {
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
return Status::OK();
|
SimplifyReduction(optimized_graph, *properties, node));
|
||||||
}
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
|
SimplifyReshape(*properties, use_shape_info, node));
|
||||||
bool simplify_slice_successful = false;
|
RETURN_IF_ERROR_OR_MODIFIED(SimplifyArithmeticOperations(
|
||||||
Status simplify_slice_status =
|
*properties, use_shape_info, optimized_graph, node));
|
||||||
SimplifySlice(*properties, use_shape_info, optimized_graph, node,
|
SET_AND_RETURN_IF_MODIFIED(ReduceDivToReciprocalMul(optimized_graph, node));
|
||||||
&simplify_slice_successful);
|
SET_AND_RETURN_IF_MODIFIED(ConstantPushDown(optimized_graph, node));
|
||||||
if (!simplify_slice_status.ok()) {
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
return simplify_slice_status;
|
MulConvPushDown(optimized_graph, node, *properties));
|
||||||
} else if (simplify_slice_successful) {
|
SET_AND_RETURN_IF_MODIFIED(PartialConstPropThroughIdentityN(node));
|
||||||
return Status::OK();
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
}
|
PartialAssocOpConstFolding(optimized_graph, properties, node));
|
||||||
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
bool simplify_strided_slice_successful = false;
|
PartialConcatConstFolding(optimized_graph, properties, node));
|
||||||
Status simplify_strided_slice_status =
|
SET_AND_RETURN_IF_MODIFIED(
|
||||||
SimplifyStridedSlice(*properties, use_shape_info, optimized_graph, node,
|
MergeConcat(*properties, use_shape_info, optimized_graph, node));
|
||||||
&simplify_strided_slice_successful);
|
|
||||||
if (!simplify_strided_slice_status.ok()) {
|
|
||||||
return simplify_strided_slice_status;
|
|
||||||
} else if (simplify_strided_slice_successful) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool simplify_tile_successful = false;
|
|
||||||
Status simplify_tile_status =
|
|
||||||
SimplifyTile(*properties, use_shape_info, optimized_graph, node,
|
|
||||||
&simplify_tile_successful);
|
|
||||||
if (!simplify_tile_status.ok()) {
|
|
||||||
return simplify_tile_status;
|
|
||||||
} else if (simplify_tile_successful) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool simplify_pad_successful = false;
|
|
||||||
Status simplify_pad_status =
|
|
||||||
SimplifyPad(*properties, use_shape_info, optimized_graph, node,
|
|
||||||
&simplify_pad_successful);
|
|
||||||
if (!simplify_pad_status.ok()) {
|
|
||||||
return simplify_pad_status;
|
|
||||||
} else if (simplify_pad_successful) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (SimplifySqueeze(*properties, use_shape_info, optimized_graph, node)) {
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (SimplifyPack(optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (MoveConstantsPastEnter(optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (SimplifySwitch(optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (SimplifyReduction(optimized_graph, *properties, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (SimplifyReshape(*properties, use_shape_info, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool arithmetic_simplification_succeed = false;
|
|
||||||
Status simplify_arithmetic_status =
|
|
||||||
SimplifyArithmeticOperations(*properties, use_shape_info, optimized_graph,
|
|
||||||
node, &arithmetic_simplification_succeed);
|
|
||||||
if (!simplify_arithmetic_status.ok()) {
|
|
||||||
return simplify_arithmetic_status;
|
|
||||||
} else if (arithmetic_simplification_succeed) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ReduceDivToReciprocalMul(optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (ConstantPushDown(optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (MulConvPushDown(optimized_graph, node, *properties)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (PartialConstPropThroughIdentityN(node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (PartialAssocOpConstFolding(optimized_graph, properties, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (PartialConcatConstFolding(optimized_graph, properties, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
if (MergeConcat(*properties, use_shape_info, optimized_graph, node)) {
|
|
||||||
graph_modified_ = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
|
|
||||||
|
graph_modified_ = graph_modified_cached;
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
|
void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
|
||||||
GraphDef* optimized_graph,
|
GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
if (node->attr().count("num_split") == 0) return false;
|
if (node->attr().count("num_split") == 0) return;
|
||||||
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
|
if (IsSplit(*node) && node->attr().at("num_split").i() == 1) {
|
||||||
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
|
if (IsSplitV(*node) && node->attr().at("num_split").i() == 1) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::RemoveShuffleOrTranspose(
|
Status ConstantFolding::RemoveShuffleOrTranspose(
|
||||||
const GraphProperties& properties, bool use_shape_info,
|
const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success) {
|
GraphDef* optimized_graph, NodeDef* node) {
|
||||||
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
|
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
|
||||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||||
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
||||||
@ -1948,15 +1906,14 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
|
|
||||||
|
void ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph,
|
GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
@ -1968,16 +1925,14 @@ bool ConstantFolding::RemoveRandomShuffle(const GraphProperties& properties,
|
|||||||
if (!shape.unknown_rank() &&
|
if (!shape.unknown_rank() &&
|
||||||
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
|
(shape.dim_size() == 0 || shape.dim(0).size() == 1)) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
|
Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph,
|
||||||
bool* success) {
|
NodeDef* node) {
|
||||||
if (use_shape_info && node->op() == "ReverseV2" &&
|
if (use_shape_info && node->op() == "ReverseV2" &&
|
||||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||||
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
||||||
@ -2015,19 +1970,16 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
|
Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph,
|
||||||
bool* success) {
|
NodeDef* node) {
|
||||||
if (use_shape_info && IsSlice(*node) &&
|
if (use_shape_info && IsSlice(*node) &&
|
||||||
properties.GetInputProperties(node->name()).size() == 3) {
|
properties.GetInputProperties(node->name()).size() == 3) {
|
||||||
const auto& input = properties.GetInputProperties(node->name())[0];
|
const auto& input = properties.GetInputProperties(node->name())[0];
|
||||||
@ -2064,19 +2016,17 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
|
Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph,
|
GraphDef* optimized_graph,
|
||||||
NodeDef* node, bool* success) {
|
NodeDef* node) {
|
||||||
if (use_shape_info && IsStridedSlice(*node) &&
|
if (use_shape_info && IsStridedSlice(*node) &&
|
||||||
properties.GetInputProperties(node->name()).size() == 4) {
|
properties.GetInputProperties(node->name()).size() == 4) {
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
@ -2168,19 +2118,15 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph, NodeDef* node) {
|
||||||
bool* success) {
|
|
||||||
if (use_shape_info && IsTile(*node) &&
|
if (use_shape_info && IsTile(*node) &&
|
||||||
properties.GetInputProperties(node->name()).size() == 2) {
|
properties.GetInputProperties(node->name()).size() == 2) {
|
||||||
const auto& m = properties.GetInputProperties(node->name())[1];
|
const auto& m = properties.GetInputProperties(node->name())[1];
|
||||||
@ -2204,19 +2150,15 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
|
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph, NodeDef* node) {
|
||||||
bool* success) {
|
|
||||||
if (use_shape_info && IsPad(*node) &&
|
if (use_shape_info && IsPad(*node) &&
|
||||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||||
const auto& p = properties.GetInputProperties(node->name())[1];
|
const auto& p = properties.GetInputProperties(node->name())[1];
|
||||||
@ -2236,16 +2178,13 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
|
void ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph,
|
GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
@ -2263,15 +2202,15 @@ bool ConstantFolding::SimplifySqueeze(const GraphProperties& properties,
|
|||||||
}
|
}
|
||||||
if (replaceable) {
|
if (replaceable) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
return true;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
||||||
if (IsPack(*node) && NumNonControlInputs(*node) == 1 &&
|
if (!(IsPack(*node) && NumNonControlInputs(*node) == 1 &&
|
||||||
!OptimizedNodeExists(*node, "_const_axis")) {
|
!OptimizedNodeExists(*node, "_const_axis"))) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
// Create constant axis node.
|
// Create constant axis node.
|
||||||
Tensor axis_t(DT_INT32, TensorShape({}));
|
Tensor axis_t(DT_INT32, TensorShape({}));
|
||||||
NodeDef* axis_node = optimized_graph->add_node();
|
NodeDef* axis_node = optimized_graph->add_node();
|
||||||
@ -2279,8 +2218,7 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
|||||||
const int axis =
|
const int axis =
|
||||||
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
|
node->attr().count("axis") == 0 ? 0 : node->attr().at("axis").i();
|
||||||
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
|
if (!SetTensorValue(DT_INT32, axis, &axis_t).ok() ||
|
||||||
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node)
|
!CreateNodeDef(axis_node->name(), TensorValue(&axis_t), axis_node).ok()) {
|
||||||
.ok()) {
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// Add a control dependency to make sure axis_node is in the right frame.
|
// Add a control dependency to make sure axis_node is in the right frame.
|
||||||
@ -2301,21 +2239,21 @@ bool ConstantFolding::SimplifyPack(GraphDef* optimized_graph, NodeDef* node) {
|
|||||||
node->mutable_input()->SwapElements(1, node->input_size() - 1);
|
node->mutable_input()->SwapElements(1, node->input_size() - 1);
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
||||||
NodeDef* node) {
|
NodeDef* node) {
|
||||||
if (IsEnter(*node) && node->input_size() > 0) {
|
if (!IsEnter(*node) || node->input_size() == 0 ||
|
||||||
if (node->attr().count("is_constant") == 0 ||
|
node->attr().count("is_constant") == 0 ||
|
||||||
!node->attr().at("is_constant").b()) {
|
!node->attr().at("is_constant").b()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const string& node_name = node->name();
|
const string& node_name = node->name();
|
||||||
const NodeDef* input = node_map_->GetNode(node->input(0));
|
const NodeDef* input = node_map_->GetNode(node->input(0));
|
||||||
if (input != nullptr && IsReallyConstant(*input) &&
|
if (input == nullptr || !IsReallyConstant(*input) ||
|
||||||
!OptimizedNodeExists(*input, "_enter")) {
|
OptimizedNodeExists(*input, "_enter")) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
auto fanouts = node_map_->GetOutputs(node_name);
|
auto fanouts = node_map_->GetOutputs(node_name);
|
||||||
// Find non-constant nodes that consume the output of *node.
|
// Find non-constant nodes that consume the output of *node.
|
||||||
std::vector<NodeDef*> consumers;
|
std::vector<NodeDef*> consumers;
|
||||||
@ -2329,7 +2267,10 @@ bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (!consumers.empty()) {
|
if (consumers.empty()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
graph_modified_ = true;
|
||||||
NodeDef* new_node = optimized_graph->add_node();
|
NodeDef* new_node = optimized_graph->add_node();
|
||||||
*new_node = *input;
|
*new_node = *input;
|
||||||
new_node->set_name(OptimizedNodeName(*input, "_enter"));
|
new_node->set_name(OptimizedNodeName(*input, "_enter"));
|
||||||
@ -2341,17 +2282,12 @@ bool ConstantFolding::MoveConstantsPastEnter(GraphDef* optimized_graph,
|
|||||||
for (NodeDef* consumer : consumers) {
|
for (NodeDef* consumer : consumers) {
|
||||||
for (int i = 0; i < consumer->input_size(); ++i) {
|
for (int i = 0; i < consumer->input_size(); ++i) {
|
||||||
if (NodeName(consumer->input(i)) == node_name) {
|
if (NodeName(consumer->input(i)) == node_name) {
|
||||||
node_map_->UpdateInput(consumer->name(), node_name,
|
node_map_->UpdateInput(consumer->name(), node_name, new_node->name());
|
||||||
new_node->name());
|
|
||||||
consumer->set_input(i, new_node->name());
|
consumer->set_input(i, new_node->name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
|
bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
|
||||||
@ -2387,21 +2323,28 @@ bool ConstantFolding::SimplifySwitch(GraphDef* optimized_graph, NodeDef* node) {
|
|||||||
return n1->name() < n2->name();
|
return n1->name() < n2->name();
|
||||||
});
|
});
|
||||||
// Create constant false & true nodes.
|
// Create constant false & true nodes.
|
||||||
NodeDef* false_node = optimized_graph->add_node();
|
NodeDef tmp_false_node;
|
||||||
false_node->set_name(OptimizedNodeName(*node, "_const_false"));
|
tmp_false_node.set_name(OptimizedNodeName(*node, "_const_false"));
|
||||||
if (!CreateNodeDef(false_node->name(), TensorValue(&false_t), false_node)
|
if (!CreateNodeDef(tmp_false_node.name(), TensorValue(&false_t),
|
||||||
|
&tmp_false_node)
|
||||||
.ok()) {
|
.ok()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
false_node->set_device(node->device());
|
tmp_false_node.set_device(node->device());
|
||||||
|
NodeDef tmp_true_node;
|
||||||
|
tmp_true_node.set_name(OptimizedNodeName(*node, "_const_true"));
|
||||||
|
if (!CreateNodeDef(tmp_true_node.name(), TensorValue(&true_t),
|
||||||
|
&tmp_true_node)
|
||||||
|
.ok()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
tmp_true_node.set_device(node->device());
|
||||||
|
|
||||||
|
// Add const nodes to graph.
|
||||||
|
NodeDef* false_node = optimized_graph->add_node();
|
||||||
|
false_node->Swap(&tmp_false_node);
|
||||||
NodeDef* true_node = optimized_graph->add_node();
|
NodeDef* true_node = optimized_graph->add_node();
|
||||||
true_node->set_name(OptimizedNodeName(*node, "_const_true"));
|
true_node->Swap(&tmp_true_node);
|
||||||
if (!CreateNodeDef(true_node->name(), TensorValue(&true_t), true_node)
|
|
||||||
.ok()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
true_node->set_device(node->device());
|
|
||||||
|
|
||||||
// Add controls from the switch ports to the constants, and connect the
|
// Add controls from the switch ports to the constants, and connect the
|
||||||
// constants to the original switch outputs.
|
// constants to the original switch outputs.
|
||||||
@ -2615,11 +2558,9 @@ bool ConstantFolding::SimplifyReshape(const GraphProperties& properties,
|
|||||||
|
|
||||||
Status ConstantFolding::SimplifyArithmeticOperations(
|
Status ConstantFolding::SimplifyArithmeticOperations(
|
||||||
const GraphProperties& properties, bool use_shape_info,
|
const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success) {
|
GraphDef* optimized_graph, NodeDef* node) {
|
||||||
*success = false;
|
|
||||||
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
|
const bool is_mul = IsAnyMul(*node) || IsLogicalAnd(*node);
|
||||||
const bool is_matmul = IsAnyMatMul(*node);
|
const bool is_matmul = IsAnyMatMul(*node);
|
||||||
const bool is_quantized_matmul = IsQuantizedMatMul(*node);
|
|
||||||
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
|
const bool is_add = IsAdd(*node) || IsBiasAdd(*node) || IsLogicalOr(*node);
|
||||||
const bool is_sub = IsSub(*node);
|
const bool is_sub = IsSub(*node);
|
||||||
const bool is_any_div = IsAnyDiv(*node);
|
const bool is_any_div = IsAnyDiv(*node);
|
||||||
@ -2641,22 +2582,29 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
|||||||
// of zeros.
|
// of zeros.
|
||||||
const TensorShapeProto& y_shape =
|
const TensorShapeProto& y_shape =
|
||||||
properties.GetInputProperties(node->name())[1].shape();
|
properties.GetInputProperties(node->name())[1].shape();
|
||||||
const bool x_is_zero = IsZeros(*x);
|
const TensorShapeProto& x_shape =
|
||||||
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
|
properties.GetInputProperties(node->name())[0].shape();
|
||||||
const bool y_matches_output_shape =
|
const bool y_matches_output_shape =
|
||||||
ShapesSymbolicallyEqual(output_shape, y_shape);
|
ShapesSymbolicallyEqual(output_shape, y_shape);
|
||||||
if (y_matches_output_shape &&
|
const bool x_matches_output_shape =
|
||||||
((is_mul && x_is_one) || (is_add && x_is_zero))) {
|
ShapesSymbolicallyEqual(output_shape, x_shape);
|
||||||
|
|
||||||
|
const bool x_is_zero = IsZeros(*x);
|
||||||
|
const bool x_is_one = x_is_zero ? false : IsOnes(*x);
|
||||||
|
if ((is_mul && x_is_one) || (is_add && x_is_zero)) {
|
||||||
// 1 * y = y or 0 + y = y.
|
// 1 * y = y or 0 + y = y.
|
||||||
|
if (y_matches_output_shape) {
|
||||||
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
|
ReplaceOperationWithSnapshot(1, properties, node, optimized_graph);
|
||||||
*success = true;
|
} else if (x_matches_output_shape) {
|
||||||
|
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
|
||||||
|
optimized_graph);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (y_matches_output_shape && (is_sub && x_is_zero)) {
|
if (y_matches_output_shape && (is_sub && x_is_zero)) {
|
||||||
// Replace 0 - y with Neg(y).
|
// Replace 0 - y with Neg(y).
|
||||||
ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
|
ReplaceSubtractionFromZeroByNegation(node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2666,37 +2614,30 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
|||||||
DataType type = node->attr().at("T").type();
|
DataType type = node->attr().at("T").type();
|
||||||
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
|
if (DataTypeIsFloating(type) || DataTypeIsComplex(type)) {
|
||||||
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
|
ReplaceDivisionOfOnesByReciprocal(node, optimized_graph);
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const TensorShapeProto& x_shape =
|
|
||||||
properties.GetInputProperties(node->name())[0].shape();
|
|
||||||
const bool y_is_zero = IsZeros(*y);
|
const bool y_is_zero = IsZeros(*y);
|
||||||
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
|
const bool y_is_one = y_is_zero ? false : IsOnes(*y);
|
||||||
const bool x_matches_output_shape =
|
if (((is_mul || is_any_div) && y_is_one) ||
|
||||||
ShapesSymbolicallyEqual(output_shape, x_shape);
|
((is_add || is_sub) && y_is_zero)) {
|
||||||
if (x_matches_output_shape && (((is_mul || is_any_div) && y_is_one) ||
|
|
||||||
((is_add || is_sub) && y_is_zero))) {
|
|
||||||
// x * 1 = x or x / 1 = x or x +/- 0 = x
|
// x * 1 = x or x / 1 = x or x +/- 0 = x
|
||||||
|
if (x_matches_output_shape) {
|
||||||
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
|
ReplaceOperationWithSnapshot(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
} else if (y_matches_output_shape) {
|
||||||
|
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
|
||||||
|
optimized_graph);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
// x OR true = true OR y = true.
|
// x OR true = true OR y = true.
|
||||||
bool updated_graph = false;
|
|
||||||
const PartialTensorShape shp(output_shape);
|
const PartialTensorShape shp(output_shape);
|
||||||
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
|
if (shp.IsFullyDefined() && IsLogicalOr(*node) && (y_is_one || x_is_one)) {
|
||||||
bool replace_succeed = false;
|
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
|
||||||
Status replace_op_status = ReplaceOperationWithConstant(
|
1, properties, output_shape, node, optimized_graph));
|
||||||
1, properties, output_shape, node, optimized_graph, &replace_succeed);
|
return Status::OK();
|
||||||
if (!replace_op_status.ok()) {
|
|
||||||
return replace_op_status;
|
|
||||||
} else if (replace_succeed) {
|
|
||||||
updated_graph = true;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Simplify multiplication and matmul by zeros.
|
// Simplify multiplication and matmul by zeros.
|
||||||
@ -2707,40 +2648,37 @@ Status ConstantFolding::SimplifyArithmeticOperations(
|
|||||||
if ((x_is_zero || y_is_zero) &&
|
if ((x_is_zero || y_is_zero) &&
|
||||||
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
|
(is_mul || is_matmul || optimize_zeros_divided_by_y)) {
|
||||||
if (shp.IsFullyDefined()) {
|
if (shp.IsFullyDefined()) {
|
||||||
bool replace_succeed = false;
|
bool is_quantized = IsQuantizedMatMul(*node);
|
||||||
Status replace_op_status =
|
TF_RETURN_IF_ERROR(ReplaceOperationWithConstant(
|
||||||
ReplaceOperationWithConstant(0, properties, output_shape, node,
|
0, properties, output_shape, node, optimized_graph));
|
||||||
optimized_graph, &replace_succeed);
|
if (is_quantized && graph_modified_) {
|
||||||
if (!replace_op_status.ok()) {
|
|
||||||
return replace_op_status;
|
|
||||||
} else if (replace_succeed) {
|
|
||||||
if (is_quantized_matmul) {
|
|
||||||
TF_RETURN_IF_ERROR(
|
TF_RETURN_IF_ERROR(
|
||||||
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
|
AddQuantizedMatMulMinMaxOutConstNodes(node, optimized_graph));
|
||||||
}
|
}
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
|
||||||
// Even if an input shape is only partially known, we may known that it
|
// Even if an input shape is only partially known, we may known that it
|
||||||
// matches the output shape and thus forward the corresponding zero
|
// matches the output shape and thus forward or broadcast the
|
||||||
// input.
|
// corresponding zero input.
|
||||||
if ((is_mul || is_any_div) && x_is_zero && x_matches_output_shape) {
|
if ((is_mul || is_any_div) && x_is_zero) {
|
||||||
|
if (x_matches_output_shape) {
|
||||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||||
*success = true;
|
} else if (y_matches_output_shape) {
|
||||||
|
ReplaceBinaryOperationWithBroadcastTo(0, properties, node,
|
||||||
|
optimized_graph);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
} else if (is_mul && y_is_zero && y_matches_output_shape) {
|
} else if (is_mul && y_is_zero) {
|
||||||
|
if (y_matches_output_shape) {
|
||||||
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
|
ReplaceOperationWithIdentity(1, properties, node, optimized_graph);
|
||||||
*success = true;
|
} else if (x_matches_output_shape) {
|
||||||
|
ReplaceBinaryOperationWithBroadcastTo(1, properties, node,
|
||||||
|
optimized_graph);
|
||||||
|
}
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (updated_graph) {
|
|
||||||
*success = true;
|
|
||||||
return Status::OK();
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
*success = false;
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -3300,6 +3238,7 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
|
|||||||
auto add_quantized_out = [this, node, optimized_graph](
|
auto add_quantized_out = [this, node, optimized_graph](
|
||||||
const string& out_const_name, int index) {
|
const string& out_const_name, int index) {
|
||||||
NodeDef* out_node = optimized_graph->add_node();
|
NodeDef* out_node = optimized_graph->add_node();
|
||||||
|
graph_modified_ = true;
|
||||||
Tensor value(DT_FLOAT, TensorShape({}));
|
Tensor value(DT_FLOAT, TensorShape({}));
|
||||||
const bool is_min = index == 1;
|
const bool is_min = index == 1;
|
||||||
const DataType type_attr = node->attr().at("dtype").type();
|
const DataType type_attr = node->attr().at("dtype").type();
|
||||||
@ -3310,7 +3249,6 @@ Status ConstantFolding::AddQuantizedMatMulMinMaxOutConstNodes(
|
|||||||
CreateNodeDef(out_const_name, TensorValue(&value), out_node));
|
CreateNodeDef(out_const_name, TensorValue(&value), out_node));
|
||||||
node_map_->AddNode(out_const_name, out_node);
|
node_map_->AddNode(out_const_name, out_node);
|
||||||
out_node->set_device(node->device());
|
out_node->set_device(node->device());
|
||||||
|
|
||||||
// Copy all inputs from node.
|
// Copy all inputs from node.
|
||||||
out_node->mutable_input()->CopyFrom(node->input());
|
out_node->mutable_input()->CopyFrom(node->input());
|
||||||
for (const string& input : out_node->input()) {
|
for (const string& input : out_node->input()) {
|
||||||
|
@ -92,12 +92,14 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
void ReplaceOperationWithSnapshot(int input_to_forward,
|
void ReplaceOperationWithSnapshot(int input_to_forward,
|
||||||
const GraphProperties& properties,
|
const GraphProperties& properties,
|
||||||
NodeDef* node, GraphDef* graph);
|
NodeDef* node, GraphDef* graph);
|
||||||
|
void ReplaceBinaryOperationWithBroadcastTo(int input_to_broadcast,
|
||||||
|
const GraphProperties& properties,
|
||||||
|
NodeDef* node, GraphDef* graph);
|
||||||
void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
|
void ReplaceSubtractionFromZeroByNegation(NodeDef* node, GraphDef* graph);
|
||||||
Status ReplaceOperationWithConstant(double value,
|
Status ReplaceOperationWithConstant(double value,
|
||||||
const GraphProperties& properties,
|
const GraphProperties& properties,
|
||||||
const TensorShapeProto& shape,
|
const TensorShapeProto& shape,
|
||||||
NodeDef* node, GraphDef* graph,
|
NodeDef* node, GraphDef* graph);
|
||||||
bool* success);
|
|
||||||
void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
|
void ReplaceDivisionOfOnesByReciprocal(NodeDef* node, GraphDef* graph);
|
||||||
Status FoldGraph(GraphDef* output,
|
Status FoldGraph(GraphDef* output,
|
||||||
absl::flat_hash_set<string>* nodes_to_not_simplify);
|
absl::flat_hash_set<string>* nodes_to_not_simplify);
|
||||||
@ -145,8 +147,7 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
// was applied.
|
// was applied.
|
||||||
Status SimplifyArithmeticOperations(const GraphProperties& properties,
|
Status SimplifyArithmeticOperations(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
bool* success);
|
|
||||||
|
|
||||||
// Simplifies a Reshape operation to an Identity operation if applicable.
|
// Simplifies a Reshape operation to an Identity operation if applicable.
|
||||||
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
|
bool SimplifyReshape(const GraphProperties& properties, bool use_shape_info,
|
||||||
@ -194,43 +195,42 @@ class ConstantFolding : public GraphOptimizer {
|
|||||||
bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
|
bool SimplifyPack(GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Simplifies a Squeeze operation to an Identity operation if applicable.
|
// Simplifies a Squeeze operation to an Identity operation if applicable.
|
||||||
bool SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
|
void SimplifySqueeze(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Simplifies a Pad operation to an Identity operation if applicable.
|
// Simplifies a Pad operation to an Identity operation if applicable.
|
||||||
Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
|
Status SimplifyPad(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Simplifies a Tile operation to an Identity operation if applicable.
|
// Simplifies a Tile operation to an Identity operation if applicable.
|
||||||
Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
|
Status SimplifyTile(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Simplifies a StridedSlice operation to an Identity operation if applicable.
|
// Simplifies a StridedSlice operation to an Identity operation if applicable.
|
||||||
Status SimplifyStridedSlice(const GraphProperties& properties,
|
Status SimplifyStridedSlice(const GraphProperties& properties,
|
||||||
bool use_shape_info, GraphDef* optimized_graph,
|
bool use_shape_info, GraphDef* optimized_graph,
|
||||||
NodeDef* node, bool* success);
|
NodeDef* node);
|
||||||
|
|
||||||
// Simplifies a Slice operation to an Identity operation if applicable.
|
// Simplifies a Slice operation to an Identity operation if applicable.
|
||||||
Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
|
Status SimplifySlice(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Removes Reverse op over dimensions with size 1.
|
// Removes Reverse op over dimensions with size 1.
|
||||||
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
|
Status RemoveReverse(const GraphProperties& properties, bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node, bool* success);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
// Removes RandomShuffle op if it is scalar or first dimension is of size 1.
|
// Removes RandomShuffle op if it is scalar or first dimension is of size 1.
|
||||||
bool RemoveRandomShuffle(const GraphProperties& properties,
|
void RemoveRandomShuffle(const GraphProperties& properties,
|
||||||
bool use_shape_info, GraphDef* optimized_graph,
|
bool use_shape_info, GraphDef* optimized_graph,
|
||||||
NodeDef* node);
|
NodeDef* node);
|
||||||
|
|
||||||
// Removes Shuffle or Transpose op over dimensions of size 1.
|
// Removes Shuffle or Transpose op over dimensions of size 1.
|
||||||
Status RemoveShuffleOrTranspose(const GraphProperties& properties,
|
Status RemoveShuffleOrTranspose(const GraphProperties& properties,
|
||||||
bool use_shape_info,
|
bool use_shape_info,
|
||||||
GraphDef* optimized_graph, NodeDef* node,
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
bool* success);
|
|
||||||
|
|
||||||
// Removes Split or SplitV node if possible.
|
// Removes Split or SplitV node if possible.
|
||||||
bool RemoveSplitOrSplitV(const GraphProperties& properties,
|
void RemoveSplitOrSplitV(const GraphProperties& properties,
|
||||||
GraphDef* optimized_graph, NodeDef* node);
|
GraphDef* optimized_graph, NodeDef* node);
|
||||||
|
|
||||||
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
|
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
|
||||||
|
@ -502,12 +502,16 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
ops::Placeholder::Shape(TensorShape({2})));
|
ops::Placeholder::Shape(TensorShape({2})));
|
||||||
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
|
Output zeros_1d = ops::Const(s.WithOpName("zeros_1d"), 0.0f, {2});
|
||||||
Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
|
Output zeros_const = ops::Const(s.WithOpName("zeros_const"), 0.0f, {2, 2});
|
||||||
|
Output zeros_const_bcast =
|
||||||
|
ops::Const(s.WithOpName("zeros_const_bcast"), 0.0f, {2, 2, 2});
|
||||||
Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
|
Output zeros_like = ops::ZerosLike(s.WithOpName("zeros_like"), x);
|
||||||
Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
|
Output zeros_fill = ops::Fill(s.WithOpName("zeros_fill"), {2, 2}, 0.0f);
|
||||||
Output zeros = const_type == kConst
|
Output zeros = const_type == kConst
|
||||||
? zeros_const
|
? zeros_const
|
||||||
: (const_type == kLike ? zeros_like : zeros_fill);
|
: (const_type == kLike ? zeros_like : zeros_fill);
|
||||||
Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
|
Output ones_const = ops::Const(s.WithOpName("ones_const"), 1.0f, {2, 2});
|
||||||
|
Output ones_const_bcast =
|
||||||
|
ops::Const(s.WithOpName("ones_const_bcast"), 1.0f, {2, 2, 2});
|
||||||
Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
|
Output ones_like = ops::OnesLike(s.WithOpName("ones_like"), x);
|
||||||
Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
|
Output ones_fill = ops::Fill(s.WithOpName("ones_fill"), {2, 2}, 1.0f);
|
||||||
Output ones = const_type == kConst
|
Output ones = const_type == kConst
|
||||||
@ -515,6 +519,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
: (const_type == kLike ? ones_like : ones_fill);
|
: (const_type == kLike ? ones_like : ones_fill);
|
||||||
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
|
Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
|
||||||
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
|
Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
|
||||||
|
Output mul1_bcast =
|
||||||
|
ops::Mul(s.WithOpName("mul1_bcast"), x, ones_const_bcast);
|
||||||
|
Output mul2_bcast =
|
||||||
|
ops::Mul(s.WithOpName("mul2_bcast"), ones_const_bcast, y);
|
||||||
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
|
Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
|
||||||
Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
|
Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
|
||||||
Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
|
Output mul5 = ops::MulNoNan(s.WithOpName("mul5"), x, zeros_1d);
|
||||||
@ -527,6 +535,10 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
|
Output matmul4 = ops::MatMul(s.WithOpName("matmul4"), zeros, b);
|
||||||
Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
|
Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
|
||||||
Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
|
Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
|
||||||
|
Output add1_bcast =
|
||||||
|
ops::Add(s.WithOpName("add1_bcast"), x, zeros_const_bcast);
|
||||||
|
Output add2_bcast =
|
||||||
|
ops::Add(s.WithOpName("add2_bcast"), zeros_const_bcast, y);
|
||||||
Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
|
Output bias_add1 = ops::BiasAdd(s.WithOpName("bias_add1"), x, zeros_1d);
|
||||||
Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
|
Output bias_add2 = ops::BiasAdd(s.WithOpName("bias_add2"), zeros, bias);
|
||||||
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
|
Output sub1 = ops::Sub(s.WithOpName("sub1"), x, zeros);
|
||||||
@ -537,7 +549,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
|
matmul2, add1, add2, bias_add1, bias_add2, sub1, sub2});
|
||||||
GrapplerItem item;
|
GrapplerItem item;
|
||||||
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
TF_CHECK_OK(s.ToGraphDef(&item.graph));
|
||||||
item.fetch = {"stack", "matmul3", "matmul4"};
|
item.fetch = {"stack", "matmul3", "matmul4", "mul1_bcast",
|
||||||
|
"mul2_bcast", "add1_bcast", "add2_bcast"};
|
||||||
|
|
||||||
ConstantFolding optimizer(/*cpu_device=*/nullptr);
|
ConstantFolding optimizer(/*cpu_device=*/nullptr);
|
||||||
GraphDef output;
|
GraphDef output;
|
||||||
@ -551,7 +564,8 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
const string ones_name = strings::StrCat("ones", suffix);
|
const string ones_name = strings::StrCat("ones", suffix);
|
||||||
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
|
const string ctrl_zeros_name = strings::StrCat("^zeros", suffix);
|
||||||
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
|
const string ctrl_ones_name = strings::StrCat("^ones", suffix);
|
||||||
EXPECT_EQ(const_type == kFill ? 31 : 27, output.node_size());
|
|
||||||
|
EXPECT_EQ(const_type == kFill ? 42 : 38, output.node_size());
|
||||||
for (int i = 0; i < output.node_size(); ++i) {
|
for (int i = 0; i < output.node_size(); ++i) {
|
||||||
const NodeDef& node = output.node(i);
|
const NodeDef& node = output.node(i);
|
||||||
const string& name = node.name();
|
const string& name = node.name();
|
||||||
@ -563,6 +577,14 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
EXPECT_EQ("Const", node.op());
|
EXPECT_EQ("Const", node.op());
|
||||||
EXPECT_EQ(ctrl_zeros_name, node.input(0));
|
EXPECT_EQ(ctrl_zeros_name, node.input(0));
|
||||||
EXPECT_EQ("^y", node.input(1));
|
EXPECT_EQ("^y", node.input(1));
|
||||||
|
} else if (name == "mul1_bcast") {
|
||||||
|
EXPECT_EQ("BroadcastTo", node.op());
|
||||||
|
EXPECT_EQ("x", node.input(0));
|
||||||
|
EXPECT_EQ("^ones_const_bcast", node.input(2));
|
||||||
|
} else if (name == "mul2_bcast") {
|
||||||
|
EXPECT_EQ("BroadcastTo", node.op());
|
||||||
|
EXPECT_EQ("y", node.input(0));
|
||||||
|
EXPECT_EQ("^ones_const_bcast", node.input(2));
|
||||||
} else if (name == "mul3") {
|
} else if (name == "mul3") {
|
||||||
EXPECT_EQ("Identity", node.op());
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ("x", node.input(0));
|
EXPECT_EQ("x", node.input(0));
|
||||||
@ -623,15 +645,32 @@ TEST_F(ConstantFoldingTest, NeutralElement) {
|
|||||||
EXPECT_EQ("Identity", node.op());
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ("y", node.input(0));
|
EXPECT_EQ("y", node.input(0));
|
||||||
EXPECT_EQ(ctrl_zeros_name, node.input(1));
|
EXPECT_EQ(ctrl_zeros_name, node.input(1));
|
||||||
|
} else if (name == "add1_bcast") {
|
||||||
|
EXPECT_EQ("BroadcastTo", node.op());
|
||||||
|
EXPECT_EQ("x", node.input(0));
|
||||||
|
EXPECT_EQ("^zeros_const_bcast", node.input(2));
|
||||||
|
} else if (name == "add2_bcast") {
|
||||||
|
EXPECT_EQ("BroadcastTo", node.op());
|
||||||
|
EXPECT_EQ("y", node.input(0));
|
||||||
|
EXPECT_EQ("^zeros_const_bcast", node.input(2));
|
||||||
} else if (name == "bias_add1") {
|
} else if (name == "bias_add1") {
|
||||||
EXPECT_EQ("Identity", node.op());
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ("x", node.input(0));
|
EXPECT_EQ("x", node.input(0));
|
||||||
EXPECT_EQ("^zeros_1d", node.input(1));
|
EXPECT_EQ("^zeros_1d", node.input(1));
|
||||||
} else if (name == "bias_add2") {
|
} else if (name == "bias_add2") {
|
||||||
// We don't eliminate this one, because it requires broadcasting.
|
EXPECT_EQ("BroadcastTo", node.op());
|
||||||
EXPECT_EQ("BiasAdd", node.op());
|
EXPECT_EQ("bias", node.input(0));
|
||||||
EXPECT_EQ(zeros_name, node.input(0));
|
EXPECT_EQ("ConstantFolding/bias_add2-broadcastto_shape-1",
|
||||||
EXPECT_EQ("bias", node.input(1));
|
node.input(1));
|
||||||
|
EXPECT_EQ(ctrl_zeros_name, node.input(2));
|
||||||
|
} else if (name == "ConstantFolding/bias_add2-broadcastto_shape-1") {
|
||||||
|
EXPECT_EQ("Const", node.op());
|
||||||
|
EXPECT_EQ(ctrl_zeros_name, node.input(0));
|
||||||
|
EXPECT_EQ(node.attr().at("dtype").type(), DT_INT32);
|
||||||
|
TensorProto t = node.attr().at("value").tensor();
|
||||||
|
EXPECT_EQ(DT_INT32, t.dtype());
|
||||||
|
EXPECT_EQ(1, t.tensor_shape().dim_size());
|
||||||
|
EXPECT_EQ(2, t.tensor_shape().dim(0).size());
|
||||||
} else if (name == "sub1") {
|
} else if (name == "sub1") {
|
||||||
EXPECT_EQ("Identity", node.op());
|
EXPECT_EQ("Identity", node.op());
|
||||||
EXPECT_EQ("x", node.input(0));
|
EXPECT_EQ("x", node.input(0));
|
||||||
|
@ -550,4 +550,30 @@ Status StridedSliceGradGrad(const AttrSlice& attrs, FunctionDef* g) {
|
|||||||
}
|
}
|
||||||
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
|
REGISTER_OP_GRADIENT("StridedSliceGrad", StridedSliceGradGrad);
|
||||||
|
|
||||||
|
Status BroadcastToGrad(const AttrSlice& attrs, FunctionDef* g) {
|
||||||
|
DataType itype;
|
||||||
|
TF_RETURN_IF_ERROR(GetNodeAttr(attrs, "Tidx", &itype));
|
||||||
|
if (itype != DT_INT32) {
|
||||||
|
return errors::Unimplemented(
|
||||||
|
"BroadcastToGrad for int64 index are not supported.");
|
||||||
|
}
|
||||||
|
std::vector<FDH::Node> nodes = {
|
||||||
|
{{"sx"}, "Shape", {"x"}, {{"T", "$T"}}},
|
||||||
|
{{"rx", "ry"}, "BroadcastGradientArgs", {"sx", "shape"}},
|
||||||
|
{{"sum_gx"}, "Sum", {"dy", "rx"}, {{"T", "$T"}}},
|
||||||
|
{{"dx"}, "Reshape", {"sum_gx", "sx"}, {{"T", "$T"}}},
|
||||||
|
{{"dshape"}, "ZerosLike", {"shape"}, {{"T", "$Tidx"}}}};
|
||||||
|
*g = FDH::Define(
|
||||||
|
// Arg defs
|
||||||
|
{"x: T", "shape: int32", "dy: T"},
|
||||||
|
// Ret val defs
|
||||||
|
{"dx: T", "dshape: Tidx"},
|
||||||
|
// Attr defs
|
||||||
|
{{"T: type"}, {"Tidx: {int32, int64}"}},
|
||||||
|
// Nodes
|
||||||
|
nodes);
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
REGISTER_OP_GRADIENT("BroadcastTo", BroadcastToGrad);
|
||||||
|
|
||||||
} // end namespace tensorflow
|
} // end namespace tensorflow
|
||||||
|
@ -765,5 +765,40 @@ TEST(ArrayGradTest, StridedSliceGrad) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
std::vector<Tensor> BroadcastToGrad(const Tensor& x, const Tensor& shape,
|
||||||
|
const Tensor& dy) {
|
||||||
|
auto T = DT_FLOAT;
|
||||||
|
auto Tidx = DT_INT32;
|
||||||
|
auto gdef = test::function::GDef(
|
||||||
|
{f::NDef("x", "Placeholder", {}, {{"dtype", T}}),
|
||||||
|
f::NDef("shape", "Placeholder", {}, {{"dtype", Tidx}}),
|
||||||
|
f::NDef("dy", "Placeholder", {}, {{"dtype", T}}),
|
||||||
|
f::NDef(
|
||||||
|
"dx", "SymbolicGradient", {"x", "shape", "dy"},
|
||||||
|
{{"f", FDH::FunctionRef("BroadcastTo", {{"T", T}, {"Tidx", Tidx}})},
|
||||||
|
{"Tin", DataTypeSlice{T, Tidx, T}},
|
||||||
|
{"Tout", DataTypeSlice{T, Tidx}}})});
|
||||||
|
VLOG(1) << DebugStringWhole(gdef);
|
||||||
|
auto sess = NewSession();
|
||||||
|
TF_CHECK_OK(sess->Create(gdef));
|
||||||
|
std::vector<Tensor> out;
|
||||||
|
TF_CHECK_OK(sess->Run({{"x:0", x}, {"shape:0", shape}, {"dy:0", dy}},
|
||||||
|
{"dx:0", "dx:1"}, {}, &out));
|
||||||
|
CHECK_EQ(out.size(), 2);
|
||||||
|
TF_CHECK_OK(sess->Close());
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ArrayGradTest, BroadcastToGrad) {
|
||||||
|
Tensor x(DT_FLOAT, {2, 2});
|
||||||
|
x.flat<float>().setZero();
|
||||||
|
Tensor shape(DT_INT32, {3});
|
||||||
|
test::FillValues<int32>(&shape, {2, 2, 2});
|
||||||
|
Tensor dy(DT_FLOAT, {2, 2, 2});
|
||||||
|
test::FillIota<float>(&dy, 0);
|
||||||
|
auto dx = BroadcastToGrad(x, shape, dy);
|
||||||
|
test::ExpectClose(dx[0], test::AsTensor<float>({4., 6., 8., 10.}, {2, 2}));
|
||||||
|
test::ExpectTensorEqual<int32>(dx[1], test::AsTensor<int32>({0, 0, 0}, {3}));
|
||||||
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user