Add TpuCompileOp implementation.

PiperOrigin-RevId: 321717418
Change-Id: I4e0fb203014c54252511d6596063461dcc5de250
This commit is contained in:
Henry Tan 2020-07-16 22:44:59 -07:00 committed by TensorFlower Gardener
parent d7918bcd43
commit 94d2ab31f8
8 changed files with 36 additions and 390 deletions

View File

@ -178,14 +178,12 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:dump",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_module_config",
"//tensorflow/compiler/xla/service:hlo_module_group",
"//tensorflow/core:framework",
"//tensorflow/core/framework:protos_all_cc",
"//tensorflow/core/platform:errors",
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
"//tensorflow/stream_executor/tpu:proto_helper",
"@com_google_absl//absl/strings",
@ -496,10 +494,7 @@ tf_proto_library_cc(
cc_library(
name = "tpu_compile_op_hdrs",
hdrs = ["tpu_compile_op.h"],
deps = [
":tpu_compile_op_common",
"//tensorflow/core:framework",
],
deps = ["//tensorflow/core:framework"],
)
cc_library(
@ -558,60 +553,3 @@ cc_library(
],
alwayslink = 1,
)
# Module-initializer target that installs the default TpuCompileOp factory
# callbacks (see tpu_compile_op_registration.cc).
cc_library(
    name = "tpu_compile_op_registration",
    srcs = ["tpu_compile_op_registration.cc"],
    deps = [
        ":tpu_compile_op_common",
        ":tpu_compile_op_impl",
        ":tpu_compile_op_impl_factory",
        ":tpu_compile_op_support",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
    ],
)
# Factory holding swappable creation callbacks for the TpuCompileOp
# implementation. alwayslink keeps the registration symbols live.
cc_library(
    name = "tpu_compile_op_impl_factory",
    srcs = ["tpu_compile_op_impl_factory.cc"],
    hdrs = ["tpu_compile_op_impl_factory.h"],
    deps = [
        ":tpu_compile_op_common",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/service:computation_placer",
        "//tensorflow/core:framework",
    ],
    alwayslink = 1,
)
# Kernel definitions for TPUCompile / _TPUCompileMlir /
# TPUCompileSucceededAssert (tpu_compile_op.cc). alwayslink keeps the
# REGISTER_KERNEL_BUILDER registrations from being dead-stripped.
cc_library(
    name = "tpu_compile_op_lib",
    srcs = ["tpu_compile_op.cc"],
    deps = [
        ":tpu_compile_op_hdrs",
        ":tpu_compile_op_impl_factory",
        ":tpu_compile_op_options",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
    ],
    alwayslink = 1,
)
# Umbrella target: bundles the kernels, the default implementation, the
# factory, and its registration so that depending on this single target
# yields a fully working TPUCompile op.
cc_library(
    name = "tpu_compile_op",
    deps = [
        ":tpu_compile_op_hdrs",
        ":tpu_compile_op_impl",
        ":tpu_compile_op_impl_factory",
        ":tpu_compile_op_lib",
        ":tpu_compile_op_options",
        ":tpu_compile_op_registration",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
        "//tensorflow/stream_executor/tpu:tpu_node_context",
    ],
    alwayslink = 1,
)

View File

@ -1,94 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op.h"
#include <string>
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
namespace tensorflow {
namespace tpu {
using ::stream_executor::port::StatusOr;
// Constructs the op by invoking the factory callback registered in
// tpu_compile_op_impl_factory; the callback builds the platform-specific
// TpuCompileOpKernelCommon implementation from the construction context.
TpuCompileOp::TpuCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> compile_op =
      (*GetTpuCompileOpCreateFn())(ctx);
  // On failure this flags `ctx` and returns early, leaving impl_ null.
  OP_REQUIRES_OK(ctx, compile_op.status());
  impl_ = std::move(compile_op.ValueOrDie());
}
void TpuCompileOp::Compute(OpKernelContext* ctx) { impl_->Compute(ctx); }
// Constructs the MLIR variant of the compile op via the MLIR-specific
// factory callback registered in tpu_compile_op_impl_factory.
TpuCompileMlirOp::TpuCompileMlirOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
  StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> compile_op =
      (*GetTpuCompileOpMlirCreateFn())(ctx);
  // On failure this flags `ctx` and returns early, leaving impl_ null.
  OP_REQUIRES_OK(ctx, compile_op.status());
  impl_ = std::move(compile_op.ValueOrDie());
}
void TpuCompileMlirOp::Compute(OpKernelContext* ctx) { impl_->Compute(ctx); }
// Checks the serialized CompilationResultProto produced by a TPUCompile op.
// If parsing fails or the proto records a non-OK status, the step is failed
// with a warning and — when the build option says compilation failures close
// chips — the TPU host is shut down to release pending interactions.
void TpuCompileSucceededAssertOp::Compute(OpKernelContext* ctx) {
  const Tensor compilation_result = ctx->input(0);
  CompilationResultProto proto;
  Status status;
  if (!proto.ParseFromString(compilation_result.scalar<tstring>()())) {
    status =
        errors::InvalidArgument("Unable to parse compilation result proto");
  }
  if (!status.ok() || proto.status_code() != error::Code::OK) {
    // Status::Update keeps the first non-OK status, so a parse failure above
    // takes precedence over the status carried in the proto.
    status.Update(Status(proto.status_code(), proto.status_error_message()));
    errors::AppendToMessage(&status, "TPU compilation failed");
    if (tensorflow::internal::TpuCompilationFailureClosesChips()) {
      // At this point, if compilation fails we do not know if a task
      // is already running that expects results from this compiled
      // program to complete. So close the TPU driver to release all
      // awaiting interactions (all awaiting interaction will fail and
      // continue to fail until reinitialized).
      LOG(ERROR) << "Cloud TPU: Closing chips. TPU compilation is considered "
                    "as part of device state, and a failed compilation results "
                    "in a device reset.";
      Status close_status = TpuNodeContext::CloseTpuHost();
      if (!close_status.ok()) {
        errors::AppendToMessage(&status, close_status.error_message());
      }
    }
    ctx->CtxFailureWithWarning(status);
  }
}
// Installs the default factory callbacks at module load time.
// NOTE(review): registration runs only when LIBTFTPU is NOT defined;
// presumably the LIBTFTPU (external library) build registers its own
// implementations elsewhere — confirm against the build configuration.
REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_impl, {
#if !defined(LIBTFTPU)
  VLOG(1) << "register_tpu_compile_op_impl: TpuCompileOpKernelImpl";
  SetTpuCompileOpCreateFn(CreateTpuCompileOpImpl);
  SetTpuCompileOpMlirCreateFn(CreateTpuCompileOpMlirImpl);
#endif // LIBTFTPU
});

