[TF:XLA] Use HloEvaluator for ComputeConstant, remove the need for a dedicated

compute constant backend.

PiperOrigin-RevId: 164940970
This commit is contained in:
Kay Zhu 2017-08-10 21:12:44 -07:00 committed by TensorFlower Gardener
parent eeacdcdb14
commit 87605f3d6a
12 changed files with 90 additions and 161 deletions

View File

@ -119,22 +119,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
// Ask the XLA compiler to evaluate the data handle to a literal.
xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed =
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
builder()->ComputeConstant(handle, &layout);
if (!computed.ok()) {
return errors::InvalidArgument(
"Error evaluating ", context_->op_kernel().name(), " input ", index,
": ", computed.status().error_message());
}
// Fetch the literal from the compiler service.
xla::StatusOr<std::unique_ptr<xla::Literal>> constant =
builder()->client()->Transfer(*computed.ValueOrDie());
if (!constant.ok()) {
return errors::InvalidArgument(
"Error evaluating ", context_->op_kernel().name(), " input ", index,
": ", constant.status().error_message());
}
constant_literal->Swap(constant.ValueOrDie().get());
constant_literal->Swap(computed.ValueOrDie().get());
return Status::OK();
}

View File

@ -111,13 +111,12 @@ bool ComputationBuilder::MakeWindow(
return true;
} else {
NoteError(InvalidArgument(
"%s",
tensorflow::strings::StrCat(
"Window has different number of window dimensions than of ",
x_name, "\nNumber of window dimensions: ",
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
"\n")
.c_str())); //
"%s", tensorflow::strings::StrCat(
"Window has different number of window dimensions than of ",
x_name, "\nNumber of window dimensions: ",
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
"\n")
.c_str())); //
return false;
}
};
@ -663,24 +662,26 @@ bool ComputationBuilder::VerifyConvolution(
}
int num_spatial_dims = num_dims - 2;
const auto check_spatial_dimensions = [&](
const char* const field_name,
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
numbers) {
if (numbers.size() != num_spatial_dims) {
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
num_spatial_dims, field_name, numbers.size()));
return false;
}
for (int i = 0; i < numbers.size(); ++i) {
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
field_name, i, numbers.Get(i)));
return false;
}
}
return true;
};
const auto check_spatial_dimensions =
[&](const char* const field_name,
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
numbers) {
if (numbers.size() != num_spatial_dims) {
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
num_spatial_dims, field_name,
numbers.size()));
return false;
}
for (int i = 0; i < numbers.size(); ++i) {
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
NoteError(
InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
field_name, i, numbers.Get(i)));
return false;
}
}
return true;
};
return check_spatial_dimensions("spatial_dimensions",
dimension_numbers.spatial_dimensions()) &&
check_spatial_dimensions(
@ -1268,7 +1269,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
return response.is_constant();
}
StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
const ComputationDataHandle& operand, const Layout* output_layout) {
if (!first_error_.ok()) {
return first_error_;
@ -1291,8 +1292,14 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
return s;
}
TF_RET_CHECK(response.output().handle() != 0);
return MakeUnique<GlobalData>(client_->stub(), response.output());
VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
if (!response.has_literal()) {
return InternalError(
"no computed literal in the provided response in ComputeConstant "
"request");
}
return MakeUnique<Literal>(response.literal());
}
ComputationDataHandle ComputationBuilder::Map(

View File

@ -679,12 +679,12 @@ class ComputationBuilder {
// Computes the value of a constant indicated by a
// ComputationDataHandle.
//
// The handle must be from the computation currently being built -
// The operand must be from the computation currently being built -
// i.e., returned from this builder with no intervening call to
// Build(). This happens to currently work regardless of that, but
// that may stop working at any time.
//
// The handle must represent a constant value, which in this case
// The operand must represent a constant value, which in this case
// means that it must not statically depend on a parameter to the
// computation that is being built.
//
@ -702,8 +702,8 @@ class ComputationBuilder {
//
// If output_layout is non-null, then the output of the computation
// will be stored using that layout.
StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
const ComputationDataHandle& handle,
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
const ComputationDataHandle& operand,
const Layout* output_layout = nullptr);
// Returns a new ComputationBuilder whose resultant Computation is used only

View File

@ -428,6 +428,7 @@ cc_library(
":gpu_transfer_manager",
":hlo",
":hlo_cost_analysis",
":hlo_evaluator",
":hlo_execution_profile",
":hlo_module_config",
":platform_util",

View File

@ -50,19 +50,14 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
options, compiler, std::move(compute_constant_backend)));
std::unique_ptr<CompileOnlyService> service(
new CompileOnlyService(options, compiler));
return std::move(service);
}
CompileOnlyService::CompileOnlyService(
const ServiceOptions& options, Compiler* compiler,
std::unique_ptr<Backend> compute_constant_backend)
: Service(options, /*backend=*/nullptr,
std::move(compute_constant_backend)),
compiler_(compiler) {}
CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
Compiler* compiler)
: Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
CompileOnlyService::CompileAheadOfTime(

View File

@ -102,9 +102,8 @@ class CompileOnlyService : public Service {
}
private:
explicit CompileOnlyService(
const ServiceOptions& options, Compiler* compiler,
std::unique_ptr<Backend> compute_constant_backend);
explicit CompileOnlyService(const ServiceOptions& options,
Compiler* compiler);
CompileOnlyService(const CompileOnlyService&) = delete;
void operator=(const CompileOnlyService&) = delete;

View File

@ -54,23 +54,19 @@ namespace xla {
}
BackendOptions backend_options;
backend_options.set_platform(platform)
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
backend_options.set_platform(platform).set_intra_op_parallelism_threads(
options.intra_op_parallelism_threads());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
Backend::CreateBackend(backend_options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<LocalService> service(new LocalService(
options, std::move(backend), std::move(compute_constant_backend)));
std::unique_ptr<LocalService> service(
new LocalService(options, std::move(backend)));
return std::move(service);
}
LocalService::LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend,
std::unique_ptr<Backend> compute_constant_backend)
: Service(options, std::move(execute_backend),
std::move(compute_constant_backend)) {}
std::unique_ptr<Backend> execute_backend)
: Service(options, std::move(execute_backend)) {}
namespace {
// Returns the space required to allocate a shape. If
@ -161,7 +157,6 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
argument_layouts.size());
return BuildExecutable(versioned_handle, std::move(module_config),
/*executable_for_compute_constant=*/false,
argument_buffers, execute_backend_.get(), executor);
}

