Allow the XRTCompile op to return the ProgramShape resulting from the XLA compilation.
PiperOrigin-RevId: 216619617
commit 3abfe2cd9b (parent 0be7b32fa4)
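With this change, XRTCompile gains a second output: alongside the scalar int64 cache handle it returns a rank-1 string tensor holding the serialized xla::ProgramShape produced by the XLA compilation. A minimal client-side sketch, modeled on the updated raw_api_test.cc code in the hunks below (the Scope `root` and the serialized xrt::XLAComputation `c` are assumed to be set up as in those tests):

    // Minimal sketch, assuming the test-style setup from the hunks below
    // (inside namespace tensorflow, with `root` and `c` already populated).
    auto computation =
        ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
    auto c_handle = ops::XRTCompile(root, computation);

    ClientSession session(root);
    std::vector<Tensor> outputs;
    TF_CHECK_OK(session.Run({c_handle.handle, c_handle.program_shape}, &outputs));

    // Output 0 is the compilation cache key (scalar int64).
    int64 uid = outputs[0].scalar<int64>()();

    // Output 1 carries the post-compilation program shape, serialized into a
    // rank-1 string tensor of size 1.
    xla::ProgramShape program_shape;
    CHECK(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
    VLOG(1) << "uid=" << uid << " result shape: "
            << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());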
@@ -89,6 +89,8 @@ CompileOnlyService::CompileAheadOfTime(
     const auto& program_shape = instance.computation.program_shape();
     ExecutionOptions execution_options;
     *execution_options.mutable_debug_options() = debug_options;
     *execution_options.mutable_shape_with_output_layout() =
         *instance.result_layout;
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
         CreateModuleConfig(program_shape, instance.argument_layouts,
@@ -166,10 +166,21 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
         VLOG(1) << "Compiling XLA executable";
         return Compile(ctx, computation_proto, program);
       }));
+  std::unique_ptr<XRTCompilationCacheEntryRef> entry;
+  OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));

-  Tensor output(DT_INT64, TensorShape({}));
-  output.scalar<int64>()() = uid;
-  ctx->set_output(0, output);
+  Tensor handle_output(DT_INT64, TensorShape({}));
+  handle_output.scalar<int64>()() = uid;
+  ctx->set_output(0, handle_output);
+
+  xla::LocalExecutable* executable = entry->get().get_executable();
+  xla::ProgramShape program_shape = executable->executable()
+                                        ->module()
+                                        .entry_computation()
+                                        ->ComputeProgramShape();
+  Tensor program_shape_output(DT_STRING, TensorShape({1}));
+  program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
+  ctx->set_output(1, program_shape_output);
 }

 XRTCompileOp::~XRTCompileOp() = default;
@@ -64,14 +64,6 @@ uint32 GetXLARandomSeed() {
   return counter.fetch_add(2);
 }

-// Looks up the input `key` in the compilation cache.
-Status GetComputationCacheEntry(
-    XRTCompilationCache* cache, int64 key,
-    std::unique_ptr<XRTCompilationCacheEntryRef>* entry) {
-  TF_RETURN_IF_ERROR(cache->Lookup(key, entry));
-  return Status::OK();
-}
-
 // Populates `inputs` with the input tensors to the computation.
 Status GetComputationInputs(OpKernelContext* context, ResourceMgr* rm,
                             bool release_inputs,
@@ -23,7 +23,12 @@ namespace tensorflow {
 REGISTER_OP("XRTCompile")
     .Input("computation: string")
     .Output("handle: int64")
-    .SetShapeFn(tensorflow::shape_inference::ScalarShape)
+    .Output("program_shape: string")
+    .SetShapeFn([](shape_inference::InferenceContext* c) {
+      c->set_output(0, c->Scalar());
+      c->set_output(1, c->UnknownShapeOfRank(1));
+      return Status::OK();
+    })
     .Doc(
         R"(
Reads a computation proto, compiles it, and places it in the global compilation
@@ -29,8 +29,11 @@ cc_library(
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto",
+       "//tensorflow/compiler/xla/client:client_library",
+       "//tensorflow/compiler/xla/client:local_client",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client:xla_computation",
+       "//tensorflow/compiler/xla/service:platform_util",
        "//tensorflow/compiler/xrt:xrt_proto",
        "//tensorflow/compiler/xrt:xrt_server",
        "//tensorflow/compiler/xrt/cc:xrt_ops",
@@ -49,7 +52,10 @@ tf_cc_test(
     name = "raw_api_test_cpu",
     size = "medium",
     srcs = [],
-    args = ["--xla_test_device=XLA_CPU"],
+    args = [
+        "--xla_test_device=XLA_CPU",
+        "--xla_platform=CPU",
+    ],
     deps = [
         ":raw_api_test_lib",
         "//tensorflow/compiler/jit:xla_cpu_device",
@@ -60,7 +66,10 @@ tf_cuda_cc_test(
     name = "raw_api_test_gpu",
     size = "medium",
     srcs = [],
-    args = ["--xla_test_device=XLA_GPU"],
+    args = [
+        "--xla_test_device=XLA_GPU",
+        "--xla_platform=GPU",
+    ],
     tags = tf_cuda_tests_tags(),
     deps = [
         ":raw_api_test_lib",
@@ -22,10 +22,13 @@ limitations under the License.
 #include "tensorflow/cc/framework/ops.h"
 #include "tensorflow/cc/framework/scope.h"
 #include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/literal.h"
 #include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/service/platform_util.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/compiler/xrt/cc/ops/xrt_compile_ops.h"
@@ -43,6 +46,7 @@ namespace tensorflow {
 namespace {

 string* xla_test_device_ptr;  // initial value set in main()
+string* xla_platform_ptr;     // initial value set in main()

 string DeviceFromFlag() {
   string xla_test_device = *xla_test_device_ptr;
@@ -145,6 +149,28 @@ void StoreComputationSnapshot(const xla::XlaComputation& computation,
   *dst = *snapshot;
 }

+xla::ProgramShape XlaCompiledProgramShape(
+    const xla::XlaComputation& computation,
+    const xla::ProgramShape& input_program_shape) {
+  se::Platform* platform =
+      xla::PlatformUtil::GetPlatform(*xla_platform_ptr).ValueOrDie();
+  xla::LocalClient* client =
+      xla::ClientLibrary::GetOrCreateLocalClient(platform).ValueOrDie();
+  xla::ExecutableBuildOptions exec_options;
+  exec_options.set_result_layout(input_program_shape.result());
+  std::vector<const xla::Shape*> parameters_shapes;
+  for (int64 i = 0; i < input_program_shape.parameters_size(); ++i) {
+    parameters_shapes.push_back(&input_program_shape.parameters(i));
+  }
+  auto local_executable =
+      client->Compile(computation, parameters_shapes, exec_options)
+          .ValueOrDie();
+  return local_executable->executable()
+      ->module()
+      .entry_computation()
+      ->ComputeProgramShape();
+}
+
 TEST(RawApiTest, ReadAndWriteState) {
   xrt::XLAAllocation alloc;
   alloc.set_device_ordinal(0);
@@ -338,20 +364,87 @@ TEST(RawApiTest, CompileAndExecute) {
   auto p1_value =
       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
   auto p1_handle = ops::XRTAllocate(root, p1_value);
-  auto result = ops::XRTExecute(root, c_handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                 {Output(p0_handle), Output(p1_handle)});
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());

   ClientSession session(root);
   std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run({read_back}, &outputs));
+  TF_EXPECT_OK(session.Run({read_back, c_handle.program_shape}, &outputs));

   xla::LiteralProto response;
   EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));

   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+
+  xla::ProgramShape program_shape;
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
+  EXPECT_EQ(program_shape.parameters_size(), 2);
 }

+TEST(RawApiTest, CompileWithXlaReturnShapes) {
+  xla::XlaBuilder builder("XrtXlaShapes");
+  auto input_shape = xla::ShapeUtil::MakeShape(xla::BF16, {32, 3, 128, 128});
+  auto kernel_shape = xla::ShapeUtil::MakeShape(xla::BF16, {3, 3, 5, 5});
+  // Clear the layouts to signal XLA that we accept whatever layouts come out
+  // of the compilation process.
+  xla::LayoutUtil::ClearLayout(&input_shape);
+  xla::LayoutUtil::ClearLayout(&kernel_shape);
+  auto param_shape =
+      xla::ShapeUtil::MakeTupleShape({input_shape, kernel_shape});
+  auto param = xla::Parameter(&builder, 0, param_shape, "param");
+  auto input = xla::GetTupleElement(param, 0);
+  auto kernel = xla::GetTupleElement(param, 1);
+  xla::Conv(input, kernel, {1, 1}, xla::Padding::kSame);
+  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation xla_computation, builder.Build());
+
+  auto result_shape = xla_computation.GetProgramShape().ValueOrDie().result();
+  // Clear the result shape layout to tell XLA that we accept whatever layout
+  // comes out of the compilation process.
+  xla::LayoutUtil::ClearLayout(&result_shape);
+
+  xrt::XLAComputation c;
+  auto config = c.mutable_config();
+  auto shapes = config->mutable_program_shape();
+  *shapes->add_parameters() = param_shape;
+  *shapes->mutable_result() = result_shape;
+  StoreComputationSnapshot(xla_computation, c.mutable_hlo_snapshot());
+
+  Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+  auto computation =
+      ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+  auto c_handle = ops::XRTCompile(root, computation);
+  auto release = ops::XRTReleaseCompilationHandle(root, c_handle.handle);
+  TF_ASSERT_OK(root.status());
+
+  ClientSession session(root);
+  std::vector<Tensor> outputs;
+  TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
+                           {c_handle.program_shape}, {release}, &outputs));
+
+  xla::ProgramShape program_shape;
+  EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
+  EXPECT_EQ(program_shape.parameters_size(), 1);
+
+  VLOG(2) << "Param: "
+          << xla::ShapeUtil::HumanStringWithLayout(program_shape.parameters(0));
+  VLOG(2) << "Result: "
+          << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
+
+  xla::ProgramShape xla_program_shape =
+      XlaCompiledProgramShape(xla_computation, *shapes);
+  EXPECT_TRUE(xla::LayoutUtil::Equal(
+      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
+      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
+          .layout()));
+  EXPECT_TRUE(xla::LayoutUtil::Equal(
+      xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {1}).layout(),
+      xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {1})
+          .layout()));
+  EXPECT_TRUE(xla::LayoutUtil::Equal(program_shape.result().layout(),
+                                     xla_program_shape.result().layout()));
+}
+
 TEST(RawApiTest, CompileAndExecuteZeroArg) {
@@ -371,7 +464,7 @@ TEST(RawApiTest, CompileAndExecuteZeroArg) {
   auto computation =
       ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
   auto c_handle = ops::XRTCompile(root, computation);
-  auto result = ops::XRTExecute(root, c_handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                 std::initializer_list<Input>({}));
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -420,7 +513,7 @@ TEST(RawApiTest, CompileAndExecuteReturnTuple) {
   auto p1_value =
       ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
   auto p1_handle = ops::XRTAllocate(root, p1_value);
-  auto result = ops::XRTExecute(root, c_handle, e_config,
+  auto result = ops::XRTExecute(root, c_handle.handle, e_config,
                                 {Output(p0_handle), Output(p1_handle)});
   auto read_back = ops::XRTReadLiteralAndRelease(root, result);
   TF_ASSERT_OK(root.status());
@@ -455,7 +548,7 @@ TEST(RawApiTest, LeakCompilationReference) {

   ClientSession session(root);
   std::vector<Tensor> outputs;
-  TF_EXPECT_OK(session.Run({c_handle}, &outputs));
+  TF_EXPECT_OK(session.Run({c_handle.handle}, &outputs));
 }

 }  // namespace
@@ -464,9 +557,12 @@ TEST(RawApiTest, LeakCompilationReference) {

 int main(int argc, char** argv) {
   tensorflow::xla_test_device_ptr = new tensorflow::string("XLA_CPU");
+  tensorflow::xla_platform_ptr = new tensorflow::string("CPU");
   std::vector<tensorflow::Flag> flag_list = {
       tensorflow::Flag("xla_test_device", tensorflow::xla_test_device_ptr,
                        "Tensorflow device type to use for test, e.g., XLA_CPU"),
+      tensorflow::Flag("xla_platform", tensorflow::xla_platform_ptr,
+                       "The XLA platform to select for the device"),
   };
   tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
   const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
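The new --xla_platform flag pairs with the existing --xla_test_device flag, as wired up in the BUILD hunks above. A hypothetical manual invocation (target names are taken from those hunks; the package path is assumed from the xrt test layout and is not stated in this diff) might look like:

    # Hypothetical invocation; see the raw_api_test_cpu target above.
    bazel test //tensorflow/compiler/xrt/tests:raw_api_test_cpu \
        --test_arg=--xla_test_device=XLA_CPU --test_arg=--xla_platform=CPU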