From 3abfe2cd9befa263de57edfae7d4c0d29c9c9182 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Wed, 10 Oct 2018 17:07:19 -0700
Subject: [PATCH] Allow the XRTCompile op to return the ProgramShape resulting
 from the XLA compilation.

PiperOrigin-RevId: 216619617
---
 .../xla/service/compile_only_service.cc       |   2 +
 .../compiler/xrt/kernels/xrt_compile_ops.cc   |  17 ++-
 .../compiler/xrt/kernels/xrt_execute_op.cc    |   8 --
 .../compiler/xrt/ops/xrt_compile_ops.cc       |   7 +-
 tensorflow/compiler/xrt/tests/BUILD           |  13 ++-
 tensorflow/compiler/xrt/tests/raw_api_test.cc | 106 +++++++++++++++++-
 6 files changed, 134 insertions(+), 19 deletions(-)

diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc
index 96bd2616f56..bd5045b9b91 100644
--- a/tensorflow/compiler/xla/service/compile_only_service.cc
+++ b/tensorflow/compiler/xla/service/compile_only_service.cc
@@ -89,6 +89,8 @@ CompileOnlyService::CompileAheadOfTime(
     const auto& program_shape = instance.computation.program_shape();
     ExecutionOptions execution_options;
     *execution_options.mutable_debug_options() = debug_options;
+    *execution_options.mutable_shape_with_output_layout() =
+        *instance.result_layout;
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
         CreateModuleConfig(program_shape, instance.argument_layouts,
diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
index 1d4f8d97f2e..1ab836a4960 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc
@@ -166,10 +166,21 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
                    VLOG(1) << "Compiling XLA executable";
                    return Compile(ctx, computation_proto, program);
                  }));
+  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+  OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));
 
-  Tensor output(DT_INT64, TensorShape({}));
-  output.scalar<int64>()() = uid;
-  ctx->set_output(0, output);
+  Tensor handle_output(DT_INT64, TensorShape({}));
+  handle_output.scalar<int64>()() = uid;
+  ctx->set_output(0, handle_output);
+
+  xla::LocalExecutable* executable = entry->get().get_executable();
+  xla::ProgramShape program_shape = executable->executable()
+                                        ->module()
+                                        .entry_computation()
+                                        ->ComputeProgramShape();
+  Tensor program_shape_output(DT_STRING, TensorShape({1}));
+  program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
+  ctx->set_output(1, program_shape_output);
 }
 
 XRTCompileOp::~XRTCompileOp() = default;
diff --git a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
index 257b054f16a..3a1e03280a3 100644
--- a/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/kernels/xrt_execute_op.cc
@@ -64,14 +64,6 @@ uint32 GetXLARandomSeed() {
   return counter.fetch_add(2);
 }
 
-// Looks up the input `key` in the compilation cache.
-Status GetComputationCacheEntry(
-    XRTCompilationCache* cache, int64 key,
-    std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
-  TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
-  return Status::OK();
-}
-
 // Populates `inputs` with the input tensors to the computation.
 Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                             bool release_inputs,
diff --git a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
index 5cfc8711f9f..7b3b50c6955 100644
--- a/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_compile_ops.cc
@@ -23,7 +23,12 @@ namespace tensorflow {
 REGISTER_OP("XRTCompile")
     .Input("computation: string")
     .Output("handle: int64")
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Output("program_shape: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->UnknownShapeOfRank(1));
+      return Status::OK();
+    })
     .Doc(
         R"(
 Reads a computation proto, compiles it, and places it in the global compilation
diff --git a/tensorflow/compiler/xrt/tests/BUILD b/tensorflow/compiler/xrt/tests/BUILD
index b6dcfc4eb96..be44a3474ac 100644
--- a/tensorflow/compiler/xrt/tests/BUILD
+++ b/tensorflow/compiler/xrt/tests/BUILD
@@ -29,8 +29,11 @@ cc_library(
         "//tensorflow/compiler/xla:literal_util",
         "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:xla_data_proto",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/client:local_client",
         "//tensorflow/compiler/xla/client:xla_builder",
         "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/service:platform_util",
         "//tensorflow/compiler/xrt:xrt_proto",
         "//tensorflow/compiler/xrt:xrt_server",
         "//tensorflow/compiler/xrt/cc:xrt_ops",
@@ -49,7 +52,10 @@ tf_cc_test(
     name = "raw_api_test_cpu",
     size = "medium",
     srcs = [],
-    args = ["--xla_test_device=XLA_CPU"],
+    args = [
+        "--xla_test_device=XLA_CPU",
+        "--xla_platform=CPU",
+    ],
     deps = [
         ":raw_api_test_lib",
         "//tensorflow/compiler/jit:xla_cpu_device",
@@ -60,7 +66,10 @@ tf_cuda_cc_test(
     name = "raw_api_test_gpu",
     size = "medium",
     srcs = [],
-    args = ["--xla_test_device=XLA_GPU"],
+    args = [
+        "--xla_test_device=XLA_GPU",
+        "--xla_platform=GPU",
+    ],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":raw_api_test_lib",
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 9fc01e6304c..ee6734020d9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -22,10 +22,13 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h" #include "tensorflow/cc/framework/scope.h" #include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/platform_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h" @@ -43,6 +46,7 @@ namespace tensorflow { namespace { string* xla_test_device_ptr; // initial value set in main() +string* xla_platform_ptr; // initial value set in main() string DeviceFromFlag() { string xla_test_device = *xla_test_device_ptr; @@ -145,6 +149,28 @@ void StoreComputationSnapshot(const xla::XlaComputation& computation, *dst = *snapshot; } +xla::ProgramShape XlaCompiledProgramShape( + const xla::XlaComputation& computation, + const xla::ProgramShape& input_program_shape) { + se::Platform* platform = + xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie(); + xla::LocalClient* client = + xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie(); + xla::ExecutableBuildOptions exec_options; + exec_options.set_result_layout(input_program_shape.result()); + std::vector parameters_shapes; + for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) { + parameters_shapes.push_back(&input_program_shape.parameters(i)); + } + auto local_executable = + client->Compile(computation, parameters_shapes, exec_options) + .ValueOrDie(); + return local_executable->executable() + ->module() + .entry_computation() + ->ComputeProgramShape(); +} + TEST(RawApiTest, ReadAndWriteState) { xrt::XLAAllocation alloc; alloc.set_device_ordinal(0); @@ -338,20 +364,87 @@ TEST(RawApiTest, CompileAndExecute) { auto p1_value = ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString()); auto p1_handle = ops::XRTAllocate(root, p1_value); - auto result = ops::XRTExecute(root, c_handle, e_config, + auto result = ops::XRTExecute(root, c_handle.handle, e_config, {Output(p0_handle), Output(p1_handle)}); auto read_back = ops::XRTReadLiteralAndRelease(root, result); TF_ASSERT_OK(root.status()); ClientSession session(root); std::vector outputs; - TF_EXPECT_OK(session.Run({read_back}, &outputs)); + TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs)); xla::LiteralProto response; EXPECT_TRUE(response.ParseFromString(outputs[0].scalar()())); auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); + + xla::ProgramShape program_shape; + EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); + EXPECT_EQ(program_shape.parameters_size(), 2); +} + +TEST(RawApiTest, CompileWithXlaReturnShapes) { + xla::XlaBuilder builder("XrtXlaShapes"); + auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128}); + auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5}); + // Clear layouts to signal XLA we are ready to get whatever are coming out of + // the compilation process. 
+  xla::LayoutUtil::ClearLayout(&input_shape);
+  xla::LayoutUtil::ClearLayout(&kernel_shape);
+  auto param_shape =
+      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
+  auto param = xla::Parameter(&builder, 0, param_shape, "param");
+  auto input = xla::GetTupleElement(param, 0);
+  auto kernel = xla::GetTupleElement(param, 1);
+  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
+  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation,
+                          builder.Build());
+
+  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
+  // Clear the result shape layout to tell XLA that we accept whatever layout
+  // comes out of the compilation process.
+  xla::LayoutUtil::ClearLayout(&result_shape);
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = param_shape;
+  *shapes->mutable_result() = result_shape;
+  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
+                           {c_handle.program_shape}, {release}, &outputs));
+
+  xla::ProgramShape program_shape;
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
+  EXPECT_EQ(program_shape.parameters_size(), 1);
+
+  VLOG(2) << "Param: "
+          << xla::ShapeUtil::HumanStringWithLayout(
+                 program_shape.parameters(0));
+  VLOG(2) << "Result: "
+          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
+
+  xla::ProgramShape xla_program_shape =
+      XlaCompiledProgramShape(xla_computation, *shapes);
+  EXPECT_TRUE(xla::LayoutUtil::Equal(
+      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
+      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
+          .layout()));
+  EXPECT_TRUE(xla::LayoutUtil::Equal(
+      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
+      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
+          .layout()));
+  EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
+                                     xla_program_shape.result().layout()));
 }
 
 TEST(RawApiTest, CompileAndExecuteZeroArg) {
@@ -371,7 +464,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) {
   auto computation =
       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
   auto c_handle = ops::XRTCompile(root, computation);
-  auto result = ops::XRTExecute(root, c_handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                 std::initializer_list<Input>({}));
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -420,7 +513,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
   auto p1_value =
       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
   auto p1_handle = ops::XRTAllocate(root, p1_value);
-  auto result = ops::XRTExecute(root, c_handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                 {Output(p0_handle), Output(p1_handle)});
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -455,7 +548,7 @@ TEST(RawApiTest, LeakCompilationReference) {
 
   ClientSession session(root);
   std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run({c_handle}, &outputs));
+  TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs));
 }
 
 }  // namespace
@@ -464,9 +557,12 @@
 int main(int argc, char** argv) {
   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
+  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
+      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
+                       "The XLA platform to select for the device"),
   };
   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
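
Usage note (not part of the patch): a minimal sketch of how client code might
consume the new program_shape output of ops::XRTCompile, mirroring the
CompileAndExecute test above. The xrt::XLAComputation proto `c` is assumed to
be built and serialized as in that test; the device string and variable names
are illustrative.

  // Build the graph: compile once, then fetch both XRTCompile outputs
  // (the cache handle and the serialized program shape) in one run.
  Scope root = Scope::NewRootScope().WithDevice("/device:XLA_CPU:0");
  auto computation =
      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
  auto c_handle = ops::XRTCompile(root, computation);
  TF_CHECK_OK(root.status());

  ClientSession session(root);
  std::vector<Tensor> outputs;
  TF_CHECK_OK(
      session.Run({c_handle.handle, c_handle.program_shape}, &outputs));

  // Output 0 is the int64 compilation-cache handle; output 1 is a rank-1
  // string tensor holding the serialized xla::ProgramShape.
  int64 uid = outputs[0].scalar<int64>()();
  xla::ProgramShape program_shape;
  CHECK(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
  LOG(INFO) << "handle=" << uid << " result="
            << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());

Returning the shape from the compile step lets clients discover the layouts
XLA actually chose (as exercised by CompileWithXlaReturnShapes) without
compiling the computation a second time through a LocalClient.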