View File

@ -57,8 +57,7 @@ class LocalService : public Service {
private:
explicit LocalService(const ServiceOptions& options,
std::unique_ptr<Backend> backend,
std::unique_ptr<Backend> compute_constant_backend);
std::unique_ptr<Backend> backend);
LocalService(const LocalService&) = delete;
void operator=(const LocalService&) = delete;
};

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/executable.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
@ -144,36 +145,15 @@ int ServiceOptions::intra_op_parallelism_threads() const {
backend_options.set_platform(platform);
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
CreateComputeConstantBackend());
std::unique_ptr<Service> service(
new Service(options, std::move(execute_backend),
std::move(compute_constant_backend)));
new Service(options, std::move(execute_backend)));
return std::move(service);
}
/* static */ StatusOr<std::unique_ptr<Backend>>
Service::CreateComputeConstantBackend() {
TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
PlatformUtil::GetSupportedPlatforms());
for (auto* platform : platforms) {
if (platform->id() == se::host::kHostPlatformId) {
BackendOptions backend_options;
backend_options.set_platform(platform);
return Backend::CreateBackend(backend_options);
}
}
return NotFound("CPU platform not found");
}
Service::Service(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend,
std::unique_ptr<Backend> compute_constant_backend)
: options_(options),
execute_backend_(std::move(execute_backend)),
compute_constant_backend_(std::move(compute_constant_backend)) {
std::unique_ptr<Backend> execute_backend)
: options_(options), execute_backend_(std::move(execute_backend)) {
CHECK(options_.number_of_replicas() > 0);
if (execute_backend_) {
if (execute_backend_->device_count() > 0) {
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
@ -418,7 +398,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config,
bool executable_for_compute_constant,
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Backend* backend, se::StreamExecutor* executor) {
@ -431,8 +410,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
module_config->debug_options().xla_dump_computations_to();
const string& other_directory_path =
module_config->debug_options().xla_dump_executions_to();
if (!executable_for_compute_constant &&
(!directory_path.empty() || !other_directory_path.empty())) {
if (!directory_path.empty() || !other_directory_path.empty()) {
TF_ASSIGN_OR_RETURN(
session_module,
computation_tracker_.SnapshotComputation(versioned_handle.handle));
@ -450,7 +428,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
std::unique_ptr<HloModule> module,
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
/*include_unreachable_instructions=*/
!executable_for_compute_constant));
true));
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable,
@ -490,8 +468,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
HloModuleConfig original_module_config = *module_config;
TF_ASSIGN_OR_RETURN(
std::unique_ptr<Executable> executable_unique_ptr,
BuildExecutable(versioned_handle, std::move(module_config),
/*executable_for_compute_constant=*/false, arguments,
BuildExecutable(versioned_handle, std::move(module_config), arguments,
backend, executor));
if (profile != nullptr) {
@ -1098,7 +1075,6 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
TF_ASSIGN_OR_RETURN(bool is_constant,
user_computation->IsConstant(arg->operand()));
if (!is_constant) {
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
}
@ -1114,8 +1090,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
execution_options.mutable_debug_options()->set_xla_backend_optimization_level(
0);
execution_options.mutable_debug_options()
->set_xla_eliminate_hlo_implicit_broadcast(true);
*execution_options.mutable_shape_with_output_layout() =
program_shape.result();
@ -1130,20 +1106,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
CreateModuleConfig(program_shape, {}, execution_options));
// Exclude dead parameter instructions for the purpose of computing constants.
TF_ASSIGN_OR_RETURN(
std::shared_ptr<Executable> executable,
BuildExecutable(versioned_handle, std::move(module_config),
/*executable_for_compute_constant=*/true,
/*arguments=*/{}, compute_constant_backend_.get(),
compute_constant_backend_->default_stream_executor()));
std::unique_ptr<HloModule> module,
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
/*include_unreachable_instructions=*/
false));
HloEvaluator evaluator;
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
// Since the shape_with_output_layout option in ExecutionOptions has no
// effect on the Evaluator's results, explicitly relayout here.
if (arg->has_output_layout()) {
result_literal = result_literal->Relayout(arg->output_layout());
}
*result->mutable_literal() = result_literal->ToProto();
TF_ASSIGN_OR_RETURN(
*result->mutable_output(),
ExecuteAndRegisterResult(
executable.get(), /*arguments=*/{}, compute_constant_backend_.get(),
compute_constant_backend_->default_stream_executor(),
"constant computed from " + user_computation->name(),
/*profile=*/nullptr));
return tensorflow::Status::OK();
}

