Add an option to avoid copying TensorProtos in shape inference when not needed. This saves a significant amount of memory and time in various optimizers requiring static shape inference. Calling GraphProperties::InferStatically would make at least one copy of all constant tensors stored in the graph. This CL adds a new argument to InferStatically, allowing callers to skip copying tensor values to InputProperties and OutputProperties for optimizers that only need shape and datatype information.

Turning this off for constant folding and arithmetic optimizer required cleaning up a few rewrite rules that relied on reading the copy of constant tensor values from GraphProperties instead of just extracting the original value from the corresponding constant node in the graph.

Measured results on Transformer graph:

CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB
Benchmark                Time(ns)        CPU(ns) Allocs     Iterations
----------------------------------------------------------------------
Before:
BM_OptimizeTransformer 5161658741     5281083018 23532315            1   139.189MB peak-mem
After:
BM_OptimizeTransformer 4887891650     5005937073 23063225            1   132.652MB peak-mem

Effects on this graph are not dramatic, since it does not contain a lot of large constant tensors. The effect would be much more pronounced on frozen inference graphs.

PiperOrigin-RevId: 250740251
This commit is contained in:
A. Unique TensorFlower 2019-05-30 12:36:28 -07:00 committed by TensorFlower Gardener
parent d7d2307248
commit 705b193812
15 changed files with 643 additions and 663 deletions

View File

@ -712,7 +712,10 @@ class SymbolicShapeRefiner {
// Perform inference on function body.
GraphProperties gp(grappler_function_item);
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
TF_RETURN_IF_ERROR(gp.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/aggressive_shape_inference_,
/*include_tensor_values=*/true));
// Add return nodes for output shapes.
int output = 0;
@ -2066,7 +2069,8 @@ Status GraphProperties::UpdateEnqueue(
}
Status GraphProperties::InferStatically(bool assume_valid_feeds,
bool aggressive_shape_inference) {
bool aggressive_shape_inference,
bool include_tensor_values) {
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library());
std::unordered_map<string, std::unordered_set<int>> fed_ports;
@ -2225,20 +2229,23 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
&input_properties[i]);
input.port_id = i;
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
// Export tensor value to input_properties.value.
if (IsConstant(*fanin.node)) {
const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
*input_properties[i].mutable_value() = raw_val;
} else if (ctx->input_tensor_protos.size() > i &&
ctx->input_tensor_protos[i] != nullptr) {
*input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
} else if (ic->input_tensors_as_shapes().size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i])) {
*input_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i]);
if (include_tensor_values) {
// Export tensor value to input_properties.value.
if (IsConstant(*fanin.node)) {
const TensorProto& raw_val =
fanin.node->attr().at("value").tensor();
*input_properties[i].mutable_value() = raw_val;
} else if (ctx->input_tensor_protos.size() > i &&
ctx->input_tensor_protos[i] != nullptr) {
*input_properties[i].mutable_value() = *ctx->input_tensor_protos[i];
} else if (ic->input_tensors_as_shapes().size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i])) {
*input_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->input(i), ic->input_tensors_as_shapes()[i],
ctx->input_types[i]);
}
}
}
}
@ -2254,20 +2261,24 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
for (int i = 0; i < ic->num_outputs(); ++i) {
shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
&output_properties[i]);
// Export tensor value to output_properties.value.
if (IsConstant(node)) {
const TensorProto& raw_val = node.attr().at("value").tensor();
*output_properties[i].mutable_value() = raw_val;
} else if (ctx->output_tensor_protos.size() > i &&
ctx->output_tensor_protos[i] != nullptr) {
*output_properties[i].mutable_value() = *ctx->output_tensor_protos[i];
} else if (ctx->output_tensors_as_shapes.size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i])) {
*output_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i]);
if (include_tensor_values) {
// Export tensor value to output_properties.value.
if (IsConstant(node)) {
// TODO(rmlarsen): Eliminate this copy.
const TensorProto& raw_val = node.attr().at("value").tensor();
*output_properties[i].mutable_value() = raw_val;
} else if (ctx->output_tensor_protos.size() > i &&
ctx->output_tensor_protos[i] != nullptr) {
*output_properties[i].mutable_value() =
*ctx->output_tensor_protos[i];
} else if (ctx->output_tensors_as_shapes.size() > i &&
IsShapeFullyDefinedIntegerVectorOrScalar(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i])) {
*output_properties[i].mutable_value() = MakeTensorProtoFromShape(
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
ctx->output_types[i]);
}
}
}
}

View File

@ -89,11 +89,15 @@ class GraphProperties {
// output values when possible and does other aggressive strategies.
// Similar to assuming_valid_feeds, this may cause incorrectness in graph
// analyses, but is useful for simulation or scheduling.
// If include_values is true, the values of constant tensors will be
// included in the input and output properties.
Status InferStatically(bool assume_valid_feeds,
bool aggressive_shape_inference);
bool aggressive_shape_inference,
bool include_tensor_values);
Status InferStatically(bool assume_valid_feeds) {
return InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false);
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true);
}
// Infer the shape by running the graph on the specified cluster and recording
// the shapes of the processed tensors.
@ -117,6 +121,7 @@ class GraphProperties {
const string& node_name) const;
const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
const string& node_name) const;
// Invalidate input/output properties for nodes modified during graph
// optimization pass, to prevent potential optimizations, based on incorrect
// shape information.

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/graph_def_util.h"
@ -1009,7 +1010,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("fill");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
@ -1028,7 +1030,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("fill");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
@ -1248,10 +1251,12 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
// evaluate output value.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/false));
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
LOG(INFO) << out_prop0.DebugString();
EXPECT_FALSE(out_prop0.has_value());
}
@ -1260,7 +1265,8 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
// With aggressive_shape_inference, output value is evaluated.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto out_props = properties.GetOutputProperties("MyFunc");
const OpInfo::TensorProperties out_prop0 = out_props[0];
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
@ -1802,7 +1808,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false));
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
}
@ -1812,7 +1819,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
const auto slice_value =
properties.GetOutputProperties("slice").at(0).value();
@ -1838,7 +1846,8 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
// Check output shapes and values.
const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0];
@ -1881,7 +1890,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
// Without aggressive_shape_inference, ignore annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false));
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
@ -1895,7 +1905,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
@ -1923,7 +1934,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
@ -1950,7 +1962,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("Identity");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];
@ -1977,7 +1990,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
// Use annotated information.
TF_CHECK_OK(properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/true));
/*aggressive_shape_inference=*/true,
/*include_tensor_values=*/true));
const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
EXPECT_EQ(1, props.size());
const OpInfo::TensorProperties& prop = props[0];

