[XLA][TF2XLA] Remove CF for Shape ops.

XLA does its own shape value inference, so there is no need to do it in TF's constant folding.

PiperOrigin-RevId: 271218195
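
For background: XLA models a dynamic dimension by pairing a static upper-bound shape with a runtime size value, which its shape inference then propagates. A minimal sketch of that mechanism via the xla::XlaBuilder client API (not taken from this commit; the function name and setup are illustrative):

    #include "tensorflow/compiler/xla/client/xla_builder.h"
    #include "tensorflow/compiler/xla/shape_util.h"

    // Sketch: a [2, 3] parameter whose dimension 1 is marked dynamic, with its
    // runtime size supplied as a scalar parameter. GetDimensionSize yields the
    // runtime value (2 at execution time), not the static bound (3) -- exactly
    // the information that constant folding a Shape op would discard.
    xla::XlaComputation BuildDynamicShapeExample() {
      xla::XlaBuilder b("dynamic_shape_example");
      xla::XlaOp data = xla::Parameter(
          &b, 0, xla::ShapeUtil::MakeShape(xla::S32, {2, 3}), "data");
      xla::XlaOp size = xla::Parameter(
          &b, 1, xla::ShapeUtil::MakeShape(xla::S32, {}), "dynamic_size");
      // Mark dimension 1 of `data` as dynamic, sized at runtime by `size`.
      xla::XlaOp dynamic = xla::SetDimensionSize(data, size, 1);
      // Query the runtime size of dimension 1; this is the computation root.
      xla::GetDimensionSize(dynamic, 1);
      return b.Build().ConsumeValueOrDie();
    }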
Yunxing Dai 2019-09-25 15:20:01 -07:00 committed by TensorFlower Gardener
parent 7d4643432b
commit b6a97cecc5
2 changed files with 68 additions and 0 deletions

tensorflow/compiler/tf2xla/xla_compiler.cc

@@ -609,12 +609,23 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
   // However since we are only allowed to specify the filter at the "Node"
   // level there is no good way to allow the above behavior. So we
   // disallow any sort of constant folding on Variant nodes for now.
+  //
+  // Also do not consider constant folding Shape ops. When a tensor has a
+  // dynamic dimension, TF2XLA currently represents it with its static
+  // upper-bound shape, which could be constant folded away, losing the
+  // information that the shape is dynamic.
   auto cf_consider_fn = [](const Node* n) {
     for (const auto& output_arg : n->op_def().output_arg()) {
       if (output_arg.type() == DT_VARIANT) {
         return false;
       }
     }
+    const auto& ts = n->type_string();
+    // XLA has special logic to handle dynamic shapes, so do not constant
+    // fold them.
+    if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
+      return false;
+    }
     return true;
   };
   GraphOptimizer::Options graph_optimizer_options;
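
For context, this filter is handed to the constant-folding pass just below the hunk; any node for which it returns false is left untouched. A minimal sketch of that wiring (surrounding code paraphrased, not part of the diff):

    // The optimizer consults cf_consider_fn before folding each node, so
    // Shape/ShapeN/Size and Variant-producing nodes are skipped.
    GraphOptimizer::Options graph_optimizer_options;
    graph_optimizer_options.cf_consider_fn = cf_consider_fn;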

tensorflow/compiler/tf2xla/xla_compiler_test.cc

@@ -1790,5 +1790,62 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
                 tuple_sharding.ToProto().SerializeAsString());
 }
 
+TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
+  // When we have a dynamic shape input followed by a Shape op, the Shape op
+  // should return the dynamic size:
+  //
+  //  [2, b]   // b's static size is 3 and its dynamic size is 2
+  //    |
+  //  Shape    // should return [2, 2]
+  Scope scope = Scope::NewRootScope().ExitOnError();
+  auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
+  auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
+  auto shape = ops::Shape(scope.WithOpName("shape"), a);
+  (void)ops::_Retval(scope.WithOpName("retval"), shape, 0);
+  std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+  TF_ASSERT_OK(scope.ToGraph(graph.get()));
+
+  // Builds a description of the arguments.
+  std::vector<XlaCompiler::Argument> args(2);
+  args[0].kind = XlaCompiler::Argument::kParameter;
+  args[0].type = DT_INT32;
+  args[0].shape = TensorShape({2, 3});
+  // Indicates that dimension 1 of arg 0 is dynamic, and that arg 1 holds its
+  // runtime value.
+  args[0].dynamic_dim_to_arg_num_map.insert({1, 1});
+  // Arg 1 holds the dynamic size.
+  args[1].kind = XlaCompiler::Argument::kParameter;
+  args[1].type = DT_INT32;
+  args[1].shape = TensorShape({});
+
+  // Compiles the graph.
+  XlaCompiler compiler(DefaultOptions());
+  XlaCompiler::CompilationResult result;
+  auto options = XlaCompiler::CompileOptions();
+  options.resolve_compile_time_constants = false;
+  TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
+                                     /*user_aliases=*/{}, &result));
+
+  // Prepares arguments: data for arg 0 and the runtime size (2) for arg 1.
+  xla::Literal literal0 =
+      xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
+  xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
+  std::unique_ptr<xla::GlobalData> data0 =
+      client_->TransferToServer(literal0).ConsumeValueOrDie();
+  std::unique_ptr<xla::GlobalData> data1 =
+      client_->TransferToServer(literal1).ConsumeValueOrDie();
+
+  std::unique_ptr<xla::GlobalData> actual =
+      client_->Execute(*result.computation, {data0.get(), data1.get()})
+          .ConsumeValueOrDie();
+  xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
+
+  // The Shape op returns the dynamic size [2, 2] instead of the static size
+  // [2, 3].
+  xla::Literal expected = xla::LiteralUtil::CreateR1<int32>({2, 2});
+  xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected});
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
+}
+
 } // namespace
 } // namespace tensorflow
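
As a contrast to the assertion above, a hypothetical extra check (not part of this commit; `folded` is an illustrative name) would fail exactly when the Shape op had been constant folded to its static bound:

    // Had Shape been constant folded, the computation would have baked in the
    // static shape {2, 3} regardless of the runtime size.
    xla::Literal folded = xla::LiteralUtil::CreateR1<int32>({2, 3});
    xla::Literal folded_literal = xla::LiteralUtil::MakeTuple({&folded});
    EXPECT_FALSE(xla::LiteralTestUtil::Equal(folded_literal, actual_literal));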