View File

@ -71,9 +71,9 @@ class ServiceOptions {
int intra_op_parallelism_threads_ = -1;
};
// The XLA service object, which is the same across all
// platforms. It maintains the service state of computations and allocations,
// and delegates target-specific requests to the target-specific infrastructure
// The XLA service object, which is the same across all platforms. It maintains
// the service state of computations and allocations, and delegates
// target-specific requests to the target-specific infrastructure
// (target-specific compiler, StreamExecutor).
class Service : public ServiceInterface {
public:
@ -258,8 +258,8 @@ class Service : public ServiceInterface {
// The constructor is private. Use the NewService factory to create new
// service objects.
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
std::unique_ptr<Backend> compute_constant_backend);
Service(const ServiceOptions& options,
std::unique_ptr<Backend> execute_backend);
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
@ -280,16 +280,10 @@ class Service : public ServiceInterface {
const ExecutionOptions* execution_options,
bool has_hybrid_result = false);
// Builds an Executable for the given parameters. If
// executable_for_compute_constant is true, then the executable is intended to
// be used for ComputeConstant which means dead parameter instructions are not
// included in the executable. The parameter "profile" can optionally point to
// an ExecutionProfile object which will be filled in with profile data
// relevant to compilation.
// Builds an Executable for the given parameters.
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
const VersionedComputationHandle& versioned_handle,
std::unique_ptr<HloModuleConfig> module_config,
bool executable_for_compute_constant,
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
arguments,
Backend* backend, perftools::gputools::StreamExecutor* executor);
@ -381,9 +375,6 @@ class Service : public ServiceInterface {
// TODO(b/28616830): Support multiple backends for execution.
std::unique_ptr<Backend> execute_backend_;
// Backend to use when executing ComputeConstant.
std::unique_ptr<Backend> compute_constant_backend_;
TF_DISALLOW_COPY_AND_ASSIGN(Service);
};

View File

@ -72,9 +72,8 @@ class ComputeConstantTest : public ::testing::Test {
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
Client* client, const ComputationDataHandle& operand,
ComputationBuilder* builder, Layout* output_layout = nullptr) {
TF_ASSIGN_OR_RETURN(auto remote_computed,
TF_ASSIGN_OR_RETURN(auto computed,
builder->ComputeConstant(operand, output_layout));
TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed));
return std::move(computed);
}
@ -253,35 +252,5 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
}
}
// This test is permanently disabled on CPU because it requires that the
// backend used for execution is different than the backend used for
// ComputeConstant which is always cpu.
TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
// Compute a trivial constant, then try to use the value in an Execute
// call. This should fail because the constant resides on the CPU and the
// Execute call is executed on a different backend. This test only makes
// sense with LocalClient, since CompileOnlyClient does not support
// execution.
Client* client = ClientOrDie(platform_, ClientType::kLocal);
ComputationBuilder constant_b(client, TestName());
auto constant = constant_b.ConstantR0<int32>(42);
auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
auto literal = client->Transfer(*handle).ConsumeValueOrDie();
LiteralTestUtil::ExpectR0Equal(42, *literal);
// Build trivial computation which takes one parameter.
ComputationBuilder b(client, TestName());
b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
auto computation = b.Build().ConsumeValueOrDie();
// Try to use value from ComputeConstant in Execute.
auto execute_status = client->Execute(computation, {handle.get()});
EXPECT_FALSE(execute_status.ok());
EXPECT_THAT(
execute_status.status().error_message(),
::testing::ContainsRegex("argument 0 is on device Host:0 but computation "
"will be executed on device"));
}
} // namespace
} // namespace xla

View File

@ -350,7 +350,9 @@ message ComputeConstantRequest {
}
message ComputeConstantResponse {
GlobalDataHandle output = 1;
// A LiteralProto is returned directly for this request, instead of a
// ComputationDataHandle.
LiteralProto literal = 1;
}
message DeconstructTupleRequest {