// All three kernels execute on the host CPU device.
REGISTER_KERNEL_BUILDER(Name("TPUCompile").Device(DEVICE_CPU), TpuCompileOp);
REGISTER_KERNEL_BUILDER(Name("_TPUCompileMlir").Device(DEVICE_CPU),
                        TpuCompileMlirOp);
REGISTER_KERNEL_BUILDER(Name("TPUCompileSucceededAssert").Device(DEVICE_CPU),
                        TpuCompileSucceededAssertOp);
} // namespace tpu
} // namespace tensorflow

View File

@ -18,10 +18,18 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
namespace tensorflow {
namespace tpu {
// Forward declaration.
#if defined(LIBTFTPU)
class TpuCompileOpKernelImpl;
#else
namespace internal {
class TpuCompileOpKernelImpl;
}
#endif
} // namespace tpu
// The TPUCompile operator compiles a Tensorflow function into a
// TPU executable to be run by TPUExecute.
@ -34,9 +42,13 @@ class TpuCompileOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
private:
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
#if defined(LIBTFTPU)
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
#else
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
#endif
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
};
// The TPUCompile operator compiles a MLIR module into a
@ -50,9 +62,13 @@ class TpuCompileMlirOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
private:
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
#if defined(LIBTFTPU)
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
#else
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
#endif
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
};
class TpuCompileSucceededAssertOp : public OpKernel {
@ -64,10 +80,9 @@ class TpuCompileSucceededAssertOp : public OpKernel {
void Compute(OpKernelContext* ctx) override;
private:
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
};
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_

View File

@ -1,46 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
namespace tensorflow {
namespace tpu {
namespace {
// Process-wide factory callbacks, heap-allocated so the Set* functions below
// can delete and replace them. They default to the non-MLIR and MLIR
// TpuCompileOpImpl creation functions declared in this header.
static TpuCompileOpImplCreateFn* tpu_compile_op_impl_creation_fn =
    new TpuCompileOpImplCreateFn(CreateTpuCompileOpImpl);
static TpuCompileOpImplCreateFn* tpu_compile_op_mlir_impl_creation_fn =
    new TpuCompileOpImplCreateFn(CreateTpuCompileOpMlirImpl);
}  // namespace
// Returns the currently-registered factory for the non-MLIR compile op.
// The returned pointer is owned by this module; callers must not delete it.
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn() {
  return tpu_compile_op_impl_creation_fn;
}
// Returns the currently-registered factory for the MLIR compile op.
// The returned pointer is owned by this module; callers must not delete it.
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn() {
  return tpu_compile_op_mlir_impl_creation_fn;
}
// Replaces the non-MLIR factory callback. No synchronization is used, so
// this presumably runs only during module initialization, before any op is
// constructed — TODO confirm with callers.
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn) {
  VLOG(1) << "SetTpuCompileOpCreateFn.";
  delete tpu_compile_op_impl_creation_fn;
  tpu_compile_op_impl_creation_fn = new TpuCompileOpImplCreateFn(fn);
}
// Replaces the MLIR factory callback. No synchronization is used, so this
// presumably runs only during module initialization, before any op is
// constructed — TODO confirm with callers.
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn) {
  VLOG(1) << "SetTpuCompileOpMlirCreateFn.";
  delete tpu_compile_op_mlir_impl_creation_fn;
  tpu_compile_op_mlir_impl_creation_fn = new TpuCompileOpImplCreateFn(fn);
}
} // namespace tpu
} // namespace tensorflow

