[TF:XLA] Use HloEvaluator for ComputeConstant, remove the need of a dedicated
compute constant backend. PiperOrigin-RevId: 164940970
This commit is contained in:
parent
eeacdcdb14
commit
87605f3d6a
@ -119,22 +119,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
|
||||
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
|
||||
|
||||
// Ask the XLA compiler to evaluate the data handle to a literal.
|
||||
xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed =
|
||||
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
|
||||
builder()->ComputeConstant(handle, &layout);
|
||||
if (!computed.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"Error evaluating ", context_->op_kernel().name(), " input ", index,
|
||||
": ", computed.status().error_message());
|
||||
}
|
||||
// Fetch the literal from the compiler service.
|
||||
xla::StatusOr<std::unique_ptr<xla::Literal>> constant =
|
||||
builder()->client()->Transfer(*computed.ValueOrDie());
|
||||
if (!constant.ok()) {
|
||||
return errors::InvalidArgument(
|
||||
"Error evaluating ", context_->op_kernel().name(), " input ", index,
|
||||
": ", constant.status().error_message());
|
||||
}
|
||||
constant_literal->Swap(constant.ValueOrDie().get());
|
||||
constant_literal->Swap(computed.ValueOrDie().get());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -111,13 +111,12 @@ bool ComputationBuilder::MakeWindow(
|
||||
return true;
|
||||
} else {
|
||||
NoteError(InvalidArgument(
|
||||
"%s",
|
||||
tensorflow::strings::StrCat(
|
||||
"Window has different number of window dimensions than of ",
|
||||
x_name, "\nNumber of window dimensions: ",
|
||||
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
|
||||
"\n")
|
||||
.c_str())); //
|
||||
"%s", tensorflow::strings::StrCat(
|
||||
"Window has different number of window dimensions than of ",
|
||||
x_name, "\nNumber of window dimensions: ",
|
||||
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
|
||||
"\n")
|
||||
.c_str())); //
|
||||
return false;
|
||||
}
|
||||
};
|
||||
@ -663,24 +662,26 @@ bool ComputationBuilder::VerifyConvolution(
|
||||
}
|
||||
int num_spatial_dims = num_dims - 2;
|
||||
|
||||
const auto check_spatial_dimensions = [&](
|
||||
const char* const field_name,
|
||||
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
|
||||
numbers) {
|
||||
if (numbers.size() != num_spatial_dims) {
|
||||
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
|
||||
num_spatial_dims, field_name, numbers.size()));
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < numbers.size(); ++i) {
|
||||
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
|
||||
NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
|
||||
field_name, i, numbers.Get(i)));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
const auto check_spatial_dimensions =
|
||||
[&](const char* const field_name,
|
||||
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
|
||||
numbers) {
|
||||
if (numbers.size() != num_spatial_dims) {
|
||||
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
|
||||
num_spatial_dims, field_name,
|
||||
numbers.size()));
|
||||
return false;
|
||||
}
|
||||
for (int i = 0; i < numbers.size(); ++i) {
|
||||
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
|
||||
NoteError(
|
||||
InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
|
||||
field_name, i, numbers.Get(i)));
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
return check_spatial_dimensions("spatial_dimensions",
|
||||
dimension_numbers.spatial_dimensions()) &&
|
||||
check_spatial_dimensions(
|
||||
@ -1268,7 +1269,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
||||
return response.is_constant();
|
||||
}
|
||||
|
||||
StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
||||
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
|
||||
const ComputationDataHandle& operand, const Layout* output_layout) {
|
||||
if (!first_error_.ok()) {
|
||||
return first_error_;
|
||||
@ -1291,8 +1292,14 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
||||
return s;
|
||||
}
|
||||
|
||||
TF_RET_CHECK(response.output().handle() != 0);
|
||||
return MakeUnique<GlobalData>(client_->stub(), response.output());
|
||||
VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
|
||||
|
||||
if (!response.has_literal()) {
|
||||
return InternalError(
|
||||
"no computed literal in the provided response in ComputeConstant "
|
||||
"request");
|
||||
}
|
||||
return MakeUnique<Literal>(response.literal());
|
||||
}
|
||||
|
||||
ComputationDataHandle ComputationBuilder::Map(
|
||||
|
@ -679,12 +679,12 @@ class ComputationBuilder {
|
||||
// Computes the value of a constant indicated by a
|
||||
// ComputationDataHandle.
|
||||
//
|
||||
// The handle must be from the computation currently being built -
|
||||
// The operand must be from the computation currently being built -
|
||||
// i.e., returned from this builder with no intervening call to
|
||||
// Build(). This happens to currently work regardless of that, but
|
||||
// that may stop working at any time.
|
||||
//
|
||||
// The handle must represent a constant value, which in this case
|
||||
// The operand must represent a constant value, which in this case
|
||||
// means that it must not statically depend on a parameter to the
|
||||
// computation that is being built.
|
||||
//
|
||||
@ -702,8 +702,8 @@ class ComputationBuilder {
|
||||
//
|
||||
// If output_layout is non-null, then the output of the computation
|
||||
// will be stored using that layout.
|
||||
StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
|
||||
const ComputationDataHandle& handle,
|
||||
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
|
||||
const ComputationDataHandle& operand,
|
||||
const Layout* output_layout = nullptr);
|
||||
|
||||
// Returns a new ComputationBuilder whose resultant Computation is used only
|
||||
|
@ -428,6 +428,7 @@ cc_library(
|
||||
":gpu_transfer_manager",
|
||||
":hlo",
|
||||
":hlo_cost_analysis",
|
||||
":hlo_evaluator",
|
||||
":hlo_execution_profile",
|
||||
":hlo_module_config",
|
||||
":platform_util",
|
||||
|
@ -50,19 +50,14 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
|
||||
options, compiler, std::move(compute_constant_backend)));
|
||||
std::unique_ptr<CompileOnlyService> service(
|
||||
new CompileOnlyService(options, compiler));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
CompileOnlyService::CompileOnlyService(
|
||||
const ServiceOptions& options, Compiler* compiler,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(options, /*backend=*/nullptr,
|
||||
std::move(compute_constant_backend)),
|
||||
compiler_(compiler) {}
|
||||
CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
|
||||
Compiler* compiler)
|
||||
: Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}
|
||||
|
||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||
CompileOnlyService::CompileAheadOfTime(
|
||||
|
@ -102,9 +102,8 @@ class CompileOnlyService : public Service {
|
||||
}
|
||||
|
||||
private:
|
||||
explicit CompileOnlyService(
|
||||
const ServiceOptions& options, Compiler* compiler,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
explicit CompileOnlyService(const ServiceOptions& options,
|
||||
Compiler* compiler);
|
||||
CompileOnlyService(const CompileOnlyService&) = delete;
|
||||
void operator=(const CompileOnlyService&) = delete;
|
||||
|
||||
|
@ -54,23 +54,19 @@ namespace xla {
|
||||
}
|
||||
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform)
|
||||
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
|
||||
backend_options.set_platform(platform).set_intra_op_parallelism_threads(
|
||||
options.intra_op_parallelism_threads());
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
||||
Backend::CreateBackend(backend_options));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<LocalService> service(new LocalService(
|
||||
options, std::move(backend), std::move(compute_constant_backend)));
|
||||
std::unique_ptr<LocalService> service(
|
||||
new LocalService(options, std::move(backend)));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
LocalService::LocalService(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> execute_backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: Service(options, std::move(execute_backend),
|
||||
std::move(compute_constant_backend)) {}
|
||||
std::unique_ptr<Backend> execute_backend)
|
||||
: Service(options, std::move(execute_backend)) {}
|
||||
|
||||
namespace {
|
||||
// Returns the space required to allocate a shape. If
|
||||
@ -161,7 +157,6 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
||||
std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
|
||||
argument_layouts.size());
|
||||
return BuildExecutable(versioned_handle, std::move(module_config),
|
||||
/*executable_for_compute_constant=*/false,
|
||||
argument_buffers, execute_backend_.get(), executor);
|
||||
}
|
||||
|
||||
|
@ -57,8 +57,7 @@ class LocalService : public Service {
|
||||
|
||||
private:
|
||||
explicit LocalService(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
std::unique_ptr<Backend> backend);
|
||||
LocalService(const LocalService&) = delete;
|
||||
void operator=(const LocalService&) = delete;
|
||||
};
|
||||
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/service/executable.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||
@ -144,36 +145,15 @@ int ServiceOptions::intra_op_parallelism_threads() const {
|
||||
backend_options.set_platform(platform);
|
||||
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
||||
CreateComputeConstantBackend());
|
||||
std::unique_ptr<Service> service(
|
||||
new Service(options, std::move(execute_backend),
|
||||
std::move(compute_constant_backend)));
|
||||
new Service(options, std::move(execute_backend)));
|
||||
return std::move(service);
|
||||
}
|
||||
|
||||
/* static */ StatusOr<std::unique_ptr<Backend>>
|
||||
Service::CreateComputeConstantBackend() {
|
||||
TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
|
||||
PlatformUtil::GetSupportedPlatforms());
|
||||
for (auto* platform : platforms) {
|
||||
if (platform->id() == se::host::kHostPlatformId) {
|
||||
BackendOptions backend_options;
|
||||
backend_options.set_platform(platform);
|
||||
return Backend::CreateBackend(backend_options);
|
||||
}
|
||||
}
|
||||
return NotFound("CPU platform not found");
|
||||
}
|
||||
|
||||
Service::Service(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> execute_backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend)
|
||||
: options_(options),
|
||||
execute_backend_(std::move(execute_backend)),
|
||||
compute_constant_backend_(std::move(compute_constant_backend)) {
|
||||
std::unique_ptr<Backend> execute_backend)
|
||||
: options_(options), execute_backend_(std::move(execute_backend)) {
|
||||
CHECK(options_.number_of_replicas() > 0);
|
||||
|
||||
if (execute_backend_) {
|
||||
if (execute_backend_->device_count() > 0) {
|
||||
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
|
||||
@ -418,7 +398,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
const VersionedComputationHandle& versioned_handle,
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
bool executable_for_compute_constant,
|
||||
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments,
|
||||
Backend* backend, se::StreamExecutor* executor) {
|
||||
@ -431,8 +410,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
module_config->debug_options().xla_dump_computations_to();
|
||||
const string& other_directory_path =
|
||||
module_config->debug_options().xla_dump_executions_to();
|
||||
if (!executable_for_compute_constant &&
|
||||
(!directory_path.empty() || !other_directory_path.empty())) {
|
||||
if (!directory_path.empty() || !other_directory_path.empty()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
session_module,
|
||||
computation_tracker_.SnapshotComputation(versioned_handle.handle));
|
||||
@ -450,7 +428,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||
std::unique_ptr<HloModule> module,
|
||||
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
|
||||
/*include_unreachable_instructions=*/
|
||||
!executable_for_compute_constant));
|
||||
true));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable,
|
||||
@ -490,8 +468,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
|
||||
HloModuleConfig original_module_config = *module_config;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::unique_ptr<Executable> executable_unique_ptr,
|
||||
BuildExecutable(versioned_handle, std::move(module_config),
|
||||
/*executable_for_compute_constant=*/false, arguments,
|
||||
BuildExecutable(versioned_handle, std::move(module_config), arguments,
|
||||
backend, executor));
|
||||
|
||||
if (profile != nullptr) {
|
||||
@ -1098,7 +1075,6 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||
|
||||
TF_ASSIGN_OR_RETURN(bool is_constant,
|
||||
user_computation->IsConstant(arg->operand()));
|
||||
|
||||
if (!is_constant) {
|
||||
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
|
||||
}
|
||||
@ -1114,8 +1090,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||
|
||||
ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
|
||||
execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
|
||||
execution_options.mutable_debug_options()->set_xla_backend_optimization_level(
|
||||
0);
|
||||
execution_options.mutable_debug_options()
|
||||
->set_xla_eliminate_hlo_implicit_broadcast(true);
|
||||
*execution_options.mutable_shape_with_output_layout() =
|
||||
program_shape.result();
|
||||
|
||||
@ -1130,20 +1106,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
|
||||
CreateModuleConfig(program_shape, {}, execution_options));
|
||||
|
||||
// Exclude dead parameter instructions for the purpose of computing constants.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
std::shared_ptr<Executable> executable,
|
||||
BuildExecutable(versioned_handle, std::move(module_config),
|
||||
/*executable_for_compute_constant=*/true,
|
||||
/*arguments=*/{}, compute_constant_backend_.get(),
|
||||
compute_constant_backend_->default_stream_executor()));
|
||||
std::unique_ptr<HloModule> module,
|
||||
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
|
||||
/*include_unreachable_instructions=*/
|
||||
false));
|
||||
|
||||
HloEvaluator evaluator;
|
||||
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
|
||||
// Since the shape_with_output_layout option in ExecutionOption is
|
||||
// non-effective to the Evaluator results, explicit relayout here.
|
||||
if (arg->has_output_layout()) {
|
||||
result_literal = result_literal->Relayout(arg->output_layout());
|
||||
}
|
||||
*result->mutable_literal() = result_literal->ToProto();
|
||||
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
*result->mutable_output(),
|
||||
ExecuteAndRegisterResult(
|
||||
executable.get(), /*arguments=*/{}, compute_constant_backend_.get(),
|
||||
compute_constant_backend_->default_stream_executor(),
|
||||
"constant computed from " + user_computation->name(),
|
||||
/*profile=*/nullptr));
|
||||
return tensorflow::Status::OK();
|
||||
}
|
||||
|
||||
|
@ -71,9 +71,9 @@ class ServiceOptions {
|
||||
int intra_op_parallelism_threads_ = -1;
|
||||
};
|
||||
|
||||
// The XLA service object, which is the same across all
|
||||
// platforms. It maintains the service state of computations and allocations,
|
||||
// and delegates target-specific requests to the target-specific infrastructure
|
||||
// The XLA service object, which is the same across all platforms. It maintains
|
||||
// the service state of computations and allocations, and delegates
|
||||
// target-specific requests to the target-specific infrastructure
|
||||
// (target-specific compiler, StreamExecutor).
|
||||
class Service : public ServiceInterface {
|
||||
public:
|
||||
@ -258,8 +258,8 @@ class Service : public ServiceInterface {
|
||||
|
||||
// The constructor is private. Use the NewService factory to create new
|
||||
// service objects.
|
||||
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
|
||||
std::unique_ptr<Backend> compute_constant_backend);
|
||||
Service(const ServiceOptions& options,
|
||||
std::unique_ptr<Backend> execute_backend);
|
||||
|
||||
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
||||
|
||||
@ -280,16 +280,10 @@ class Service : public ServiceInterface {
|
||||
const ExecutionOptions* execution_options,
|
||||
bool has_hybrid_result = false);
|
||||
|
||||
// Builds an Executable for the given parameters. If
|
||||
// executable_for_compute_constant is true, then the executable is intended to
|
||||
// be used for ComputeConstant which means dead parameter instructions are not
|
||||
// included in the executable.The parameter "profile" can optionally point to
|
||||
// an ExecutionProfile object which will be filled in with profile data
|
||||
// relevant to compilation.
|
||||
// Builds an Executable for the given parameters.
|
||||
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
|
||||
const VersionedComputationHandle& versioned_handle,
|
||||
std::unique_ptr<HloModuleConfig> module_config,
|
||||
bool executable_for_compute_constant,
|
||||
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||
arguments,
|
||||
Backend* backend, perftools::gputools::StreamExecutor* executor);
|
||||
@ -381,9 +375,6 @@ class Service : public ServiceInterface {
|
||||
// TODO(b/28616830): Support multiple backends for execution.
|
||||
std::unique_ptr<Backend> execute_backend_;
|
||||
|
||||
// Backend to use when executing ComputeConstant.
|
||||
std::unique_ptr<Backend> compute_constant_backend_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(Service);
|
||||
};
|
||||
|
||||
|
@ -72,9 +72,8 @@ class ComputeConstantTest : public ::testing::Test {
|
||||
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
|
||||
Client* client, const ComputationDataHandle& operand,
|
||||
ComputationBuilder* builder, Layout* output_layout = nullptr) {
|
||||
TF_ASSIGN_OR_RETURN(auto remote_computed,
|
||||
TF_ASSIGN_OR_RETURN(auto computed,
|
||||
builder->ComputeConstant(operand, output_layout));
|
||||
TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed));
|
||||
return std::move(computed);
|
||||
}
|
||||
|
||||
@ -253,35 +252,5 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
|
||||
}
|
||||
}
|
||||
|
||||
// This test is permanently disabled on CPU because it requires that the
|
||||
// backend used for execution is different than the backend used for
|
||||
// ComputeConstant which is always cpu.
|
||||
TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
|
||||
// Compute a trivial constant, then try to use the value in an Execute
|
||||
// call. This should fail because the constant resides on the CPU and the
|
||||
// Execute call is executed on a different backend. This test only makes
|
||||
// sense with LocalClient, since CompileOnlyClient does not support
|
||||
// execution.
|
||||
Client* client = ClientOrDie(platform_, ClientType::kLocal);
|
||||
ComputationBuilder constant_b(client, TestName());
|
||||
auto constant = constant_b.ConstantR0<int32>(42);
|
||||
auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
|
||||
auto literal = client->Transfer(*handle).ConsumeValueOrDie();
|
||||
LiteralTestUtil::ExpectR0Equal(42, *literal);
|
||||
|
||||
// Build trivial computation which takes one parameter.
|
||||
ComputationBuilder b(client, TestName());
|
||||
b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
|
||||
auto computation = b.Build().ConsumeValueOrDie();
|
||||
|
||||
// Try to use value from ComputeConstant in Execute.
|
||||
auto execute_status = client->Execute(computation, {handle.get()});
|
||||
EXPECT_FALSE(execute_status.ok());
|
||||
EXPECT_THAT(
|
||||
execute_status.status().error_message(),
|
||||
::testing::ContainsRegex("argument 0 is on device Host:0 but computation "
|
||||
"will be executed on device"));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
@ -350,7 +350,9 @@ message ComputeConstantRequest {
|
||||
}
|
||||
|
||||
message ComputeConstantResponse {
|
||||
GlobalDataHandle output = 1;
|
||||
// A LiteralProto is returned directly for this request, instead of a
|
||||
// ComputationDataHandle.
|
||||
LiteralProto literal = 1;
|
||||
}
|
||||
|
||||
message DeconstructTupleRequest {
|
||||
|
Loading…
Reference in New Issue
Block a user