Add TpuCompileOp implementation.
PiperOrigin-RevId: 321717418 Change-Id: I4e0fb203014c54252511d6596063461dcc5de250
This commit is contained in:
parent
d7918bcd43
commit
94d2ab31f8
@ -178,14 +178,12 @@ cc_library(
|
|||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||||
"//tensorflow/compiler/xla/service:computation_layout",
|
"//tensorflow/compiler/xla/service:computation_layout",
|
||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
|
||||||
"//tensorflow/compiler/xla/service:dump",
|
"//tensorflow/compiler/xla/service:dump",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/compiler/xla/service:hlo_module_config",
|
"//tensorflow/compiler/xla/service:hlo_module_config",
|
||||||
"//tensorflow/compiler/xla/service:hlo_module_group",
|
"//tensorflow/compiler/xla/service:hlo_module_group",
|
||||||
"//tensorflow/core:framework",
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core/framework:protos_all_cc",
|
"//tensorflow/core/framework:protos_all_cc",
|
||||||
"//tensorflow/core/platform:errors",
|
|
||||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
||||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
@ -496,10 +494,7 @@ tf_proto_library_cc(
|
|||||||
cc_library(
|
cc_library(
|
||||||
name = "tpu_compile_op_hdrs",
|
name = "tpu_compile_op_hdrs",
|
||||||
hdrs = ["tpu_compile_op.h"],
|
hdrs = ["tpu_compile_op.h"],
|
||||||
deps = [
|
deps = ["//tensorflow/core:framework"],
|
||||||
":tpu_compile_op_common",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
@ -558,60 +553,3 @@ cc_library(
|
|||||||
],
|
],
|
||||||
alwayslink = 1,
|
alwayslink = 1,
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compile_op_registration",
|
|
||||||
srcs = ["tpu_compile_op_registration.cc"],
|
|
||||||
deps = [
|
|
||||||
":tpu_compile_op_common",
|
|
||||||
":tpu_compile_op_impl",
|
|
||||||
":tpu_compile_op_impl_factory",
|
|
||||||
":tpu_compile_op_support",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/core:protos_all_cc",
|
|
||||||
"//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compile_op_impl_factory",
|
|
||||||
srcs = ["tpu_compile_op_impl_factory.cc"],
|
|
||||||
hdrs = ["tpu_compile_op_impl_factory.h"],
|
|
||||||
deps = [
|
|
||||||
":tpu_compile_op_common",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/compiler/xla/service:computation_placer",
|
|
||||||
"//tensorflow/core:framework",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compile_op_lib",
|
|
||||||
srcs = ["tpu_compile_op.cc"],
|
|
||||||
deps = [
|
|
||||||
":tpu_compile_op_hdrs",
|
|
||||||
":tpu_compile_op_impl_factory",
|
|
||||||
":tpu_compile_op_options",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
|
||||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
|
||||||
name = "tpu_compile_op",
|
|
||||||
deps = [
|
|
||||||
":tpu_compile_op_hdrs",
|
|
||||||
":tpu_compile_op_impl",
|
|
||||||
":tpu_compile_op_impl_factory",
|
|
||||||
":tpu_compile_op_lib",
|
|
||||||
":tpu_compile_op_options",
|
|
||||||
":tpu_compile_op_registration",
|
|
||||||
"//tensorflow/compiler/xla:statusor",
|
|
||||||
"//tensorflow/core/protobuf/tpu:compilation_result_proto_cc",
|
|
||||||
"//tensorflow/stream_executor/tpu:tpu_node_context",
|
|
||||||
],
|
|
||||||
alwayslink = 1,
|
|
||||||
)
|
|
||||||
|
@ -1,94 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op.h"
|
|
||||||
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
|
||||||
#include "tensorflow/core/protobuf/tpu/compilation_result.pb.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
|
|
||||||
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
using ::stream_executor::port::StatusOr;
|
|
||||||
|
|
||||||
TpuCompileOp::TpuCompileOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
|
||||||
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> compile_op =
|
|
||||||
(*GetTpuCompileOpCreateFn())(ctx);
|
|
||||||
OP_REQUIRES_OK(ctx, compile_op.status());
|
|
||||||
impl_ = std::move(compile_op.ValueOrDie());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TpuCompileOp::Compute(OpKernelContext* ctx) { impl_->Compute(ctx); }
|
|
||||||
|
|
||||||
TpuCompileMlirOp::TpuCompileMlirOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
|
||||||
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> compile_op =
|
|
||||||
(*GetTpuCompileOpMlirCreateFn())(ctx);
|
|
||||||
OP_REQUIRES_OK(ctx, compile_op.status());
|
|
||||||
impl_ = std::move(compile_op.ValueOrDie());
|
|
||||||
}
|
|
||||||
|
|
||||||
void TpuCompileMlirOp::Compute(OpKernelContext* ctx) { impl_->Compute(ctx); }
|
|
||||||
|
|
||||||
void TpuCompileSucceededAssertOp::Compute(OpKernelContext* ctx) {
|
|
||||||
const Tensor compilation_result = ctx->input(0);
|
|
||||||
CompilationResultProto proto;
|
|
||||||
Status status;
|
|
||||||
if (!proto.ParseFromString(compilation_result.scalar<tstring>()())) {
|
|
||||||
status =
|
|
||||||
errors::InvalidArgument("Unable to parse compilation result proto");
|
|
||||||
}
|
|
||||||
if (!status.ok() || proto.status_code() != error::Code::OK) {
|
|
||||||
status.Update(Status(proto.status_code(), proto.status_error_message()));
|
|
||||||
errors::AppendToMessage(&status, "TPU compilation failed");
|
|
||||||
if (tensorflow::internal::TpuCompilationFailureClosesChips()) {
|
|
||||||
// At this point, if compilation fails we do not know if a task
|
|
||||||
// is already running that expects results from this compiled
|
|
||||||
// program to complete. So close the TPU driver to release all
|
|
||||||
// awaiting interactions (all awaiting interaction will fail and
|
|
||||||
// continue to fail until reinitialized).
|
|
||||||
LOG(ERROR) << "Cloud TPU: Closing chips. TPU compilation is considered "
|
|
||||||
"as part of device state, and a failed compilation results "
|
|
||||||
"in a device reset.";
|
|
||||||
|
|
||||||
Status close_status = TpuNodeContext::CloseTpuHost();
|
|
||||||
|
|
||||||
if (!close_status.ok()) {
|
|
||||||
errors::AppendToMessage(&status, close_status.error_message());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ctx->CtxFailureWithWarning(status);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
REGISTER_MODULE_INITIALIZER(register_tpu_compile_op_impl, {
|
|
||||||
#if !defined(LIBTFTPU)
|
|
||||||
VLOG(1) << "register_tpu_compile_op_impl: TpuCompileOpKernelImpl";
|
|
||||||
SetTpuCompileOpCreateFn(CreateTpuCompileOpImpl);
|
|
||||||
SetTpuCompileOpMlirCreateFn(CreateTpuCompileOpMlirImpl);
|
|
||||||
#endif // LIBTFTPU
|
|
||||||
});
|
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("TPUCompile").Device(DEVICE_CPU), TpuCompileOp);
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_TPUCompileMlir").Device(DEVICE_CPU),
|
|
||||||
TpuCompileMlirOp);
|
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("TPUCompileSucceededAssert").Device(DEVICE_CPU),
|
|
||||||
TpuCompileSucceededAssertOp);
|
|
||||||
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
@ -18,10 +18,18 @@ limitations under the License.
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
|
// Forward declaration.
|
||||||
|
#if defined(LIBTFTPU)
|
||||||
|
class TpuCompileOpKernelImpl;
|
||||||
|
#else
|
||||||
|
namespace internal {
|
||||||
|
class TpuCompileOpKernelImpl;
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
} // namespace tpu
|
||||||
|
|
||||||
// The TPUCompile operator compiles a Tensorflow function into a
|
// The TPUCompile operator compiles a Tensorflow function into a
|
||||||
// TPU executable to be run by TPUExecute.
|
// TPU executable to be run by TPUExecute.
|
||||||
@ -34,9 +42,13 @@ class TpuCompileOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* ctx) override;
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
|
#if defined(LIBTFTPU)
|
||||||
|
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
|
||||||
|
#else
|
||||||
|
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
|
||||||
|
#endif
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
|
DISALLOW_COPY_AND_ASSIGN(TpuCompileOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
// The TPUCompile operator compiles a MLIR module into a
|
// The TPUCompile operator compiles a MLIR module into a
|
||||||
@ -50,9 +62,13 @@ class TpuCompileMlirOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* ctx) override;
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::unique_ptr<TpuCompileOpKernelCommon> impl_;
|
#if defined(LIBTFTPU)
|
||||||
|
std::unique_ptr<tpu::TpuCompileOpKernelImpl> impl_;
|
||||||
|
#else
|
||||||
|
std::unique_ptr<tpu::internal::TpuCompileOpKernelImpl> impl_;
|
||||||
|
#endif
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
|
DISALLOW_COPY_AND_ASSIGN(TpuCompileMlirOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
class TpuCompileSucceededAssertOp : public OpKernel {
|
class TpuCompileSucceededAssertOp : public OpKernel {
|
||||||
@ -64,10 +80,9 @@ class TpuCompileSucceededAssertOp : public OpKernel {
|
|||||||
void Compute(OpKernelContext* ctx) override;
|
void Compute(OpKernelContext* ctx) override;
|
||||||
|
|
||||||
private:
|
private:
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
|
DISALLOW_COPY_AND_ASSIGN(TpuCompileSucceededAssertOp);
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_H_
|
||||||
|
@ -1,46 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
namespace {
|
|
||||||
static TpuCompileOpImplCreateFn* tpu_compile_op_impl_creation_fn =
|
|
||||||
new TpuCompileOpImplCreateFn(CreateTpuCompileOpImpl);
|
|
||||||
static TpuCompileOpImplCreateFn* tpu_compile_op_mlir_impl_creation_fn =
|
|
||||||
new TpuCompileOpImplCreateFn(CreateTpuCompileOpMlirImpl);
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn() {
|
|
||||||
return tpu_compile_op_impl_creation_fn;
|
|
||||||
}
|
|
||||||
|
|
||||||
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn() {
|
|
||||||
return tpu_compile_op_mlir_impl_creation_fn;
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn) {
|
|
||||||
VLOG(1) << "SetTpuCompileOpCreateFn.";
|
|
||||||
delete tpu_compile_op_impl_creation_fn;
|
|
||||||
tpu_compile_op_impl_creation_fn = new TpuCompileOpImplCreateFn(fn);
|
|
||||||
}
|
|
||||||
|
|
||||||
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn) {
|
|
||||||
VLOG(1) << "SetTpuCompileOpMlirCreateFn.";
|
|
||||||
delete tpu_compile_op_mlir_impl_creation_fn;
|
|
||||||
tpu_compile_op_mlir_impl_creation_fn = new TpuCompileOpImplCreateFn(fn);
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
@ -1,55 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
|
||||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
|
||||||
|
|
||||||
#include <functional>
|
|
||||||
#include <memory>
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
|
||||||
#include "tensorflow/core/framework/op.h"
|
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
|
|
||||||
typedef std::function<stream_executor::port::StatusOr<
|
|
||||||
std::unique_ptr<TpuCompileOpKernelCommon>>(OpKernelConstruction*)>
|
|
||||||
TpuCompileOpImplCreateFn;
|
|
||||||
|
|
||||||
// Creates the callback for creating `TpuCompileOpImpl` instance.
|
|
||||||
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
|
|
||||||
CreateTpuCompileOpImpl(OpKernelConstruction* ctx);
|
|
||||||
|
|
||||||
// Creates the callback for creating Mlir `TpuCompileOpImpl` instance.
|
|
||||||
stream_executor::port::StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>>
|
|
||||||
CreateTpuCompileOpMlirImpl(OpKernelConstruction* ctx);
|
|
||||||
|
|
||||||
// Gets the callback for creating default `TpuCompileOpImpl` instance.
|
|
||||||
TpuCompileOpImplCreateFn* GetTpuCompileOpCreateFn();
|
|
||||||
|
|
||||||
// Gets the callback for creating Mlir `TpuCompileOpImpl` instance.
|
|
||||||
TpuCompileOpImplCreateFn* GetTpuCompileOpMlirCreateFn();
|
|
||||||
|
|
||||||
// Sets the callback for creating default `TpuCompileOpImpl` instance.
|
|
||||||
void SetTpuCompileOpCreateFn(TpuCompileOpImplCreateFn fn);
|
|
||||||
|
|
||||||
// Sets the callback for creating Mlir `TpuCompileOpImpl` instance.
|
|
||||||
void SetTpuCompileOpMlirCreateFn(TpuCompileOpImplCreateFn fn);
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
||||||
|
|
||||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_FACTORY_H_
|
|
@ -1,52 +0,0 @@
|
|||||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
==============================================================================*/
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
|
||||||
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl_factory.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
|
||||||
|
|
||||||
namespace tensorflow {
|
|
||||||
namespace tpu {
|
|
||||||
using ::stream_executor::port::StatusOr;
|
|
||||||
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpImpl(
|
|
||||||
OpKernelConstruction* ctx) {
|
|
||||||
NameAttrList function_name;
|
|
||||||
TPUCompileMetadataProto metadata;
|
|
||||||
TF_RETURN_IF_ERROR(CompileOpMetadataFromContext(ctx, &metadata,
|
|
||||||
&function_name,
|
|
||||||
/*mlir_module=*/nullptr));
|
|
||||||
VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
|
|
||||||
return {std::make_unique<TpuCompileOpKernelImpl>(
|
|
||||||
function_name, metadata, metadata.num_cores_per_replica(),
|
|
||||||
/*return_hlo_protos=*/false,
|
|
||||||
/*unload_cache_on_session_close=*/false)};
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<TpuCompileOpKernelCommon>> CreateTpuCompileOpMlirImpl(
|
|
||||||
OpKernelConstruction* ctx) {
|
|
||||||
TPUCompileMetadataProto metadata;
|
|
||||||
std::string mlir_module;
|
|
||||||
TF_RETURN_IF_ERROR(CompileOpMetadataFromContext(
|
|
||||||
ctx, &metadata, /*function_name=*/nullptr, &mlir_module));
|
|
||||||
VLOG(1) << "Create tensorflow::tpu::TpuCompileOpKernelImpl";
|
|
||||||
return {std::make_unique<TpuCompileOpKernelImpl>(
|
|
||||||
mlir_module, metadata, metadata.num_cores_per_replica(),
|
|
||||||
/*return_hlo_protos=*/false,
|
|
||||||
/*unload_cache_on_session_close=*/false)};
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
|
||||||
} // namespace tensorflow
|
|
@ -16,28 +16,27 @@ limitations under the License.
|
|||||||
|
|
||||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
#include "tensorflow/compiler/xla/service/computation_layout.h"
|
||||||
#include "tensorflow/compiler/xla/service/computation_placer.h"
|
|
||||||
#include "tensorflow/compiler/xla/service/dump.h"
|
#include "tensorflow/compiler/xla/service/dump.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||||
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
|
||||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||||
|
|
||||||
namespace tensorflow {
|
namespace tensorflow {
|
||||||
namespace tpu {
|
namespace tpu {
|
||||||
using ::stream_executor::port::Status;
|
|
||||||
using ::stream_executor::port::StatusOr;
|
using stream_executor::port::Status;
|
||||||
using ::xla::ComputationLayout;
|
using stream_executor::port::StatusOr;
|
||||||
using ::xla::DebugOptions;
|
using xla::ComputationLayout;
|
||||||
using ::xla::DeviceAssignment;
|
using xla::DebugOptions;
|
||||||
using ::xla::HloModuleConfig;
|
using xla::DeviceAssignment;
|
||||||
using ::xla::HloSharding;
|
using xla::HloModuleConfig;
|
||||||
using ::xla::InvalidArgument;
|
using xla::HloSharding;
|
||||||
using ::xla::ProgramShape;
|
using xla::InvalidArgument;
|
||||||
using ::xla::Shape;
|
using xla::ProgramShape;
|
||||||
using ::xla::ShapeTree;
|
using xla::Shape;
|
||||||
using ::xla::ShapeUtil;
|
using xla::ShapeTree;
|
||||||
|
using xla::ShapeUtil;
|
||||||
|
|
||||||
Status ValidateResultShape(const Shape& client_shape,
|
Status ValidateResultShape(const Shape& client_shape,
|
||||||
const Shape& result_shape) {
|
const Shape& result_shape) {
|
||||||
@ -486,59 +485,5 @@ StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
|
|||||||
VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
|
VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
|
||||||
return compilation_request;
|
return compilation_request;
|
||||||
}
|
}
|
||||||
|
|
||||||
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|
||||||
TPUCompileMetadataProto* metadata,
|
|
||||||
NameAttrList* function_name,
|
|
||||||
std::string* mlir_module) {
|
|
||||||
CHECK_NE(metadata, nullptr);
|
|
||||||
|
|
||||||
int num_computations;
|
|
||||||
TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));
|
|
||||||
|
|
||||||
std::string metadata_string;
|
|
||||||
TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
|
|
||||||
if (!metadata->ParsePartialFromString(metadata_string)) {
|
|
||||||
return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (function_name != nullptr) {
|
|
||||||
TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mlir_module != nullptr) {
|
|
||||||
TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
|
|
||||||
}
|
|
||||||
|
|
||||||
if (num_computations != metadata->num_cores_per_replica()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"num_computations must be equal to "
|
|
||||||
"num_cores_per_replica in the 'metadata' "
|
|
||||||
"attribute (",
|
|
||||||
num_computations, " vs ", metadata->num_cores_per_replica(), ")");
|
|
||||||
}
|
|
||||||
|
|
||||||
if (metadata->has_device_assignment()) {
|
|
||||||
StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
|
|
||||||
DeviceAssignment::Deserialize(metadata->device_assignment());
|
|
||||||
TF_RETURN_IF_ERROR(device_assignment_or_error.status());
|
|
||||||
const DeviceAssignment& device_assignment =
|
|
||||||
*device_assignment_or_error.ValueOrDie();
|
|
||||||
const int num_replicas = metadata->num_replicas();
|
|
||||||
if (device_assignment.replica_count() != num_replicas) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Device assignment replica_count != num_replicas; ",
|
|
||||||
device_assignment.replica_count(), " vs ", num_replicas);
|
|
||||||
}
|
|
||||||
if (device_assignment.computation_count() !=
|
|
||||||
metadata->num_cores_per_replica()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Device assignment computation_count != num_cores_per_replica; ",
|
|
||||||
device_assignment.computation_count(), " vs ",
|
|
||||||
metadata->num_cores_per_replica());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return Status::OK();
|
|
||||||
}
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -31,7 +31,6 @@ limitations under the License.
|
|||||||
#include "tensorflow/compiler/xla/shape_tree.h"
|
#include "tensorflow/compiler/xla/shape_tree.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
|
||||||
#include "tensorflow/core/framework/function.h"
|
#include "tensorflow/core/framework/function.h"
|
||||||
#include "tensorflow/core/framework/op_kernel.h"
|
#include "tensorflow/core/framework/op_kernel.h"
|
||||||
#include "tensorflow/core/framework/tensor.pb.h"
|
#include "tensorflow/core/framework/tensor.pb.h"
|
||||||
@ -155,10 +154,6 @@ se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
|
|||||||
const TPUCompileMetadataProto& metadata,
|
const TPUCompileMetadataProto& metadata,
|
||||||
const std::vector<TensorShape>& arg_shapes);
|
const std::vector<TensorShape>& arg_shapes);
|
||||||
|
|
||||||
se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
|
|
||||||
TPUCompileMetadataProto* metadata,
|
|
||||||
NameAttrList* function_name,
|
|
||||||
std::string* mlir_module);
|
|
||||||
} // namespace tpu
|
} // namespace tpu
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user