Add an option to avoid copying TensorProtos in shape inference when not needed. This saves a significant amount of memory and time in various optimizers requiring static shape inference. Calling GraphProperties::InferStatically would make at least one copy of all constant tensors stored in the graph. This CL adds a new argument to InferStatically, allowing callers to skip copying tensor values to InputProperties and OutputProperties for optimizers that only need shape and datatype information.
Turning this off for constant folding and arithmetic optimizer required cleaning up a few rewrite rules that relied on reading the copy of constant tensor values from GraphProperties instead of just extracting the original value from the corresponding constant node in the graph. Measured results on Transformer graph: CPU: Intel Skylake Xeon with HyperThreading (36 cores) dL1:32KB dL2:1024KB dL3:24MB Benchmark Time(ns) CPU(ns) Allocs Iterations ---------------------------------------------------------------------- Before: BM_OptimizeTransformer 5161658741 5281083018 23532315 1 139.189MB peak-mem After: BM_OptimizeTransformer 4887891650 5005937073 23063225 1 132.652MB peak-mem Effects on this graph are not dramatic, since it does not contain a lot of large constant tensors. The effect would be much more pronounced on frozen inference graphs. PiperOrigin-RevId: 250740251
This commit is contained in:
parent
d7d2307248
commit
705b193812
@ -712,7 +712,10 @@ class SymbolicShapeRefiner {
|
||||
|
||||
// Perform inference on function body.
|
||||
GraphProperties gp(grappler_function_item);
|
||||
TF_RETURN_IF_ERROR(gp.InferStatically(true, aggressive_shape_inference_));
|
||||
TF_RETURN_IF_ERROR(gp.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/aggressive_shape_inference_,
|
||||
/*include_tensor_values=*/true));
|
||||
|
||||
// Add return nodes for output shapes.
|
||||
int output = 0;
|
||||
@ -2066,7 +2069,8 @@ Status GraphProperties::UpdateEnqueue(
|
||||
}
|
||||
|
||||
Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
bool aggressive_shape_inference) {
|
||||
bool aggressive_shape_inference,
|
||||
bool include_tensor_values) {
|
||||
FunctionLibraryDefinition function_library(OpRegistry::Global(),
|
||||
item_.graph.library());
|
||||
std::unordered_map<string, std::unordered_set<int>> fed_ports;
|
||||
@ -2225,9 +2229,11 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
&input_properties[i]);
|
||||
input.port_id = i;
|
||||
GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
|
||||
if (include_tensor_values) {
|
||||
// Export tensor value to input_properties.value.
|
||||
if (IsConstant(*fanin.node)) {
|
||||
const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
|
||||
const TensorProto& raw_val =
|
||||
fanin.node->attr().at("value").tensor();
|
||||
*input_properties[i].mutable_value() = raw_val;
|
||||
} else if (ctx->input_tensor_protos.size() > i &&
|
||||
ctx->input_tensor_protos[i] != nullptr) {
|
||||
@ -2242,6 +2248,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fill output properties.
|
||||
{
|
||||
@ -2254,13 +2261,16 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
for (int i = 0; i < ic->num_outputs(); ++i) {
|
||||
shape_manager->AsTensorProperties(ic->output(i), ctx->output_types[i],
|
||||
&output_properties[i]);
|
||||
if (include_tensor_values) {
|
||||
// Export tensor value to output_properties.value.
|
||||
if (IsConstant(node)) {
|
||||
// TODO(rmlarsen): Eliminate this copy.
|
||||
const TensorProto& raw_val = node.attr().at("value").tensor();
|
||||
*output_properties[i].mutable_value() = raw_val;
|
||||
} else if (ctx->output_tensor_protos.size() > i &&
|
||||
ctx->output_tensor_protos[i] != nullptr) {
|
||||
*output_properties[i].mutable_value() = *ctx->output_tensor_protos[i];
|
||||
*output_properties[i].mutable_value() =
|
||||
*ctx->output_tensor_protos[i];
|
||||
} else if (ctx->output_tensors_as_shapes.size() > i &&
|
||||
IsShapeFullyDefinedIntegerVectorOrScalar(
|
||||
ic, ic->output(i), ctx->output_tensors_as_shapes[i],
|
||||
@ -2272,6 +2282,7 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Help trace the unknown dimensions to their origins.
|
||||
VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
|
||||
|
@ -89,11 +89,15 @@ class GraphProperties {
|
||||
// output values when possible and does other aggressive strategies.
|
||||
// Similar to assuming_valid_feeds, this may cause incorrectness in graph
|
||||
// analyses, but is useful for simulation or scheduling.
|
||||
// If include_values is true, the values of constant tensors will be
|
||||
// included in the input and output properties.
|
||||
Status InferStatically(bool assume_valid_feeds,
|
||||
bool aggressive_shape_inference);
|
||||
bool aggressive_shape_inference,
|
||||
bool include_tensor_values);
|
||||
Status InferStatically(bool assume_valid_feeds) {
|
||||
return InferStatically(assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false);
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true);
|
||||
}
|
||||
// Infer the shape by running the graph on the specified cluster and recording
|
||||
// the shapes of the processed tensors.
|
||||
@ -117,6 +121,7 @@ class GraphProperties {
|
||||
const string& node_name) const;
|
||||
const std::vector<OpInfo::TensorProperties>& GetOutputProperties(
|
||||
const string& node_name) const;
|
||||
|
||||
// Invalidate input/output properties for nodes modified during graph
|
||||
// optimization pass, to prevent potential optimizations, based on incorrect
|
||||
// shape information.
|
||||
|
@ -14,6 +14,7 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
|
||||
#include "tensorflow/cc/framework/scope.h"
|
||||
#include "tensorflow/cc/ops/standard_ops.h"
|
||||
#include "tensorflow/core/framework/graph_def_util.h"
|
||||
@ -1009,7 +1010,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto out_props = properties.GetOutputProperties("fill");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("float: [4,4]", PropToString(out_prop0));
|
||||
@ -1028,7 +1030,8 @@ TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) {
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto out_props = properties.GetOutputProperties("fill");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("float: [1000,1000,1000,1000]", PropToString(out_prop0));
|
||||
@ -1248,10 +1251,12 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
|
||||
// evaluate output value.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false));
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto out_props = properties.GetOutputProperties("MyFunc");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
|
||||
LOG(INFO) << out_prop0.DebugString();
|
||||
EXPECT_FALSE(out_prop0.has_value());
|
||||
}
|
||||
|
||||
@ -1260,7 +1265,8 @@ TEST_F(GraphPropertiesTest, ArithmeticFunctionReturnTensorValue) {
|
||||
// With aggressive_shape_inference, output value is evaluated.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto out_props = properties.GetOutputProperties("MyFunc");
|
||||
const OpInfo::TensorProperties out_prop0 = out_props[0];
|
||||
EXPECT_EQ("int32: [2]", PropToString(out_prop0));
|
||||
@ -1802,7 +1808,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false));
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
EXPECT_FALSE(properties.GetOutputProperties("slice").at(0).has_value());
|
||||
}
|
||||
|
||||
@ -1812,7 +1819,8 @@ TEST_F(GraphPropertiesTest, StridedSliceOfShapeWithShrinkAxisMask) {
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
EXPECT_TRUE(properties.GetOutputProperties("slice").at(0).has_value());
|
||||
const auto slice_value =
|
||||
properties.GetOutputProperties("slice").at(0).value();
|
||||
@ -1838,7 +1846,8 @@ TEST_F(GraphPropertiesTest, ValuePropagationThroughArithmeticOps) {
|
||||
GraphProperties properties(item);
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
|
||||
// Check output shapes and values.
|
||||
const auto& a_plus_one_prop = properties.GetOutputProperties("a_plus_one")[0];
|
||||
@ -1881,7 +1890,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
|
||||
// Without aggressive_shape_inference, ignore annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false));
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
@ -1895,7 +1905,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotation) {
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
@ -1923,7 +1934,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithCompatibleShapes) {
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
@ -1950,7 +1962,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithIncompatibleShapes) {
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto props = properties.GetOutputProperties("Identity");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
@ -1977,7 +1990,8 @@ TEST_F(GraphPropertiesTest, ShapeAnnotationWithoutInferenceFn) {
|
||||
// Use annotated information.
|
||||
TF_CHECK_OK(properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/true));
|
||||
/*aggressive_shape_inference=*/true,
|
||||
/*include_tensor_values=*/true));
|
||||
const auto props = properties.GetOutputProperties("TestOpWithNoInferenceFn");
|
||||
EXPECT_EQ(1, props.size());
|
||||
const OpInfo::TensorProperties& prop = props[0];
|
||||
|
@ -359,7 +359,7 @@ Status VirtualScheduler::Init(const GrapplerItem* item) {
|
||||
graph_properties_ = absl::make_unique<GraphProperties>(*item);
|
||||
if (use_static_shapes_) {
|
||||
TF_RETURN_IF_ERROR(graph_properties_->InferStatically(
|
||||
true, use_aggressive_shape_inference_));
|
||||
true, use_aggressive_shape_inference_, true));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(graph_properties_->InferDynamically(cluster_));
|
||||
}
|
||||
|
@ -234,6 +234,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
|
||||
DedupControlInputs(target_node);
|
||||
}
|
||||
|
||||
bool IsReallyConstant(const NodeDef& node) const {
|
||||
if (!IsConstant(node)) {
|
||||
return false;
|
||||
}
|
||||
// If the node is fed it's not constant anymore.
|
||||
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
|
||||
}
|
||||
|
||||
bool IsInPreserveSet(const NodeDef& node) const {
|
||||
return ctx().nodes_to_preserve->find(node.name()) !=
|
||||
ctx().nodes_to_preserve->end();
|
||||
@ -259,6 +267,14 @@ class ArithmeticOptimizerStage : public GraphOptimizerStage<string> {
|
||||
return false;
|
||||
}
|
||||
|
||||
bool GetTensorFromConstNode(const string& node_name_or_input,
|
||||
Tensor* tensor) {
|
||||
const NodeDef* node = ctx().node_map->GetNode(node_name_or_input);
|
||||
return node != nullptr && IsReallyConstant(*node) &&
|
||||
CheckAttrExists(*node, "value").ok() &&
|
||||
tensor->FromProto(node->attr().at("value").tensor());
|
||||
}
|
||||
|
||||
private:
|
||||
// Extended context required for ArithmeticOptimizer.
|
||||
const ArithmeticOptimizerContext ctx_ext_;
|
||||
@ -2480,27 +2496,16 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
|
||||
bool IsSupported(const NodeDef* node) const override {
|
||||
return IsPow(*node) &&
|
||||
ctx().graph_properties->GetInputProperties(node->name()).size() == 2;
|
||||
ctx().graph_properties->HasOutputProperties(node->name()) &&
|
||||
ctx().graph_properties->HasInputProperties(node->name());
|
||||
}
|
||||
|
||||
Status TrySimplify(NodeDef* node, string* simplified_node_name) override {
|
||||
const auto& pow_props =
|
||||
ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
PartialTensorShape shape(pow_props.shape());
|
||||
if (!shape.IsFullyDefined()) {
|
||||
// skip if p is not fully defined.
|
||||
return Status::OK();
|
||||
}
|
||||
if (TensorShape::IsValid(pow_props.shape()) && pow_props.has_value()) {
|
||||
Tensor pow(pow_props.dtype(), pow_props.shape());
|
||||
if (!pow.FromProto(pow_props.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
pow_props.value().DebugString());
|
||||
}
|
||||
|
||||
Tensor pow;
|
||||
if (!GetTensorFromConstNode(node->input(1), &pow)) return Status::OK();
|
||||
complex128 prev, curr;
|
||||
for (int i = 0; i < pow.NumElements(); ++i) {
|
||||
if (!GetElementUnexhaustive(pow, i, {pow_props.dtype()}, &curr)) {
|
||||
if (!GetElementUnexhaustive(pow, i, {pow.dtype()}, &curr)) {
|
||||
// input data type is not supported by Pow. Skip.
|
||||
return Status::OK();
|
||||
}
|
||||
@ -2536,34 +2541,23 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(y);
|
||||
} else if (curr == complex128(0, 0) &&
|
||||
ShapesSymbolicallyEqual(value_props.shape(), output_shape)) {
|
||||
PartialTensorShape shape(value_props.shape());
|
||||
if (!shape.IsFullyDefined()) {
|
||||
// skip if b is not fully defined.
|
||||
return Status::OK();
|
||||
}
|
||||
if (TensorShape::IsValid(value_props.shape()) &&
|
||||
value_props.has_value()) {
|
||||
Tensor base(value_props.dtype(), value_props.shape());
|
||||
if (!base.FromProto(value_props.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
value_props.value().DebugString());
|
||||
ShapesSymbolicallyEqual(value_props.shape(), output_shape) &&
|
||||
PartialTensorShape(output_shape).IsFullyDefined()) {
|
||||
const auto dtype = node->attr().at("T").type();
|
||||
Tensor ones(dtype, output_shape);
|
||||
for (int i = 0; i < ones.NumElements(); ++i) {
|
||||
TF_RETURN_IF_ERROR(SetElementToOne(i, &ones));
|
||||
}
|
||||
node->set_op("Const");
|
||||
Tensor c(base.dtype(), base.shape());
|
||||
for (int i = 0; i < c.NumElements(); ++i) {
|
||||
TF_RETURN_IF_ERROR(SetElementToOne(i, &c));
|
||||
}
|
||||
(*node->mutable_attr())["dtype"].set_type(base.dtype());
|
||||
c.AsProtoTensorContent(
|
||||
(*node->mutable_attr())["value"].mutable_tensor());
|
||||
(*node->mutable_attr())["dtype"].set_type(dtype);
|
||||
node->mutable_attr()->erase("T");
|
||||
ones.AsProtoTensorContent(
|
||||
(*node->mutable_attr())["value"].mutable_tensor());
|
||||
node->set_input(0, AsControlDependency(x->name()));
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(x);
|
||||
AddToOptimizationQueue(y);
|
||||
}
|
||||
} else if (curr == complex128(-0.5, 0)) {
|
||||
node->set_op("Rsqrt");
|
||||
node->set_input(1, AsControlDependency(y->name()));
|
||||
@ -2575,7 +2569,6 @@ class ConvertPowStage : public ArithmeticOptimizerStage {
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(y);
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -2638,12 +2631,12 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
|
||||
}
|
||||
|
||||
private:
|
||||
Status TrySimplifyInternal(NodeDef* node, NodeDef* input, int i, int j,
|
||||
Status TrySimplifyInternal(NodeDef* node, NodeDef* add_node, int i, int j,
|
||||
bool* modified) {
|
||||
const auto& t =
|
||||
ctx().graph_properties->GetInputProperties(input->name())[i];
|
||||
ctx().graph_properties->GetInputProperties(add_node->name())[i];
|
||||
const auto& c =
|
||||
ctx().graph_properties->GetInputProperties(input->name())[j];
|
||||
ctx().graph_properties->GetInputProperties(add_node->name())[j];
|
||||
for (int k = 0; k < c.shape().dim_size(); ++k) {
|
||||
// Skip if c shape is not fully determined.
|
||||
if (c.shape().dim(k).size() < 0) {
|
||||
@ -2659,13 +2652,13 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
|
||||
// broadcast.
|
||||
return Status::OK();
|
||||
}
|
||||
if (TensorShape::IsValid(c.shape()) && c.has_value()) {
|
||||
Tensor constant(c.dtype(), c.shape());
|
||||
if (!constant.FromProto(c.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
c.value().DebugString());
|
||||
}
|
||||
Tensor constant;
|
||||
if (GetTensorFromConstNode(add_node->input(j), &constant)) {
|
||||
complex128 element;
|
||||
// TODO(rmlarsen): Refactor the more general IsOnes from
|
||||
// constant_folding.cc and use it here. Perhaps also convert log(x - (-1))
|
||||
// or (preferably) add a passes to canonicalize Sub(x, -1) to Add(x, 1),
|
||||
// and Neg(-1) to 1.
|
||||
for (int k = 0; k < constant.NumElements(); ++k) {
|
||||
if (!GetElementUnexhaustive(constant, k,
|
||||
{DT_BFLOAT16, DT_HALF, DT_FLOAT, DT_DOUBLE,
|
||||
@ -2680,15 +2673,15 @@ class ConvertLog1pStage : public ArithmeticOptimizerStage {
|
||||
}
|
||||
}
|
||||
NodeDef *x, *y;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(input->input(i), &x));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(input->input(j), &y));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(add_node->input(i), &x));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(add_node->input(j), &y));
|
||||
node->set_op("Log1p");
|
||||
node->set_input(0, input->input(i));
|
||||
node->set_input(0, add_node->input(i));
|
||||
node->add_input(AsControlDependency(y->name()));
|
||||
ForwardControlDependencies(node, {input});
|
||||
ForwardControlDependencies(node, {add_node});
|
||||
|
||||
AddToOptimizationQueue(node);
|
||||
AddToOptimizationQueue(input);
|
||||
AddToOptimizationQueue(add_node);
|
||||
AddToOptimizationQueue(x);
|
||||
AddToOptimizationQueue(y);
|
||||
*modified = true;
|
||||
@ -2717,25 +2710,8 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
if (ctx().graph_properties->GetInputProperties(node->name()).size() < 2) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
NodeDef* exp;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
|
||||
if (!IsExp(*exp)) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (ctx().graph_properties->GetInputProperties(exp->name()).empty()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const auto& t = ctx().graph_properties->GetInputProperties(exp->name())[0];
|
||||
const auto& t = ctx().graph_properties->GetInputProperties(node->name())[0];
|
||||
const auto& c = ctx().graph_properties->GetInputProperties(node->name())[1];
|
||||
for (int k = 0; k < c.shape().dim_size(); ++k) {
|
||||
// Skip if c shape is not fully determined.
|
||||
if (c.shape().dim(k).size() < 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
TensorShapeProto broadcast_shape;
|
||||
if (!ShapeAfterBroadcast(t.shape(), c.shape(), &broadcast_shape)) {
|
||||
return Status::OK();
|
||||
@ -2745,12 +2721,9 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
// broadcast.
|
||||
return Status::OK();
|
||||
}
|
||||
if (TensorShape::IsValid(c.shape()) && c.has_value()) {
|
||||
Tensor constant(c.dtype(), c.shape());
|
||||
if (!constant.FromProto(c.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
c.value().DebugString());
|
||||
}
|
||||
Tensor constant;
|
||||
if (!GetTensorFromConstNode(node->input(1), &constant)) return Status::OK();
|
||||
// TODO(rmlarsen): Use the more general IsOnes helper here.
|
||||
complex128 element;
|
||||
for (int k = 0; k < constant.NumElements(); ++k) {
|
||||
if (!GetElementUnexhaustive(constant, k,
|
||||
@ -2760,11 +2733,14 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
// input data type is not supported by expm1. Skip.
|
||||
return Status::OK();
|
||||
}
|
||||
LOG(INFO) << "Got element = " << element;
|
||||
if (element != complex128(1)) {
|
||||
// current element is not 1. Skip.
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
NodeDef* exp;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(0), &exp));
|
||||
NodeDef *exp_input, *ones;
|
||||
TF_RETURN_IF_ERROR(GetInputNode(exp->input(0), &exp_input));
|
||||
TF_RETURN_IF_ERROR(GetInputNode(node->input(1), &ones));
|
||||
@ -2777,7 +2753,7 @@ class ConvertExpm1Stage : public ArithmeticOptimizerStage {
|
||||
AddToOptimizationQueue(exp);
|
||||
AddToOptimizationQueue(exp_input);
|
||||
AddToOptimizationQueue(ones);
|
||||
}
|
||||
*simplified_node_name = node->name();
|
||||
return Status::OK();
|
||||
}
|
||||
};
|
||||
@ -3096,14 +3072,6 @@ class RemoveStackStridedSliceSameAxis : public ArithmeticOptimizerStage {
|
||||
}
|
||||
|
||||
protected:
|
||||
bool IsReallyConstant(const NodeDef& node) const {
|
||||
if (!IsConstant(node)) {
|
||||
return false;
|
||||
}
|
||||
// If the node is fed it's not constant anymore.
|
||||
return ctx().feed_nodes->find(node.name()) == ctx().feed_nodes->end();
|
||||
}
|
||||
|
||||
bool GetConstantAsInt64(const NodeDef& node, DataType dtype,
|
||||
std::vector<int64>* values) {
|
||||
if (dtype == DT_INT32) {
|
||||
@ -3430,8 +3398,6 @@ bool ArithmeticOptimizer::CanDedup(const NodeDef& node) const {
|
||||
|
||||
void ArithmeticOptimizer::DedupComputations() {
|
||||
CanonicalizeGraph(optimized_graph_);
|
||||
// LOG(INFO) << "Graph after canonicalization: \n"
|
||||
// << optimized_graph_->DebugString();
|
||||
|
||||
GraphTopologyView graph_view;
|
||||
if (!graph_view.InitializeFromGraph(*optimized_graph_).ok()) {
|
||||
@ -3683,7 +3649,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
|
||||
|
||||
graph_properties_.reset(new GraphProperties(optimized_item));
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
const Status status = graph_properties_->InferStatically(assume_valid_feeds);
|
||||
const Status status =
|
||||
graph_properties_->InferStatically(assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false);
|
||||
const bool can_use_shapes = status.ok();
|
||||
if (!can_use_shapes) {
|
||||
VLOG(1) << "Shape inference failed." << status.error_message();
|
||||
|
@ -122,27 +122,6 @@ bool MaybeRemoveControlInput(const string& old_input, NodeDef* node,
|
||||
return removed_input;
|
||||
}
|
||||
|
||||
bool GetConcatAxis(const GraphProperties& properties, NodeDef* node,
|
||||
int* axis) {
|
||||
if (node->op() != "ConcatV2" ||
|
||||
properties.GetInputProperties(node->name()).empty()) {
|
||||
return false;
|
||||
}
|
||||
const auto& axis_input = properties.GetInputProperties(node->name()).back();
|
||||
if (!TensorShape::IsValid(axis_input.shape()) || !axis_input.has_value()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor axis_tensor(axis_input.dtype(), axis_input.shape());
|
||||
if (!axis_tensor.FromProto(axis_input.value())) {
|
||||
return false;
|
||||
}
|
||||
*axis = axis_input.dtype() == DT_INT64
|
||||
? static_cast<int>(axis_tensor.scalar<int64>()())
|
||||
: axis_tensor.scalar<int32>()();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool HasTPUAttributes(const NodeDef& node) {
|
||||
AttrSlice attrs(node);
|
||||
for (auto attr : attrs) {
|
||||
@ -220,9 +199,9 @@ string ConstantFolding::AddControlDependency(const string& input_name,
|
||||
if (IsControlInput(input_name)) {
|
||||
return input_name;
|
||||
}
|
||||
const NodeDef* node = node_map->GetNode(input_name);
|
||||
if (!IsSwitch(*node)) {
|
||||
return AsControlDependency(*node);
|
||||
const NodeDef& node = *node_map->GetNode(input_name);
|
||||
if (!IsSwitch(node)) {
|
||||
return AsControlDependency(node);
|
||||
} else {
|
||||
// We can't anchor control dependencies directly on the switch node: unlike
|
||||
// other nodes only one of the outputs of the switch node will be generated
|
||||
@ -230,10 +209,10 @@ string ConstantFolding::AddControlDependency(const string& input_name,
|
||||
// dependency is only triggered when the corresponding output is triggered.
|
||||
// We start by looking for an identity node connected to the output of the
|
||||
// switch node, and use it to anchor the control dependency.
|
||||
auto outputs = node_map->GetOutputs(node->name());
|
||||
auto outputs = node_map->GetOutputs(node.name());
|
||||
for (const NodeDef* output : outputs) {
|
||||
if (IsIdentity(*output) || IsIdentityNSingleInput(*output)) {
|
||||
if (IsSameInput(node->input(0), input_name)) {
|
||||
if (IsSameInput(node.input(0), input_name)) {
|
||||
return AsControlDependency(*output);
|
||||
}
|
||||
}
|
||||
@ -244,19 +223,19 @@ string ConstantFolding::AddControlDependency(const string& input_name,
|
||||
string ctrl_dep_name = ParseNodeName(input_name, &port);
|
||||
strings::StrAppend(&ctrl_dep_name, "_", port);
|
||||
ctrl_dep_name = AddPrefixToNodeName(ctrl_dep_name, kConstantFoldingCtrl);
|
||||
const DataType output_type = node->attr().at("T").type();
|
||||
const DataType output_type = node.attr().at("T").type();
|
||||
|
||||
NodeDef* added_node = node_map->GetNode(ctrl_dep_name);
|
||||
if (added_node == nullptr) {
|
||||
added_node = graph->add_node();
|
||||
added_node->set_name(ctrl_dep_name);
|
||||
added_node->set_op("Identity");
|
||||
added_node->set_device(node->device());
|
||||
added_node->set_device(node.device());
|
||||
|
||||
(*added_node->mutable_attr())["T"].set_type(output_type);
|
||||
*added_node->add_input() = input_name;
|
||||
node_map->AddNode(added_node->name(), added_node);
|
||||
node_map->AddOutput(node->name(), added_node->name());
|
||||
node_map->AddOutput(node.name(), added_node->name());
|
||||
}
|
||||
return AsControlDependency(*added_node);
|
||||
}
|
||||
@ -321,6 +300,15 @@ bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
|
||||
return feed_nodes_.find(node.name()) == feed_nodes_.end();
|
||||
}
|
||||
|
||||
// TODO(rmlarsen): Refactor to shared util.
|
||||
bool ConstantFolding::GetTensorFromConstNode(const string& node_name_or_input,
|
||||
Tensor* tensor) {
|
||||
const NodeDef* node = node_map_->GetNode(node_name_or_input);
|
||||
return node != nullptr && IsReallyConstant(*node) &&
|
||||
CheckAttrExists(*node, "value").ok() &&
|
||||
tensor->FromProto(node->attr().at("value").tensor());
|
||||
}
|
||||
|
||||
// Materialize the shapes using constants whenever possible.
|
||||
Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
|
||||
// We may add some nodes to the graph to encode control dependencies and hold
|
||||
@ -868,6 +856,9 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (node.op() == "AccumulateNV2") {
|
||||
return false;
|
||||
}
|
||||
// Skips ops that don't benefit from folding.
|
||||
if (IsPlaceholder(node)) {
|
||||
return false;
|
||||
@ -1856,9 +1847,9 @@ Status ConstantFolding::SimplifyNode(bool use_shape_info, NodeDef* node,
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
PartialAssocOpConstFolding(optimized_graph, properties, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
PartialConcatConstFolding(optimized_graph, properties, node));
|
||||
MergeConcat(use_shape_info, optimized_graph, node));
|
||||
SET_AND_RETURN_IF_MODIFIED(
|
||||
MergeConcat(*properties, use_shape_info, optimized_graph, node));
|
||||
PartialConcatConstFolding(optimized_graph, properties, node));
|
||||
|
||||
graph_modified_ = graph_modified_cached;
|
||||
return Status::OK();
|
||||
@ -1879,26 +1870,18 @@ void ConstantFolding::RemoveSplitOrSplitV(const GraphProperties& properties,
|
||||
Status ConstantFolding::RemoveShuffleOrTranspose(
|
||||
const GraphProperties& properties, bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
if (use_shape_info && (IsShuffle(*node) || IsTranspose(*node)) &&
|
||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
||||
if (shape.unknown_rank()) {
|
||||
// Not optimizable.
|
||||
if (!use_shape_info || !(IsShuffle(*node) || IsTranspose(*node)))
|
||||
return Status::OK();
|
||||
}
|
||||
const auto& p = properties.GetInputProperties(node->name())[1];
|
||||
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
|
||||
Tensor perm(p.dtype(), p.shape());
|
||||
if (!perm.FromProto(p.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
p.value().DebugString());
|
||||
}
|
||||
Tensor permutation_tensor;
|
||||
if (GetTensorFromConstNode(node->input(1), &permutation_tensor) &&
|
||||
properties.HasInputProperties(node->name())) {
|
||||
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
||||
std::vector<int> permutation;
|
||||
for (int j = 0; j < perm.NumElements(); ++j) {
|
||||
if (perm.dtype() == DT_INT64) {
|
||||
permutation.push_back(perm.vec<int64>()(j));
|
||||
for (int j = 0; j < permutation_tensor.NumElements(); ++j) {
|
||||
if (permutation_tensor.dtype() == DT_INT64) {
|
||||
permutation.push_back(permutation_tensor.vec<int64>()(j));
|
||||
} else {
|
||||
permutation.push_back(perm.vec<int>()(j));
|
||||
permutation.push_back(permutation_tensor.vec<int>()(j));
|
||||
}
|
||||
}
|
||||
if (permutation.size() != shape.dim_size()) {
|
||||
@ -1914,8 +1897,6 @@ Status ConstantFolding::RemoveShuffleOrTranspose(
|
||||
}
|
||||
if (replaceable) {
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -1941,20 +1922,12 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
|
||||
bool use_shape_info,
|
||||
GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
if (use_shape_info && node->op() == "ReverseV2" &&
|
||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||
if (!use_shape_info || node->op() != "ReverseV2") return Status::OK();
|
||||
Tensor axis;
|
||||
if (properties.HasInputProperties(node->name()) &&
|
||||
GetTensorFromConstNode(node->input(1), &axis)) {
|
||||
const auto& shape = properties.GetInputProperties(node->name())[0].shape();
|
||||
if (shape.unknown_rank()) {
|
||||
// Not optimizable.
|
||||
return Status::OK();
|
||||
}
|
||||
const auto& a = properties.GetInputProperties(node->name())[1];
|
||||
if (TensorShape::IsValid(a.shape()) && a.has_value()) {
|
||||
Tensor axis(a.dtype(), a.shape());
|
||||
if (!axis.FromProto(a.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
a.value().DebugString());
|
||||
}
|
||||
if (shape.unknown_rank()) return Status::OK();
|
||||
std::set<int> target_axes;
|
||||
for (int j = 0; j < axis.NumElements(); ++j) {
|
||||
// value of axis can be negative.
|
||||
@ -1971,16 +1944,15 @@ Status ConstantFolding::RemoveReverse(const GraphProperties& properties,
|
||||
// unknown_rank == false &&
|
||||
// (dim_size == 0 || all dims have size 1 ||
|
||||
// all dims with > 1 size are not in target_axes)
|
||||
bool replaceable = !shape.unknown_rank();
|
||||
bool replaceable = true;
|
||||
for (int j = 0; replaceable && j < shape.dim_size(); ++j) {
|
||||
replaceable &= shape.dim(j).size() == 1 ||
|
||||
target_axes.find(j) == target_axes.end();
|
||||
replaceable &=
|
||||
shape.dim(j).size() == 1 || target_axes.find(j) == target_axes.end();
|
||||
}
|
||||
if (replaceable) {
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -1988,23 +1960,13 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
|
||||
bool use_shape_info,
|
||||
GraphDef* optimized_graph,
|
||||
NodeDef* node) {
|
||||
if (use_shape_info && IsSlice(*node) &&
|
||||
properties.GetInputProperties(node->name()).size() == 3) {
|
||||
if (!use_shape_info || !IsSlice(*node)) return Status::OK();
|
||||
Tensor begin;
|
||||
Tensor size;
|
||||
if (properties.HasInputProperties(node->name()) &&
|
||||
GetTensorFromConstNode(node->input(1), &begin) &&
|
||||
GetTensorFromConstNode(node->input(2), &size)) {
|
||||
const auto& input = properties.GetInputProperties(node->name())[0];
|
||||
const auto& b = properties.GetInputProperties(node->name())[1];
|
||||
const auto& s = properties.GetInputProperties(node->name())[2];
|
||||
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
|
||||
TensorShape::IsValid(s.shape()) && s.has_value()) {
|
||||
Tensor begin(b.dtype(), b.shape());
|
||||
if (!begin.FromProto(b.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
b.value().DebugString());
|
||||
}
|
||||
Tensor size(s.dtype(), s.shape());
|
||||
if (!size.FromProto(s.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
s.value().DebugString());
|
||||
}
|
||||
// The node is replaceable iff unknown_rank == false &&
|
||||
// begin == 0 && (size == -1 || size == input_shape) for all dimensions
|
||||
bool replaceable = !input.shape().unknown_rank();
|
||||
@ -2024,8 +1986,6 @@ Status ConstantFolding::SimplifySlice(const GraphProperties& properties,
|
||||
}
|
||||
if (replaceable) {
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
@ -2052,27 +2012,18 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
const auto& b = properties.GetInputProperties(node->name())[1];
|
||||
const auto& e = properties.GetInputProperties(node->name())[2];
|
||||
const auto& s = properties.GetInputProperties(node->name())[3];
|
||||
if (TensorShape::IsValid(b.shape()) && b.has_value() &&
|
||||
TensorShape::IsValid(e.shape()) && e.has_value() &&
|
||||
TensorShape::IsValid(s.shape()) && s.has_value()) {
|
||||
Tensor begin(b.dtype(), b.shape());
|
||||
if (!begin.FromProto(b.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
b.value().DebugString());
|
||||
|
||||
std::vector<Tensor> input_tensors(3);
|
||||
for (int i = 1; i < 4; ++i) {
|
||||
if (!GetTensorFromConstNode(node->input(i), &input_tensors[i - 1])) {
|
||||
return Status::OK();
|
||||
}
|
||||
Tensor end(e.dtype(), e.shape());
|
||||
if (!end.FromProto(e.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
e.value().DebugString());
|
||||
}
|
||||
Tensor strides(s.dtype(), s.shape());
|
||||
if (!strides.FromProto(s.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
s.value().DebugString());
|
||||
}
|
||||
|
||||
const Tensor& begin = input_tensors[0];
|
||||
const Tensor& end = input_tensors[1];
|
||||
const Tensor& strides = input_tensors[2];
|
||||
|
||||
TF_RETURN_IF_ERROR(
|
||||
CheckAttrsExist(*node, {"begin_mask", "end_mask", "ellipsis_mask"}));
|
||||
int begin_mask = node->attr().at("begin_mask").i();
|
||||
@ -2116,34 +2067,26 @@ Status ConstantFolding::SimplifyStridedSlice(const GraphProperties& properties,
|
||||
}
|
||||
int b = begin.dtype() == DT_INT32 ? begin.vec<int>()(i)
|
||||
: begin.vec<int64>()(i);
|
||||
int e =
|
||||
end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
|
||||
int e = end.dtype() == DT_INT32 ? end.vec<int>()(i) : end.vec<int64>()(i);
|
||||
int s = strides.dtype() == DT_INT32 ? strides.vec<int>()(i)
|
||||
: strides.vec<int64>()(i);
|
||||
replaceable &=
|
||||
(begin_mask & 1 << i || b == 0) &&
|
||||
(end_mask & 1 << i || e == input.shape().dim(j).size()) && s == 1;
|
||||
replaceable &= (begin_mask & 1 << i || b == 0) &&
|
||||
(end_mask & 1 << i || e == input.shape().dim(j).size()) &&
|
||||
s == 1;
|
||||
}
|
||||
if (replaceable) {
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
||||
bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
Tensor multiplies;
|
||||
if (use_shape_info && IsTile(*node) &&
|
||||
properties.GetInputProperties(node->name()).size() == 2) {
|
||||
const auto& m = properties.GetInputProperties(node->name())[1];
|
||||
if (TensorShape::IsValid(m.shape()) && m.has_value()) {
|
||||
Tensor multiplies(m.dtype(), m.shape());
|
||||
if (!multiplies.FromProto(m.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
m.value().DebugString());
|
||||
}
|
||||
GetTensorFromConstNode(node->input(1), &multiplies)) {
|
||||
// The node is replaceable iff all values in multiplies are 1.
|
||||
bool replaceable = true;
|
||||
if (multiplies.dtype() == DT_INT32) {
|
||||
@ -2151,8 +2094,7 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
||||
replaceable &= multiplies.vec<int>()(j) == 1;
|
||||
}
|
||||
} else {
|
||||
for (int j = 0; replaceable && j < multiplies.vec<int64>().size();
|
||||
++j) {
|
||||
for (int j = 0; replaceable && j < multiplies.vec<int64>().size(); ++j) {
|
||||
replaceable &= multiplies.vec<int64>()(j) == 1;
|
||||
}
|
||||
}
|
||||
@ -2160,22 +2102,16 @@ Status ConstantFolding::SimplifyTile(const GraphProperties& properties,
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
|
||||
bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
if (use_shape_info && IsPad(*node) &&
|
||||
properties.GetInputProperties(node->name()).size() >= 2) {
|
||||
const auto& p = properties.GetInputProperties(node->name())[1];
|
||||
if (TensorShape::IsValid(p.shape()) && p.has_value()) {
|
||||
Tensor paddings(p.dtype(), p.shape());
|
||||
if (!paddings.FromProto(p.value())) {
|
||||
return errors::InvalidArgument("Cannot parse tensor from proto: ",
|
||||
p.value().DebugString());
|
||||
}
|
||||
if (!use_shape_info || !IsPad(*node)) return Status::OK();
|
||||
|
||||
Tensor paddings;
|
||||
if (GetTensorFromConstNode(node->input(1), &paddings)) {
|
||||
// The node is replaceable iff all values in paddings are 0.
|
||||
bool replaceable = true;
|
||||
// The operation requires it to be int32 value so we don't check for
|
||||
@ -2188,7 +2124,6 @@ Status ConstantFolding::SimplifyPad(const GraphProperties& properties,
|
||||
ReplaceOperationWithIdentity(0, properties, node, optimized_graph);
|
||||
}
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -3031,16 +2966,17 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
|
||||
// folding of ops when more than one but not all inputs are constant.
|
||||
// For AddN and AccumulateNV2, we may furthermore reorder inputs, since
|
||||
// addition is commutative.
|
||||
if (!IsAggregate(*node) || !IsCommutative(*node)) return false;
|
||||
|
||||
const int num_non_control_inputs = NumNonControlInputs(*node);
|
||||
if (IsAggregate(*node) && IsCommutative(*node) &&
|
||||
num_non_control_inputs > 2) {
|
||||
if (num_non_control_inputs <= 2) return false;
|
||||
const int num_control_inputs = node->input_size() - num_non_control_inputs;
|
||||
std::vector<int> const_inputs;
|
||||
std::vector<int> nonconst_inputs;
|
||||
for (int i = 0; i < node->input_size(); ++i) {
|
||||
const string& input = node->input(i);
|
||||
const NodeDef* input_node = node_map_->GetNode(NodeName(input));
|
||||
CHECK(input_node != nullptr) << input;
|
||||
if (input_node == nullptr) return false;
|
||||
if (!IsControlInput(input) && IsReallyConstant(*input_node)) {
|
||||
const_inputs.push_back(i);
|
||||
} else {
|
||||
@ -3058,8 +2994,7 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
|
||||
}
|
||||
const string new_node_name = OptimizedNodeName(
|
||||
*node, strings::StrCat("_partial_split_", const_inputs.size()));
|
||||
if (1 < const_inputs.size() &&
|
||||
const_inputs.size() < num_non_control_inputs &&
|
||||
if (const_inputs.size() > 1 && const_inputs.size() < num_non_control_inputs &&
|
||||
!node_map_->NodeExists(new_node_name)) {
|
||||
NodeDef* added_node = optimized_graph->add_node();
|
||||
*added_node = *node;
|
||||
@ -3091,13 +3026,11 @@ bool ConstantFolding::PartialAssocOpConstFolding(GraphDef* optimized_graph,
|
||||
}
|
||||
node->mutable_input()->DeleteSubrange(nonconst_inputs.size(),
|
||||
const_inputs.size() - 1);
|
||||
(*node->mutable_attr())["N"].set_i(node->input_size() -
|
||||
num_control_inputs);
|
||||
(*node->mutable_attr())["N"].set_i(node->input_size() - num_control_inputs);
|
||||
properties->ClearInputProperties(node->name());
|
||||
(*added_node->mutable_attr())["N"].set_i(const_inputs.size());
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -3107,9 +3040,12 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
// Partial constant folding for Concat which is not commutative, so
|
||||
// we have to preserve order and can only push consecutive runs of constant
|
||||
// inputs into sub-nodes.
|
||||
if (!IsConcat(*node) ||
|
||||
node->name().rfind("_partial_split_") != string::npos) {
|
||||
return false;
|
||||
}
|
||||
const int num_non_control_inputs = NumNonControlInputs(*node);
|
||||
if (IsConcat(*node) && num_non_control_inputs > 3 &&
|
||||
node->name().rfind("_partial_split_") == string::npos) {
|
||||
if (num_non_control_inputs <= 3) return false;
|
||||
int axis_arg = -1;
|
||||
int begin = 0;
|
||||
int end = num_non_control_inputs;
|
||||
@ -3123,14 +3059,6 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
return false;
|
||||
}
|
||||
|
||||
const NodeDef* axis_arg_node =
|
||||
node_map_->GetNode(NodeName(node->input(axis_arg)));
|
||||
if (axis_arg_node == nullptr || !IsReallyConstant(*axis_arg_node)) {
|
||||
// We cannot constant fold Concat unless we the axis argument is
|
||||
// constant. Skip node.
|
||||
return false;
|
||||
}
|
||||
|
||||
// We search for consecutive runs of constant inputs in the range
|
||||
// [begin:end[ and push then down into child nodes.
|
||||
std::vector<std::pair<int, int>> constant_input_runs;
|
||||
@ -3143,8 +3071,8 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
}
|
||||
// Invariant: node[first] is constant || first >= end.
|
||||
last = first + 1;
|
||||
while (last < end && IsReallyConstant(*node_map_->GetNode(
|
||||
NodeName(node->input(last))))) {
|
||||
while (last < end &&
|
||||
IsReallyConstant(*node_map_->GetNode(NodeName(node->input(last))))) {
|
||||
++last;
|
||||
}
|
||||
// Invariant: node[last] is not constant || last >= end
|
||||
@ -3156,9 +3084,9 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
}
|
||||
|
||||
// Skip if all inputs are constant, and let constant folding take over.
|
||||
if (constant_input_runs.size() == 1 &&
|
||||
if (constant_input_runs.empty() || (constant_input_runs.size() == 1 &&
|
||||
constant_input_runs[0].first == begin &&
|
||||
constant_input_runs[0].second == end) {
|
||||
constant_input_runs[0].second == end)) {
|
||||
return false;
|
||||
}
|
||||
std::set<int> inputs_to_delete;
|
||||
@ -3172,35 +3100,32 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
|
||||
NodeDef* added_node = optimized_graph->add_node();
|
||||
*added_node = *node;
|
||||
added_node->set_op("ConcatV2");
|
||||
added_node->set_name(new_node_name);
|
||||
node_map_->AddNode(added_node->name(), added_node);
|
||||
added_node->clear_input();
|
||||
for (int i = interval.first; i < interval.second; ++i) {
|
||||
added_node->add_input(node->input(i));
|
||||
node_map_->UpdateOutput(NodeName(node->input(i)), node->name(),
|
||||
added_node->name());
|
||||
node_map_->UpdateInput(node->name(), node->input(i), added_node->name());
|
||||
if (i != interval.first) {
|
||||
inputs_to_delete.insert(i);
|
||||
}
|
||||
}
|
||||
added_node->add_input(node->input(axis_arg));
|
||||
(*added_node->mutable_attr())["N"].set_i(interval.second -
|
||||
interval.first);
|
||||
(*added_node->mutable_attr())["N"].set_i(interval.second - interval.first);
|
||||
node_map_->AddOutput(NodeName(node->input(axis_arg)), added_node->name());
|
||||
|
||||
// Overwrite the first constant input with the result of the added
|
||||
// child node.
|
||||
node->set_input(interval.first, added_node->name());
|
||||
node_map_->AddOutput(added_node->name(), node->name());
|
||||
}
|
||||
if (!constant_input_runs.empty()) {
|
||||
if (!inputs_to_delete.empty()) {
|
||||
if (!constant_input_runs.empty() && !inputs_to_delete.empty()) {
|
||||
// Fix up the inputs to the original node.
|
||||
std::vector<string> tmp(node->input().begin(), node->input().end());
|
||||
node->clear_input();
|
||||
protobuf::RepeatedPtrField<string> tmp;
|
||||
tmp.Swap(node->mutable_input());
|
||||
for (int i = 0; i < tmp.size(); ++i) {
|
||||
if (inputs_to_delete.find(i) == inputs_to_delete.end()) {
|
||||
node->add_input(tmp[i]);
|
||||
node->add_input(tmp.Get(i));
|
||||
}
|
||||
}
|
||||
(*node->mutable_attr())["N"].set_i(node->input_size() - 1);
|
||||
@ -3208,55 +3133,83 @@ bool ConstantFolding::PartialConcatConstFolding(GraphDef* optimized_graph,
|
||||
}
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
bool ConstantFolding::GetConcatAxis(const NodeDef& node, int* axis) {
|
||||
if (node.op() != "ConcatV2") {
|
||||
return false;
|
||||
}
|
||||
int axis_idx = node.input_size() - 1;
|
||||
while (axis_idx > 0 && IsControlInput(node.input(axis_idx))) {
|
||||
--axis_idx;
|
||||
}
|
||||
if (axis_idx <= 0) {
|
||||
return false;
|
||||
}
|
||||
Tensor axis_tensor;
|
||||
if (!GetTensorFromConstNode(node.input(axis_idx), &axis_tensor)) {
|
||||
return false;
|
||||
}
|
||||
*axis = axis_tensor.dtype() == DT_INT64
|
||||
? static_cast<int>(axis_tensor.scalar<int64>()())
|
||||
: axis_tensor.scalar<int32>()();
|
||||
return true;
|
||||
}
|
||||
|
||||
bool ConstantFolding::MergeConcat(const GraphProperties& properties,
|
||||
bool use_shape_info,
|
||||
bool ConstantFolding::MergeConcat(bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node) {
|
||||
// We only optimize for ConcatV2.
|
||||
int axis;
|
||||
if (!use_shape_info || !GetConcatAxis(properties, node, &axis) ||
|
||||
if (!use_shape_info || !GetConcatAxis(*node, &axis) ||
|
||||
nodes_to_preserve_.find(node->name()) != nodes_to_preserve_.end() ||
|
||||
node_map_->GetOutputs(node->name()).size() != 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// If all inputs are constant, don't merge and let folding take case of it.
|
||||
const int num_regular_inputs = NumNonControlInputs(*node);
|
||||
bool all_inputs_are_const = true;
|
||||
for (int i = 0; i < num_regular_inputs - 1; ++i) {
|
||||
const NodeDef* input_node = node_map_->GetNode(node->input(i));
|
||||
if (!IsReallyConstant(*input_node)) {
|
||||
all_inputs_are_const = false;
|
||||
}
|
||||
}
|
||||
if (all_inputs_are_const) return false;
|
||||
|
||||
NodeDef* parent = *node_map_->GetOutputs(node->name()).begin();
|
||||
int parent_axis;
|
||||
if (!GetConcatAxis(properties, parent, &parent_axis) || axis != parent_axis) {
|
||||
if (!GetConcatAxis(*parent, &parent_axis) || axis != parent_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const int index = NumNonControlInputs(*node) - 1;
|
||||
auto inputs = parent->input();
|
||||
parent->clear_input();
|
||||
for (int i = 0; i < inputs.size(); ++i) {
|
||||
if (IsSameInput(inputs.Get(i), node->name())) {
|
||||
for (int j = 0; j < node->input_size(); ++j) {
|
||||
if (j < index) {
|
||||
// Input tensors (non axis), add to input list of parent.
|
||||
protobuf::RepeatedPtrField<string> parent_inputs;
|
||||
parent_inputs.Swap(parent->mutable_input());
|
||||
std::vector<string> ctrl_output;
|
||||
// TODO(rmlarsen): IF the child occurs more than once, is it beneficial to
|
||||
// collapse it into the parent multiple times? Probablyu not.
|
||||
for (const auto& input : parent_inputs) {
|
||||
if (IsSameInput(input, node->name())) {
|
||||
for (int j = 0; j < num_regular_inputs - 1; ++j) {
|
||||
// Add tensor inputs to first child concat tensors (exceptthe final axis
|
||||
// input) to the parent's inputs.
|
||||
parent->add_input(node->input(j));
|
||||
node_map_->RemoveOutput(node->input(j), node->name());
|
||||
node_map_->AddOutput(node->input(j), parent->name());
|
||||
}
|
||||
// Skip j == index, which means axis tensor.
|
||||
if (j > index) {
|
||||
// Control Dependencies, push back to inputs so they can be forwarded
|
||||
// to parent.
|
||||
*inputs.Add() = node->input(j);
|
||||
}
|
||||
node_map_->UpdateInput(parent->name(), node->name(), node->input(j));
|
||||
}
|
||||
} else {
|
||||
parent->add_input(inputs.Get(i));
|
||||
parent->add_input(input);
|
||||
}
|
||||
}
|
||||
// Forward Add control inputs
|
||||
for (int i = num_regular_inputs; i < node->input_size(); ++i) {
|
||||
parent->add_input(node->input(i));
|
||||
node_map_->UpdateInput(parent->name(), node->name(), node->input(i));
|
||||
}
|
||||
node->clear_input();
|
||||
node->set_op("NoOp");
|
||||
node->clear_attr();
|
||||
node_map_->RemoveNode(node->name());
|
||||
(*parent->mutable_attr())["N"].set_i(NumNonControlInputs(*parent) - 1);
|
||||
DedupControlInputs(parent);
|
||||
|
||||
return true;
|
||||
}
|
||||
@ -3344,7 +3297,9 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
|
||||
// that the shape inference deals with this conservatively unless we're in
|
||||
// aggressive mode.
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
Status s = properties.InferStatically(assume_valid_feeds);
|
||||
Status s = properties.InferStatically(assume_valid_feeds,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false);
|
||||
const bool can_use_shape_info = s.ok();
|
||||
|
||||
if (can_use_shape_info) {
|
||||
|
@ -61,6 +61,8 @@ class ConstantFolding : public GraphOptimizer {
|
||||
|
||||
bool IsReallyConstant(const NodeDef& node) const;
|
||||
|
||||
bool GetTensorFromConstNode(const string& node_name_or_input, Tensor* tensor);
|
||||
|
||||
Status MaterializeShapes(const GraphProperties& properties);
|
||||
|
||||
Status MaterializeBroadcastGradientArgs(const NodeDef& node,
|
||||
@ -239,8 +241,9 @@ class ConstantFolding : public GraphOptimizer {
|
||||
void RemoveSplitOrSplitV(const GraphProperties& properties,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
|
||||
bool MergeConcat(const GraphProperties& properties, bool use_shape_info,
|
||||
GraphDef* optimized_graph, NodeDef* node);
|
||||
bool GetConcatAxis(const NodeDef& node, int* axis);
|
||||
bool MergeConcat(bool use_shape_info, GraphDef* optimized_graph,
|
||||
NodeDef* node);
|
||||
|
||||
Status AddQuantizedMatMulMinMaxOutConstNodes(NodeDef* node,
|
||||
GraphDef* optimized_graph);
|
||||
|
@ -2390,12 +2390,11 @@ TEST_F(ConstantFoldingTest, MergeConcat_PartialFolding) {
|
||||
TF_EXPECT_OK(status);
|
||||
|
||||
GraphDef want;
|
||||
AddNode("ConstantFolding/concat2_partial_split_0_0", "Const", {}, {}, &want);
|
||||
AddNode("ConstantFolding/concat2_partial_split_0", "Const", {}, {}, &want);
|
||||
AddNode("axis", "Const", {}, {}, &want);
|
||||
AddNode("ph", "Placeholder", {}, {}, &want);
|
||||
AddNode("concat2", "ConcatV2",
|
||||
{"ConstantFolding/concat2_partial_split_0_0", "ph", "axis"}, {},
|
||||
&want);
|
||||
{"ConstantFolding/concat2_partial_split_0", "ph", "axis"}, {}, &want);
|
||||
|
||||
CompareGraphs(want, got);
|
||||
}
|
||||
|
@ -2209,7 +2209,10 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
|
||||
GraphProperties graph_properties(item);
|
||||
TF_RETURN_IF_ERROR(graph_properties.InferStatically(false));
|
||||
TF_RETURN_IF_ERROR(
|
||||
graph_properties.InferStatically(/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
|
||||
|
||||
virtual_placer_.reset(new VirtualPlacer(cluster->GetDevices()));
|
||||
|
@ -102,7 +102,8 @@ Status IsNodeOutputPortHostFriendly(const GraphView& graph,
|
||||
if (!properties->has_properties()) {
|
||||
// This is an expensive call, call it lazily.
|
||||
TF_RETURN_IF_ERROR(properties->InferStatically(
|
||||
/*assume_valid_feeds=*/false));
|
||||
/*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
}
|
||||
const auto& output_properties = properties->GetOutputProperties(node.name());
|
||||
if (port_id >= output_properties.size()) {
|
||||
@ -252,7 +253,8 @@ Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
|
||||
if (!properties->has_properties()) {
|
||||
// This is an expensive call, call it lazily.
|
||||
TF_RETURN_IF_ERROR(properties->InferStatically(
|
||||
/*assume_valid_feeds=*/false));
|
||||
/*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
}
|
||||
for (const auto& prop : properties->GetOutputProperties(node.name())) {
|
||||
if (!IsTensorSmall(prop)) {
|
||||
|
@ -1170,7 +1170,11 @@ Status Remapper::Optimize(Cluster* /*cluster*/, const GrapplerItem& item,
|
||||
|
||||
// Infer properties lazily in case they are not needed.
|
||||
if (!ctx.inferred_graph_properties && IsFusedBatchNormCandidate(node)) {
|
||||
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(false));
|
||||
// TODO(rmlarsen): Get rid of tensor value copies.
|
||||
TF_RETURN_IF_ERROR(ctx.graph_properties.InferStatically(
|
||||
/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/true));
|
||||
ctx.inferred_graph_properties = true;
|
||||
}
|
||||
|
||||
|
@ -739,9 +739,9 @@ Status ScopedAllocatorOptimizer::Optimize(Cluster* /*cluster*/,
|
||||
|
||||
GraphProperties graph_properties(item);
|
||||
const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
|
||||
LOG_WARNING_AND_RETURN_IF_ERROR(
|
||||
graph_properties.InferStatically(assume_valid_feeds));
|
||||
|
||||
LOG_WARNING_AND_RETURN_IF_ERROR(graph_properties.InferStatically(
|
||||
assume_valid_feeds, /*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
*optimized_graph = item.graph;
|
||||
node_map_.reset(new NodeMap(optimized_graph));
|
||||
|
||||
|
@ -87,7 +87,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
graph.GetRegularFanin(MutableGraphView::InputPort(fanout.node, 1));
|
||||
if (!inferred_properties) {
|
||||
// Infer properties lazily in case they are not needed.
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(false));
|
||||
TF_RETURN_IF_ERROR(
|
||||
properties.InferStatically(/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
inferred_properties = true;
|
||||
}
|
||||
const auto& prop =
|
||||
@ -144,7 +147,10 @@ Status ShapeOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
}
|
||||
if (!inferred_properties) {
|
||||
// Infer properties lazily in case they are not needed.
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(false));
|
||||
TF_RETURN_IF_ERROR(
|
||||
properties.InferStatically(/*assume_valid_feeds=*/false,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
inferred_properties = true;
|
||||
}
|
||||
const auto& prop1 = properties.GetInputProperties(input1.node->name());
|
||||
|
@ -14,7 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/static_schedule.h"
|
||||
|
||||
#include <deque>
|
||||
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/grappler/costs/graph_properties.h"
|
||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||
@ -92,7 +94,10 @@ Status EstimateEarliestExecutionTimes(
|
||||
name_map.clear();
|
||||
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(true));
|
||||
TF_RETURN_IF_ERROR(
|
||||
properties.InferStatically(/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
OpLevelCostEstimator estimator;
|
||||
VirtualPlacer placer(cluster->GetDevices());
|
||||
|
||||
@ -160,7 +165,10 @@ Status EstimateRequiredTimes(
|
||||
}
|
||||
}
|
||||
GraphProperties properties(item);
|
||||
TF_RETURN_IF_ERROR(properties.InferStatically(true));
|
||||
TF_RETURN_IF_ERROR(
|
||||
properties.InferStatically(/*assume_valid_feeds=*/true,
|
||||
/*aggressive_shape_inference=*/false,
|
||||
/*include_tensor_values=*/false));
|
||||
OpLevelCostEstimator estimator;
|
||||
VirtualPlacer placer(cluster->GetDevices());
|
||||
|
||||
|
@ -346,7 +346,8 @@ class FoldOldBatchNormsTest : public ::testing::Test {
|
||||
std::vector<Tensor> fused_outputs;
|
||||
TF_ASSERT_OK(fused_session->Run({}, {"output"}, {}, &fused_outputs));
|
||||
|
||||
test::ExpectTensorNear<float>(original_outputs[0], fused_outputs[0], 2e-5);
|
||||
test::ExpectClose(original_outputs[0], fused_outputs[0], /*atol=*/2e-5,
|
||||
/*rtol=*/2e-5);
|
||||
|
||||
for (const NodeDef& node : fused_graph_def.node()) {
|
||||
EXPECT_NE("FusedBatchNorm", node.op());
|
||||
|
Loading…
Reference in New Issue
Block a user