View File

@ -1,55 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
#include <functional>
#include <memory>
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
namespace tensorflow {
namespace tpu {
// Signature of a factory callback that builds a TpuCompileOp implementation
// from the kernel-construction context, returning an error status on failure.
typedef std::function<stream_executor::port::StatusOr<
    std::unique_ptr<TpuCompileOpKernelCommon>>(OpKernelConstruction*)>
    TpuCompileOpImplCreateFn;

// Creates the callback for creating `TpuCompileOpImpl` instance.
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
CreateTpuCompileOpImpl(OpKernelConstruction* ctx);

// Creates the callback for creating Mlir `TpuCompileOpImpl` instance.
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
CreateTpuCompileOpMlirImpl(OpKernelConstruction* ctx);

// Gets the callback for creating default `TpuCompileOpImpl` instance.
// The returned pointer is owned by the factory module.
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn();

// Gets the callback for creating Mlir `TpuCompileOpImpl` instance.
// The returned pointer is owned by the factory module.
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn();

// Sets the callback for creating default `TpuCompileOpImpl` instance.
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn);

// Sets the callback for creating Mlir `TpuCompileOpImpl` instance.
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn);
} // namespace tpu
} // namespace tensorflow
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_

View File

@ -1,52 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <string>
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
namespace tensorflow {
namespace tpu {
using ::stream_executor::port::StatusOr;
// Default factory for the non-MLIR compile op: reads the "function",
// "metadata", and "num_computations" attrs from `ctx` and builds a
// TpuCompileOpKernelImpl keyed by the TF function name.
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpImpl(
    OpKernelConstruction* ctx) {
  NameAttrList function_name;
  TPUCompileMetadataProto metadata;
  TF_RETURN_IF_ERROR(CompileOpMetadataFromContext(ctx, &metadata,
                                                  &function_name,
                                                  /*mlir_module=*/nullptr));
  VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
  return {std::make_unique<TpuCompileOpKernelImpl>(
      function_name, metadata, metadata.num_cores_per_replica(),
      /*return_hlo_protos=*/false,
      /*unload_cache_on_session_close=*/false)};
}
// Factory for the MLIR compile op: reads the "metadata" and "mlir_module"
// attrs from `ctx` (no TF function name) and builds a TpuCompileOpKernelImpl
// keyed by the serialized MLIR module.
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpMlirImpl(
    OpKernelConstruction* ctx) {
  TPUCompileMetadataProto metadata;
  std::string mlir_module;
  TF_RETURN_IF_ERROR(CompileOpMetadataFromContext(
      ctx, &metadata, /*function_name=*/nullptr, &mlir_module));
  VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
  return {std::make_unique<TpuCompileOpKernelImpl>(
      mlir_module, metadata, metadata.num_cores_per_replica(),
      /*return_hlo_protos=*/false,
      /*unload_cache_on_session_close=*/false)};
}
} // namespace tpu
} // namespace tensorflow

