[XLA][TF2XLA] Remove CF for Shape ops.
XLA do special shape value inferring, no need to do it in tf's constant folding. PiperOrigin-RevId: 271218195
This commit is contained in:
parent
7d4643432b
commit
b6a97cecc5
@ -609,12 +609,23 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
||||
// However since we are only allowed to specify the filter at the "Node"
|
||||
// level there is no good way to allow the above behavior. So we
|
||||
// disallow any sort of constant folding on Variant nodes for now.
|
||||
//
|
||||
// Also do not consider constant folding Shape ops. When there is a dynamic
|
||||
// dimension in a tensor, TF2XLA currently represent them as the static
|
||||
// upperbound shape, which can be constant folded and then lose the info
|
||||
// that this Shape is dynamic.
|
||||
auto cf_consider_fn = [](const Node* n) {
|
||||
for (const auto& output_arg : n->op_def().output_arg()) {
|
||||
if (output_arg.type() == DT_VARIANT) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
const auto& ts = n->type_string();
|
||||
// XLA has special logic to handle dynamic shapes, don't constant fold
|
||||
// them.
|
||||
if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
GraphOptimizer::Options graph_optimizer_options;
|
||||
|
@ -1790,5 +1790,62 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
|
||||
tuple_sharding.ToProto().SerializeAsString());
|
||||
}
|
||||
|
||||
TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
|
||||
// When we have a dynamic shape input followed by a Shape op, the Shape op
|
||||
// should return dynamic size:
|
||||
//
|
||||
// [2, b] // b's static size is 3 and dynamic size is 2
|
||||
// |
|
||||
// Size // should return 2, 2
|
||||
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
|
||||
auto shape = ops::Shape(scope.WithOpName("shape"), a);
|
||||
(void)ops::_Retval(scope.WithOpName("retval"), shape, 0);
|
||||
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||
|
||||
// Builds a description of the arguments.
|
||||
std::vector<XlaCompiler::Argument> args(2);
|
||||
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||
args[0].type = DT_INT32;
|
||||
args[0].shape = TensorShape({2, 3});
|
||||
// Indicates that first dimension is dynamic, and arg 1 holds the runtime
|
||||
// value of it.
|
||||
args[0].dynamic_dim_to_arg_num_map.insert({1, 1});
|
||||
|
||||
// Arg 1 holds the dynamic size.
|
||||
args[1].kind = XlaCompiler::Argument::kParameter;
|
||||
args[1].type = DT_INT32;
|
||||
args[1].shape = TensorShape({});
|
||||
|
||||
// Compiles the graph.
|
||||
XlaCompiler compiler(DefaultOptions());
|
||||
|
||||
XlaCompiler::CompilationResult result;
|
||||
auto options = XlaCompiler::CompileOptions();
|
||||
options.resolve_compile_time_constants = false;
|
||||
TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
|
||||
/*user_aliases=*/{}, &result));
|
||||
|
||||
xla::Literal literal0 =
|
||||
xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
|
||||
xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
|
||||
std::unique_ptr<xla::GlobalData> data0 =
|
||||
client_->TransferToServer(literal0).ConsumeValueOrDie();
|
||||
std::unique_ptr<xla::GlobalData> data1 =
|
||||
client_->TransferToServer(literal1).ConsumeValueOrDie();
|
||||
|
||||
// Prepare arguments.
|
||||
std::unique_ptr<xla::GlobalData> actual =
|
||||
client_->Execute(*result.computation, {data0.get(), data1.get()})
|
||||
.ConsumeValueOrDie();
|
||||
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
|
||||
// The dynamic size of the op is <2, 2> instead of static size <2, 3>
|
||||
xla::Literal expected = xla::LiteralUtil::CreateR1<int32>({2, 2});
|
||||
xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected});
|
||||
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
Reference in New Issue
Block a user