View File

@ -359,7 +359,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
graph_properties_ = absl::make_unique<GraphProperties>(*item);
if (use_static_shapes_) {
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
true, use_aggressive_shape_inference_));
true, use_aggressive_shape_inference_, true));
} else {
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
}

View File

@ -234,6 +234,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
DedupControlInputs(target_node);
}
bool IsReallyConstant(const NodeDef& node) const {
if (!IsConstant(node)) {
return false;
}
// If the node is fed it's not constant anymore.
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
}
bool IsInPreserveSet(const NodeDef& node) const {
return ctx().nodes_to_preserve->find(node.name()) !=
ctx().nodes_to_preserve->end();
@ -259,6 +267,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
return false;
}
bool GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
private:
// Extended context required for ArithmeticOptimizer.
const ArithmeticOptimizerContext ctx_ext_;
@ -2480,101 +2496,78 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
bool IsSupported(const NodeDef* node) const override {
return IsPow(*node) &&
ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
ctx().graph_properties->HasOutputProperties(node->name()) &&
ctx().graph_properties->HasInputProperties(node->name());
}
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
const auto& pow_props =
ctx().graph_properties->GetInputProperties(node->name())[1];
PartialTensorShape shape(pow_props.shape());
if (!shape.IsFullyDefined()) {
// skip if p is not fully defined.
return Status::OK();
Tensor pow;
if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
}
prev = curr;
}
if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
Tensor pow(pow_props.dtype(), pow_props.shape());
if (!pow.FromProto(pow_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
pow_props.value().DebugString());
}
complex128 prev, curr;
for (int i = 0; i < pow.NumElements(); ++i) {
if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
// input data type is not supported by Pow. Skip.
return Status::OK();
}
if (i != 0 && curr != prev) {
// pow has different values on different elements. Skip.
return Status::OK();
}
prev = curr;
}
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
const auto& value_props =
ctx().graph_properties->GetInputProperties(node->name())[0];
const TensorShapeProto& output_shape =
ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(1, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
// Pow could be used to broadcast, so make sure the shapes of the two
// arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(0.5, 0)) {
node->set_op("Sqrt");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(0, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
PartialTensorShape shape(value_props.shape());
if (!shape.IsFullyDefined()) {
// skip if b is not fully defined.
return Status::OK();
}
if (TensorShape::IsValid(value_props.shape()) &&
value_props.has_value()) {
Tensor base(value_props.dtype(), value_props.shape());
if (!base.FromProto(value_props.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
value_props.value().DebugString());
}
node->set_op("Const");
Tensor c(base.dtype(), base.shape());
for (int i = 0; i < c.NumElements(); ++i) {
TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
}
(*node->mutable_attr())["dtype"].set_type(base.dtype());
c.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
node->mutable_attr()->erase("T");
node->set_input(0, AsControlDependency(x->name()));
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(x);
AddToOptimizationQueue(y);
}
} else if (curr == complex128(-0.5, 0)) {
node->set_op("Rsqrt");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(-1, 0)) {
node->set_op("Reciprocal");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &x));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &y));
const auto& value_props =
ctx().graph_properties->GetInputProperties(node->name())[0];
const TensorShapeProto& output_shape =
ctx().graph_properties->GetOutputProperties(node->name())[0].shape();
if (curr == complex128(2, 0)) {
node->set_op("Square");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(1, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
// Pow could be used to broadcast, so make sure the shapes of the two
// arguments are identical before replacing Pow with Identity.
node->set_op("Identity");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(0.5, 0)) {
node->set_op("Sqrt");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(0, 0) &&
ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
PartialTensorShape(output_shape).IsFullyDefined()) {
const auto dtype = node->attr().at("T").type();
Tensor ones(dtype, output_shape);
for (int i = 0; i < ones.NumElements(); ++i) {
TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
}
node->set_op("Const");
(*node->mutable_attr())["dtype"].set_type(dtype);
node->mutable_attr()->erase("T");
ones.AsProtoTensorContent(
(*node->mutable_attr())["value"].mutable_tensor());
node->set_input(0, AsControlDependency(x->name()));
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(x);
AddToOptimizationQueue(y);
} else if (curr == complex128(-0.5, 0)) {
node->set_op("Rsqrt");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
} else if (curr == complex128(-1, 0)) {
node->set_op("Reciprocal");
node->set_input(1, AsControlDependency(y->name()));
AddToOptimizationQueue(node);
AddToOptimizationQueue(y);
}
return Status::OK();
}
@ -2638,12 +2631,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
private:
Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
bool* modified) {
const auto& t =
ctx().graph_properties->GetInputProperties(input->name())[i];
ctx().graph_properties->GetInputProperties(add_node->name())[i];
const auto& c =
ctx().graph_properties->GetInputProperties(input->name())[j];
ctx().graph_properties->GetInputProperties(add_node->name())[j];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
@ -2659,13 +2652,13 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
// broadcast.
return Status::OK();
}
if (TensorShape::IsValid(c.shape()) && c.has_value()) {
Tensor constant(c.dtype(), c.shape());
if (!constant.FromProto(c.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
c.value().DebugString());
}
Tensor constant;
if (GetTensorFromConstNode(add_node->input(j), &constant)) {
complex128 element;
// TODO(rmlarsen): Refactor the more general IsOnes from
// constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
// or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
// and Neg(-1) to 1.
for (int k = 0; k < constant.NumElements(); ++k) {
if (!GetElementUnexhaustive(constant, k,
{DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
@ -2680,15 +2673,15 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
}
}
NodeDef *x, *y;
TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
node->set_op("Log1p");
node->set_input(0, input->input(i));
node->set_input(0, add_node->input(i));
node->add_input(AsControlDependency(y->name()));
ForwardControlDependencies(node, {input});
ForwardControlDependencies(node, {add_node});
AddToOptimizationQueue(node);
AddToOptimizationQueue(input);
AddToOptimizationQueue(add_node);
AddToOptimizationQueue(x);
AddToOptimizationQueue(y);
*modified = true;
@ -2717,25 +2710,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
return Status::OK();
}
NodeDef* exp;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
if (!IsExp(*exp)) {
return Status::OK();
}
if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
return Status::OK();
}
const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
for (int k = 0; k < c.shape().dim_size(); ++k) {
// Skip if c shape is not fully determined.
if (c.shape().dim(k).size() < 0) {
return Status::OK();
}
}
TensorShapeProto broadcast_shape;
if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
return Status::OK();
@ -2745,39 +2721,39 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
// broadcast.
return Status::OK();
}
if (TensorShape::IsValid(c.shape()) && c.has_value()) {
Tensor constant(c.dtype(), c.shape());
if (!constant.FromProto(c.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
c.value().DebugString());
Tensor constant;
if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
// TODO(rmlarsen): Use the more general IsOnes helper here.
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
if (!GetElementUnexhaustive(constant, k,
{DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_COMPLEX128},
&element)) {
// input data type is not supported by expm1. Skip.
return Status::OK();
}
complex128 element;
for (int k = 0; k < constant.NumElements(); ++k) {
if (!GetElementUnexhaustive(constant, k,
{DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
DT_COMPLEX64, DT_COMPLEX128},
&element)) {
// input data type is not supported by expm1. Skip.
return Status::OK();
}
if (element != complex128(1)) {
// current element is not 1. Skip.
return Status::OK();
}
LOG(INFO) << "Got element = " << element;
if (element != complex128(1)) {
// current element is not 1. Skip.
return Status::OK();
}
NodeDef *exp_input, *ones;
TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
node->set_op("Expm1");
node->set_input(0, exp->input(0));
node->set_input(1, AsControlDependency(ones->name()));
ForwardControlDependencies(node, {exp});
AddToOptimizationQueue(node);
AddToOptimizationQueue(exp);
AddToOptimizationQueue(exp_input);
AddToOptimizationQueue(ones);
}
NodeDef* exp;
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
NodeDef *exp_input, *ones;
TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
node->set_op("Expm1");
node->set_input(0, exp->input(0));
node->set_input(1, AsControlDependency(ones->name()));
ForwardControlDependencies(node, {exp});
AddToOptimizationQueue(node);
AddToOptimizationQueue(exp);
AddToOptimizationQueue(exp_input);
AddToOptimizationQueue(ones);
*simplified_node_name = node->name();
return Status::OK();
}
};
@ -3096,14 +3072,6 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
}
protected:
bool IsReallyConstant(const NodeDef& node) const {
if (!IsConstant(node)) {
return false;
}
// If the node is fed it's not constant anymore.
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
}
bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
std::vector<int64>* values) {
if (dtype == DT_INT32) {
@ -3430,8 +3398,6 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
void ArithmeticOptimizer::DedupComputations() {
CanonicalizeGraph(optimized_graph_);
// LOG(INFO) << "Graph after canonicalization: \n"
// << optimized_graph_->DebugString();
GraphTopologyView graph_view;
if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
@ -3683,7 +3649,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
graph_properties_.reset(new GraphProperties(optimized_item));
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
const Status status = graph_properties_->InferStatically(assume_valid_feeds);
const Status status =
graph_properties_->InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false);
const bool can_use_shapes = status.ok();
if (!can_use_shapes) {
VLOG(1) << "Shape inference failed." << status.error_message();

View File

@ -122,27 +122,6 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
return removed_input;
}
bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
int* axis) {
if (node->op() != "ConcatV2" ||
properties.GetInputProperties(node->name()).empty()) {
return false;
}
const auto& axis_input = properties.GetInputProperties(node->name()).back();
if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
return false;
}
Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
if (!axis_tensor.FromProto(axis_input.value())) {
return false;
}
*axis = axis_input.dtype() == DT_INT64
? static_cast<int>(axis_tensor.scalar<int64>()())
: axis_tensor.scalar<int32>()();
return true;
}
bool HasTPUAttributes(const NodeDef& node) {
AttrSlice attrs(node);
for (auto attr : attrs) {
@ -220,9 +199,9 @@ string ConstantFolding::AddControlDependency(const string& input_name,
if (IsControlInput(input_name)) {
return input_name;
}
const NodeDef* node = node_map->GetNode(input_name);
if (!IsSwitch(*node)) {
return AsControlDependency(*node);
const NodeDef& node = *node_map->GetNode(input_name);
if (!IsSwitch(node)) {
return AsControlDependency(node);
} else {
// We can't anchor control dependencies directly on the switch node: unlike
// other nodes only one of the outputs of the switch node will be generated
@ -230,10 +209,10 @@ string ConstantFolding::AddControlDependency(const string& input_name,
// dependency is only triggered when the corresponding output is triggered.
// We start by looking for an identity node connected to the output of the
// switch node, and use it to anchor the control dependency.
auto outputs = node_map->GetOutputs(node->name());
auto outputs = node_map->GetOutputs(node.name());
for (const NodeDef* output : outputs) {
if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
if (IsSameInput(node->input(0), input_name)) {
if (IsSameInput(node.input(0), input_name)) {
return AsControlDependency(*output);
}
}
@ -244,19 +223,19 @@ string ConstantFolding::AddControlDependency(const string& input_name,
string ctrl_dep_name = ParseNodeName(input_name, &port);
strings::StrAppend(&ctrl_dep_name, "_", port);
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
const DataType output_type = node->attr().at("T").type();
const DataType output_type = node.attr().at("T").type();
NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
if (added_node == nullptr) {
added_node = graph->add_node();
added_node->set_name(ctrl_dep_name);
added_node->set_op("Identity");
added_node->set_device(node->device());
added_node->set_device(node.device());
(*added_node->mutable_attr())["T"].set_type(output_type);
*added_node->add_input() = input_name;
node_map->AddNode(added_node->name(), added_node);
node_map->AddOutput(node->name(), added_node->name());
node_map->AddOutput(node.name(), added_node->name());
}
return AsControlDependency(*added_node);
}
@ -321,6 +300,15 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
return feed_nodes_.find(node.name()) == feed_nodes_.end();
}
// TODO(rmlarsen): Refactor to shared util.
bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
Tensor* tensor) {
const NodeDef* node = node_map_->GetNode(node_name_or_input);
return node != nullptr && IsReallyConstant(*node) &&
CheckAttrExists(*node, "value").ok() &&
tensor->FromProto(node->attr().at("value").tensor());
}
// Materialize the shapes using constants whenever possible.
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies and hold
@ -868,6 +856,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return false;
}
if (node.op() == "AccumulateNV2") {
return false;
}
// Skips ops that don't benefit from folding.
if (IsPlaceholder(node)) {
return false;
@ -1856,9 +1847,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
SET_AND_RETURN_IF_MODIFIED(
PartialAssocOpConstFolding(optimized_graph, properties, node));
SET_AND_RETURN_IF_MODIFIED(
PartialConcatConstFolding(optimized_graph, properties, node));
MergeConcat(use_shape_info, optimized_graph, node));
SET_AND_RETURN_IF_MODIFIED(
MergeConcat(*properties, use_shape_info, optimized_graph, node));
PartialConcatConstFolding(optimized_graph, properties, node));
graph_modified_ = graph_modified_cached;
return Status::OK();
@ -1879,43 +1870,33 @@ void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
Status ConstantFolding::RemoveShuffleOrTranspose(
const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
properties.GetInputProperties(node->name()).size() >= 2) {
if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
return Status::OK();
Tensor permutation_tensor;
if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
properties.HasInputProperties(node->name())) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) {
// Not optimizable.
std::vector<int> permutation;
for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
if (permutation_tensor.dtype() == DT_INT64) {
permutation.push_back(permutation_tensor.vec<int64>()(j));
} else {
permutation.push_back(permutation_tensor.vec<int>()(j));
}
}
if (permutation.size() != shape.dim_size()) {
// Number of elements in perm should be same as dim_size. Skip if not.
return Status::OK();
}
const auto& p = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor perm(p.dtype(), p.shape());
if (!perm.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
std::vector<int> permutation;
for (int j = 0; j < perm.NumElements(); ++j) {
if (perm.dtype() == DT_INT64) {
permutation.push_back(perm.vec<int64>()(j));
} else {
permutation.push_back(perm.vec<int>()(j));
}
}
if (permutation.size() != shape.dim_size()) {
// Number of elements in perm should be same as dim_size. Skip if not.
return Status::OK();
}
// The node is replaceable iff
// dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not permuted.
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
}
// The node is replaceable iff
// dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not permuted.
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 || j == permutation[j];
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
@ -1941,44 +1922,35 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && node->op() == "ReverseV2" &&
properties.GetInputProperties(node->name()).size() >= 2) {
if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
Tensor axis;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &axis)) {
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
if (shape.unknown_rank()) {
// Not optimizable.
return Status::OK();
if (shape.unknown_rank()) return Status::OK();
std::set<int> target_axes;
for (int j = 0; j < axis.NumElements(); ++j) {
// value of axis can be negative.
if (axis.dtype() == DT_INT64) {
target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
shape.dim_size());
} else {
target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
shape.dim_size());
}
}
const auto& a = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(a.shape()) && a.has_value()) {
Tensor axis(a.dtype(), a.shape());
if (!axis.FromProto(a.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
a.value().DebugString());
}
std::set<int> target_axes;
for (int j = 0; j < axis.NumElements(); ++j) {
// value of axis can be negative.
if (axis.dtype() == DT_INT64) {
target_axes.insert((axis.vec<int64>()(j) + shape.dim_size()) %
shape.dim_size());
} else {
target_axes.insert((axis.vec<int>()(j) + shape.dim_size()) %
shape.dim_size());
}
}
// The node is replaceable iff
// unknown_rank == false &&
// (dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not in target_axes)
bool replaceable = !shape.unknown_rank();
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &= shape.dim(j).size() == 1 ||
target_axes.find(j) == target_axes.end();
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
// The node is replaceable iff
// unknown_rank == false &&
// (dim_size == 0 || all dims have size 1 ||
// all dims with > 1 size are not in target_axes)
bool replaceable = true;
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
replaceable &=
shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
@ -1988,45 +1960,33 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph,
NodeDef* node) {
if (use_shape_info && IsSlice(*node) &&
properties.GetInputProperties(node->name()).size() == 3) {
if (!use_shape_info || !IsSlice(*node)) return Status::OK();
Tensor begin;
Tensor size;
if (properties.HasInputProperties(node->name()) &&
GetTensorFromConstNode(node->input(1), &begin) &&
GetTensorFromConstNode(node->input(2), &size)) {
const auto& input = properties.GetInputProperties(node->name())[0];
const auto& b = properties.GetInputProperties(node->name())[1];
const auto& s = properties.GetInputProperties(node->name())[2];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
if (!begin.FromProto(b.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
// The node is replaceable iff unknown_rank == false &&
// begin == 0 && (size == -1 || size == input_shape) for all dimensions
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (begin.dtype() == DT_INT32) {
replaceable &= begin.vec<int>()(j) == 0;
} else {
replaceable &= begin.vec<int64>()(j) == 0;
}
Tensor size(s.dtype(), s.shape());
if (!size.FromProto(s.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
}
// The node is replaceable iff unknown_rank == false &&
// begin == 0 && (size == -1 || size == input_shape) for all dimensions
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (begin.dtype() == DT_INT32) {
replaceable &= begin.vec<int>()(j) == 0;
} else {
replaceable &= begin.vec<int64>()(j) == 0;
}
if (size.dtype() == DT_INT32) {
replaceable &= (size.vec<int>()(j) == -1 ||
size.vec<int>()(j) == input.shape().dim(j).size());
} else {
replaceable &= (size.vec<int64>()(j) == -1 ||
size.vec<int64>()(j) == input.shape().dim(j).size());
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
return Status::OK();
if (size.dtype() == DT_INT32) {
replaceable &= (size.vec<int>()(j) == -1 ||
size.vec<int>()(j) == input.shape().dim(j).size());
} else {
replaceable &= (size.vec<int64>()(j) == -1 ||
size.vec<int64>()(j) == input.shape().dim(j).size());
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
@ -2052,81 +2012,70 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
return Status::OK();
}
}
const auto& b = properties.GetInputProperties(node->name())[1];
const auto& e = properties.GetInputProperties(node->name())[2];
const auto& s = properties.GetInputProperties(node->name())[3];
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
TensorShape::IsValid(e.shape()) && e.has_value() &&
TensorShape::IsValid(s.shape()) && s.has_value()) {
Tensor begin(b.dtype(), b.shape());
if (!begin.FromProto(b.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
b.value().DebugString());
}
Tensor end(e.dtype(), e.shape());
if (!end.FromProto(e.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
e.value().DebugString());
}
Tensor strides(s.dtype(), s.shape());
if (!strides.FromProto(s.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
s.value().DebugString());
}
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
int ellipsis_index = -1;
for (int j = 0; j < input.shape().dim_size(); ++j) {
// find the ellipsis_mask. If not found, insert one in the end if
// necessary.
if (node->attr().at("ellipsis_mask").i() & 1 << j ||
(ellipsis_index == -1 && j >= strides.NumElements())) {
ellipsis_index = j;
}
// insert the indices that are immediately after ellipsis_index if
// necessary.
if (ellipsis_index != -1 &&
input.shape().dim_size() >
strides.NumElements() + j - ellipsis_index) {
expanded_ellipsis_indices.insert(j);
}
}
// The node is replaceable iff unknown_rank == false &&
// ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
// && strides == 1) for all dimensions.
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (expanded_ellipsis_indices.find(j) !=
expanded_ellipsis_indices.end()) {
// ellipsis_mask is effective on current dimension.
continue;
}
// when we have ellipsis_mask in between, input.shape().dim_size() will
// be greater than strides.NumElements(), since we will insert
// as many as expanded_ellipsis_indices.size() axes during computation.
// We need to subtract this number from j.
int i = j;
if (ellipsis_index != -1 &&
j >= ellipsis_index + expanded_ellipsis_indices.size()) {
i = j - expanded_ellipsis_indices.size();
}
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64>()(i);
int e =
end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64>()(i);
replaceable &=
(begin_mask & 1 << i || b == 0) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
std::vector<Tensor> input_tensors(3);
for (int i = 1; i < 4; ++i) {
if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
return Status::OK();
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
const Tensor& begin = input_tensors[0];
const Tensor& end = input_tensors[1];
const Tensor& strides = input_tensors[2];
TF_RETURN_IF_ERROR(
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
int begin_mask = node->attr().at("begin_mask").i();
int end_mask = node->attr().at("end_mask").i();
std::set<int> expanded_ellipsis_indices;
int ellipsis_index = -1;
for (int j = 0; j < input.shape().dim_size(); ++j) {
// find the ellipsis_mask. If not found, insert one in the end if
// necessary.
if (node->attr().at("ellipsis_mask").i() & 1 << j ||
(ellipsis_index == -1 && j >= strides.NumElements())) {
ellipsis_index = j;
}
// insert the indices that are immediately after ellipsis_index if
// necessary.
if (ellipsis_index != -1 &&
input.shape().dim_size() >
strides.NumElements() + j - ellipsis_index) {
expanded_ellipsis_indices.insert(j);
}
}
// The node is replaceable iff unknown_rank == false &&
// ((begin_mask is set || begin == 0) && (end_mask is set || end == dim)
// && strides == 1) for all dimensions.
bool replaceable = !input.shape().unknown_rank();
for (int j = 0; replaceable && j < input.shape().dim_size(); ++j) {
if (expanded_ellipsis_indices.find(j) !=
expanded_ellipsis_indices.end()) {
// ellipsis_mask is effective on current dimension.
continue;
}
// when we have ellipsis_mask in between, input.shape().dim_size() will
// be greater than strides.NumElements(), since we will insert
// as many as expanded_ellipsis_indices.size() axes during computation.
// We need to subtract this number from j.
int i = j;
if (ellipsis_index != -1 &&
j >= ellipsis_index + expanded_ellipsis_indices.size()) {
i = j - expanded_ellipsis_indices.size();
}
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
: begin.vec<int64>()(i);
int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
: strides.vec<int64>()(i);
replaceable &= (begin_mask & 1 << i || b == 0) &&
(end_mask & 1 << i || e == input.shape().dim(j).size()) &&
s == 1;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
@ -2135,31 +2084,23 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
Tensor multiplies;
if (use_shape_info && IsTile(*node) &&
properties.GetInputProperties(node->name()).size() == 2) {
const auto& m = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
Tensor multiplies(m.dtype(), m.shape());
if (!multiplies.FromProto(m.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
m.value().DebugString());
GetTensorFromConstNode(node->input(1), &multiplies)) {
// The node is replaceable iff all values in multiplies are 1.
bool replaceable = true;
if (multiplies.dtype() == DT_INT32) {
for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
replaceable &= multiplies.vec<int>()(j) == 1;
}
// The node is replaceable iff all values in multiplies are 1.
bool replaceable = true;
if (multiplies.dtype() == DT_INT32) {
for (int j = 0; replaceable && j < multiplies.vec<int>().size(); ++j) {
replaceable &= multiplies.vec<int>()(j) == 1;
}
} else {
for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
++j) {
replaceable &= multiplies.vec<int64>()(j) == 1;
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
} else {
for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
replaceable &= multiplies.vec<int64>()(j) == 1;
}
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
}
@ -2167,26 +2108,20 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
if (use_shape_info && IsPad(*node) &&
properties.GetInputProperties(node->name()).size() >= 2) {
const auto& p = properties.GetInputProperties(node->name())[1];
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
Tensor paddings(p.dtype(), p.shape());
if (!paddings.FromProto(p.value())) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
p.value().DebugString());
}
// The node is replaceable iff all values in paddings are 0.
bool replaceable = true;
// The operation requires it to be int32 value so we don't check for
// 1nt64.
const auto flatten = paddings.flat<int32>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
if (!use_shape_info || !IsPad(*node)) return Status::OK();
Tensor paddings;
if (GetTensorFromConstNode(node->input(1), &paddings)) {
// The node is replaceable iff all values in paddings are 0.
bool replaceable = true;
// The operation requires it to be int32 value so we don't check for
// 1nt64.
const auto flatten = paddings.flat<int32>();
for (int j = 0; replaceable && j < flatten.size(); ++j) {
replaceable &= flatten(j) == 0;
}
if (replaceable) {
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
}
}
return Status::OK();
@ -3031,73 +2966,71 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
// folding of ops when more than one but not all inputs are constant.
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since
// addition is commutative.
const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsAggregate(*node) && IsCommutative(*node) &&
num_non_control_inputs > 2) {
const int num_control_inputs = node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs;
std::vector<int> nonconst_inputs;
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
CHECK(input_node != nullptr) << input;
if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
const_inputs.push_back(i);
} else {
// Non-const and control inputs.
nonconst_inputs.push_back(i);
}
}
// Promote AccumulateNV2 with all constant inputs to AddN, since it is
// a fake node that cannot be constant folded by itself.
if (const_inputs.size() == num_non_control_inputs &&
node->op() == "AccumulateNV2") {
node->set_op("AddN");
node->mutable_attr()->erase("shape");
return true;
}
const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs.size()));
if (1 < const_inputs.size() &&
const_inputs.size() < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
// Always use AddN for the constant node, since AccumulateNV2 is a fake
// node that cannot be constant folded, since it does not have a kernel.
added_node->set_op("AddN");
added_node->mutable_attr()->erase("shape");
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i : const_inputs) {
added_node->add_input(node->input(i));
node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
added_node->name());
}
if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
// Overwrite the first const input with the added node.
node->set_input(const_inputs[0], added_node->name());
node_map_->AddOutput(added_node->name(), node->name());
nonconst_inputs.push_back(const_inputs[0]);
// Compact the remaining inputs to the original node.
std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
int idx = 0;
for (int i : nonconst_inputs) {
if (idx != i) {
node->set_input(idx, node->input(i));
}
++idx;
}
node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() -
num_control_inputs);
properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
return true;
const int num_non_control_inputs = NumNonControlInputs(*node);
if (num_non_control_inputs <= 2) return false;
const int num_control_inputs = node->input_size() - num_non_control_inputs;
std::vector<int> const_inputs;
std::vector<int> nonconst_inputs;
for (int i = 0; i < node->input_size(); ++i) {
const string& input = node->input(i);
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
if (input_node == nullptr) return false;
if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
const_inputs.push_back(i);
} else {
// Non-const and control inputs.
nonconst_inputs.push_back(i);
}
}
// Promote AccumulateNV2 with all constant inputs to AddN, since it is
// a fake node that cannot be constant folded by itself.
if (const_inputs.size() == num_non_control_inputs &&
node->op() == "AccumulateNV2") {
node->set_op("AddN");
node->mutable_attr()->erase("shape");
return true;
}
const string new_node_name = OptimizedNodeName(
*node, strings::StrCat("_partial_split_", const_inputs.size()));
if (const_inputs.size() > 1 && const_inputs.size() < num_non_control_inputs &&
!node_map_->NodeExists(new_node_name)) {
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
// Always use AddN for the constant node, since AccumulateNV2 is a fake
// node that cannot be constant folded, since it does not have a kernel.
added_node->set_op("AddN");
added_node->mutable_attr()->erase("shape");
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i : const_inputs) {
added_node->add_input(node->input(i));
node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
added_node->name());
}
// Overwrite the first const input with the added node.
node->set_input(const_inputs[0], added_node->name());
node_map_->AddOutput(added_node->name(), node->name());
nonconst_inputs.push_back(const_inputs[0]);
// Compact the remaining inputs to the original node.
std::sort(nonconst_inputs.begin(), nonconst_inputs.end());
int idx = 0;
for (int i : nonconst_inputs) {
if (idx != i) {
node->set_input(idx, node->input(i));
}
++idx;
}
node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
const_inputs.size() - 1);
(*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
properties->ClearInputProperties(node->name());
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
return true;
}
return false;
}
@ -3107,156 +3040,176 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
// Partial constant folding for Concat which is not commutative, so
// we have to preserve order and can only push consecutive runs of constant
// inputs into sub-nodes.
const int num_non_control_inputs = NumNonControlInputs(*node);
if (IsConcat(*node) && num_non_control_inputs > 3 &&
node->name().rfind("_partial_split_") == string::npos) {
int axis_arg = -1;
int begin = 0;
int end = num_non_control_inputs;
if (node->op() == "Concat") {
begin = 1;
axis_arg = 0;
} else if (node->op() == "ConcatV2") {
end = num_non_control_inputs - 1;
axis_arg = num_non_control_inputs - 1;
} else {
return false;
}
const NodeDef* axis_arg_node =
node_map_->GetNode(NodeName(node->input(axis_arg)));
if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
// We cannot constant fold Concat unless we the axis argument is
// constant. Skip node.
return false;
}
// We search for consecutive runs of constant inputs in the range
// [begin:end[ and push then down into child nodes.
std::vector<std::pair<int, int>> constant_input_runs;
int first = begin;
int last = begin;
while (last < end) {
while (first < end && !IsReallyConstant(*node_map_->GetNode(
NodeName(node->input(first))))) {
++first;
}
// Invariant: node[first] is constant || first >= end.
last = first + 1;
while (last < end && IsReallyConstant(*node_map_->GetNode(
NodeName(node->input(last))))) {
++last;
}
// Invariant: node[last] is not constant || last >= end
// Discard intervals shorter than 2 elements.
if (first < end && (last - first) > 1) {
constant_input_runs.emplace_back(first, last);
}
first = last;
}
// Skip if all inputs are constant, and let constant folding take over.
if (constant_input_runs.size() == 1 &&
constant_input_runs[0].first == begin &&
constant_input_runs[0].second == end) {
return false;
}
std::set<int> inputs_to_delete;
for (auto interval : constant_input_runs) {
// Push the constant inputs in the interval to a child node than can be
// constant folded.
string new_node_name = OptimizedNodeName(*node, "_partial_split");
do {
new_node_name += strings::StrCat("_", interval.first);
} while (node_map_->NodeExists(new_node_name));
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i = interval.first; i < interval.second; ++i) {
added_node->add_input(node->input(i));
node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
added_node->name());
if (i != interval.first) {
inputs_to_delete.insert(i);
}
}
added_node->add_input(node->input(axis_arg));
(*added_node->mutable_attr())["N"].set_i(interval.second -
interval.first);
node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
// Overwrite the first constant input with the result of the added
// child node.
node->set_input(interval.first, added_node->name());
node_map_->AddOutput(added_node->name(), node->name());
}
if (!constant_input_runs.empty()) {
if (!inputs_to_delete.empty()) {
// Fix up the inputs to the original node.
std::vector<string> tmp(node->input().begin(), node->input().end());
node->clear_input();
for (int i = 0; i < tmp.size(); ++i) {
if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
node->add_input(tmp[i]);
}
}
(*node->mutable_attr())["N"].set_i(node->input_size() - 1);
properties->ClearInputProperties(node->name());
}
return true;
}
if (!IsConcat(*node) ||
node->name().rfind("_partial_split_") != string::npos) {
return false;
}
return false;
const int num_non_control_inputs = NumNonControlInputs(*node);
if (num_non_control_inputs <= 3) return false;
int axis_arg = -1;
int begin = 0;
int end = num_non_control_inputs;
if (node->op() == "Concat") {
begin = 1;
axis_arg = 0;
} else if (node->op() == "ConcatV2") {
end = num_non_control_inputs - 1;
axis_arg = num_non_control_inputs - 1;
} else {
return false;
}
// We search for consecutive runs of constant inputs in the range
// [begin:end[ and push then down into child nodes.
std::vector<std::pair<int, int>> constant_input_runs;
int first = begin;
int last = begin;
while (last < end) {
while (first < end && !IsReallyConstant(*node_map_->GetNode(
NodeName(node->input(first))))) {
++first;
}
// Invariant: node[first] is constant || first >= end.
last = first + 1;
while (last < end &&
IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
++last;
}
// Invariant: node[last] is not constant || last >= end
// Discard intervals shorter than 2 elements.
if (first < end && (last - first) > 1) {
constant_input_runs.emplace_back(first, last);
}
first = last;
}
// Skip if all inputs are constant, and let constant folding take over.
if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
constant_input_runs[0].first == begin &&
constant_input_runs[0].second == end)) {
return false;
}
std::set<int> inputs_to_delete;
for (auto interval : constant_input_runs) {
// Push the constant inputs in the interval to a child node than can be
// constant folded.
string new_node_name = OptimizedNodeName(*node, "_partial_split");
do {
new_node_name += strings::StrCat("_", interval.first);
} while (node_map_->NodeExists(new_node_name));
NodeDef* added_node = optimized_graph->add_node();
*added_node = *node;
added_node->set_op("ConcatV2");
added_node->set_name(new_node_name);
node_map_->AddNode(added_node->name(), added_node);
added_node->clear_input();
for (int i = interval.first; i < interval.second; ++i) {
added_node->add_input(node->input(i));
node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
if (i != interval.first) {
inputs_to_delete.insert(i);
}
}
added_node->add_input(node->input(axis_arg));
(*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
// Overwrite the first constant input with the result of the added
// child node.
node->set_input(interval.first, added_node->name());
}
if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
// Fix up the inputs to the original node.
protobuf::RepeatedPtrField<string> tmp;
tmp.Swap(node->mutable_input());
for (int i = 0; i < tmp.size(); ++i) {
if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
node->add_input(tmp.Get(i));
}
}
(*node->mutable_attr())["N"].set_i(node->input_size() - 1);
properties->ClearInputProperties(node->name());
}
return true;
}
bool ConstantFolding::MergeConcat(const GraphProperties& properties,
bool use_shape_info,
bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
if (node.op() != "ConcatV2") {
return false;
}
int axis_idx = node.input_size() - 1;
while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
--axis_idx;
}
if (axis_idx <= 0) {
return false;
}
Tensor axis_tensor;
if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
return false;
}
*axis = axis_tensor.dtype() == DT_INT64
? static_cast<int>(axis_tensor.scalar<int64>()())
: axis_tensor.scalar<int32>()();
return true;
}
bool ConstantFolding::MergeConcat(bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node) {
// We only optimize for ConcatV2.
int axis;
if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
node_map_->GetOutputs(node->name()).size() != 1) {
return false;
}
// If all inputs are constant, don't merge and let folding take case of it.
const int num_regular_inputs = NumNonControlInputs(*node);
bool all_inputs_are_const = true;
for (int i = 0; i < num_regular_inputs - 1; ++i) {
const NodeDef* input_node = node_map_->GetNode(node->input(i));
if (!IsReallyConstant(*input_node)) {
all_inputs_are_const = false;
}
}
if (all_inputs_are_const) return false;
NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
int parent_axis;
if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
return false;
}
const int index = NumNonControlInputs(*node) - 1;
auto inputs = parent->input();
parent->clear_input();
for (int i = 0; i < inputs.size(); ++i) {
if (IsSameInput(inputs.Get(i), node->name())) {
for (int j = 0; j < node->input_size(); ++j) {
if (j < index) {
// Input tensors (non axis), add to input list of parent.
parent->add_input(node->input(j));
node_map_->RemoveOutput(node->input(j), node->name());
node_map_->AddOutput(node->input(j), parent->name());
}
// Skip j == index, which means axis tensor.
if (j > index) {
// Control Dependencies, push back to inputs so they can be forwarded
// to parent.
*inputs.Add() = node->input(j);
}
protobuf::RepeatedPtrField<string> parent_inputs;
parent_inputs.Swap(parent->mutable_input());
std::vector<string> ctrl_output;
// TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
// collapse it into the parent multiple times? Probablyu not.
for (const auto& input : parent_inputs) {
if (IsSameInput(input, node->name())) {
for (int j = 0; j < num_regular_inputs - 1; ++j) {
// Add tensor inputs to first child concat tensors (exceptthe final axis
// input) to the parent's inputs.
parent->add_input(node->input(j));
node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
}
} else {
parent->add_input(inputs.Get(i));
parent->add_input(input);
}
}
// Forward Add control inputs
for (int i = num_regular_inputs; i < node->input_size(); ++i) {
parent->add_input(node->input(i));
node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
}
node->clear_input();
node->set_op("NoOp");
node->clear_attr();
node_map_->RemoveNode(node->name());
(*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
DedupControlInputs(parent);
return true;
}
@ -3344,7 +3297,9 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
// that the shape inference deals with this conservatively unless we're in
// aggressive mode.
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
Status s = properties.InferStatically(assume_valid_feeds);
Status s = properties.InferStatically(assume_valid_feeds,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false);
const bool can_use_shape_info = s.ok();
if (can_use_shape_info) {

View File

@ -61,6 +61,8 @@ class ConstantFolding : public GraphOptimizer {
bool IsReallyConstant(const NodeDef& node) const;
bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
Status MaterializeShapes(const GraphProperties& properties);
Status MaterializeBroadcastGradientArgs(const NodeDef& node,
@ -239,8 +241,9 @@ class ConstantFolding : public GraphOptimizer {
void RemoveSplitOrSplitV(const GraphProperties& properties,
GraphDef* optimized_graph, NodeDef* node);
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
GraphDef* optimized_graph, NodeDef* node);
bool GetConcatAxis(const NodeDef& node, int* axis);
bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
NodeDef* node);
Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
GraphDef* optimized_graph);

View File

@ -2390,12 +2390,11 @@ TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
TF_EXPECT_OK(status);
GraphDef want;
AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want);
AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
AddNode("axis", "Const", {}, {}, &want);
AddNode("ph", "Placeholder", {}, {}, &want);
AddNode("concat2", "ConcatV2",
{"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {},
&want);
{"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
CompareGraphs(want, got);
}

View File

@ -2209,7 +2209,10 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
GraphProperties graph_properties(item);
TF_RETURN_IF_ERROR(graph_properties.InferStatically(false));
TF_RETURN_IF_ERROR(
graph_properties.InferStatically(/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));

View File

@ -102,7 +102,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
TF_RETURN_IF_ERROR(properties->InferStatically(
/*assume_valid_feeds=*/false));
/*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
}
const auto& output_properties = properties->GetOutputProperties(node.name());
if (port_id >= output_properties.size()) {
@ -252,7 +253,8 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
if (!properties->has_properties()) {
// This is an expensive call, call it lazily.
TF_RETURN_IF_ERROR(properties->InferStatically(
/*assume_valid_feeds=*/false));
/*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
}
for (const auto& prop : properties->GetOutputProperties(node.name())) {
if (!IsTensorSmall(prop)) {

View File

@ -1170,7 +1170,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
// Infer properties lazily in case they are not needed.
if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(false));
// TODO(rmlarsen): Get rid of tensor value copies.
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/true));
ctx.inferred_graph_properties = true;
}

View File

@ -739,9 +739,9 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
GraphProperties graph_properties(item);
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
LOG_WARNING_AND_RETURN_IF_ERROR(
graph_properties.InferStatically(assume_valid_feeds));
LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
assume_valid_feeds, /*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
*optimized_graph = item.graph;
node_map_.reset(new NodeMap(optimized_graph));

View File

@ -87,7 +87,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
if (!inferred_properties) {
// Infer properties lazily in case they are not needed.
TF_RETURN_IF_ERROR(properties.InferStatically(false));
TF_RETURN_IF_ERROR(
properties.InferStatically(/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
inferred_properties = true;
}
const auto& prop =
@ -144,7 +147,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
}
if (!inferred_properties) {
// Infer properties lazily in case they are not needed.
TF_RETURN_IF_ERROR(properties.InferStatically(false));
TF_RETURN_IF_ERROR(
properties.InferStatically(/*assume_valid_feeds=*/false,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
inferred_properties = true;
}
const auto& prop1 = properties.GetInputProperties(input1.node->name());

View File

@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
#include <deque>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
@ -92,7 +94,10 @@ Status EstimateEarliestExecutionTimes(
name_map.clear();
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(true));
TF_RETURN_IF_ERROR(
properties.InferStatically(/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster->GetDevices());
@ -160,7 +165,10 @@ Status EstimateRequiredTimes(
}
}
GraphProperties properties(item);
TF_RETURN_IF_ERROR(properties.InferStatically(true));
TF_RETURN_IF_ERROR(
properties.InferStatically(/*assume_valid_feeds=*/true,
/*aggressive_shape_inference=*/false,
/*include_tensor_values=*/false));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster->GetDevices());

View File

@ -346,7 +346,8 @@ class FoldOldBatchNormsTest : public ::testing::Test {
std::vector<Tensor> fused_outputs;
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
/*rtol=*/2e-5);
for (const NodeDef& node : fused_graph_def.node()) {
EXPECT_NE("FusedBatchNorm", node.op());