View File

@ -16,28 +16,27 @@ limitations under the License.
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"
namespace tensorflow {
namespace tpu {
using ::stream_executor::port::Status;
using ::stream_executor::port::StatusOr;
using ::xla::ComputationLayout;
using ::xla::DebugOptions;
using ::xla::DeviceAssignment;
using ::xla::HloModuleConfig;
using ::xla::HloSharding;
using ::xla::InvalidArgument;
using ::xla::ProgramShape;
using ::xla::Shape;
using ::xla::ShapeTree;
using ::xla::ShapeUtil;
using stream_executor::port::Status;
using stream_executor::port::StatusOr;
using xla::ComputationLayout;
using xla::DebugOptions;
using xla::DeviceAssignment;
using xla::HloModuleConfig;
using xla::HloSharding;
using xla::InvalidArgument;
using xla::ProgramShape;
using xla::Shape;
using xla::ShapeTree;
using xla::ShapeUtil;
Status ValidateResultShape(const Shape& client_shape,
const Shape& result_shape) {
@ -486,59 +485,5 @@ StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
return compilation_request;
}
// Parses and validates the compile-op attributes from `ctx`.
//  - `metadata` (required, non-null): filled from the serialized
//    TPUCompileMetadataProto in the "metadata" attr.
//  - `function_name` (optional): filled from the "function" attr when
//    non-null.
//  - `mlir_module` (optional): filled from the "mlir_module" attr when
//    non-null.
// Returns InvalidArgument when the metadata cannot be parsed, when
// "num_computations" disagrees with num_cores_per_replica, or when the
// embedded device assignment is inconsistent with the metadata counts.
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  // Null metadata is a programming error, not user input: crash immediately.
  CHECK_NE(metadata, nullptr);
  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));
  std::string metadata_string;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
  // Partial parse: missing fields are tolerated here; the cross-checks below
  // validate the fields this function relies on.
  if (!metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }
  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }
  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }
  if (num_computations != metadata->num_cores_per_replica()) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", metadata->num_cores_per_replica(), ")");
  }
  if (metadata->has_device_assignment()) {
    // Cross-check the serialized device assignment against the replica and
    // per-replica core counts declared in the metadata.
    StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
        DeviceAssignment::Deserialize(metadata->device_assignment());
    TF_RETURN_IF_ERROR(device_assignment_or_error.status());
    const DeviceAssignment& device_assignment =
        *device_assignment_or_error.ValueOrDie();
    const int num_replicas = metadata->num_replicas();
    if (device_assignment.replica_count() != num_replicas) {
      return errors::InvalidArgument(
          "Device assignment replica_count != num_replicas; ",
          device_assignment.replica_count(), " vs ", num_replicas);
    }
    if (device_assignment.computation_count() !=
        metadata->num_cores_per_replica()) {
      return errors::InvalidArgument(
          "Device assignment computation_count != num_cores_per_replica; ",
          device_assignment.computation_count(), " vs ",
          metadata->num_cores_per_replica());
    }
  }
  return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -31,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
@ -155,10 +154,6 @@ se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
const TPUCompileMetadataProto& metadata,
const std::vector<TensorShape>& arg_shapes);
se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
TPUCompileMetadataProto* metadata,
NameAttrList* function_name,
std::string* mlir_module);
} // namespace tpu
} // namespace tensorflow