[XLA][TF2XLA] Remove CF for Shape ops.
XLA do special shape value inferring, no need to do it in tf's constant folding. PiperOrigin-RevId: 271218195
This commit is contained in:
parent
7d4643432b
commit
b6a97cecc5
@ -609,12 +609,23 @@ std::unique_ptr<Graph> XlaCompiler::GetGraph(const FunctionBody* fbody) {
|
|||||||
// However since we are only allowed to specify the filter at the "Node"
|
// However since we are only allowed to specify the filter at the "Node"
|
||||||
// level there is no good way to allow the above behavior. So we
|
// level there is no good way to allow the above behavior. So we
|
||||||
// disallow any sort of constant folding on Variant nodes for now.
|
// disallow any sort of constant folding on Variant nodes for now.
|
||||||
|
//
|
||||||
|
// Also do not consider constant folding Shape ops. When there is a dynamic
|
||||||
|
// dimension in a tensor, TF2XLA currently represent them as the static
|
||||||
|
// upperbound shape, which can be constant folded and then lose the info
|
||||||
|
// that this Shape is dynamic.
|
||||||
auto cf_consider_fn = [](const Node* n) {
|
auto cf_consider_fn = [](const Node* n) {
|
||||||
for (const auto& output_arg : n->op_def().output_arg()) {
|
for (const auto& output_arg : n->op_def().output_arg()) {
|
||||||
if (output_arg.type() == DT_VARIANT) {
|
if (output_arg.type() == DT_VARIANT) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
const auto& ts = n->type_string();
|
||||||
|
// XLA has special logic to handle dynamic shapes, don't constant fold
|
||||||
|
// them.
|
||||||
|
if (ts == "Shape" || ts == "ShapeN" || ts == "Size") {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
return true;
|
return true;
|
||||||
};
|
};
|
||||||
GraphOptimizer::Options graph_optimizer_options;
|
GraphOptimizer::Options graph_optimizer_options;
|
||||||
|
@ -1790,5 +1790,62 @@ TEST_F(XlaCompilerTest, SetShardingForReturnedTuple) {
|
|||||||
tuple_sharding.ToProto().SerializeAsString());
|
tuple_sharding.ToProto().SerializeAsString());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(XlaCompilerTest, DoNotConstantFoldShapeOp) {
|
||||||
|
// When we have a dynamic shape input followed by a Shape op, the Shape op
|
||||||
|
// should return dynamic size:
|
||||||
|
//
|
||||||
|
// [2, b] // b's static size is 3 and dynamic size is 2
|
||||||
|
// |
|
||||||
|
// Size // should return 2, 2
|
||||||
|
Scope scope = Scope::NewRootScope().ExitOnError();
|
||||||
|
auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0);
|
||||||
|
auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1);
|
||||||
|
auto shape = ops::Shape(scope.WithOpName("shape"), a);
|
||||||
|
(void)ops::_Retval(scope.WithOpName("retval"), shape, 0);
|
||||||
|
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
|
||||||
|
TF_ASSERT_OK(scope.ToGraph(graph.get()));
|
||||||
|
|
||||||
|
// Builds a description of the arguments.
|
||||||
|
std::vector<XlaCompiler::Argument> args(2);
|
||||||
|
args[0].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[0].type = DT_INT32;
|
||||||
|
args[0].shape = TensorShape({2, 3});
|
||||||
|
// Indicates that first dimension is dynamic, and arg 1 holds the runtime
|
||||||
|
// value of it.
|
||||||
|
args[0].dynamic_dim_to_arg_num_map.insert({1, 1});
|
||||||
|
|
||||||
|
// Arg 1 holds the dynamic size.
|
||||||
|
args[1].kind = XlaCompiler::Argument::kParameter;
|
||||||
|
args[1].type = DT_INT32;
|
||||||
|
args[1].shape = TensorShape({});
|
||||||
|
|
||||||
|
// Compiles the graph.
|
||||||
|
XlaCompiler compiler(DefaultOptions());
|
||||||
|
|
||||||
|
XlaCompiler::CompilationResult result;
|
||||||
|
auto options = XlaCompiler::CompileOptions();
|
||||||
|
options.resolve_compile_time_constants = false;
|
||||||
|
TF_ASSERT_OK(compiler.CompileGraph(options, "test", std::move(graph), args,
|
||||||
|
/*user_aliases=*/{}, &result));
|
||||||
|
|
||||||
|
xla::Literal literal0 =
|
||||||
|
xla::LiteralUtil::CreateR2<int32>({{0, 1, 2}, {3, 4, 5}});
|
||||||
|
xla::Literal literal1 = xla::LiteralUtil::CreateR0<int32>(2);
|
||||||
|
std::unique_ptr<xla::GlobalData> data0 =
|
||||||
|
client_->TransferToServer(literal0).ConsumeValueOrDie();
|
||||||
|
std::unique_ptr<xla::GlobalData> data1 =
|
||||||
|
client_->TransferToServer(literal1).ConsumeValueOrDie();
|
||||||
|
|
||||||
|
// Prepare arguments.
|
||||||
|
std::unique_ptr<xla::GlobalData> actual =
|
||||||
|
client_->Execute(*result.computation, {data0.get(), data1.get()})
|
||||||
|
.ConsumeValueOrDie();
|
||||||
|
xla::Literal actual_literal = client_->Transfer(*actual).ConsumeValueOrDie();
|
||||||
|
// The dynamic size of the op is <2, 2> instead of static size <2, 3>
|
||||||
|
xla::Literal expected = xla::LiteralUtil::CreateR1<int32>({2, 2});
|
||||||
|
xla::Literal expected_literal = xla::LiteralUtil::MakeTuple({&expected});
|
||||||
|
EXPECT_TRUE(xla::LiteralTestUtil::Equal(expected_literal, actual_literal));
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
Loading…
Reference in New Issue
Block a user