Replace ProgramShape proto with a C++ class.
Rename the protobuf message ProgramShape to ProgramShapeProto and create a new ProgramShape C++ class with an interface that mirrors the protobuf generated-code interface. This CL is a step toward replacing the Shape proto with a C++ class; ProgramShape is migrated first because ProgramShape contains Shapes.

PiperOrigin-RevId: 222435461
parent f6ce9fd485
commit f22eec10b6
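The pattern applied at call sites throughout this diff is mechanical: code that previously passed the xla.ProgramShape proto around now either wraps the proto in the new xla::ProgramShape class or calls ToProto() where a serialized/embedded proto is still needed. A minimal sketch of the two conversions, using only the constructor and ToProto() method declared in the new tensorflow/compiler/xla/shape.h further down; the helper function itself is hypothetical and not part of this change:

#include "tensorflow/compiler/xla/shape.h"        // new xla::ProgramShape class
#include "tensorflow/compiler/xla/xla_data.pb.h"  // xla::ProgramShapeProto

// Hypothetical helper, not part of the CL: shows the two conversions the
// diff inserts at call sites.
xla::ProgramShapeProto RoundTrip(const xla::ProgramShapeProto& proto) {
  // Proto -> C++ class (used where code previously read the proto directly).
  xla::ProgramShape program_shape(proto);
  // C++ class -> proto (used where a proto must be serialized or embedded).
  return program_shape.ToProto();
}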
@@ -164,7 +164,8 @@ string RewriteWithName(const string& name, string code,
 }
 
 // Generate methods for args (inputs).
-Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
+Status GenArgMethods(const tf2xla::Config& config,
+                     const xla::ProgramShapeProto& ps,
                      const CompileResult& compile_result, string* methods) {
   size_t num_args = ps.parameters_size();
   if (config.feed_size() != num_args) {
@@ -204,7 +205,7 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps,
 
 // Generate methods for results (outputs).
 Status GenResultMethods(const tf2xla::Config& config,
-                        const xla::ProgramShape& ps, string* methods) {
+                        const xla::ProgramShapeProto& ps, string* methods) {
   if (ps.result().element_type() != xla::TUPLE) {
     // The XlaCompiler we use to build the xla computation always generates a
     // tuple result, and we rely on this to simplify code generation.
@@ -336,7 +337,7 @@ Status GenerateHeader(const CodegenOpts& opts, const tf2xla::Config& config,
       ExtractEntryParamBufferInfos(buffer_infos);
   std::vector<BufferInfo> buffer_infos_for_temps =
       ExtractTempBufferInfos(buffer_infos);
-  const xla::ProgramShape& ps = compile_result.program_shape;
+  const xla::ProgramShapeProto& ps = compile_result.program_shape;
   string methods_arg, methods_result;
   TF_RETURN_IF_ERROR(GenArgMethods(config, ps, compile_result, &methods_arg));
   TF_RETURN_IF_ERROR(GenResultMethods(config, ps, &methods_result));
@@ -548,8 +549,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
   static const char** StaticResultNames() {{RESULT_NAMES_CODE}}
 
   // Shape of the args and results.
-  static const xla::ProgramShape* StaticProgramShape() {
-    static const xla::ProgramShape* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
+  static const xla::ProgramShapeProto* StaticProgramShape() {
+    static const xla::ProgramShapeProto* kShape = {{PROGRAM_SHAPE_SHIM_EXPRESSION}};
     return kShape;
   }
 
@@ -615,11 +616,11 @@ static string CreateUniqueIdentifier(const CodegenOpts& opts,
 Status GenerateMetadata(const CodegenOpts& opts,
                         const CompileResult& compile_result,
                         MetadataResult* metadata_result) {
-  std::unique_ptr<xla::ProgramShape> program_shape;
+  std::unique_ptr<xla::ProgramShapeProto> program_shape;
 
   if (opts.gen_program_shape) {
     program_shape =
-        absl::make_unique<xla::ProgramShape>(compile_result.program_shape);
+        absl::make_unique<xla::ProgramShapeProto>(compile_result.program_shape);
 
     // The parameter names are currently meaningless, and redundant with the
     // rest of our metadata, so clear them out to avoid confusion and save
@@ -631,8 +632,8 @@ Status GenerateMetadata(const CodegenOpts& opts,
   // a shim that evaluates to nullptr, which is what we want.
 
   ProtobufToEmbed program_shape_protobuf{
-      CreateUniqueIdentifier(opts, "ProgramShape"), "xla::ProgramShape",
-      program_shape.get()};
+      CreateUniqueIdentifier(opts, "ProgramShapeProto"),
+      "xla::ProgramShapeProto", program_shape.get()};
 
   ProtobufToEmbed hlo_profile_printer_data_protobuf{
       CreateUniqueIdentifier(opts, "HloProfilePrinterData"),
@@ -57,7 +57,7 @@ struct MetadataResult {
   std::vector<string> header_variable_decls;
 
   // program_shape_access_shim is a C++ expression that constructs the
-  // xla::ProgramShape instance for the CompileResult passed to
+  // xla::ProgramShapeProto instance for the CompileResult passed to
   // GenerateMetadata.
   string program_shape_access_shim;
 
@@ -181,13 +181,15 @@ TEST(CodegenTest, Golden) {
        BufferInfo::MakeEntryParameter(/*size=*/96, /*param_number=*/1),
        BufferInfo::MakeTempBuffer(3), BufferInfo::MakeTempBuffer(120)},
       5, {}));
-  compile_result.program_shape = xla::ShapeUtil::MakeProgramShape(
-      {
-          xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
-          xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
-      },
-      xla::ShapeUtil::MakeTupleShape(
-          {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}));
+  compile_result.program_shape =
+      xla::ShapeUtil::MakeProgramShape(
+          {
+              xla::ShapeUtil::MakeShape(xla::F32, {1, 2}),
+              xla::ShapeUtil::MakeShape(xla::S64, {3, 4}),
+          },
+          xla::ShapeUtil::MakeTupleShape(
+              {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})}))
+          .ToProto();
   compile_result.entry_point = "entry_point";
   compile_result.pointer_size = 8;
 
@@ -22,7 +22,7 @@ extern "C" void entry_point(
     void* result, const xla::ExecutableRunOptions* run_options,
     const void** args, void** temps, tensorflow::int64* profile_counters);
 
-extern "C" char __tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[];
+extern "C" char __tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[];
 
 
 namespace foo {
@@ -253,10 +253,10 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
   }
 
   // Shape of the args and results.
-  static const xla::ProgramShape* StaticProgramShape() {
-    static const xla::ProgramShape* kShape = []() {
-    xla::ProgramShape* proto = new xla::ProgramShape;
-    proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShape_protobuf_array_contents[0], 52);
+  static const xla::ProgramShapeProto* StaticProgramShape() {
+    static const xla::ProgramShapeProto* kShape = []() {
+    xla::ProgramShapeProto* proto = new xla::ProgramShapeProto;
+    proto->ParseFromArray(&__tfcompile_foo_bar_MyClass_ProgramShapeProto_protobuf_array_contents[0], 52);
     return proto;
   }();
     return kShape;
Binary file not shown.
@@ -56,8 +56,8 @@ Status CompileXla(xla::CompileOnlyClient* client,
     return errors::Unknown("Couldn't get XLA program shape: ",
                            pshape_or.status().error_message());
   }
-  compile_result->program_shape = *pshape_or.ValueOrDie();
-  xla::ProgramShape* pshape = &compile_result->program_shape;
+  compile_result->program_shape = pshape_or.ValueOrDie()->ToProto();
+  xla::ProgramShapeProto* pshape = &compile_result->program_shape;
   std::vector<const xla::Shape*> arg_layouts;
   arg_layouts.reserve(pshape->parameters_size());
   for (int i = 0; i < pshape->parameters_size(); ++i) {
@@ -33,9 +33,9 @@ namespace tfcompile {
 struct CompileResult {
   // Contains object file and meta-info.
   std::unique_ptr<xla::cpu::CpuAotCompilationResult> aot;
-  xla::ProgramShape program_shape;  // Static shape of args and results.
-  string entry_point;               // Name of generated function.
-  int pointer_size = 0;             // Size of a pointer in bytes.
+  xla::ProgramShapeProto program_shape;  // Static shape of args and results.
+  string entry_point;                    // Name of generated function.
+  int pointer_size = 0;                  // Size of a pointer in bytes.
 };
 
 // CompileGraph compiles the graph_def into an object file containing a function
@@ -526,7 +526,7 @@ TEST(TFCompileTest, ProgramShape) {
 
   // muladd has the program shape defined.
   MatMulAndAddComp muladd;
-  const xla::ProgramShape* muladd_shape = muladd.ProgramShape();
+  const xla::ProgramShapeProto* muladd_shape = muladd.ProgramShape();
   ASSERT_TRUE(muladd_shape != nullptr);
   ASSERT_EQ(muladd_shape->parameters_size(), 2);
   EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2));
@@ -26,7 +26,7 @@ limitations under the License.
 // Forward-declare, rather than include, to reduce code size for users that
 // never use this functionality.
 namespace xla {
-class ProgramShape;
+class ProgramShapeProto;
 class HloProfilePrinterData;
 }
 
@@ -84,7 +84,7 @@ class XlaCompiledCpuFunction {
     void set_result_names(const char** result_names) {
       result_names_ = result_names;
     }
-    void set_program_shape(const xla::ProgramShape* program_shape) {
+    void set_program_shape(const xla::ProgramShapeProto* program_shape) {
       program_shape_ = program_shape;
     }
     const xla::HloProfilePrinterData* hlo_profile_printer_data() const {
@@ -122,7 +122,7 @@ class XlaCompiledCpuFunction {
     const char** result_names_ = nullptr;
 
     // [Optional] Arg and result shapes.
-    const xla::ProgramShape* program_shape_ = nullptr;
+    const xla::ProgramShapeProto* program_shape_ = nullptr;
 
     // [Optional] Profile printer data. Null if profiling is disabled.
     const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
@@ -264,7 +264,7 @@ class XlaCompiledCpuFunction {
 
   // Returns the shape of the args and results. May return nullptr if the
   // program shape isn't available.
-  const xla::ProgramShape* ProgramShape() const { return program_shape_; }
+  const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
 
   bool hlo_profiling_enabled() const {
     return hlo_profile_printer_data_ != nullptr;
@@ -305,7 +305,7 @@ class XlaCompiledCpuFunction {
   // Optional metadata.
   const char** arg_names_ = nullptr;
   const char** result_names_ = nullptr;
-  const xla::ProgramShape* program_shape_ = nullptr;
+  const xla::ProgramShapeProto* program_shape_ = nullptr;
   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
 };
 
@@ -133,7 +133,8 @@ XlaJitCompiledCpuFunction::Compile(
   jit->executable_ = std::move(executable);
   jit->buffer_infos_ = std::move(buffer_infos);
   jit->arg_index_table_ = std::move(arg_index_table);
-  jit->program_shape_ = std::move(program_shape);
+  jit->program_shape_ =
+      absl::make_unique<xla::ProgramShapeProto>(program_shape->ToProto());
   jit->static_data_.set_raw_function(raw_function);
   jit->static_data_.set_buffer_infos(jit->buffer_infos_.data());
   jit->static_data_.set_num_buffers(jit->buffer_infos_.size());
@@ -80,8 +80,10 @@ class XlaJitCompiledCpuFunction {
   std::vector<const char*> arg_names_;
   std::vector<const char*> result_names_;
 
-  // The backing data for the program shape.
-  std::unique_ptr<const xla::ProgramShape> program_shape_;
+  // The backing data for the program shape. The proto form of program shape is
+  // used because the program shape is serialized and embedded in the object
+  // file.
+  std::unique_ptr<const xla::ProgramShapeProto> program_shape_;
 };
 
 }  // namespace tensorflow
@@ -116,7 +116,7 @@ TEST(XlaJitCompiledCpuFunction, Sum) {
   // Check program shape.
   using xla::ShapeUtil;
   const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {});
-  const xla::ProgramShape* program_shape = function.ProgramShape();
+  const xla::ProgramShapeProto* program_shape = function.ProgramShape();
   ASSERT_TRUE(program_shape != nullptr);
   ASSERT_EQ(program_shape->parameters_size(), 2);
   EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32));
@@ -226,12 +226,14 @@ cc_library(
         "index_util.cc",
         "layout_util.cc",
         "primitive_util.cc",
+        "shape.cc",
         "shape_util.cc",
     ],
     hdrs = [
        "index_util.h",
        "layout_util.h",
        "primitive_util.h",
+       "shape.h",
        "shape_util.h",
    ],
    visibility = ["//visibility:public"],
@@ -254,6 +256,23 @@ cc_library(
     ],
 )
 
+tf_cc_test(
+    name = "shape_test",
+    srcs = ["shape_test.cc"],
+    deps = [
+        ":shape_util",
+        ":status_macros",
+        ":test",
+        ":test_helpers",
+        ":types",
+        ":util",
+        ":xla_data_proto",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:test_main",
+        "@com_google_absl//absl/strings",
+    ],
+)
+
 tf_cc_test(
     name = "shape_util_test",
     srcs = ["shape_util_test.cc"],
@@ -191,6 +191,7 @@ cc_library(
     hdrs = ["xla_computation.h"],
     visibility = ["//visibility:public"],
     deps = [
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla:xla_data_proto",
@@ -288,7 +288,8 @@ StatusOr<XlaComputation> XlaBuilder::Build(int64 root_id) {
 
   HloComputationProto entry;
   SetProtoIdAndName(&entry, name_, kNameSeparator, GetNextId());
-  TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(), GetProgramShape(root_id));
+  TF_ASSIGN_OR_RETURN(ProgramShape program_shape, GetProgramShape(root_id));
+  *entry.mutable_program_shape() = program_shape.ToProto();
   entry.set_root_id(root_id);
 
   for (auto& instruction : instructions_) {
@@ -2372,7 +2373,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
   SetProtoIdAndName(&entry, StrCat(name_, "_compute_constant"), kNameSeparator,
                     GetNextId());
   entry.set_root_id(root->id());
-  ProgramShape* program_shape = entry.mutable_program_shape();
+  ProgramShapeProto* program_shape = entry.mutable_program_shape();
   *program_shape->mutable_result() = root->shape();
 
   // We use std::set to keep the instruction ids in ascending order (which is
@@ -25,7 +25,7 @@ namespace xla {
 
 StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
   TF_RET_CHECK(proto_.has_host_program_shape());
-  return proto_.host_program_shape();
+  return ProgramShape(proto_.host_program_shape());
 }
 
 StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
@@ -19,6 +19,7 @@ limitations under the License.
 #include <utility>
 
 #include "tensorflow/compiler/xla/service/hlo.pb.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
@@ -21,6 +21,7 @@ limitations under the License.
 #include <string>
 
 #include "absl/types/span.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status.h"
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -487,12 +487,13 @@ StatusOr<CompiledXrtComputation*> LocalComputation::CompileForXrt(
 
   xrt::XLAComputation c;
   auto config = c.mutable_config();
-  auto shapes = config->mutable_program_shape();
+  ProgramShape shapes;
   for (auto& shape : argument_shapes) {
-    *shapes->add_parameters() = shape;
+    *shapes.add_parameters() = shape;
   }
-  TF_ASSIGN_OR_RETURN(*shapes->mutable_result(), GetReturnValueShape());
-  LayoutUtil::SetToDefaultLayout(shapes);
+  TF_ASSIGN_OR_RETURN(*shapes.mutable_result(), GetReturnValueShape());
+  LayoutUtil::SetToDefaultLayout(&shapes);
+  *config->mutable_program_shape() = shapes.ToProto();
   auto snapshot = computation().Snapshot().ValueOrDie();
   *c.mutable_hlo_snapshot() = *snapshot;
 
@@ -86,15 +86,15 @@ CompileOnlyService::CompileAheadOfTime(
           Executable::DumpToDirectory(per_host_path, filename, hlo_snapshot));
     }
 
-    const auto& program_shape = instance.computation.host_program_shape();
     ExecutionOptions execution_options;
     *execution_options.mutable_debug_options() = debug_options;
     *execution_options.mutable_shape_with_output_layout() =
         *instance.result_layout;
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
-        CreateModuleConfig(program_shape, instance.argument_layouts,
-                           &execution_options));
+        CreateModuleConfig(
+            ProgramShape(instance.computation.host_program_shape()),
+            instance.argument_layouts, &execution_options));
 
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModule> hlo_module,
@@ -205,7 +205,8 @@ message HloComputationProto {
   repeated HloInstructionProto instructions = 2;
 
   // The program shape (with layout) of this computation.
-  xla.ProgramShape program_shape = 4;
+
+  xla.ProgramShapeProto program_shape = 4;
 
   // The id of this computation.
   int64 id = 5;
@@ -297,7 +298,7 @@ message HloModuleProto {
   repeated HloComputationProto computations = 3;
 
   // The host program shape (with layout) of the entry computation.
-  xla.ProgramShape host_program_shape = 4;
+  xla.ProgramShapeProto host_program_shape = 4;
 
   // The id of this module.
   int64 id = 5;
@@ -499,7 +499,7 @@ HloComputationProto HloComputation::ToProto() const {
     proto.add_instructions()->Swap(&instruction_proto);
   }
   proto.set_root_id(root_instruction()->unique_id());
-  *proto.mutable_program_shape() = ComputeProgramShape();
+  *proto.mutable_program_shape() = ComputeProgramShape().ToProto();
   return proto;
 }
 
@@ -240,7 +240,7 @@ HloModuleProto HloModule::ToProto() const {
     *proto.mutable_schedule() = schedule().ToProto().ValueOrDie();
   }
   *proto.mutable_host_program_shape() =
-      entry_computation_layout().ComputeProgramShape();
+      entry_computation_layout().ComputeProgramShape().ToProto();
   *proto.mutable_input_output_alias() = input_output_alias_config().ToProto();
   *proto.mutable_dynamic_parameter_binding() =
       dynamic_parameter_binding().ToProto();
@@ -371,7 +371,7 @@ StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
       << "No program shape found in the proto";
   const auto& program_shape = module.host_program_shape();
 
-  HloModuleConfig module_config(program_shape);
+  HloModuleConfig module_config(ProgramShape{program_shape});
   module_config.set_debug_options(debug_options);
 
   // The module config is constructed with default layouts regardless of what is
@@ -145,7 +145,7 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
     const ExecutableBuildOptions& build_options) {
   const HloModuleProto& proto = computation.proto();
   TF_RET_CHECK(proto.has_host_program_shape());
-  const ProgramShape& program_shape = proto.host_program_shape();
+  ProgramShape program_shape(proto.host_program_shape());
 
   // Validate incoming layouts.
   if (argument_layouts.size() != program_shape.parameters_size()) {
@@ -658,9 +658,9 @@ Status Service::ExecuteGraphParallel(const ExecuteGraphParallelRequest* arg,
     // replica 0.
     TF_ASSIGN_OR_RETURN(
         std::unique_ptr<HloModuleConfig> module_config,
-        CreateModuleConfig(request.computation().host_program_shape(),
-                           replicated_arguments.front(),
-                           request.execution_options()));
+        CreateModuleConfig(
+            ProgramShape{request.computation().host_program_shape()},
+            replicated_arguments.front(), request.execution_options()));
     VLOG(3)
         << "ExecuteGraphParallel created HloModuleConfig computation layout: "
         << module_config->entry_computation_layout().ToString();
@@ -824,7 +824,7 @@ Status Service::Compile(const CompileRequest* arg, CompileResponse* result) {
                  [](const Shape& shape) { return &shape; });
   TF_ASSIGN_OR_RETURN(
       std::unique_ptr<HloModuleConfig> module_config,
-      CreateModuleConfig(arg->computation().host_program_shape(),
+      CreateModuleConfig(ProgramShape{arg->computation().host_program_shape()},
                          argument_shapes, &arg->execution_options()));
   VLOG(3) << "Compile created HloModuleConfig computation layout: "
           << module_config->entry_computation_layout().ToString();
@@ -1072,7 +1072,7 @@ Status Service::ComputeConstantGraph(const ComputeConstantGraphRequest* arg,
         "constant computation may not depend on any parameters.");
   }
 
-  ProgramShape program_shape = arg->computation().host_program_shape();
+  ProgramShape program_shape(arg->computation().host_program_shape());
   TF_DCHECK_OK(ShapeUtil::ValidateShape(program_shape.result()));
   if (arg->has_output_layout()) {
     TF_RETURN_IF_ERROR(LayoutUtil::ValidateLayoutForShape(
@@ -1116,7 +1116,7 @@ Status Service::GetComputationGraphStats(
     return InvalidArgument("Program shape may not be empty.");
   }
 
-  HloModuleConfig config(arg->computation().host_program_shape());
+  HloModuleConfig config(ProgramShape{arg->computation().host_program_shape()});
   config.set_debug_options(arg->debug_options());
   TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModule> module,
                       CreateModuleFromProto(arg->computation(), config));
tensorflow/compiler/xla/shape.cc  (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/shape.h"

#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {

ProgramShape::ProgramShape(const ProgramShapeProto& program_shape_proto) {
  for (const Shape& shape : program_shape_proto.parameters()) {
    *add_parameters() = shape;
  }
  *mutable_result() = program_shape_proto.result();
  for (const string& name : program_shape_proto.parameter_names()) {
    add_parameter_names(name);
  }
}

ProgramShapeProto ProgramShape::ToProto() const {
  ProgramShapeProto proto;
  for (const Shape& shape : parameters()) {
    *proto.add_parameters() = shape;
  }
  *proto.mutable_result() = result();
  for (const string& name : parameter_names()) {
    proto.add_parameter_names(name);
  }
  return proto;
}

string ProgramShape::ToString() const {
  std::vector<string> parameter_strings(parameters_size());
  for (int i = 0; i < parameters_size(); ++i) {
    parameter_strings[i] = absl::StrCat(
        i < parameter_names_size() ? parameter_names(i) : "(unknown)", ": ",
        ShapeUtil::HumanString(parameters(i)));
  }
  return absl::StrCat("(", absl::StrJoin(parameter_strings, ", "), ") -> ",
                      ShapeUtil::HumanString(result()));
}

std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape) {
  out << program_shape.ToString() << "\n";
  return out;
}

}  // namespace xla
tensorflow/compiler/xla/shape.h  (new file, 108 lines)
@@ -0,0 +1,108 @@
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SHAPE_H_
#define TENSORFLOW_COMPILER_XLA_SHAPE_H_

#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace xla {

// Shape of the parameters and output of an XLA computation. This is analogous
// to a traditional function signature.
class ProgramShape {
 public:
  ProgramShape() = default;

  // Creates a ProgramShape from a ProgramShapeProto protobuf.
  explicit ProgramShape(const ProgramShapeProto& program_shape_proto);

  // Returns a proto representation of the object.
  ProgramShapeProto ToProto() const;

  string ToString() const;

  // The following methods mirror the protobuf generated code interface for the
  // message ProgramShapeProto. This enabled easy migration of this data
  // structure from a proto to a proper C++ class.
  // TODO(b/29771030): Replace or augment these methods with a more ergonomic
  // interface.

  // Methods for accessing and manipulating the Shape of the parameters.
  int parameters_size() const { return parameters_.size(); }
  const Shape& parameters(int index) const { return parameters_.at(index); }
  Shape* mutable_parameters(int index) { return &parameters_.at(index); }
  Shape* add_parameters() {
    parameters_.emplace_back();
    return &parameters_.back();
  }
  void clear_parameters() { parameters_.clear(); }
  const std::vector<Shape>& parameters() const { return parameters_; }
  std::vector<Shape>* mutable_parameters() { return &parameters_; }

  // Methods for accessing and manipulating the Shape of the result.
  const Shape& result() const { return result_; }
  Shape* mutable_result() { return &result_; }
  void clear_result() { result_.Clear(); }

  // Methods for accessing and manipulating the names of the parameters.
  int parameter_names_size() const { return parameter_names_.size(); }
  const string& parameter_names(int index) const {
    return parameter_names_.at(index);
  }
  void set_parameter_names(int index, const string& value) {
    parameter_names_.at(index) = value;
  }
  string* mutable_parameter_names(int index) {
    return &parameter_names_.at(index);
  }
  void add_parameter_names(const string& value) {
    parameter_names_.push_back(value);
  }
  string* add_parameter_names() {
    parameter_names_.push_back("");
    return &parameter_names_.back();
  }
  void clear_parameter_names() { parameter_names_.clear(); }
  const std::vector<string>& parameter_names() const {
    return parameter_names_;
  }
  std::vector<string>* mutable_parameter_names() { return &parameter_names_; }

  string ShortDebugString() const { return ToProto().ShortDebugString(); }
  string DebugString() const { return ToProto().DebugString(); }

 private:
  // The shapes of the parameters of the computation represented by this object.
  std::vector<Shape> parameters_;

  // The names of the parameters of the computation represented by this object.
  std::vector<string> parameter_names_;

  // The shape of the result of the computation represented by this object.
  Shape result_;
};

std::ostream& operator<<(std::ostream& out, const ProgramShape& program_shape);

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SHAPE_H_
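Because the class deliberately mirrors the protobuf generated-code accessors (parameters_size(), parameters(i), result(), and so on), call sites can often be migrated by changing only the declared type. A small illustrative sketch of that property; the templated helper below is hypothetical and not part of this change:

#include "tensorflow/compiler/xla/shape.h"        // new xla::ProgramShape class
#include "tensorflow/compiler/xla/xla_data.pb.h"  // xla::ProgramShapeProto

// Hypothetical helper: compiles against either xla::ProgramShape or
// xla::ProgramShapeProto, since the class mirrors the proto accessors.
template <typename ProgramShapeT>
int NumTupleParameters(const ProgramShapeT& ps) {
  int count = 0;
  for (int i = 0; i < ps.parameters_size(); ++i) {
    // ps.parameters(i) is a Shape (still a proto) in both cases.
    if (ps.parameters(i).element_type() == xla::TUPLE) ++count;
  }
  return count;
}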
tensorflow/compiler/xla/shape_test.cc  (new file, 112 lines)
@@ -0,0 +1,112 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/shape.h"

#include <numeric>
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {
namespace {

TEST(ShapeTest, ProgramShapeToFromProto) {
  ProgramShape program_shape;
  *program_shape.add_parameters() = ShapeUtil::MakeShape(F32, {1, 2, 3});
  *program_shape.add_parameters() = ShapeUtil::MakeTokenShape();
  *program_shape.add_parameters() = ShapeUtil::MakeShape(S64, {});
  *program_shape.add_parameters() = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(S32, {}),
       ShapeUtil::MakeTupleShape({ShapeUtil::MakeTokenShape()}),
       ShapeUtil::MakeShape(F32, {42, 42})});

  *program_shape.mutable_result() = ShapeUtil::MakeShape(F32, {7});

  program_shape.add_parameter_names("foo");
  program_shape.add_parameter_names("bar");
  program_shape.add_parameter_names("baz");
  program_shape.add_parameter_names("qux qux");

  // Create a copy of the program shape by round-tripping through a proto.
  ProgramShape program_shape_copy(program_shape.ToProto());
  ASSERT_EQ(program_shape.parameters_size(),
            program_shape_copy.parameters_size());
  for (int i = 0; i < program_shape.parameters_size(); ++i) {
    EXPECT_TRUE(ShapeUtil::Equal(program_shape.parameters(i),
                                 program_shape_copy.parameters(i)));
  }

  EXPECT_TRUE(
      ShapeUtil::Equal(program_shape.result(), program_shape_copy.result()));

  ASSERT_EQ(program_shape.parameter_names_size(),
            program_shape_copy.parameter_names_size());
  for (int i = 0; i < program_shape.parameter_names_size(); ++i) {
    EXPECT_EQ(program_shape.parameter_names(i),
              program_shape_copy.parameter_names(i));
  }
}

TEST(ShapeTest, ProgramShapeToString) {
  Shape opaque = ShapeUtil::MakeOpaqueShape();
  Shape token = ShapeUtil::MakeTokenShape();
  Shape scalar = ShapeUtil::MakeShape(F32, {});
  Shape matrix = ShapeUtil::MakeShape(U32, {1, 2});
  Shape matrix2 = ShapeUtil::MakeShapeWithLayout(S32, {3, 4}, {0, 1});
  Shape tuple = ShapeUtil::MakeTupleShape({opaque, scalar, matrix, matrix2});
  Shape nested_tuple = ShapeUtil::MakeTupleShape({tuple, matrix, token});

  ProgramShape prog = ShapeUtil::MakeProgramShape(
      {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
  EXPECT_EQ(
      "((unknown): opaque[], "
      "(unknown): f32[], "
      "(unknown): u32[1,2], "
      "(unknown): s32[3,4], "
      "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
      "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
      "-> "
      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
      ShapeUtil::HumanString(prog));

  prog.add_parameter_names("arg0");
  prog.add_parameter_names("scalar");
  prog.add_parameter_names("matrix");
  prog.add_parameter_names("matrix2");
  prog.add_parameter_names("tuple");
  prog.add_parameter_names("nested_tuple");
  EXPECT_EQ(
      "(arg0: opaque[], "
      "scalar: f32[], "
      "matrix: u32[1,2], "
      "matrix2: s32[3,4], "
      "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
      "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
      "token[])) "
      "-> "
      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
      ShapeUtil::HumanString(prog));
}

}  // namespace
}  // namespace xla
@@ -563,6 +563,20 @@ StatusOr<PrimitiveType> StringToPrimitiveType(const string& name) {
                 HumanString(program_shape.result()));
 }
 
+/* static */ string ShapeUtil::HumanString(
+    const ProgramShapeProto& program_shape_proto) {
+  std::vector<string> parameters;
+  for (auto& shape : program_shape_proto.parameters()) {
+    const int i = parameters.size();
+    parameters.push_back(StrCat(i < program_shape_proto.parameter_names_size()
+                                    ? program_shape_proto.parameter_names(i)
+                                    : "(unknown)",
+                                ": ", HumanString(shape)));
+  }
+  return StrCat("(", absl::StrJoin(parameters, ", "), ") -> ",
+                HumanString(program_shape_proto.result()));
+}
+
 namespace {
 // Parses shapes with simple recursive descent structure -- consumes from the
 // front of s and passes that view recursively as required.
@@ -28,6 +28,7 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/xla/layout_util.h"
 #include "tensorflow/compiler/xla/primitive_util.h"
+#include "tensorflow/compiler/xla/shape.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/statusor.h"
 #include "tensorflow/compiler/xla/types.h"
@@ -239,6 +240,7 @@ class ShapeUtil {
   //
   //    (param_name: f32[42x12], ...) -> f32[24x42]
   static string HumanString(const ProgramShape& program_shape);
+  static string HumanString(const ProgramShapeProto& program_shape_proto);
 
   // Parses a ShapeUtil::HumanString-format shape string back into a shape
   // object.
@@ -575,37 +575,6 @@ TEST(ShapeUtilTest, HumanString) {
       "((opaque[], f32[], u32[1,2]{1,0}, s32[3,4]{0,1}), u32[1,2]{1,0}, "
       "token[])",
       ShapeUtil::HumanStringWithLayout(nested_tuple));
-
-  ProgramShape prog = ShapeUtil::MakeProgramShape(
-      {opaque, scalar, matrix, matrix2, tuple, nested_tuple}, nested_tuple);
-  EXPECT_EQ(
-      "((unknown): opaque[], "
-      "(unknown): f32[], "
-      "(unknown): u32[1,2], "
-      "(unknown): s32[3,4], "
-      "(unknown): (opaque[], f32[], u32[1,2], s32[3,4]), "
-      "(unknown): ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])) "
-      "-> "
-      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
-      ShapeUtil::HumanString(prog));
-
-  prog.add_parameter_names("arg0");
-  prog.add_parameter_names("scalar");
-  prog.add_parameter_names("matrix");
-  prog.add_parameter_names("matrix2");
-  prog.add_parameter_names("tuple");
-  prog.add_parameter_names("nested_tuple");
-  EXPECT_EQ(
-      "(arg0: opaque[], "
-      "scalar: f32[], "
-      "matrix: u32[1,2], "
-      "matrix2: s32[3,4], "
-      "tuple: (opaque[], f32[], u32[1,2], s32[3,4]), "
-      "nested_tuple: ((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], "
-      "token[])) "
-      "-> "
-      "((opaque[], f32[], u32[1,2], s32[3,4]), u32[1,2], token[])",
-      ShapeUtil::HumanString(prog));
 }
 
 TEST(ShapeUtilTest, ForEachSubshapeArray) {
@@ -55,7 +55,8 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
       client_->GetComputationShape(computation).ConsumeValueOrDie();
   std::unique_ptr<ProgramShape> replayed_shape =
       client_->GetComputationShape(replayed).ConsumeValueOrDie();
-  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+  ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+                                            replayed_shape->ToProto()));
 
   // Run it.
   Literal literal =
@@ -87,7 +88,8 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
       client_->GetComputationShape(computation).ConsumeValueOrDie();
   std::unique_ptr<ProgramShape> replayed_shape =
       client_->GetComputationShape(replayed).ConsumeValueOrDie();
-  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+  ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+                                            replayed_shape->ToProto()));
 
   // Run it.
   std::unique_ptr<GlobalData> x_data =
@@ -133,7 +135,8 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
       client_->GetComputationShape(computation).ConsumeValueOrDie();
   std::unique_ptr<ProgramShape> replayed_shape =
       client_->GetComputationShape(replayed).ConsumeValueOrDie();
-  ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
+  ASSERT_TRUE(protobuf_util::ProtobufEquals(original_shape->ToProto(),
+                                            replayed_shape->ToProto()));
 
   // Run it.
   Literal literal =
@@ -183,7 +183,7 @@ message Shape {
 
 // Shape of the parameters and output of a computation (like a traditional
 // function signature).
-message ProgramShape {
+message ProgramShapeProto {
   repeated Shape parameters = 1;
   Shape result = 2;
   repeated string parameter_names = 3;
@@ -174,11 +174,12 @@ void XRTCompileOp::Compute(OpKernelContext* ctx) {
   ctx->set_output(0, handle_output);
 
   xla::LocalExecutable* executable = entry->get().get_executable();
-  xla::ProgramShape program_shape = executable->executable()
-                                        ->module()
-                                        .config()
-                                        .entry_computation_layout()
-                                        .ComputeProgramShape();
+  xla::ProgramShapeProto program_shape = executable->executable()
+                                             ->module()
+                                             .config()
+                                             .entry_computation_layout()
+                                             .ComputeProgramShape()
+                                             .ToProto();
   Tensor program_shape_output(DT_STRING, TensorShape({1}));
   program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
   ctx->set_output(1, program_shape_output);
@@ -411,7 +411,7 @@ TEST(RawApiTest, CompileAndExecute) {
   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
-  xla::ProgramShape program_shape;
+  xla::ProgramShapeProto program_shape;
   EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
@@ -465,7 +465,7 @@ TEST(RawApiTest, CompileAndExecuteWithArgumentVector) {
   auto expected = xla::LiteralUtil::CreateR1<float>({27.0f, 21.0f});
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
-  xla::ProgramShape program_shape;
+  xla::ProgramShapeProto program_shape;
   EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
 }
@@ -510,7 +510,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
   TF_EXPECT_OK(session.Run(tensorflow::ClientSession::FeedType(),
                            {c_handle.program_shape}, {release}, &outputs));
 
-  xla::ProgramShape program_shape;
+  xla::ProgramShapeProto program_shape;
   EXPECT_TRUE(program_shape.ParseFromString(outputs[0].vec<string>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 1);
 
@@ -520,7 +520,7 @@ TEST(RawApiTest, CompileWithXlaReturnShapes) {
       << xla::ShapeUtil::HumanStringWithLayout(program_shape.result());
 
   xla::ProgramShape xla_program_shape =
-      XlaCompiledProgramShape(xla_computation, *shapes);
+      XlaCompiledProgramShape(xla_computation, xla::ProgramShape(*shapes));
   EXPECT_TRUE(xla::LayoutUtil::Equal(
       xla::ShapeUtil::GetSubshape(program_shape.parameters(0), {0}).layout(),
       xla::ShapeUtil::GetSubshape(xla_program_shape.parameters(0), {0})
@@ -739,7 +739,7 @@ TEST(RawApiTest, CompileAndExecuteWithS64Argument) {
   auto expected = xla::LiteralUtil::CreateR0<int64>(15123899);
   EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
 
-  xla::ProgramShape program_shape;
+  xla::ProgramShapeProto program_shape;
   EXPECT_TRUE(program_shape.ParseFromString(outputs[1].vec<string>()(0)));
   EXPECT_EQ(program_shape.parameters_size(), 2);
   EXPECT_TRUE(
|
@ -36,11 +36,11 @@ message XLAComputationConfig {
|
||||
tensorflow.tf2xla.HostComputeMetadata host_compute_metadata = 3;
|
||||
|
||||
// The arg/result shapes for the whole computation.
|
||||
xla.ProgramShape program_shape = 4;
|
||||
xla.ProgramShapeProto program_shape = 4;
|
||||
// The arg/result shapes for each core of a model-parallel
|
||||
// computation. per_core_args_and_result_shapes is optional for a
|
||||
// single-core computation.
|
||||
repeated xla.ProgramShape per_core_program_shape = 5;
|
||||
repeated xla.ProgramShapeProto per_core_program_shape = 5;
|
||||
// Describes how replicated computation instances should be assigned to
|
||||
// devices. There are num_cores_per_replica computations, and each one will be
|
||||
// sent and executed to the set of replica device numbers described in the
|
||||
|