Add TpuCompileOp implementation.
PiperOrigin-RevId: 321717418 Change-Id: I4e0fb203014c54252511d6596063461dcc5de250
This commit is contained in:
parent
d7918bcd43
commit
94d2ab31f8
@ -178,14 +178,12 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_layout",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/compiler/xla/service:dump",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||
"//tensorflow/compiler/xla/service:hlo_module_group",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core/framework:protos_all_cc",
|
||||
"//tensorflow/core/platform:errors",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||
"@com_google_absl//absl/strings",
|
||||
@ -496,10 +494,7 @@ tf_proto_library_cc(
|
||||
cc_library(
|
||||
name = "tpu_compile_op_hdrs",
|
||||
hdrs = ["tpu_compile_op.h"],
|
||||
deps = [
|
||||
":tpu_compile_op_common",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
deps = ["//tensorflow/core:framework"],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
@ -558,60 +553,3 @@ cc_library(
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library built from tpu_compile_op_registration.cc, the source that defines
# the default CreateTpuCompileOpImpl / CreateTpuCompileOpMlirImpl functions.
cc_library(
|
||||
name = "tpu_compile_op_registration",
|
||||
srcs = ["tpu_compile_op_registration.cc"],
|
||||
deps = [
|
||||
":tpu_compile_op_common",
|
||||
":tpu_compile_op_impl",
|
||||
":tpu_compile_op_impl_factory",
|
||||
":tpu_compile_op_support",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||
],
|
||||
)
|
||||
|
||||
# Library holding the get/set functions for the TpuCompileOp creation
# callbacks (tpu_compile_op_impl_factory.{h,cc}).
cc_library(
|
||||
name = "tpu_compile_op_impl_factory",
|
||||
srcs = ["tpu_compile_op_impl_factory.cc"],
|
||||
hdrs = ["tpu_compile_op_impl_factory.h"],
|
||||
deps = [
|
||||
":tpu_compile_op_common",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library built from tpu_compile_op.cc, which defines the kernel classes'
# methods and performs the kernel registrations.
cc_library(
|
||||
name = "tpu_compile_op_lib",
|
||||
srcs = ["tpu_compile_op.cc"],
|
||||
deps = [
|
||||
":tpu_compile_op_hdrs",
|
||||
":tpu_compile_op_impl_factory",
|
||||
":tpu_compile_op_options",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Aggregate target: it lists no sources of its own and only groups the
# compile-op libraries named in `deps`.
cc_library(
|
||||
name = "tpu_compile_op",
|
||||
deps = [
|
||||
":tpu_compile_op_hdrs",
|
||||
":tpu_compile_op_impl",
|
||||
":tpu_compile_op_impl_factory",
|
||||
":tpu_compile_op_lib",
|
||||
":tpu_compile_op_options",
|
||||
":tpu_compile_op_registration",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
@ -1,94 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
using ::stream_executor::port::StatusOr;
|
||||
|
||||
// Constructs the TPUCompile kernel by invoking the currently installed
// creation callback. A creation error is reported through OP_REQUIRES_OK;
// on success the created implementation is stored in `impl_`.
TpuCompileOp::TpuCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  TpuCompileOpImplCreateFn* const create_fn = GetTpuCompileOpCreateFn();
  StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> created =
      (*create_fn)(ctx);
  OP_REQUIRES_OK(ctx, created.status());
  impl_ = std::move(created.ValueOrDie());
}
|
||||
|
||||
// Forwards the computation to the implementation created in the constructor.
void TpuCompileOp::Compute(OpKernelContext* ctx) {
  impl_->Compute(ctx);
}
|
||||
|
||||
// Constructs the _TPUCompileMlir kernel by invoking the currently installed
// MLIR creation callback. A creation error is reported through
// OP_REQUIRES_OK; on success the created implementation is stored in `impl_`.
TpuCompileMlirOp::TpuCompileMlirOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  TpuCompileOpImplCreateFn* const create_fn = GetTpuCompileOpMlirCreateFn();
  StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> created =
      (*create_fn)(ctx);
  OP_REQUIRES_OK(ctx, created.status());
  impl_ = std::move(created.ValueOrDie());
}
|
||||
|
||||
// Forwards the computation to the implementation created in the constructor.
void TpuCompileMlirOp::Compute(OpKernelContext* ctx) {
  impl_->Compute(ctx);
}
|
||||
|
||||
// Inspects the serialized CompilationResultProto carried by input 0 and fails
// the op when the proto cannot be parsed or records a non-OK status code.
// When TpuCompilationFailureClosesChips() is set, a failure additionally
// closes the TPU host, and any error from that is appended to the message.
void TpuCompileSucceededAssertOp::Compute(OpKernelContext* ctx) {
  const Tensor result_tensor = ctx->input(0);
  CompilationResultProto result;
  Status failure;
  const bool parsed =
      result.ParseFromString(result_tensor.scalar<tstring>()());
  if (!parsed) {
    failure =
        errors::InvalidArgument("Unable to parse compilation result proto");
  }

  // Nothing to report: the proto parsed and records a successful compilation.
  if (failure.ok() && result.status_code() == error::Code::OK) {
    return;
  }

  failure.Update(Status(result.status_code(), result.status_error_message()));
  errors::AppendToMessage(&failure, "TPU compilation failed");

  if (tensorflow::internal::TpuCompilationFailureClosesChips()) {
    // At this point, if compilation fails we do not know if a task
    // is already running that expects results from this compiled
    // program to complete. So close the TPU driver to release all
    // awaiting interactions (all awaiting interaction will fail and
    // continue to fail until reinitialized).
    LOG(ERROR) << "Cloud TPU: Closing chips. TPU compilation is considered "
                  "as part of device state, and a failed compilation results "
                  "in a device reset.";

    const Status close_result = TpuNodeContext::CloseTpuHost();
    if (!close_result.ok()) {
      errors::AppendToMessage(&failure, close_result.error_message());
    }
  }

  ctx->CtxFailureWithWarning(failure);
}
|
||||
|
||||
// Installs the default TpuCompileOp creation callbacks at module
// initialization. In LIBTFTPU builds the initializer is still registered but
// its body is empty, matching the previous behavior.
//
// The preprocessor conditional is deliberately kept outside the macro
// invocation: preprocessing directives inside a macro's argument list have
// undefined behavior in standard C++ and are rejected by some compilers.
#if !defined(LIBTFTPU)
REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_impl, {
  VLOG(1) << "register_tpu_compile_op_impl: TpuCompileOpKernelImpl";
  SetTpuCompileOpCreateFn(CreateTpuCompileOpImpl);
  SetTpuCompileOpMlirCreateFn(CreateTpuCompileOpMlirImpl);
});
#else
REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_impl, {});
#endif  // !defined(LIBTFTPU)
|
||||
|
||||
// Kernel registrations: each op name below is bound to its kernel class for
// the CPU device (DEVICE_CPU).
REGISTER_KERNEL_BUILDER(Name("TPUCompile").Device(DEVICE_CPU), TpuCompileOp);
|
||||
REGISTER_KERNEL_BUILDER(Name("_TPUCompileMlir").Device(DEVICE_CPU),
|
||||
TpuCompileMlirOp);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("TPUCompileSucceededAssert").Device(DEVICE_CPU),
|
||||
TpuCompileSucceededAssertOp);
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
@ -18,10 +18,18 @@ limitations under the License.
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
// Forward declaration.
|
||||
#if defined(LIBTFTPU)
|
||||
class TpuCompileOpKernelImpl;
|
||||
#else
|
||||
namespace internal {
|
||||
class TpuCompileOpKernelImpl;
|
||||
}
|
||||
#endif
|
||||
} // namespace tpu
|
||||
|
||||
// The TPUCompile operator compiles a Tensorflow function into a
|
||||
// TPU executable to be run by TPUExecute.
|
||||
@ -34,9 +42,13 @@ class TpuCompileOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
|
||||
#if defined(LIBTFTPU)
|
||||
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
|
||||
#else
|
||||
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
|
||||
#endif
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
|
||||
DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
|
||||
};
|
||||
|
||||
// The TPUCompile operator compiles a MLIR module into a
|
||||
@ -50,9 +62,13 @@ class TpuCompileMlirOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
|
||||
#if defined(LIBTFTPU)
|
||||
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
|
||||
#else
|
||||
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
|
||||
#endif
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
|
||||
DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
|
||||
};
|
||||
|
||||
class TpuCompileSucceededAssertOp : public OpKernel {
|
||||
@ -64,10 +80,9 @@ class TpuCompileSucceededAssertOp : public OpKernel {
|
||||
void Compute(OpKernelContext* ctx) override;
|
||||
|
||||
private:
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
|
||||
DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
|
||||
};
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_
|
||||
|
@ -1,46 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
namespace {
|
||||
static TpuCompileOpImplCreateFn* tpu_compile_op_impl_creation_fn =
|
||||
new TpuCompileOpImplCreateFn(CreateTpuCompileOpImpl);
|
||||
static TpuCompileOpImplCreateFn* tpu_compile_op_mlir_impl_creation_fn =
|
||||
new TpuCompileOpImplCreateFn(CreateTpuCompileOpMlirImpl);
|
||||
} // namespace
|
||||
|
||||
// Returns the currently installed callback for creating the default
// compile-op implementation. The pointee is replaced (and the old one
// deleted) by SetTpuCompileOpCreateFn.
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn() {
  TpuCompileOpImplCreateFn* const installed = tpu_compile_op_impl_creation_fn;
  return installed;
}
|
||||
|
||||
// Returns the currently installed callback for creating the MLIR compile-op
// implementation. The pointee is replaced (and the old one deleted) by
// SetTpuCompileOpMlirCreateFn.
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn() {
  TpuCompileOpImplCreateFn* const installed =
      tpu_compile_op_mlir_impl_creation_fn;
  return installed;
}
|
||||
|
||||
// Installs `fn` as the callback for creating the default compile-op
// implementation. The previously installed callback object is deleted and a
// fresh heap copy of `fn` takes its place.
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn) {
  VLOG(1) << "SetTpuCompileOpCreateFn.";
  TpuCompileOpImplCreateFn*& slot = tpu_compile_op_impl_creation_fn;
  delete slot;
  slot = new TpuCompileOpImplCreateFn(fn);
}
|
||||
|
||||
// Installs `fn` as the callback for creating the MLIR compile-op
// implementation. The previously installed callback object is deleted and a
// fresh heap copy of `fn` takes its place.
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn) {
  VLOG(1) << "SetTpuCompileOpMlirCreateFn.";
  TpuCompileOpImplCreateFn*& slot = tpu_compile_op_mlir_impl_creation_fn;
  delete slot;
  slot = new TpuCompileOpImplCreateFn(fn);
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
@ -1,55 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
|
||||
// Signature of a creation callback: given the kernel-construction context it
// yields either a compile-op kernel implementation or an error status.
using TpuCompileOpImplCreateFn =
    std::function<stream_executor::port::StatusOr<
        std::unique_ptr<TpuCompileOpKernelCommon>>(OpKernelConstruction*)>;
|
||||
|
||||
// Creates the default (function-based) compile-op kernel implementation from
// the attributes available on `ctx`, or returns an error status.
|
||||
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
|
||||
CreateTpuCompileOpImpl(OpKernelConstruction* ctx);
|
||||
|
||||
// Creates the MLIR-based compile-op kernel implementation from the
// attributes available on `ctx`, or returns an error status.
|
||||
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
|
||||
CreateTpuCompileOpMlirImpl(OpKernelConstruction* ctx);
|
||||
|
||||
// Gets a pointer to the currently installed callback for creating the
// default `TpuCompileOpImpl` instance.
|
||||
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn();
|
||||
|
||||
// Gets a pointer to the currently installed callback for creating the Mlir
// `TpuCompileOpImpl` instance.
|
||||
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn();
|
||||
|
||||
// Sets the callback for creating the default `TpuCompileOpImpl` instance,
// replacing whichever callback was installed before.
|
||||
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn);
|
||||
|
||||
// Sets the callback for creating the Mlir `TpuCompileOpImpl` instance,
// replacing whichever callback was installed before.
|
||||
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn);
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
@ -1,52 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
using ::stream_executor::port::StatusOr;
|
||||
// Builds the function-based compile-op kernel implementation. The metadata
// and function name are read from `ctx` via CompileOpMetadataFromContext; any
// error from that step is returned instead of a kernel.
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpImpl(
    OpKernelConstruction* ctx) {
  TPUCompileMetadataProto compile_metadata;
  NameAttrList function;
  const auto read_status = CompileOpMetadataFromContext(
      ctx, &compile_metadata, &function, /*mlir_module=*/nullptr);
  if (!read_status.ok()) {
    return read_status;
  }

  VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
  const auto cores_per_replica = compile_metadata.num_cores_per_replica();
  return {std::make_unique<TpuCompileOpKernelImpl>(
      function, compile_metadata, cores_per_replica,
      /*return_hlo_protos=*/false,
      /*unload_cache_on_session_close=*/false)};
}
|
||||
|
||||
// Builds the MLIR-based compile-op kernel implementation. The metadata and
// the serialized MLIR module are read from `ctx` via
// CompileOpMetadataFromContext; any error from that step is returned instead
// of a kernel.
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpMlirImpl(
    OpKernelConstruction* ctx) {
  std::string module_text;
  TPUCompileMetadataProto compile_metadata;
  const auto read_status = CompileOpMetadataFromContext(
      ctx, &compile_metadata, /*function_name=*/nullptr, &module_text);
  if (!read_status.ok()) {
    return read_status;
  }

  VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
  const auto cores_per_replica = compile_metadata.num_cores_per_replica();
  return {std::make_unique<TpuCompileOpKernelImpl>(
      module_text, compile_metadata, cores_per_replica,
      /*return_hlo_protos=*/false,
      /*unload_cache_on_session_close=*/false)};
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
@ -16,28 +16,27 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
||||
#include "tensorflow/compiler/xla/service/dump.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace tpu {
|
||||
using ::stream_executor::port::Status;
|
||||
using ::stream_executor::port::StatusOr;
|
||||
using ::xla::ComputationLayout;
|
||||
using ::xla::DebugOptions;
|
||||
using ::xla::DeviceAssignment;
|
||||
using ::xla::HloModuleConfig;
|
||||
using ::xla::HloSharding;
|
||||
using ::xla::InvalidArgument;
|
||||
using ::xla::ProgramShape;
|
||||
using ::xla::Shape;
|
||||
using ::xla::ShapeTree;
|
||||
using ::xla::ShapeUtil;
|
||||
|
||||
using stream_executor::port::Status;
|
||||
using stream_executor::port::StatusOr;
|
||||
using xla::ComputationLayout;
|
||||
using xla::DebugOptions;
|
||||
using xla::DeviceAssignment;
|
||||
using xla::HloModuleConfig;
|
||||
using xla::HloSharding;
|
||||
using xla::InvalidArgument;
|
||||
using xla::ProgramShape;
|
||||
using xla::Shape;
|
||||
using xla::ShapeTree;
|
||||
using xla::ShapeUtil;
|
||||
|
||||
Status ValidateResultShape(const Shape& client_shape,
|
||||
const Shape& result_shape) {
|
||||
@ -486,59 +485,5 @@ StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
|
||||
VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
|
||||
return compilation_request;
|
||||
}
|
||||
|
||||
// Reads the compile-op attributes from `ctx` and cross-checks them.
//
// `metadata` must be non-null (CHECK-enforced) and receives the parsed
// "metadata" attribute. `function_name` and `mlir_module` are optional
// outputs: when non-null they are filled from the "function" and
// "mlir_module" attributes respectively.
//
// Returns InvalidArgument when the metadata cannot be parsed, when
// "num_computations" differs from the metadata's num_cores_per_replica, or
// when a present device assignment disagrees with the metadata's replica or
// core counts. Attribute-lookup and deserialization errors are propagated.
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  CHECK_NE(metadata, nullptr);

  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));

  std::string serialized_metadata;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &serialized_metadata));
  const bool parsed = metadata->ParsePartialFromString(serialized_metadata);
  if (!parsed) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }

  // Optional outputs are only looked up when the caller asked for them.
  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }
  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }

  const auto cores_per_replica = metadata->num_cores_per_replica();
  if (num_computations != cores_per_replica) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", cores_per_replica, ")");
  }

  // Without a device assignment there is nothing further to validate.
  if (!metadata->has_device_assignment()) {
    return Status::OK();
  }

  StatusOr<std::unique_ptr<DeviceAssignment>> assignment_or =
      DeviceAssignment::Deserialize(metadata->device_assignment());
  TF_RETURN_IF_ERROR(assignment_or.status());
  const DeviceAssignment& assignment = *assignment_or.ValueOrDie();

  const int num_replicas = metadata->num_replicas();
  if (assignment.replica_count() != num_replicas) {
    return errors::InvalidArgument(
        "Device assignment replica_count != num_replicas; ",
        assignment.replica_count(), " vs ", num_replicas);
  }
  if (assignment.computation_count() != cores_per_replica) {
    return errors::InvalidArgument(
        "Device assignment computation_count != num_cores_per_replica; ",
        assignment.computation_count(), " vs ", cores_per_replica);
  }
  return Status::OK();
}
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -31,7 +31,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/function.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
@ -155,10 +154,6 @@ se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
|
||||
const TPUCompileMetadataProto& metadata,
|
||||
const std::vector<TensorShape>& arg_shapes);
|
||||
|
||||
se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
||||
TPUCompileMetadataProto* metadata,
|
||||
NameAttrList* function_name,
|
||||
std::string* mlir_module);
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user