diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index b17bc658fa0..697599f3bb1 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code, } // Generate methods for args (inputs). -Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, +Status GenArgMethods(const tf2xla::Config& config, + const xla::ProgramShapeProto& ps, const CompileResult& compile_result, string* methods) { size_t num_args = ps.parameters_size(); if (config.feed_size() != num_args) { @@ -204,7 +205,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, // Generate methods for results (outputs). Status GenResultMethods(const tf2xla::Config& config, - const xla::ProgramShape& ps, string* methods) { + const xla::ProgramShapeProto& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { // The XlaCompiler we use to build the xla computation always generates a // tuple result, and we rely on this to simplify code generation. @@ -336,7 +337,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config, ExtractEntryParamBufferInfos(buffer_infos); std::vector buffer_infos_for_temps = ExtractTempBufferInfos(buffer_infos); - const xla::ProgramShape& ps = compile_result.program_shape; + const xla::ProgramShapeProto& ps = compile_result.program_shape; string methods_arg, methods_result; TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg)); TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result)); @@ -548,8 +549,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { static const char** StaticResultNames() {{RESULT_NAMES_CODE}} // Shape of the args and results. - static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}}; return kShape; } @@ -615,11 +616,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts, Status GenerateMetadata(const CodegenOpts& opts, const CompileResult& compile_result, MetadataResult* metadata_result) { - std::unique_ptr program_shape; + std::unique_ptr program_shape; if (opts.gen_program_shape) { program_shape = - absl::make_unique(compile_result.program_shape); + absl::make_unique(compile_result.program_shape); // The parameter names are currently meaningless, and redundant with the // rest of our metadata, so clear them out to avoid confusion and save @@ -631,8 +632,8 @@ Status GenerateMetadata(const CodegenOpts& opts, // a shim that evaluates to nullptr, which is what we want. ProtobufToEmbed program_shape_protobuf{ - CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape", - program_shape.get()}; + CreateUniqueIdentifier(opts, "ProgramShapeProto"), + "xla::ProgramShapeProto", program_shape.get()}; ProtobufToEmbed hlo_profile_printer_data_protobuf{ CreateUniqueIdentifier(opts, "HloProfilePrinterData"), diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 90410c46a8e..9485e86b10e 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -57,7 +57,7 @@ struct MetadataResult { std::vector header_variable_decls; // program_shape_access_shim is a C++ expression that constructs the - // xla::ProgramShape instance for the CompileResult passed to + // xla::ProgramShapeProto instance for the CompileResult passed to // GenerateMetadata. string program_shape_access_shim; diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index bb288d23000..c1788ca32a1 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) { BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1), BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)}, 5, {})); - compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( - { - xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), - xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), - }, - xla::ShapeUtil::MakeTupleShape( - {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); + compile_result.program_shape = + xla::ShapeUtil::MakeProgramShape( + { + xla::ShapeUtil::MakeShape(xla::F32, {1, 2}), + xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), + }, + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})) + .ToProto(); compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index e4d8a02877c..a2cdab5d1a8 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -22,7 +22,7 @@ extern "C" void entry_point( void* result, const xla::ExecutableRunOptions* run_options, const void** args, void** temps, tensorflow::int64* profile_counters); -extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[]; +extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[]; namespace foo { @@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction { } // Shape of the args and results. - static const xla::ProgramShape* StaticProgramShape() { - static const xla::ProgramShape* kShape = []() { - xla::ProgramShape* proto = new xla::ProgramShape; - proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52); + static const xla::ProgramShapeProto* StaticProgramShape() { + static const xla::ProgramShapeProto* kShape = []() { + xla::ProgramShapeProto* proto = new xla::ProgramShapeProto; + proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52); return proto; }(); return kShape; diff --git a/tensorflow/compiler/aot/codegen_test_o.golden b/tensorflow/compiler/aot/codegen_test_o.golden index eb001c5d45b..ce8e5ec8c96 100644 Binary files a/tensorflow/compiler/aot/codegen_test_o.golden and b/tensorflow/compiler/aot/codegen_test_o.golden differ diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index 2b5f97b34cd..3bc99ef7e61 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -56,8 +56,8 @@ Status CompileXla(xla::CompileOnlyClient* client, return errors::Unknown("Couldn't get XLA program shape: ", pshape_or.status().error_message()); } - compile_result->program_shape = *pshape_or.ValueOrDie(); - xla::ProgramShape* pshape = &compile_result->program_shape; + compile_result->program_shape = pshape_or.ValueOrDie()->ToProto(); + xla::ProgramShapeProto* pshape = &compile_result->program_shape; std::vector arg_layouts; arg_layouts.reserve(pshape->parameters_size()); for (int i = 0; i < pshape->parameters_size(); ++i) { diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index e03c5b1aa77..ee7bb26fabd 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -33,9 +33,9 @@ namespace tfcompile { struct CompileResult { // Contains object file and meta-info. std::unique_ptr aot; - xla::ProgramShape program_shape; // Static shape of args and results. - string entry_point; // Name of generated function. - int pointer_size = 0; // Size of a pointer in bytes. + xla::ProgramShapeProto program_shape; // Static shape of args and results. + string entry_point; // Name of generated function. + int pointer_size = 0; // Size of a pointer in bytes. }; // CompileGraph compiles the graph_def into an object file containing a function diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index f10852c7850..711feed8f32 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -526,7 +526,7 @@ TEST(TFCompileTest, ProgramShape) { // muladd has the program shape defined. MatMulAndAddComp muladd; - const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape(); ASSERT_TRUE(muladd_shape != nullptr); ASSERT_EQ(muladd_shape->parameters_size(), 2); EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index 66206909a92..a1d359e97c4 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -26,7 +26,7 @@ limitations under the License. // Forward-declare, rather than include, to reduce code size for users that // never use this functionality. namespace xla { -class ProgramShape; +class ProgramShapeProto; class HloProfilePrinterData; } @@ -84,7 +84,7 @@ class XlaCompiledCpuFunction { void set_result_names(const char** result_names) { result_names_ = result_names; } - void set_program_shape(const xla::ProgramShape* program_shape) { + void set_program_shape(const xla::ProgramShapeProto* program_shape) { program_shape_ = program_shape; } const xla::HloProfilePrinterData* hlo_profile_printer_data() const { @@ -122,7 +122,7 @@ class XlaCompiledCpuFunction { const char** result_names_ = nullptr; // [Optional] Arg and result shapes. - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; // [Optional] Profile printer data. Null if profiling is disabled. const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; @@ -264,7 +264,7 @@ class XlaCompiledCpuFunction { // Returns the shape of the args and results. May return nullptr if the // program shape isn't available. - const xla::ProgramShape* ProgramShape() const { return program_shape_; } + const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; } bool hlo_profiling_enabled() const { return hlo_profile_printer_data_ != nullptr; @@ -305,7 +305,7 @@ class XlaCompiledCpuFunction { // Optional metadata. const char** arg_names_ = nullptr; const char** result_names_ = nullptr; - const xla::ProgramShape* program_shape_ = nullptr; + const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; }; diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index 86a78ee429e..fabbcd04fed 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile( jit->executable_ = std::move(executable); jit->buffer_infos_ = std::move(buffer_infos); jit->arg_index_table_ = std::move(arg_index_table); - jit->program_shape_ = std::move(program_shape); + jit->program_shape_ = + absl::make_unique(program_shape->ToProto()); jit->static_data_.set_raw_function(raw_function); jit->static_data_.set_buffer_infos(jit->buffer_infos_.data()); jit->static_data_.set_num_buffers(jit->buffer_infos_.size()); diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index d3c8f22a807..a5392057177 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction { std::vector arg_names_; std::vector result_names_; - // The backing data for the program shape. - std::unique_ptr program_shape_; + // The backing data for the program shape. The proto form of program shape is + // used because the program shape is serialized and embedded in the object + // file. + std::unique_ptr program_shape_; }; } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 6d49298a6f3..4496255d003 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -116,7 +116,7 @@ TEST(XlaJitCompiledCpuFunction, Sum) { // Check program shape. using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); - const xla::ProgramShape* program_shape = function.ProgramShape(); + const xla::ProgramShapeProto* program_shape = function.ProgramShape(); ASSERT_TRUE(program_shape != nullptr); ASSERT_EQ(program_shape->parameters_size(), 2); EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index d914e97b6bd..4360e085796 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -226,12 +226,14 @@ cc_library( "index_util.cc", "layout_util.cc", "primitive_util.cc", + "shape.cc", "shape_util.cc", ], hdrs = [ "index_util.h", "layout_util.h", "primitive_util.h", + "shape.h", "shape_util.h", ], visibility = ["//visibility:public"], @@ -254,6 +256,23 @@ cc_library( ], ) +tf_cc_test( + name = "shape_test", + srcs = ["shape_test.cc"], + deps = [ + ":shape_util", + ":status_macros", + ":test", + ":test_helpers", + ":types", + ":util", + ":xla_data_proto", + "//tensorflow/core:lib", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings", + ], +) + tf_cc_test( name = "shape_util_test", srcs = ["shape_util_test.cc"], diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 42da0ebf499..e1f3674c4f8 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -191,6 +191,7 @@ cc_library( hdrs = ["xla_computation.h"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index f508ffb9c95..8a33b3930f3 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -288,7 +288,8 @@ StatusOr XlaBuilder::Build(int64 root_id) { HloComputationProto entry; SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId()); - TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id)); + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id)); + *entry.mutable_program_shape() = program_shape.ToProto(); entry.set_root_id(root_id); for (auto& instruction : instructions_) { @@ -2372,7 +2373,7 @@ StatusOr XlaBuilder::BuildConstantSubGraph( SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator, GetNextId()); entry.set_root_id(root->id()); - ProgramShape* program_shape = entry.mutable_program_shape(); + ProgramShapeProto* program_shape = entry.mutable_program_shape(); *program_shape->mutable_result() = root->shape(); // We use std::set to keep the instruction ids in ascending order (which is diff --git a/tensorflow/compiler/xla/client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_computation.cc index c9870b65b91..f317892c125 100644 --- a/tensorflow/compiler/xla/client/xla_computation.cc +++ b/tensorflow/compiler/xla/client/xla_computation.cc @@ -25,7 +25,7 @@ namespace xla { StatusOr XlaComputation::GetProgramShape() const { TF_RET_CHECK(proto_.has_host_program_shape()); - return proto_.host_program_shape(); + return ProgramShape(proto_.host_program_shape()); } StatusOr> XlaComputation::Snapshot() const { diff --git a/tensorflow/compiler/xla/client/xla_computation.h b/tensorflow/compiler/xla/client/xla_computation.h index 71598ef8b29..3ccbfb28bd0 100644 --- a/tensorflow/compiler/xla/client/xla_computation.h +++ b/tensorflow/compiler/xla/client/xla_computation.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 6e0390763da..6c298e57252 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -21,6 +21,7 @@ limitations under the License. #include #include "absl/types/span.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" diff --git a/tensorflow/compiler/xla/python/local_computation_builder.cc b/tensorflow/compiler/xla/python/local_computation_builder.cc index 4d2a37cfac3..2768ed618d8 100644 --- a/tensorflow/compiler/xla/python/local_computation_builder.cc +++ b/tensorflow/compiler/xla/python/local_computation_builder.cc @@ -487,12 +487,13 @@ StatusOr LocalComputation::CompileForXrt( xrt::XLAComputation c; auto config = c.mutable_config(); - auto shapes = config->mutable_program_shape(); + ProgramShape shapes; for (auto& shape : argument_shapes) { - *shapes->add_parameters() = shape; + *shapes.add_parameters() = shape; } - TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape()); - LayoutUtil::SetToDefaultLayout(shapes); + TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape()); + LayoutUtil::SetToDefaultLayout(&shapes); + *config->mutable_program_shape() = shapes.ToProto(); auto snapshot = computation().Snapshot().ValueOrDie(); *c.mutable_hlo_snapshot() = *snapshot; diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index 67132274c0d..0237f16673e 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime( Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot)); } - const auto& program_shape = instance.computation.host_program_shape(); ExecutionOptions execution_options; *execution_options.mutable_debug_options() = debug_options; *execution_options.mutable_shape_with_output_layout() = *instance.result_layout; TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(program_shape, instance.argument_layouts, - &execution_options)); + CreateModuleConfig( + ProgramShape(instance.computation.host_program_shape()), + instance.argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN( std::unique_ptr hlo_module, diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 913d4c34b43..c62c935af78 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -205,7 +205,8 @@ message HloComputationProto { repeated HloInstructionProto instructions = 2; // The program shape (with layout) of this computation. - xla.ProgramShape program_shape = 4; + + xla.ProgramShapeProto program_shape = 4; // The id of this computation. int64 id = 5; @@ -297,7 +298,7 @@ message HloModuleProto { repeated HloComputationProto computations = 3; // The host program shape (with layout) of the entry computation. - xla.ProgramShape host_program_shape = 4; + xla.ProgramShapeProto host_program_shape = 4; // The id of this module. int64 id = 5; diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 65bd251dd86..d06c2207cb7 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const { proto.add_instructions()->Swap(&instruction_proto); } proto.set_root_id(root_instruction()->unique_id()); - *proto.mutable_program_shape() = ComputeProgramShape(); + *proto.mutable_program_shape() = ComputeProgramShape().ToProto(); return proto; } diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 59f44475df5..a01853fe1f2 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -240,7 +240,7 @@ HloModuleProto HloModule::ToProto() const { *proto.mutable_schedule() = schedule().ToProto().ValueOrDie(); } *proto.mutable_host_program_shape() = - entry_computation_layout().ComputeProgramShape(); + entry_computation_layout().ComputeProgramShape().ToProto(); *proto.mutable_input_output_alias() = input_output_alias_config().ToProto(); *proto.mutable_dynamic_parameter_binding() = dynamic_parameter_binding().ToProto(); @@ -371,7 +371,7 @@ StatusOr HloModule::CreateModuleConfigFromProto( << "No program shape found in the proto"; const auto& program_shape = module.host_program_shape(); - HloModuleConfig module_config(program_shape); + HloModuleConfig module_config(ProgramShape{program_shape}); module_config.set_debug_options(debug_options); // The module config is constructed with default layouts regardless of what is diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 2180ac845dd..5c105908f36 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -145,7 +145,7 @@ StatusOr> LocalService::CompileExecutable( const ExecutableBuildOptions& build_options) { const HloModuleProto& proto = computation.proto(); TF_RET_CHECK(proto.has_host_program_shape()); - const ProgramShape& program_shape = proto.host_program_shape(); + ProgramShape program_shape(proto.host_program_shape()); // Validate incoming layouts. if (argument_layouts.size() != program_shape.parameters_size()) { diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 13fd6bc0093..c4b0a5c0805 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -658,9 +658,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg, // replica 0. TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(request.computation().host_program_shape(), - replicated_arguments.front(), - request.execution_options())); + CreateModuleConfig( + ProgramShape{request.computation().host_program_shape()}, + replicated_arguments.front(), request.execution_options())); VLOG(3) << "ExecuteGraphParallel created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -824,7 +824,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) { [](const Shape& shape) { return &shape; }); TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(arg->computation().host_program_shape(), + CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()}, argument_shapes, &arg->execution_options())); VLOG(3) << "Compile created HloModuleConfig computation layout: " << module_config->entry_computation_layout().ToString(); @@ -1072,7 +1072,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg, "constant computation may not depend on any parameters."); } - ProgramShape program_shape = arg->computation().host_program_shape(); + ProgramShape program_shape(arg->computation().host_program_shape()); TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result())); if (arg->has_output_layout()) { TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape( @@ -1116,7 +1116,7 @@ Status Service::GetComputationGraphStats( return InvalidArgument("Program shape may not be empty."); } - HloModuleConfig config(arg->computation().host_program_shape()); + HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()}); config.set_debug_options(arg->debug_options()); TF_ASSIGN_OR_RETURN(std::unique_ptr module, CreateModuleFromProto(arg->computation(), config)); diff --git a/tensorflow/compiler/xla/shape.cc b/tensorflow/compiler/xla/shape.cc new file mode 100644 index 00000000000..d209389c745 --- /dev/null +++ b/tensorflow/compiler/xla/shape.cc @@ -0,0 +1,62 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/shape_util.h" + +namespace xla { + +ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) { + for (const Shape& shape : program_shape_proto.parameters()) { + *add_parameters() = shape; + } + *mutable_result() = program_shape_proto.result(); + for (const string& name : program_shape_proto.parameter_names()) { + add_parameter_names(name); + } +} + +ProgramShapeProto ProgramShape::ToProto() const { + ProgramShapeProto proto; + for (const Shape& shape : parameters()) { + *proto.add_parameters() = shape; + } + *proto.mutable_result() = result(); + for (const string& name : parameter_names()) { + proto.add_parameter_names(name); + } + return proto; +} + +string ProgramShape::ToString() const { + std::vector parameter_strings(parameters_size()); + for (int i = 0; i < parameters_size(); ++i) { + parameter_strings[i] = absl::StrCat( + i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ", + ShapeUtil::HumanString(parameters(i))); + } + return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ", + ShapeUtil::HumanString(result())); +} + +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) { + out << program_shape.ToString() << "\n"; + return out; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/shape.h b/tensorflow/compiler/xla/shape.h new file mode 100644 index 00000000000..c3aecb17368 --- /dev/null +++ b/tensorflow/compiler/xla/shape.h @@ -0,0 +1,108 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_ +#define TENSORFLOW_COMPILER_XLA_SHAPE_H_ + +#include +#include + +#include "absl/types/optional.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// Shape of the parameters and output of an XLA computation. This is analogous +// to a traditional function signature. +class ProgramShape { + public: + ProgramShape() = default; + + // Creates a ProgramShape from a ProgramShapeProto protobuf. + explicit ProgramShape(const ProgramShapeProto& program_shape_proto); + + // Returns a proto representation of the object. + ProgramShapeProto ToProto() const; + + string ToString() const; + + // The following methods mirror the protobuf generated code interface for the + // message ProgramShapeProto. This enabled easy migration of this data + // structure from a proto to a proper C++ class. + // TODO(b/29771030): Replace or augment these methods with a more ergonomic + // interface. + + // Methods for accessing and manipulating the Shape of the parameters. + int parameters_size() const { return parameters_.size(); } + const Shape& parameters(int index) const { return parameters_.at(index); } + Shape* mutable_parameters(int index) { return ¶meters_.at(index); } + Shape* add_parameters() { + parameters_.emplace_back(); + return ¶meters_.back(); + } + void clear_parameters() { parameters_.clear(); } + const std::vector& parameters() const { return parameters_; } + std::vector* mutable_parameters() { return ¶meters_; } + + // Methods for accessing and manipulating the Shape of the result. + const Shape& result() const { return result_; } + Shape* mutable_result() { return &result_; } + void clear_result() { result_.Clear(); } + + // Methods for accessing and manipulating the names of the parameters. + int parameter_names_size() const { return parameter_names_.size(); } + const string& parameter_names(int index) const { + return parameter_names_.at(index); + } + void set_parameter_names(int index, const string& value) { + parameter_names_.at(index) = value; + } + string* mutable_parameter_names(int index) { + return ¶meter_names_.at(index); + } + void add_parameter_names(const string& value) { + parameter_names_.push_back(value); + } + string* add_parameter_names() { + parameter_names_.push_back(""); + return ¶meter_names_.back(); + } + void clear_parameter_names() { parameter_names_.clear(); } + const std::vector& parameter_names() const { + return parameter_names_; + } + std::vector* mutable_parameter_names() { return ¶meter_names_; } + + string ShortDebugString() const { return ToProto().ShortDebugString(); } + string DebugString() const { return ToProto().DebugString(); } + + private: + // The shapes of the parameters of the computation represented by this object. + std::vector parameters_; + + // The names of the parameters of the computation represented by this object. + std::vector parameter_names_; + + // The shape of the result of the computation represented by this object. + Shape result_; +}; + +std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape); + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SHAPE_H_ diff --git a/tensorflow/compiler/xla/shape_test.cc b/tensorflow/compiler/xla/shape_test.cc new file mode 100644 index 00000000000..cc3a5eb1d64 --- /dev/null +++ b/tensorflow/compiler/xla/shape_test.cc @@ -0,0 +1,112 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/shape.h" + +#include +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace { + +TEST(ShapeTest, ProgramShapeToFromProto) { + ProgramShape program_shape; + *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3}); + *program_shape.add_parameters() = ShapeUtil::MakeTokenShape(); + *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {}); + *program_shape.add_parameters() = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(S32, {}), + ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}), + ShapeUtil::MakeShape(F32, {42, 42})}); + + *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7}); + + program_shape.add_parameter_names("foo"); + program_shape.add_parameter_names("bar"); + program_shape.add_parameter_names("baz"); + program_shape.add_parameter_names("qux qux"); + + // Create a copy of the program shape by round-tripping through a proto. + ProgramShape program_shape_copy(program_shape.ToProto()); + ASSERT_EQ(program_shape.parameters_size(), + program_shape_copy.parameters_size()); + for (int i = 0; i < program_shape.parameters_size(); ++i) { + EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i), + program_shape_copy.parameters(i))); + } + + EXPECT_TRUE( + ShapeUtil::Equal(program_shape.result(), program_shape_copy.result())); + + ASSERT_EQ(program_shape.parameter_names_size(), + program_shape_copy.parameter_names_size()); + for (int i = 0; i < program_shape.parameter_names_size(); ++i) { + EXPECT_EQ(program_shape.parameter_names(i), + program_shape_copy.parameter_names(i)); + } +} + +TEST(ShapeTest, ProgramShapeToString) { + Shape opaque = ShapeUtil::MakeOpaqueShape(); + Shape token = ShapeUtil::MakeTokenShape(); + Shape scalar = ShapeUtil::MakeShape(F32, {}); + Shape matrix = ShapeUtil::MakeShape(U32, {1, 2}); + Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1}); + Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2}); + Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token}); + + ProgramShape prog = ShapeUtil::MakeProgramShape( + {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); + EXPECT_EQ( + "((unknown): opaque[], " + "(unknown): f32[], " + "(unknown): u32[1,2], " + "(unknown): s32[3,4], " + "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " + "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + ShapeUtil::HumanString(prog)); + + prog.add_parameter_names("arg0"); + prog.add_parameter_names("scalar"); + prog.add_parameter_names("matrix"); + prog.add_parameter_names("matrix2"); + prog.add_parameter_names("tuple"); + prog.add_parameter_names("nested_tuple"); + EXPECT_EQ( + "(arg0: opaque[], " + "scalar: f32[], " + "matrix: u32[1,2], " + "matrix2: s32[3,4], " + "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " + "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " + "token[])) " + "-> " + "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", + ShapeUtil::HumanString(prog)); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 7d011bfc658..b05ec209cc6 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -563,6 +563,20 @@ StatusOr StringToPrimitiveType(const string& name) { HumanString(program_shape.result())); } +/* static */ string ShapeUtil::HumanString( + const ProgramShapeProto& program_shape_proto) { + std::vector parameters; + for (auto& shape : program_shape_proto.parameters()) { + const int i = parameters.size(); + parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size() + ? program_shape_proto.parameter_names(i) + : "(unknown)", + ": ", HumanString(shape))); + } + return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ", + HumanString(program_shape_proto.result())); +} + namespace { // Parses shapes with simple recursive descent structure -- consumes from the // front of s and passes that view recursively as required. diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 7f72e57d008..3796c5be5d7 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -28,6 +28,7 @@ limitations under the License. #include "absl/types/span.h" #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/primitive_util.h" +#include "tensorflow/compiler/xla/shape.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -239,6 +240,7 @@ class ShapeUtil { // // (param_name: f32[42x12], ...) -> f32[24x42] static string HumanString(const ProgramShape& program_shape); + static string HumanString(const ProgramShapeProto& program_shape_proto); // Parses a ShapeUtil::HumanString-format shape string back into a shape // object. diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 11b493323cb..ce6330a0dc3 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -575,37 +575,6 @@ TEST(ShapeUtilTest, HumanString) { "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, " "token[])", ShapeUtil::HumanStringWithLayout(nested_tuple)); - - ProgramShape prog = ShapeUtil::MakeProgramShape( - {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple); - EXPECT_EQ( - "((unknown): opaque[], " - "(unknown): f32[], " - "(unknown): u32[1,2], " - "(unknown): s32[3,4], " - "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), " - "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); - - prog.add_parameter_names("arg0"); - prog.add_parameter_names("scalar"); - prog.add_parameter_names("matrix"); - prog.add_parameter_names("matrix2"); - prog.add_parameter_names("tuple"); - prog.add_parameter_names("nested_tuple"); - EXPECT_EQ( - "(arg0: opaque[], " - "scalar: f32[], " - "matrix: u32[1,2], " - "matrix2: s32[3,4], " - "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), " - "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], " - "token[])) " - "-> " - "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])", - ShapeUtil::HumanString(prog)); } TEST(ShapeUtilTest, ForEachSubshapeArray) { diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 5cf87e565bf..34c7dc7c464 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = @@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. std::unique_ptr x_data = @@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { client_->GetComputationShape(computation).ConsumeValueOrDie(); std::unique_ptr replayed_shape = client_->GetComputationShape(replayed).ConsumeValueOrDie(); - ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape)); + ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(), + replayed_shape->ToProto())); // Run it. Literal literal = diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 683ccc40f16..27ef86ab2eb 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -183,7 +183,7 @@ message Shape { // Shape of the parameters and output of a computation (like a traditional // function signature). -message ProgramShape { +message ProgramShapeProto { repeated Shape parameters = 1; Shape result = 2; repeated string parameter_names = 3; diff --git a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc index dc62cf7a6b2..db43aeaafec 100644 --- a/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc +++ b/tensorflow/compiler/xrt/kernels/xrt_compile_ops.cc @@ -174,11 +174,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) { ctx->set_output(0, handle_output); xla::LocalExecutable* executable = entry->get().get_executable(); - xla::ProgramShape program_shape = executable->executable() - ->module() - .config() - .entry_computation_layout() - .ComputeProgramShape(); + xla::ProgramShapeProto program_shape = executable->executable() + ->module() + .config() + .entry_computation_layout() + .ComputeProgramShape() + .ToProto(); Tensor program_shape_output(DT_STRING, TensorShape({1})); program_shape_output.vec()(0) = program_shape.SerializeAsString(); ctx->set_output(1, program_shape_output); diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc index 25464b5554d..7e73db98f78 100644 --- a/tensorflow/compiler/xrt/tests/raw_api_test.cc +++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc @@ -411,7 +411,7 @@ TEST(RawApiTest, CompileAndExecute) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -465,7 +465,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) { auto expected = xla::LiteralUtil::CreateR1({27.0f, 21.0f}); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); } @@ -510,7 +510,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(), {c_handle.program_shape}, {release}, &outputs)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 1); @@ -520,7 +520,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) { << xla::ShapeUtil::HumanStringWithLayout(program_shape.result()); xla::ProgramShape xla_program_shape = - XlaCompiledProgramShape(xla_computation, *shapes); + XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes)); EXPECT_TRUE(xla::LayoutUtil::Equal( xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(), xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0}) @@ -739,7 +739,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) { auto expected = xla::LiteralUtil::CreateR0(15123899); EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response)); - xla::ProgramShape program_shape; + xla::ProgramShapeProto program_shape; EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec()(0))); EXPECT_EQ(program_shape.parameters_size(), 2); EXPECT_TRUE( diff --git a/tensorflow/compiler/xrt/xrt.proto b/tensorflow/compiler/xrt/xrt.proto index 6ab77fbaaf0..ae44f717409 100644 --- a/tensorflow/compiler/xrt/xrt.proto +++ b/tensorflow/compiler/xrt/xrt.proto @@ -36,11 +36,11 @@ message XLAComputationConfig { tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3; // The arg/result shapes for the whole computation. - xla.ProgramShape program_shape = 4; + xla.ProgramShapeProto program_shape = 4; // The arg/result shapes for each core of a model-parallel // computation. per_core_args_and_result_shapes is optional for a // single-core computation. - repeated xla.ProgramShape per_core_program_shape = 5; + repeated xla.ProgramShapeProto per_core_program_shape = 5; // Describes how replicated computation instances should be assigned to // devices. There are num_cores_per_replica computations, and each one will be // sent and executed to the set of replica device numbers described in the