diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 4d8e5675fc3..8507a7c0d56 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -712,7 +712,10 @@ class SymbolicShapeRefiner { // Perform inference on function body. GraphProperties gp(grappler_function_item); - TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_)); + TF_RETURN_IF_ERROR(gp.InferStatically( + /*assume_valid_feeds=*/true, + /*aggressive_shape_inference=*/aggressive_shape_inference_, + /*include_tensor_values=*/true)); // Add return nodes for output shapes. int output = 0; @@ -2066,7 +2069,8 @@ Status GraphProperties::UpdateEnqueue( } Status GraphProperties::InferStatically(bool assume_valid_feeds, - bool aggressive_shape_inference) { + bool aggressive_shape_inference, + bool include_tensor_values) { FunctionLibraryDefinition function_library(OpRegistry::Global(), item_.graph.library()); std::unordered_map> fed_ports; @@ -2225,20 +2229,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, &input_properties[i]); input.port_id = i; GraphView::OutputPort fanin = graph_view.GetRegularFanin(input); - // Export tensor value to input_properties.value. - if (IsConstant(*fanin.node)) { - const TensorProto& raw_val = fanin.node->attr().at("value").tensor(); - *input_properties[i].mutable_value() = raw_val; - } else if (ctx->input_tensor_protos.size() > i && - ctx->input_tensor_protos[i] != nullptr) { - *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i]; - } else if (ic->input_tensors_as_shapes().size() > i && - IsShapeFullyDefinedIntegerVectorOrScalar( - ic, ic->input(i), ic->input_tensors_as_shapes()[i], - ctx->input_types[i])) { - *input_properties[i].mutable_value() = MakeTensorProtoFromShape( - ic, ic->input(i), ic->input_tensors_as_shapes()[i], - ctx->input_types[i]); + if (include_tensor_values) { + // Export tensor value to input_properties.value. + if (IsConstant(*fanin.node)) { + const TensorProto& raw_val = + fanin.node->attr().at("value").tensor(); + *input_properties[i].mutable_value() = raw_val; + } else if (ctx->input_tensor_protos.size() > i && + ctx->input_tensor_protos[i] != nullptr) { + *input_properties[i].mutable_value() = *ctx->input_tensor_protos[i]; + } else if (ic->input_tensors_as_shapes().size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i])) { + *input_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->input(i), ic->input_tensors_as_shapes()[i], + ctx->input_types[i]); + } } } } @@ -2254,20 +2261,24 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds, for (int i = 0; i < ic->num_outputs(); ++i) { shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i], &output_properties[i]); - // Export tensor value to output_properties.value. 
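For orientation, a minimal caller-side sketch of the widened InferStatically signature, assuming a populated GrapplerItem named `item` and a surrounding function that returns Status. The two-argument overload keeps the old behavior by defaulting include_tensor_values to true; only callers that want shape-only inference (as the arithmetic optimizer does later in this diff) pass false. The node name below is hypothetical.

// Sketch only: opting out of tensor-value materialization when only shapes are needed.
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(
    /*assume_valid_feeds=*/false,
    /*aggressive_shape_inference=*/false,
    /*include_tensor_values=*/false));
const auto& out_props = properties.GetOutputProperties("my_node");  // hypothetical node name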
- if (IsConstant(node)) { - const TensorProto& raw_val = node.attr().at("value").tensor(); - *output_properties[i].mutable_value() = raw_val; - } else if (ctx->output_tensor_protos.size() > i && - ctx->output_tensor_protos[i] != nullptr) { - *output_properties[i].mutable_value() = *ctx->output_tensor_protos[i]; - } else if (ctx->output_tensors_as_shapes.size() > i && - IsShapeFullyDefinedIntegerVectorOrScalar( - ic, ic->output(i), ctx->output_tensors_as_shapes[i], - ctx->output_types[i])) { - *output_properties[i].mutable_value() = MakeTensorProtoFromShape( - ic, ic->output(i), ctx->output_tensors_as_shapes[i], - ctx->output_types[i]); + if (include_tensor_values) { + // Export tensor value to output_properties.value. + if (IsConstant(node)) { + // TODO(rmlarsen): Eliminate this copy. + const TensorProto& raw_val = node.attr().at("value").tensor(); + *output_properties[i].mutable_value() = raw_val; + } else if (ctx->output_tensor_protos.size() > i && + ctx->output_tensor_protos[i] != nullptr) { + *output_properties[i].mutable_value() = + *ctx->output_tensor_protos[i]; + } else if (ctx->output_tensors_as_shapes.size() > i && + IsShapeFullyDefinedIntegerVectorOrScalar( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i])) { + *output_properties[i].mutable_value() = MakeTensorProtoFromShape( + ic, ic->output(i), ctx->output_tensors_as_shapes[i], + ctx->output_types[i]); + } } } } diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h index bb7e6ed16a6..49ffa467034 100644 --- a/tensorflow/core/grappler/costs/graph_properties.h +++ b/tensorflow/core/grappler/costs/graph_properties.h @@ -89,11 +89,15 @@ class GraphProperties { // output values when possible and does other aggressive strategies. // Similar to assuming_valid_feeds, this may cause incorrectness in graph // analyses, but is useful for simulation or scheduling. + // If include_values is true, the values of constant tensors will be + // included in the input and output properties. Status InferStatically(bool assume_valid_feeds, - bool aggressive_shape_inference); + bool aggressive_shape_inference, + bool include_tensor_values); Status InferStatically(bool assume_valid_feeds) { return InferStatically(assume_valid_feeds, - /*aggressive_shape_inference=*/false); + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true); } // Infer the shape by running the graph on the specified cluster and recording // the shapes of the processed tensors. @@ -117,6 +121,7 @@ class GraphProperties { const string& node_name) const; const std::vector& GetOutputProperties( const string& node_name) const; + // Invalidate input/output properties for nodes modified during graph // optimization pass, to prevent potential optimizations, based on incorrect // shape information. diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 7b9eeb298d9..3eee2e9fc9c 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -14,6 +14,7 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/core/grappler/costs/graph_properties.h" + #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/graph_def_util.h" @@ -1009,7 +1010,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) { GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto out_props = properties.GetOutputProperties("fill"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ("float: [4,4]", PropToString(out_prop0)); @@ -1028,7 +1030,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) { GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto out_props = properties.GetOutputProperties("fill"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0)); @@ -1248,10 +1251,12 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) { // evaluate output value. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/true, - /*aggressive_shape_inference=*/false)); + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ("int32: [2]", PropToString(out_prop0)); + LOG(INFO) << out_prop0.DebugString(); EXPECT_FALSE(out_prop0.has_value()); } @@ -1260,7 +1265,8 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) { // With aggressive_shape_inference, output value is evaluated. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/true, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto out_props = properties.GetOutputProperties("MyFunc"); const OpInfo::TensorProperties out_prop0 = out_props[0]; EXPECT_EQ("int32: [2]", PropToString(out_prop0)); @@ -1802,7 +1808,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) { GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/false)); + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true)); EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value()); } @@ -1812,7 +1819,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) { GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value()); const auto slice_value = properties.GetOutputProperties("slice").at(0).value(); @@ -1838,7 +1846,8 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) { GraphProperties properties(item); TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); // Check output shapes and values. 
const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0]; @@ -1881,7 +1890,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) { // Without aggressive_shape_inference, ignore annotated information. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/false)); + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/true)); const auto props = properties.GetOutputProperties("Identity"); EXPECT_EQ(1, props.size()); const OpInfo::TensorProperties& prop = props[0]; @@ -1895,7 +1905,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) { // Use annotated information. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto props = properties.GetOutputProperties("Identity"); EXPECT_EQ(1, props.size()); const OpInfo::TensorProperties& prop = props[0]; @@ -1923,7 +1934,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) { // Use annotated information. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto props = properties.GetOutputProperties("Identity"); EXPECT_EQ(1, props.size()); const OpInfo::TensorProperties& prop = props[0]; @@ -1950,7 +1962,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) { // Use annotated information. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto props = properties.GetOutputProperties("Identity"); EXPECT_EQ(1, props.size()); const OpInfo::TensorProperties& prop = props[0]; @@ -1977,7 +1990,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) { // Use annotated information. TF_CHECK_OK(properties.InferStatically( /*assume_valid_feeds=*/false, - /*aggressive_shape_inference=*/true)); + /*aggressive_shape_inference=*/true, + /*include_tensor_values=*/true)); const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn"); EXPECT_EQ(1, props.size()); const OpInfo::TensorProperties& prop = props[0]; diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc index 556f5251cc5..f4351d45f08 100644 --- a/tensorflow/core/grappler/costs/virtual_scheduler.cc +++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc @@ -359,7 +359,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) { graph_properties_ = absl::make_unique(*item); if (use_static_shapes_) { TF_RETURN_IF_ERROR(graph_properties_->InferStatically( - true, use_aggressive_shape_inference_)); + true, use_aggressive_shape_inference_, true)); } else { TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_)); } diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc index 20843f400d9..b394e71f8e6 100644 --- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc +++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc @@ -234,6 +234,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { DedupControlInputs(target_node); } + bool IsReallyConstant(const NodeDef& node) const { + if (!IsConstant(node)) { + return false; + } + // If the node is fed it's not constant anymore. 
+ return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end(); + } + bool IsInPreserveSet(const NodeDef& node) const { return ctx().nodes_to_preserve->find(node.name()) != ctx().nodes_to_preserve->end(); @@ -259,6 +267,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage { return false; } + bool GetTensorFromConstNode(const string& node_name_or_input, + Tensor* tensor) { + const NodeDef* node = ctx().node_map->GetNode(node_name_or_input); + return node != nullptr && IsReallyConstant(*node) && + CheckAttrExists(*node, "value").ok() && + tensor->FromProto(node->attr().at("value").tensor()); + } + private: // Extended context required for ArithmeticOptimizer. const ArithmeticOptimizerContext ctx_ext_; @@ -2480,101 +2496,78 @@ class ConvertPowStage : public ArithmeticOptimizerStage { bool IsSupported(const NodeDef* node) const override { return IsPow(*node) && - ctx().graph_properties->GetInputProperties(node->name()).size() == 2; + ctx().graph_properties->HasOutputProperties(node->name()) && + ctx().graph_properties->HasInputProperties(node->name()); } Status TrySimplify(NodeDef* node, string* simplified_node_name) override { - const auto& pow_props = - ctx().graph_properties->GetInputProperties(node->name())[1]; - PartialTensorShape shape(pow_props.shape()); - if (!shape.IsFullyDefined()) { - // skip if p is not fully defined. - return Status::OK(); + Tensor pow; + if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK(); + complex128 prev, curr; + for (int i = 0; i < pow.NumElements(); ++i) { + if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) { + // input data type is not supported by Pow. Skip. + return Status::OK(); + } + if (i != 0 && curr != prev) { + // pow has different values on different elements. Skip. + return Status::OK(); + } + prev = curr; } - if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) { - Tensor pow(pow_props.dtype(), pow_props.shape()); - if (!pow.FromProto(pow_props.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - pow_props.value().DebugString()); - } - - complex128 prev, curr; - for (int i = 0; i < pow.NumElements(); ++i) { - if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) { - // input data type is not supported by Pow. Skip. - return Status::OK(); - } - if (i != 0 && curr != prev) { - // pow has different values on different elements. Skip. - return Status::OK(); - } - prev = curr; - } - NodeDef *x, *y; - TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); - TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); - const auto& value_props = - ctx().graph_properties->GetInputProperties(node->name())[0]; - const TensorShapeProto& output_shape = - ctx().graph_properties->GetOutputProperties(node->name())[0].shape(); - if (curr == complex128(2, 0)) { - node->set_op("Square"); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(y); - } else if (curr == complex128(1, 0) && - ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { - // Pow could be used to broadcast, so make sure the shapes of the two - // arguments are identical before replacing Pow with Identity. 
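A standalone sketch (not TensorFlow code) of the check that IsReallyConstant and GetTensorFromConstNode rely on: a Const node that also appears in the feed set must be treated as non-constant, since the fed value can differ from the value baked into the node's attribute. The helper name and plain-string parameters are illustrative stand-ins for NodeDef::op()/name() and ctx().feed_nodes.

#include <set>
#include <string>

// Simplified stand-in for the fed-node test.
bool IsReallyConstantSketch(const std::string& op, const std::string& name,
                            const std::set<std::string>& feed_nodes) {
  if (op != "Const") return false;      // not a constant op at all
  return feed_nodes.count(name) == 0;   // fed nodes are not constant anymore
}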
- node->set_op("Identity"); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(y); - } else if (curr == complex128(0.5, 0)) { - node->set_op("Sqrt"); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(y); - } else if (curr == complex128(0, 0) && - ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { - PartialTensorShape shape(value_props.shape()); - if (!shape.IsFullyDefined()) { - // skip if b is not fully defined. - return Status::OK(); - } - if (TensorShape::IsValid(value_props.shape()) && - value_props.has_value()) { - Tensor base(value_props.dtype(), value_props.shape()); - if (!base.FromProto(value_props.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - value_props.value().DebugString()); - } - node->set_op("Const"); - Tensor c(base.dtype(), base.shape()); - for (int i = 0; i < c.NumElements(); ++i) { - TF_RETURN_IF_ERROR(SetElementToOne(i, &c)); - } - (*node->mutable_attr())["dtype"].set_type(base.dtype()); - c.AsProtoTensorContent( - (*node->mutable_attr())["value"].mutable_tensor()); - node->mutable_attr()->erase("T"); - node->set_input(0, AsControlDependency(x->name())); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(x); - AddToOptimizationQueue(y); - } - } else if (curr == complex128(-0.5, 0)) { - node->set_op("Rsqrt"); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(y); - } else if (curr == complex128(-1, 0)) { - node->set_op("Reciprocal"); - node->set_input(1, AsControlDependency(y->name())); - AddToOptimizationQueue(node); - AddToOptimizationQueue(y); + NodeDef *x, *y; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y)); + const auto& value_props = + ctx().graph_properties->GetInputProperties(node->name())[0]; + const TensorShapeProto& output_shape = + ctx().graph_properties->GetOutputProperties(node->name())[0].shape(); + if (curr == complex128(2, 0)) { + node->set_op("Square"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(1, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape)) { + // Pow could be used to broadcast, so make sure the shapes of the two + // arguments are identical before replacing Pow with Identity. 
+ node->set_op("Identity"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0.5, 0)) { + node->set_op("Sqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(0, 0) && + ShapesSymbolicallyEqual(value_props.shape(), output_shape) && + PartialTensorShape(output_shape).IsFullyDefined()) { + const auto dtype = node->attr().at("T").type(); + Tensor ones(dtype, output_shape); + for (int i = 0; i < ones.NumElements(); ++i) { + TF_RETURN_IF_ERROR(SetElementToOne(i, &ones)); } + node->set_op("Const"); + (*node->mutable_attr())["dtype"].set_type(dtype); + node->mutable_attr()->erase("T"); + ones.AsProtoTensorContent( + (*node->mutable_attr())["value"].mutable_tensor()); + node->set_input(0, AsControlDependency(x->name())); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(x); + AddToOptimizationQueue(y); + } else if (curr == complex128(-0.5, 0)) { + node->set_op("Rsqrt"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); + } else if (curr == complex128(-1, 0)) { + node->set_op("Reciprocal"); + node->set_input(1, AsControlDependency(y->name())); + AddToOptimizationQueue(node); + AddToOptimizationQueue(y); } return Status::OK(); } @@ -2638,12 +2631,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } private: - Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j, + Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j, bool* modified) { const auto& t = - ctx().graph_properties->GetInputProperties(input->name())[i]; + ctx().graph_properties->GetInputProperties(add_node->name())[i]; const auto& c = - ctx().graph_properties->GetInputProperties(input->name())[j]; + ctx().graph_properties->GetInputProperties(add_node->name())[j]; for (int k = 0; k < c.shape().dim_size(); ++k) { // Skip if c shape is not fully determined. if (c.shape().dim(k).size() < 0) { @@ -2659,13 +2652,13 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { // broadcast. return Status::OK(); } - if (TensorShape::IsValid(c.shape()) && c.has_value()) { - Tensor constant(c.dtype(), c.shape()); - if (!constant.FromProto(c.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - c.value().DebugString()); - } + Tensor constant; + if (GetTensorFromConstNode(add_node->input(j), &constant)) { complex128 element; + // TODO(rmlarsen): Refactor the more general IsOnes from + // constant_folding.cc and use it here. Perhaps also convert log(x - (-1)) + // or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1), + // and Neg(-1) to 1. 
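Stepping back to ConvertPowStage above: once GetTensorFromConstNode yields a uniform constant exponent, the rewrite reduces to a small lookup, summarized here as a standalone sketch. PowReplacementOp is a hypothetical name, std::complex<double> stands in for tensorflow::complex128, and the Identity and Const cases additionally require the shape checks shown in the stage.

#include <complex>
#include <string>

// Maps a uniform constant exponent to the op that replaces Pow; "Pow" means no rewrite.
std::string PowReplacementOp(const std::complex<double>& p) {
  if (p == std::complex<double>(2, 0)) return "Square";
  if (p == std::complex<double>(1, 0)) return "Identity";    // needs matching shapes
  if (p == std::complex<double>(0.5, 0)) return "Sqrt";
  if (p == std::complex<double>(0, 0)) return "Const";       // fully defined shape of ones
  if (p == std::complex<double>(-0.5, 0)) return "Rsqrt";
  if (p == std::complex<double>(-1, 0)) return "Reciprocal";
  return "Pow";
}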
for (int k = 0; k < constant.NumElements(); ++k) { if (!GetElementUnexhaustive(constant, k, {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, @@ -2680,15 +2673,15 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage { } } NodeDef *x, *y; - TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x)); - TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y)); + TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x)); + TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y)); node->set_op("Log1p"); - node->set_input(0, input->input(i)); + node->set_input(0, add_node->input(i)); node->add_input(AsControlDependency(y->name())); - ForwardControlDependencies(node, {input}); + ForwardControlDependencies(node, {add_node}); AddToOptimizationQueue(node); - AddToOptimizationQueue(input); + AddToOptimizationQueue(add_node); AddToOptimizationQueue(x); AddToOptimizationQueue(y); *modified = true; @@ -2717,25 +2710,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) { return Status::OK(); } - - NodeDef* exp; - TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp)); - if (!IsExp(*exp)) { - return Status::OK(); - } - - if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) { - return Status::OK(); - } - - const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0]; + const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0]; const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1]; - for (int k = 0; k < c.shape().dim_size(); ++k) { - // Skip if c shape is not fully determined. - if (c.shape().dim(k).size() < 0) { - return Status::OK(); - } - } TensorShapeProto broadcast_shape; if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) { return Status::OK(); @@ -2745,39 +2721,39 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage { // broadcast. return Status::OK(); } - if (TensorShape::IsValid(c.shape()) && c.has_value()) { - Tensor constant(c.dtype(), c.shape()); - if (!constant.FromProto(c.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - c.value().DebugString()); + Tensor constant; + if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK(); + // TODO(rmlarsen): Use the more general IsOnes helper here. + complex128 element; + for (int k = 0; k < constant.NumElements(); ++k) { + if (!GetElementUnexhaustive(constant, k, + {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_COMPLEX128}, + &element)) { + // input data type is not supported by expm1. Skip. + return Status::OK(); } - complex128 element; - for (int k = 0; k < constant.NumElements(); ++k) { - if (!GetElementUnexhaustive(constant, k, - {DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE, - DT_COMPLEX64, DT_COMPLEX128}, - &element)) { - // input data type is not supported by expm1. Skip. - return Status::OK(); - } - if (element != complex128(1)) { - // current element is not 1. Skip. - return Status::OK(); - } + LOG(INFO) << "Got element = " << element; + if (element != complex128(1)) { + // current element is not 1. Skip. 
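As background for the Expm1 rewrite handled in this stage, a small standalone program (illustration only, not part of the optimizer) showing why exp(x) - 1 is worth replacing for small x.

#include <cmath>
#include <cstdio>

int main() {
  const double x = 1e-12;
  // Catastrophic cancellation: exp(x) is approximately 1 + 1e-12, so subtracting 1
  // discards most of the significant digits.
  std::printf("exp(x) - 1 = %.17g\n", std::exp(x) - 1.0);
  // std::expm1 computes the same quantity accurately.
  std::printf("expm1(x)   = %.17g\n", std::expm1(x));
  return 0;
}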
+ return Status::OK(); } - NodeDef *exp_input, *ones; - TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input)); - TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones)); - node->set_op("Expm1"); - node->set_input(0, exp->input(0)); - node->set_input(1, AsControlDependency(ones->name())); - ForwardControlDependencies(node, {exp}); - - AddToOptimizationQueue(node); - AddToOptimizationQueue(exp); - AddToOptimizationQueue(exp_input); - AddToOptimizationQueue(ones); } + NodeDef* exp; + TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp)); + NodeDef *exp_input, *ones; + TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input)); + TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones)); + node->set_op("Expm1"); + node->set_input(0, exp->input(0)); + node->set_input(1, AsControlDependency(ones->name())); + ForwardControlDependencies(node, {exp}); + + AddToOptimizationQueue(node); + AddToOptimizationQueue(exp); + AddToOptimizationQueue(exp_input); + AddToOptimizationQueue(ones); + *simplified_node_name = node->name(); return Status::OK(); } }; @@ -3096,14 +3072,6 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage { } protected: - bool IsReallyConstant(const NodeDef& node) const { - if (!IsConstant(node)) { - return false; - } - // If the node is fed it's not constant anymore. - return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end(); - } - bool GetConstantAsInt64(const NodeDef& node, DataType dtype, std::vector* values) { if (dtype == DT_INT32) { @@ -3430,8 +3398,6 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const { void ArithmeticOptimizer::DedupComputations() { CanonicalizeGraph(optimized_graph_); - // LOG(INFO) << "Graph after canonicalization: \n" - // << optimized_graph_->DebugString(); GraphTopologyView graph_view; if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) { @@ -3683,7 +3649,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/, graph_properties_.reset(new GraphProperties(optimized_item)); const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE; - const Status status = graph_properties_->InferStatically(assume_valid_feeds); + const Status status = + graph_properties_->InferStatically(assume_valid_feeds, + /*aggressive_shape_inference=*/false, + /*include_tensor_values=*/false); const bool can_use_shapes = status.ok(); if (!can_use_shapes) { VLOG(1) << "Shape inference failed." << status.error_message(); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 26fd9e5fd78..fb6aaf7082e 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -122,27 +122,6 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node, return removed_input; } -bool GetConcatAxis(const GraphProperties& properties, NodeDef* node, - int* axis) { - if (node->op() != "ConcatV2" || - properties.GetInputProperties(node->name()).empty()) { - return false; - } - const auto& axis_input = properties.GetInputProperties(node->name()).back(); - if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) { - return false; - } - - Tensor axis_tensor(axis_input.dtype(), axis_input.shape()); - if (!axis_tensor.FromProto(axis_input.value())) { - return false; - } - *axis = axis_input.dtype() == DT_INT64 - ? 
static_cast(axis_tensor.scalar()()) - : axis_tensor.scalar()(); - return true; -} - bool HasTPUAttributes(const NodeDef& node) { AttrSlice attrs(node); for (auto attr : attrs) { @@ -220,9 +199,9 @@ string ConstantFolding::AddControlDependency(const string& input_name, if (IsControlInput(input_name)) { return input_name; } - const NodeDef* node = node_map->GetNode(input_name); - if (!IsSwitch(*node)) { - return AsControlDependency(*node); + const NodeDef& node = *node_map->GetNode(input_name); + if (!IsSwitch(node)) { + return AsControlDependency(node); } else { // We can't anchor control dependencies directly on the switch node: unlike // other nodes only one of the outputs of the switch node will be generated @@ -230,10 +209,10 @@ string ConstantFolding::AddControlDependency(const string& input_name, // dependency is only triggered when the corresponding output is triggered. // We start by looking for an identity node connected to the output of the // switch node, and use it to anchor the control dependency. - auto outputs = node_map->GetOutputs(node->name()); + auto outputs = node_map->GetOutputs(node.name()); for (const NodeDef* output : outputs) { if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) { - if (IsSameInput(node->input(0), input_name)) { + if (IsSameInput(node.input(0), input_name)) { return AsControlDependency(*output); } } @@ -244,19 +223,19 @@ string ConstantFolding::AddControlDependency(const string& input_name, string ctrl_dep_name = ParseNodeName(input_name, &port); strings::StrAppend(&ctrl_dep_name, "_", port); ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl); - const DataType output_type = node->attr().at("T").type(); + const DataType output_type = node.attr().at("T").type(); NodeDef* added_node = node_map->GetNode(ctrl_dep_name); if (added_node == nullptr) { added_node = graph->add_node(); added_node->set_name(ctrl_dep_name); added_node->set_op("Identity"); - added_node->set_device(node->device()); + added_node->set_device(node.device()); (*added_node->mutable_attr())["T"].set_type(output_type); *added_node->add_input() = input_name; node_map->AddNode(added_node->name(), added_node); - node_map->AddOutput(node->name(), added_node->name()); + node_map->AddOutput(node.name(), added_node->name()); } return AsControlDependency(*added_node); } @@ -321,6 +300,15 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const { return feed_nodes_.find(node.name()) == feed_nodes_.end(); } +// TODO(rmlarsen): Refactor to shared util. +bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input, + Tensor* tensor) { + const NodeDef* node = node_map_->GetNode(node_name_or_input); + return node != nullptr && IsReallyConstant(*node) && + CheckAttrExists(*node, "value").ok() && + tensor->FromProto(node->attr().at("value").tensor()); +} + // Materialize the shapes using constants whenever possible. Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) { // We may add some nodes to the graph to encode control dependencies and hold @@ -868,6 +856,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const { return false; } + if (node.op() == "AccumulateNV2") { + return false; + } // Skips ops that don't benefit from folding. 
if (IsPlaceholder(node)) { return false; @@ -1856,9 +1847,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node, SET_AND_RETURN_IF_MODIFIED( PartialAssocOpConstFolding(optimized_graph, properties, node)); SET_AND_RETURN_IF_MODIFIED( - PartialConcatConstFolding(optimized_graph, properties, node)); + MergeConcat(use_shape_info, optimized_graph, node)); SET_AND_RETURN_IF_MODIFIED( - MergeConcat(*properties, use_shape_info, optimized_graph, node)); + PartialConcatConstFolding(optimized_graph, properties, node)); graph_modified_ = graph_modified_cached; return Status::OK(); @@ -1879,43 +1870,33 @@ void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties, Status ConstantFolding::RemoveShuffleOrTranspose( const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) && - properties.GetInputProperties(node->name()).size() >= 2) { + if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node))) + return Status::OK(); + Tensor permutation_tensor; + if (GetTensorFromConstNode(node->input(1), &permutation_tensor) && + properties.HasInputProperties(node->name())) { const auto& shape = properties.GetInputProperties(node->name())[0].shape(); - if (shape.unknown_rank()) { - // Not optimizable. + std::vector permutation; + for (int j = 0; j < permutation_tensor.NumElements(); ++j) { + if (permutation_tensor.dtype() == DT_INT64) { + permutation.push_back(permutation_tensor.vec()(j)); + } else { + permutation.push_back(permutation_tensor.vec()(j)); + } + } + if (permutation.size() != shape.dim_size()) { + // Number of elements in perm should be same as dim_size. Skip if not. return Status::OK(); } - const auto& p = properties.GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor perm(p.dtype(), p.shape()); - if (!perm.FromProto(p.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); - } - std::vector permutation; - for (int j = 0; j < perm.NumElements(); ++j) { - if (perm.dtype() == DT_INT64) { - permutation.push_back(perm.vec()(j)); - } else { - permutation.push_back(perm.vec()(j)); - } - } - if (permutation.size() != shape.dim_size()) { - // Number of elements in perm should be same as dim_size. Skip if not. - return Status::OK(); - } - // The node is replaceable iff - // dim_size == 0 || all dims have size 1 || - // all dims with > 1 size are not permuted. - bool replaceable = true; - for (int j = 0; replaceable && j < shape.dim_size(); ++j) { - replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - return Status::OK(); - } + // The node is replaceable iff + // dim_size == 0 || all dims have size 1 || + // all dims with > 1 size are not permuted. 
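The replaceability condition restated in the comment above can be read as a small predicate over the input shape and the permutation. A standalone sketch, with TransposeIsNoOp as a hypothetical name and plain integer vectors standing in for the shape proto and the permutation tensor:

#include <cstdint>
#include <vector>

// A Transpose/Shuffle is a no-op when every dimension either has size 1 or is
// mapped to itself by the permutation (and the permutation length matches the rank).
bool TransposeIsNoOp(const std::vector<int64_t>& dims,
                     const std::vector<int>& permutation) {
  if (permutation.size() != dims.size()) return false;
  for (int j = 0; j < static_cast<int>(dims.size()); ++j) {
    if (dims[j] != 1 && permutation[j] != j) return false;
  }
  return true;
}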
+ bool replaceable = true; + for (int j = 0; replaceable && j < shape.dim_size(); ++j) { + replaceable &= shape.dim(j).size() == 1 || j == permutation[j]; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } return Status::OK(); @@ -1941,44 +1922,35 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (use_shape_info && node->op() == "ReverseV2" && - properties.GetInputProperties(node->name()).size() >= 2) { + if (!use_shape_info || node->op() != "ReverseV2") return Status::OK(); + Tensor axis; + if (properties.HasInputProperties(node->name()) && + GetTensorFromConstNode(node->input(1), &axis)) { const auto& shape = properties.GetInputProperties(node->name())[0].shape(); - if (shape.unknown_rank()) { - // Not optimizable. - return Status::OK(); + if (shape.unknown_rank()) return Status::OK(); + std::set target_axes; + for (int j = 0; j < axis.NumElements(); ++j) { + // value of axis can be negative. + if (axis.dtype() == DT_INT64) { + target_axes.insert((axis.vec()(j) + shape.dim_size()) % + shape.dim_size()); + } else { + target_axes.insert((axis.vec()(j) + shape.dim_size()) % + shape.dim_size()); + } } - const auto& a = properties.GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(a.shape()) && a.has_value()) { - Tensor axis(a.dtype(), a.shape()); - if (!axis.FromProto(a.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - a.value().DebugString()); - } - std::set target_axes; - for (int j = 0; j < axis.NumElements(); ++j) { - // value of axis can be negative. - if (axis.dtype() == DT_INT64) { - target_axes.insert((axis.vec()(j) + shape.dim_size()) % - shape.dim_size()); - } else { - target_axes.insert((axis.vec()(j) + shape.dim_size()) % - shape.dim_size()); - } - } - // The node is replaceable iff - // unknown_rank == false && - // (dim_size == 0 || all dims have size 1 || - // all dims with > 1 size are not in target_axes) - bool replaceable = !shape.unknown_rank(); - for (int j = 0; replaceable && j < shape.dim_size(); ++j) { - replaceable &= shape.dim(j).size() == 1 || - target_axes.find(j) == target_axes.end(); - } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - } + // The node is replaceable iff + // unknown_rank == false && + // (dim_size == 0 || all dims have size 1 || + // all dims with > 1 size are not in target_axes) + bool replaceable = true; + for (int j = 0; replaceable && j < shape.dim_size(); ++j) { + replaceable &= + shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end(); + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } return Status::OK(); @@ -1988,45 +1960,33 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (use_shape_info && IsSlice(*node) && - properties.GetInputProperties(node->name()).size() == 3) { + if (!use_shape_info || !IsSlice(*node)) return Status::OK(); + Tensor begin; + Tensor size; + if (properties.HasInputProperties(node->name()) && + GetTensorFromConstNode(node->input(1), &begin) && + GetTensorFromConstNode(node->input(2), &size)) { const auto& input = properties.GetInputProperties(node->name())[0]; - const auto& b = properties.GetInputProperties(node->name())[1]; - const auto& s = properties.GetInputProperties(node->name())[2]; - if (TensorShape::IsValid(b.shape()) && 
b.has_value() && - TensorShape::IsValid(s.shape()) && s.has_value()) { - Tensor begin(b.dtype(), b.shape()); - if (!begin.FromProto(b.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - b.value().DebugString()); + // The node is replaceable iff unknown_rank == false && + // begin == 0 && (size == -1 || size == input_shape) for all dimensions + bool replaceable = !input.shape().unknown_rank(); + for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { + if (begin.dtype() == DT_INT32) { + replaceable &= begin.vec()(j) == 0; + } else { + replaceable &= begin.vec()(j) == 0; } - Tensor size(s.dtype(), s.shape()); - if (!size.FromProto(s.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - s.value().DebugString()); - } - // The node is replaceable iff unknown_rank == false && - // begin == 0 && (size == -1 || size == input_shape) for all dimensions - bool replaceable = !input.shape().unknown_rank(); - for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { - if (begin.dtype() == DT_INT32) { - replaceable &= begin.vec()(j) == 0; - } else { - replaceable &= begin.vec()(j) == 0; - } - if (size.dtype() == DT_INT32) { - replaceable &= (size.vec()(j) == -1 || - size.vec()(j) == input.shape().dim(j).size()); - } else { - replaceable &= (size.vec()(j) == -1 || - size.vec()(j) == input.shape().dim(j).size()); - } - } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - return Status::OK(); + if (size.dtype() == DT_INT32) { + replaceable &= (size.vec()(j) == -1 || + size.vec()(j) == input.shape().dim(j).size()); + } else { + replaceable &= (size.vec()(j) == -1 || + size.vec()(j) == input.shape().dim(j).size()); } } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + } } return Status::OK(); } @@ -2052,81 +2012,70 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, return Status::OK(); } } - const auto& b = properties.GetInputProperties(node->name())[1]; - const auto& e = properties.GetInputProperties(node->name())[2]; - const auto& s = properties.GetInputProperties(node->name())[3]; - if (TensorShape::IsValid(b.shape()) && b.has_value() && - TensorShape::IsValid(e.shape()) && e.has_value() && - TensorShape::IsValid(s.shape()) && s.has_value()) { - Tensor begin(b.dtype(), b.shape()); - if (!begin.FromProto(b.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - b.value().DebugString()); - } - Tensor end(e.dtype(), e.shape()); - if (!end.FromProto(e.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - e.value().DebugString()); - } - Tensor strides(s.dtype(), s.shape()); - if (!strides.FromProto(s.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - s.value().DebugString()); - } - TF_RETURN_IF_ERROR( - CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"})); - int begin_mask = node->attr().at("begin_mask").i(); - int end_mask = node->attr().at("end_mask").i(); - std::set expanded_ellipsis_indices; - int ellipsis_index = -1; - for (int j = 0; j < input.shape().dim_size(); ++j) { - // find the ellipsis_mask. If not found, insert one in the end if - // necessary. - if (node->attr().at("ellipsis_mask").i() & 1 << j || - (ellipsis_index == -1 && j >= strides.NumElements())) { - ellipsis_index = j; - } - // insert the indices that are immediately after ellipsis_index if - // necessary. 
- if (ellipsis_index != -1 && - input.shape().dim_size() > - strides.NumElements() + j - ellipsis_index) { - expanded_ellipsis_indices.insert(j); - } - } - // The node is replaceable iff unknown_rank == false && - // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim) - // && strides == 1) for all dimensions. - bool replaceable = !input.shape().unknown_rank(); - for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { - if (expanded_ellipsis_indices.find(j) != - expanded_ellipsis_indices.end()) { - // ellipsis_mask is effective on current dimension. - continue; - } - // when we have ellipsis_mask in between, input.shape().dim_size() will - // be greater than strides.NumElements(), since we will insert - // as many as expanded_ellipsis_indices.size() axes during computation. - // We need to subtract this number from j. - int i = j; - if (ellipsis_index != -1 && - j >= ellipsis_index + expanded_ellipsis_indices.size()) { - i = j - expanded_ellipsis_indices.size(); - } - int b = begin.dtype() == DT_INT32 ? begin.vec()(i) - : begin.vec()(i); - int e = - end.dtype() == DT_INT32 ? end.vec()(i) : end.vec()(i); - int s = strides.dtype() == DT_INT32 ? strides.vec()(i) - : strides.vec()(i); - replaceable &= - (begin_mask & 1 << i || b == 0) && - (end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1; + std::vector input_tensors(3); + for (int i = 1; i < 4; ++i) { + if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) { + return Status::OK(); } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + } + + const Tensor& begin = input_tensors[0]; + const Tensor& end = input_tensors[1]; + const Tensor& strides = input_tensors[2]; + + TF_RETURN_IF_ERROR( + CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"})); + int begin_mask = node->attr().at("begin_mask").i(); + int end_mask = node->attr().at("end_mask").i(); + std::set expanded_ellipsis_indices; + int ellipsis_index = -1; + for (int j = 0; j < input.shape().dim_size(); ++j) { + // find the ellipsis_mask. If not found, insert one in the end if + // necessary. + if (node->attr().at("ellipsis_mask").i() & 1 << j || + (ellipsis_index == -1 && j >= strides.NumElements())) { + ellipsis_index = j; } + // insert the indices that are immediately after ellipsis_index if + // necessary. + if (ellipsis_index != -1 && + input.shape().dim_size() > + strides.NumElements() + j - ellipsis_index) { + expanded_ellipsis_indices.insert(j); + } + } + + // The node is replaceable iff unknown_rank == false && + // ((begin_mask is set || begin == 0) && (end_mask is set || end == dim) + // && strides == 1) for all dimensions. + bool replaceable = !input.shape().unknown_rank(); + for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) { + if (expanded_ellipsis_indices.find(j) != + expanded_ellipsis_indices.end()) { + // ellipsis_mask is effective on current dimension. + continue; + } + // when we have ellipsis_mask in between, input.shape().dim_size() will + // be greater than strides.NumElements(), since we will insert + // as many as expanded_ellipsis_indices.size() axes during computation. + // We need to subtract this number from j. + int i = j; + if (ellipsis_index != -1 && + j >= ellipsis_index + expanded_ellipsis_indices.size()) { + i = j - expanded_ellipsis_indices.size(); + } + int b = begin.dtype() == DT_INT32 ? begin.vec()(i) + : begin.vec()(i); + int e = end.dtype() == DT_INT32 ? end.vec()(i) : end.vec()(i); + int s = strides.dtype() == DT_INT32 ? 
strides.vec()(i) + : strides.vec()(i); + replaceable &= (begin_mask & 1 << i || b == 0) && + (end_mask & 1 << i || e == input.shape().dim(j).size()) && + s == 1; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } return Status::OK(); @@ -2135,31 +2084,23 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties, Status ConstantFolding::SimplifyTile(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { + Tensor multiplies; if (use_shape_info && IsTile(*node) && - properties.GetInputProperties(node->name()).size() == 2) { - const auto& m = properties.GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(m.shape()) && m.has_value()) { - Tensor multiplies(m.dtype(), m.shape()); - if (!multiplies.FromProto(m.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - m.value().DebugString()); + GetTensorFromConstNode(node->input(1), &multiplies)) { + // The node is replaceable iff all values in multiplies are 1. + bool replaceable = true; + if (multiplies.dtype() == DT_INT32) { + for (int j = 0; replaceable && j < multiplies.vec().size(); ++j) { + replaceable &= multiplies.vec()(j) == 1; } - // The node is replaceable iff all values in multiplies are 1. - bool replaceable = true; - if (multiplies.dtype() == DT_INT32) { - for (int j = 0; replaceable && j < multiplies.vec().size(); ++j) { - replaceable &= multiplies.vec()(j) == 1; - } - } else { - for (int j = 0; replaceable && j < multiplies.vec().size(); - ++j) { - replaceable &= multiplies.vec()(j) == 1; - } - } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + } else { + for (int j = 0; replaceable && j < multiplies.vec().size(); ++j) { + replaceable &= multiplies.vec()(j) == 1; } } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); + } } return Status::OK(); } @@ -2167,26 +2108,20 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties, Status ConstantFolding::SimplifyPad(const GraphProperties& properties, bool use_shape_info, GraphDef* optimized_graph, NodeDef* node) { - if (use_shape_info && IsPad(*node) && - properties.GetInputProperties(node->name()).size() >= 2) { - const auto& p = properties.GetInputProperties(node->name())[1]; - if (TensorShape::IsValid(p.shape()) && p.has_value()) { - Tensor paddings(p.dtype(), p.shape()); - if (!paddings.FromProto(p.value())) { - return errors::InvalidArgument("Cannot parse tensor from proto: ", - p.value().DebugString()); - } - // The node is replaceable iff all values in paddings are 0. - bool replaceable = true; - // The operation requires it to be int32 value so we don't check for - // 1nt64. - const auto flatten = paddings.flat(); - for (int j = 0; replaceable && j < flatten.size(); ++j) { - replaceable &= flatten(j) == 0; - } - if (replaceable) { - ReplaceOperationWithIdentity(0, properties, node, optimized_graph); - } + if (!use_shape_info || !IsPad(*node)) return Status::OK(); + + Tensor paddings; + if (GetTensorFromConstNode(node->input(1), &paddings)) { + // The node is replaceable iff all values in paddings are 0. + bool replaceable = true; + // The operation requires it to be int32 value so we don't check for + // 1nt64. 
+ const auto flatten = paddings.flat(); + for (int j = 0; replaceable && j < flatten.size(); ++j) { + replaceable &= flatten(j) == 0; + } + if (replaceable) { + ReplaceOperationWithIdentity(0, properties, node, optimized_graph); } } return Status::OK(); @@ -3031,73 +2966,71 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph, // folding of ops when more than one but not all inputs are constant. // For AddN and AccumulateNV2, we may furthermore reorder inputs, since // addition is commutative. - const int num_non_control_inputs = NumNonControlInputs(*node); - if (IsAggregate(*node) && IsCommutative(*node) && - num_non_control_inputs > 2) { - const int num_control_inputs = node->input_size() - num_non_control_inputs; - std::vector const_inputs; - std::vector nonconst_inputs; - for (int i = 0; i < node->input_size(); ++i) { - const string& input = node->input(i); - const NodeDef* input_node = node_map_->GetNode(NodeName(input)); - CHECK(input_node != nullptr) << input; - if (!IsControlInput(input) && IsReallyConstant(*input_node)) { - const_inputs.push_back(i); - } else { - // Non-const and control inputs. - nonconst_inputs.push_back(i); - } - } - // Promote AccumulateNV2 with all constant inputs to AddN, since it is - // a fake node that cannot be constant folded by itself. - if (const_inputs.size() == num_non_control_inputs && - node->op() == "AccumulateNV2") { - node->set_op("AddN"); - node->mutable_attr()->erase("shape"); - return true; - } - const string new_node_name = OptimizedNodeName( - *node, strings::StrCat("_partial_split_", const_inputs.size())); - if (1 < const_inputs.size() && - const_inputs.size() < num_non_control_inputs && - !node_map_->NodeExists(new_node_name)) { - NodeDef* added_node = optimized_graph->add_node(); - *added_node = *node; - // Always use AddN for the constant node, since AccumulateNV2 is a fake - // node that cannot be constant folded, since it does not have a kernel. - added_node->set_op("AddN"); - added_node->mutable_attr()->erase("shape"); - added_node->set_name(new_node_name); - node_map_->AddNode(added_node->name(), added_node); - added_node->clear_input(); - for (int i : const_inputs) { - added_node->add_input(node->input(i)); - node_map_->UpdateOutput(NodeName(node->input(i)), node->name(), - added_node->name()); - } + if (!IsAggregate(*node) || !IsCommutative(*node)) return false; - // Overwrite the first const input with the added node. - node->set_input(const_inputs[0], added_node->name()); - node_map_->AddOutput(added_node->name(), node->name()); - nonconst_inputs.push_back(const_inputs[0]); - // Compact the remaining inputs to the original node. 
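The partial folding of commutative aggregates being rewritten in this hunk starts by partitioning the input indices into constant and non-constant groups. A standalone sketch of that partition; SplitByConstness is a hypothetical name and the bool vector stands in for the IsReallyConstant checks.

#include <utility>
#include <vector>

// Returns {const_inputs, nonconst_inputs} index lists, in input order, as the
// AddN/AccumulateNV2 partial-split logic builds them.
std::pair<std::vector<int>, std::vector<int>> SplitByConstness(
    const std::vector<bool>& input_is_const) {
  std::pair<std::vector<int>, std::vector<int>> result;
  for (int i = 0; i < static_cast<int>(input_is_const.size()); ++i) {
    if (input_is_const[i]) {
      result.first.push_back(i);    // candidates for the folded child AddN
    } else {
      result.second.push_back(i);   // stay on the original node
    }
  }
  return result;
}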
- std::sort(nonconst_inputs.begin(), nonconst_inputs.end()); - int idx = 0; - for (int i : nonconst_inputs) { - if (idx != i) { - node->set_input(idx, node->input(i)); - } - ++idx; - } - node->mutable_input()->DeleteSubrange(nonconst_inputs.size(), - const_inputs.size() - 1); - (*node->mutable_attr())["N"].set_i(node->input_size() - - num_control_inputs); - properties->ClearInputProperties(node->name()); - (*added_node->mutable_attr())["N"].set_i(const_inputs.size()); - return true; + const int num_non_control_inputs = NumNonControlInputs(*node); + if (num_non_control_inputs <= 2) return false; + const int num_control_inputs = node->input_size() - num_non_control_inputs; + std::vector const_inputs; + std::vector nonconst_inputs; + for (int i = 0; i < node->input_size(); ++i) { + const string& input = node->input(i); + const NodeDef* input_node = node_map_->GetNode(NodeName(input)); + if (input_node == nullptr) return false; + if (!IsControlInput(input) && IsReallyConstant(*input_node)) { + const_inputs.push_back(i); + } else { + // Non-const and control inputs. + nonconst_inputs.push_back(i); } } + // Promote AccumulateNV2 with all constant inputs to AddN, since it is + // a fake node that cannot be constant folded by itself. + if (const_inputs.size() == num_non_control_inputs && + node->op() == "AccumulateNV2") { + node->set_op("AddN"); + node->mutable_attr()->erase("shape"); + return true; + } + const string new_node_name = OptimizedNodeName( + *node, strings::StrCat("_partial_split_", const_inputs.size())); + if (const_inputs.size() > 1 && const_inputs.size() < num_non_control_inputs && + !node_map_->NodeExists(new_node_name)) { + NodeDef* added_node = optimized_graph->add_node(); + *added_node = *node; + // Always use AddN for the constant node, since AccumulateNV2 is a fake + // node that cannot be constant folded, since it does not have a kernel. + added_node->set_op("AddN"); + added_node->mutable_attr()->erase("shape"); + added_node->set_name(new_node_name); + node_map_->AddNode(added_node->name(), added_node); + added_node->clear_input(); + for (int i : const_inputs) { + added_node->add_input(node->input(i)); + node_map_->UpdateOutput(NodeName(node->input(i)), node->name(), + added_node->name()); + } + + // Overwrite the first const input with the added node. + node->set_input(const_inputs[0], added_node->name()); + node_map_->AddOutput(added_node->name(), node->name()); + nonconst_inputs.push_back(const_inputs[0]); + // Compact the remaining inputs to the original node. + std::sort(nonconst_inputs.begin(), nonconst_inputs.end()); + int idx = 0; + for (int i : nonconst_inputs) { + if (idx != i) { + node->set_input(idx, node->input(i)); + } + ++idx; + } + node->mutable_input()->DeleteSubrange(nonconst_inputs.size(), + const_inputs.size() - 1); + (*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs); + properties->ClearInputProperties(node->name()); + (*added_node->mutable_attr())["N"].set_i(const_inputs.size()); + return true; + } return false; } @@ -3107,156 +3040,176 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph, // Partial constant folding for Concat which is not commutative, so // we have to preserve order and can only push consecutive runs of constant // inputs into sub-nodes. 
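The comment above captures the core of the Concat case: because concatenation is order-sensitive, only consecutive runs of at least two constant inputs can be pushed into a child node. A standalone sketch of that scan over half-open index ranges; ConstantInputRuns is a hypothetical name and the bool vector stands in for the IsReallyConstant lookups.

#include <utility>
#include <vector>

// Returns [first, last) intervals of length >= 2 whose inputs are all constant,
// scanning only the tensor inputs in [begin, end).
std::vector<std::pair<int, int>> ConstantInputRuns(
    const std::vector<bool>& input_is_const, int begin, int end) {
  std::vector<std::pair<int, int>> runs;
  int first = begin;
  while (first < end) {
    while (first < end && !input_is_const[first]) ++first;
    int last = first;
    while (last < end && input_is_const[last]) ++last;
    if (last - first > 1) runs.emplace_back(first, last);
    first = last;
  }
  return runs;
}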
- const int num_non_control_inputs = NumNonControlInputs(*node); - if (IsConcat(*node) && num_non_control_inputs > 3 && - node->name().rfind("_partial_split_") == string::npos) { - int axis_arg = -1; - int begin = 0; - int end = num_non_control_inputs; - if (node->op() == "Concat") { - begin = 1; - axis_arg = 0; - } else if (node->op() == "ConcatV2") { - end = num_non_control_inputs - 1; - axis_arg = num_non_control_inputs - 1; - } else { - return false; - } - - const NodeDef* axis_arg_node = - node_map_->GetNode(NodeName(node->input(axis_arg))); - if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) { - // We cannot constant fold Concat unless we the axis argument is - // constant. Skip node. - return false; - } - - // We search for consecutive runs of constant inputs in the range - // [begin:end[ and push then down into child nodes. - std::vector> constant_input_runs; - int first = begin; - int last = begin; - while (last < end) { - while (first < end && !IsReallyConstant(*node_map_->GetNode( - NodeName(node->input(first))))) { - ++first; - } - // Invariant: node[first] is constant || first >= end. - last = first + 1; - while (last < end && IsReallyConstant(*node_map_->GetNode( - NodeName(node->input(last))))) { - ++last; - } - // Invariant: node[last] is not constant || last >= end - // Discard intervals shorter than 2 elements. - if (first < end && (last - first) > 1) { - constant_input_runs.emplace_back(first, last); - } - first = last; - } - - // Skip if all inputs are constant, and let constant folding take over. - if (constant_input_runs.size() == 1 && - constant_input_runs[0].first == begin && - constant_input_runs[0].second == end) { - return false; - } - std::set inputs_to_delete; - for (auto interval : constant_input_runs) { - // Push the constant inputs in the interval to a child node than can be - // constant folded. - string new_node_name = OptimizedNodeName(*node, "_partial_split"); - do { - new_node_name += strings::StrCat("_", interval.first); - } while (node_map_->NodeExists(new_node_name)); - - NodeDef* added_node = optimized_graph->add_node(); - *added_node = *node; - added_node->set_name(new_node_name); - node_map_->AddNode(added_node->name(), added_node); - added_node->clear_input(); - for (int i = interval.first; i < interval.second; ++i) { - added_node->add_input(node->input(i)); - node_map_->UpdateOutput(NodeName(node->input(i)), node->name(), - added_node->name()); - if (i != interval.first) { - inputs_to_delete.insert(i); - } - } - added_node->add_input(node->input(axis_arg)); - (*added_node->mutable_attr())["N"].set_i(interval.second - - interval.first); - node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name()); - - // Overwrite the first constant input with the result of the added - // child node. - node->set_input(interval.first, added_node->name()); - node_map_->AddOutput(added_node->name(), node->name()); - } - if (!constant_input_runs.empty()) { - if (!inputs_to_delete.empty()) { - // Fix up the inputs to the original node. 
-        std::vector<string> tmp(node->input().begin(), node->input().end());
-        node->clear_input();
-        for (int i = 0; i < tmp.size(); ++i) {
-          if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
-            node->add_input(tmp[i]);
-          }
-        }
-        (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
-        properties->ClearInputProperties(node->name());
-      }
-      return true;
-    }
+  if (!IsConcat(*node) ||
+      node->name().rfind("_partial_split_") != string::npos) {
+    return false;
   }
-  return false;
+  const int num_non_control_inputs = NumNonControlInputs(*node);
+  if (num_non_control_inputs <= 3) return false;
+  int axis_arg = -1;
+  int begin = 0;
+  int end = num_non_control_inputs;
+  if (node->op() == "Concat") {
+    begin = 1;
+    axis_arg = 0;
+  } else if (node->op() == "ConcatV2") {
+    end = num_non_control_inputs - 1;
+    axis_arg = num_non_control_inputs - 1;
+  } else {
+    return false;
+  }
+
+  // We search for consecutive runs of constant inputs in the range
+  // [begin:end[ and push them down into child nodes.
+  std::vector<std::pair<int, int>> constant_input_runs;
+  int first = begin;
+  int last = begin;
+  while (last < end) {
+    while (first < end && !IsReallyConstant(*node_map_->GetNode(
+                              NodeName(node->input(first))))) {
+      ++first;
+    }
+    // Invariant: node[first] is constant || first >= end.
+    last = first + 1;
+    while (last < end &&
+           IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
+      ++last;
+    }
+    // Invariant: node[last] is not constant || last >= end
+    // Discard intervals shorter than 2 elements.
+    if (first < end && (last - first) > 1) {
+      constant_input_runs.emplace_back(first, last);
+    }
+    first = last;
+  }
+
+  // Skip if all inputs are constant, and let constant folding take over.
+  if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
+                                      constant_input_runs[0].first == begin &&
+                                      constant_input_runs[0].second == end)) {
+    return false;
+  }
+  std::set<int> inputs_to_delete;
+  for (auto interval : constant_input_runs) {
+    // Push the constant inputs in the interval to a child node that can be
+    // constant folded.
+    string new_node_name = OptimizedNodeName(*node, "_partial_split");
+    do {
+      new_node_name += strings::StrCat("_", interval.first);
+    } while (node_map_->NodeExists(new_node_name));
+
+    NodeDef* added_node = optimized_graph->add_node();
+    *added_node = *node;
+    added_node->set_op("ConcatV2");
+    added_node->set_name(new_node_name);
+    node_map_->AddNode(added_node->name(), added_node);
+    added_node->clear_input();
+    for (int i = interval.first; i < interval.second; ++i) {
+      added_node->add_input(node->input(i));
+      node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
+      if (i != interval.first) {
+        inputs_to_delete.insert(i);
+      }
+    }
+    added_node->add_input(node->input(axis_arg));
+    (*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
+    node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
+
+    // Overwrite the first constant input with the result of the added
+    // child node.
+    node->set_input(interval.first, added_node->name());
+  }
+  if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
+    // Fix up the inputs to the original node.
+    protobuf::RepeatedPtrField<string> tmp;
+    tmp.Swap(node->mutable_input());
+    for (int i = 0; i < tmp.size(); ++i) {
+      if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
+        node->add_input(tmp.Get(i));
+      }
+    }
+    (*node->mutable_attr())["N"].set_i(node->input_size() - 1);
+    properties->ClearInputProperties(node->name());
+  }
+  return true;
 }
-bool ConstantFolding::MergeConcat(const GraphProperties& properties,
-                                  bool use_shape_info,
+bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
+  if (node.op() != "ConcatV2") {
+    return false;
+  }
+  int axis_idx = node.input_size() - 1;
+  while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
+    --axis_idx;
+  }
+  if (axis_idx <= 0) {
+    return false;
+  }
+  Tensor axis_tensor;
+  if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
+    return false;
+  }
+  *axis = axis_tensor.dtype() == DT_INT64
+              ? static_cast<int>(axis_tensor.scalar<int64>()())
+              : axis_tensor.scalar<int32>()();
+  return true;
+}
+
+bool ConstantFolding::MergeConcat(bool use_shape_info,
                                   GraphDef* optimized_graph, NodeDef* node) {
   // We only optimize for ConcatV2.
   int axis;
-  if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
+  if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
       nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
       node_map_->GetOutputs(node->name()).size() != 1) {
     return false;
   }
+  // If all inputs are constant, don't merge and let folding take care of it.
+  const int num_regular_inputs = NumNonControlInputs(*node);
+  bool all_inputs_are_const = true;
+  for (int i = 0; i < num_regular_inputs - 1; ++i) {
+    const NodeDef* input_node = node_map_->GetNode(node->input(i));
+    if (!IsReallyConstant(*input_node)) {
+      all_inputs_are_const = false;
+    }
+  }
+  if (all_inputs_are_const) return false;
+
   NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
   int parent_axis;
-  if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
+  if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
     return false;
   }
-  const int index = NumNonControlInputs(*node) - 1;
-  auto inputs = parent->input();
-  parent->clear_input();
-  for (int i = 0; i < inputs.size(); ++i) {
-    if (IsSameInput(inputs.Get(i), node->name())) {
-      for (int j = 0; j < node->input_size(); ++j) {
-        if (j < index) {
-          // Input tensors (non axis), add to input list of parent.
-          parent->add_input(node->input(j));
-          node_map_->RemoveOutput(node->input(j), node->name());
-          node_map_->AddOutput(node->input(j), parent->name());
-        }
-        // Skip j == index, which means axis tensor.
-        if (j > index) {
-          // Control Dependencies, push back to inputs so they can be forwarded
-          // to parent.
-          *inputs.Add() = node->input(j);
-        }
+  protobuf::RepeatedPtrField<string> parent_inputs;
+  parent_inputs.Swap(parent->mutable_input());
+  std::vector<string> ctrl_output;
+  // TODO(rmlarsen): If the child occurs more than once, is it beneficial to
+  // collapse it into the parent multiple times? Probably not.
+  for (const auto& input : parent_inputs) {
+    if (IsSameInput(input, node->name())) {
+      for (int j = 0; j < num_regular_inputs - 1; ++j) {
+        // Add the tensor inputs of the child concat (except the final axis
+        // input) to the parent's inputs.
+        parent->add_input(node->input(j));
+        node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
       }
     } else {
-      parent->add_input(inputs.Get(i));
+      parent->add_input(input);
     }
   }
+  // Forward the control inputs to the parent.
+  for (int i = num_regular_inputs; i < node->input_size(); ++i) {
+    parent->add_input(node->input(i));
+    node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
+  }
   node->clear_input();
   node->set_op("NoOp");
   node->clear_attr();
   node_map_->RemoveNode(node->name());
   (*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
+  DedupControlInputs(parent);
   return true;
 }
@@ -3344,7 +3297,9 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
   // that the shape inference deals with this conservatively unless we're in
   // aggressive mode.
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  Status s = properties.InferStatically(assume_valid_feeds);
+  Status s = properties.InferStatically(assume_valid_feeds,
+                                        /*aggressive_shape_inference=*/false,
+                                        /*include_tensor_values=*/false);
   const bool can_use_shape_info = s.ok();
   if (can_use_shape_info) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index b4c39a5074b..d9e19a57ae3 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -61,6 +61,8 @@ class ConstantFolding : public GraphOptimizer {
   bool IsReallyConstant(const NodeDef& node) const;
+  bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
+
   Status MaterializeShapes(const GraphProperties& properties);
   Status MaterializeBroadcastGradientArgs(const NodeDef& node,
@@ -239,8 +241,9 @@ class ConstantFolding : public GraphOptimizer {
   void RemoveSplitOrSplitV(const GraphProperties& properties,
                            GraphDef* optimized_graph, NodeDef* node);
-  bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
-                   GraphDef* optimized_graph, NodeDef* node);
+  bool GetConcatAxis(const NodeDef& node, int* axis);
+  bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
+                   NodeDef* node);
   Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
                                                GraphDef* optimized_graph);
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 6ee32252be3..4a5d0031466 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -2390,12 +2390,11 @@ TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
   TF_EXPECT_OK(status);
   GraphDef want;
-  AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want);
+  AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
   AddNode("axis", "Const", {}, {}, &want);
   AddNode("ph", "Placeholder", {}, {}, &want);
   AddNode("concat2", "ConcatV2",
-          {"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {},
-          &want);
+          {"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
   CompareGraphs(want, got);
 }
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index 00cd776c907..02675328b71 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -2209,7 +2209,10 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
   }
   GraphProperties graph_properties(item);
-  TF_RETURN_IF_ERROR(graph_properties.InferStatically(false));
+  TF_RETURN_IF_ERROR(
+      graph_properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
   GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
   virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 391b11a0f8e..355e77fe7eb 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -102,7 +102,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
   if (!properties->has_properties()) {
     // This is an expensive call, call it lazily.
     TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
   }
   const auto& output_properties = properties->GetOutputProperties(node.name());
   if (port_id >= output_properties.size()) {
@@ -252,7 +253,8 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
   if (!properties->has_properties()) {
     // This is an expensive call, call it lazily.
     TF_RETURN_IF_ERROR(properties->InferStatically(
-        /*assume_valid_feeds=*/false));
+        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
+        /*include_tensor_values=*/false));
   }
   for (const auto& prop : properties->GetOutputProperties(node.name())) {
     if (!IsTensorSmall(prop)) {
diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc
index e095c9c6f41..1d3932f28b4 100644
--- a/tensorflow/core/grappler/optimizers/remapper.cc
+++ b/tensorflow/core/grappler/optimizers/remapper.cc
@@ -1170,7 +1170,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
     // Infer properties lazily in case they are not needed.
     if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
-      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(false));
+      // TODO(rmlarsen): Get rid of tensor value copies.
+      TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
+          /*assume_valid_feeds=*/false,
+          /*aggressive_shape_inference=*/false,
+          /*include_tensor_values=*/true));
       ctx.inferred_graph_properties = true;
     }
diff --git a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
index c51c5fcfaf5..13fb883217a 100644
--- a/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.cc
@@ -739,9 +739,9 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
   GraphProperties graph_properties(item);
   const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
-  LOG_WARNING_AND_RETURN_IF_ERROR(
-      graph_properties.InferStatically(assume_valid_feeds));
-
+  LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
+      assume_valid_feeds, /*aggressive_shape_inference=*/false,
+      /*include_tensor_values=*/false));
   *optimized_graph = item.graph;
   node_map_.reset(new NodeMap(optimized_graph));
diff --git a/tensorflow/core/grappler/optimizers/shape_optimizer.cc b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
index 66e77e7ae72..dfdbc8cfc87 100644
--- a/tensorflow/core/grappler/optimizers/shape_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/shape_optimizer.cc
@@ -87,7 +87,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
           graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
       if (!inferred_properties) {
        // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
        inferred_properties = true;
       }
       const auto& prop =
@@ -144,7 +147,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
         }
       }
       if (!inferred_properties) {
        // Infer properties lazily in case they are not needed.
-        TF_RETURN_IF_ERROR(properties.InferStatically(false));
+        TF_RETURN_IF_ERROR(
+            properties.InferStatically(/*assume_valid_feeds=*/false,
+                                       /*aggressive_shape_inference=*/false,
+                                       /*include_tensor_values=*/false));
        inferred_properties = true;
       }
       const auto& prop1 = properties.GetInputProperties(input1.node->name());
diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc
index 9950db063d6..ede89a642ee 100644
--- a/tensorflow/core/grappler/optimizers/static_schedule.cc
+++ b/tensorflow/core/grappler/optimizers/static_schedule.cc
@@ -14,7 +14,9 @@ limitations under the License.
 ==============================================================================*/
 #include "tensorflow/core/grappler/optimizers/static_schedule.h"
+
 #include <deque>
+
 #include "tensorflow/core/framework/attr_value.pb.h"
 #include "tensorflow/core/grappler/costs/graph_properties.h"
 #include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
@@ -92,7 +94,10 @@ Status EstimateEarliestExecutionTimes(
     name_map.clear();
   GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
   OpLevelCostEstimator estimator;
   VirtualPlacer placer(cluster->GetDevices());
@@ -160,7 +165,10 @@ Status EstimateRequiredTimes(
     }
   }
   GraphProperties properties(item);
-  TF_RETURN_IF_ERROR(properties.InferStatically(true));
+  TF_RETURN_IF_ERROR(
+      properties.InferStatically(/*assume_valid_feeds=*/true,
+                                 /*aggressive_shape_inference=*/false,
+                                 /*include_tensor_values=*/false));
   OpLevelCostEstimator estimator;
   VirtualPlacer placer(cluster->GetDevices());
diff --git a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
index c5fa9b16b0c..ebcbcc68334 100644
--- a/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
+++ b/tensorflow/tools/graph_transforms/fold_old_batch_norms_test.cc
@@ -346,7 +346,8 @@ class FoldOldBatchNormsTest : public ::testing::Test {
   std::vector<Tensor> fused_outputs;
   TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
-  test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
+  test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
+                    /*rtol=*/2e-5);
   for (const NodeDef& node : fused_graph_def.node()) {
     EXPECT_NE("FusedBatchNorm", node.op());