Prep change for publishing TPU Ops.

PiperOrigin-RevId: 318081188
Change-Id: Ic2facba026e4abc4766ad3466f782f1535abf243
Henry Tan, 2020-06-24 09:40:21 -07:00, committed by TensorFlower Gardener
parent 553ef2313e
commit cb7907d992
6 changed files with 50 additions and 3 deletions

tensorflow/core/tpu/kernels/BUILD View File

@@ -28,10 +28,12 @@ cc_library(
     srcs = ["tpu_compile_op_common.cc"],
     hdrs = ["tpu_compile_op_common.h"],
     deps = [
+        ":tpu_compile_op_options",
         ":tpu_compile_op_support",
         ":tpu_mesh_state_interface",
         ":tpu_program_group_interface",
         ":tpu_util",
+        ":tpu_util_c_api_hdrs",
         ":tpu_util_hdrs",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
@@ -50,6 +52,7 @@ cc_library(
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
     alwayslink = 1,

tensorflow/core/tpu/kernels/tpu_compile_op_common.cc View File

@@ -16,6 +16,7 @@ limitations under the License.
 #include <string>
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
@@ -28,8 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
+#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
@@ -518,5 +521,41 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
   return Status::OK();
 }
+
+void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
+
+  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
+  CancellationToken token =
+      ctx->cancellation_manager()->get_cancellation_token();
+  const bool already_cancelled =
+      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
+        if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
+          return;
+        }
+
+        // Sleep and exit in another thread so the cancellation manager can
+        // continue running callbacks.
+        ctx->env()->SchedClosure([ctx, done]() { ExitCountdown(ctx, done); });
+      });
+
+  // If the RPC was cancelled before we registered the cancellation callback,
+  // don't compile the TPU program.
+  OP_REQUIRES(ctx, !already_cancelled,
+              errors::Cancelled("RPC cancelled, not compiling TPU program"));
+
+  // We only want to abort the process if a cancellation actually occurs during
+  // compilation; we must deregister the callback in the success case. It
+  // doesn't hurt to also deregister the callback in the failure case; the
+  // CancellationManager ensures that already-registered callbacks will be run
+  // once cancellation has started.
+  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
+    ctx->cancellation_manager()->DeregisterCallback(token);
+    done->store(true);
+  });
+
+  OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
+}
+
 }  // namespace tpu
 }  // namespace tensorflow
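The Compute() body added above leans on the register/deregister contract of TensorFlow's CancellationManager. Below is a minimal, self-contained C++ sketch of that contract for readers unfamiliar with it; this CancellationManager is a toy stand-in written for illustration (not TensorFlow's class), and main() plays the role of the compile op.

#include <atomic>
#include <functional>
#include <iostream>
#include <map>
#include <mutex>

class CancellationManager {
 public:
  using Token = int;
  Token GetToken() { return next_token_++; }

  // Returns false if cancellation already started; this mirrors the
  // "already_cancelled" check in the diff above.
  bool RegisterCallback(Token token, std::function<void()> cb) {
    std::lock_guard<std::mutex> lock(mu_);
    if (cancelled_) return false;
    callbacks_[token] = std::move(cb);
    return true;
  }

  // Removes a callback; once cancellation starts, already-registered
  // callbacks are still guaranteed to run.
  bool DeregisterCallback(Token token) {
    std::lock_guard<std::mutex> lock(mu_);
    return callbacks_.erase(token) > 0;
  }

  void StartCancel() {
    std::map<Token, std::function<void()>> to_run;
    {
      std::lock_guard<std::mutex> lock(mu_);
      if (cancelled_) return;
      cancelled_ = true;
      to_run.swap(callbacks_);
    }
    for (auto& entry : to_run) entry.second();  // Run outside the lock.
  }

 private:
  std::mutex mu_;
  bool cancelled_ = false;
  std::map<Token, std::function<void()>> callbacks_;
  Token next_token_ = 0;
};

int main() {
  CancellationManager cm;
  std::atomic<bool> done(false);

  // Register: on cancel, the real op schedules the countdown-and-exit closure.
  auto token = cm.GetToken();
  bool already_cancelled = !cm.RegisterCallback(
      token, [&done] { if (!done.load()) std::cout << "would exit\n"; });
  if (already_cancelled) return 1;  // Mirrors OP_REQUIRES(!already_cancelled).

  // ... long-running compilation would happen here ...

  // Success path: deregister before declaring done, as the cleanup above does.
  cm.DeregisterCallback(token);
  done.store(true);
  std::cout << "compiled\n";
}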

tensorflow/core/tpu/kernels/tpu_compile_op_common.h View File

@@ -53,7 +53,8 @@ class TpuCompileOpKernelCommon {
   virtual ~TpuCompileOpKernelCommon() = default;
 
-  virtual void Compute(OpKernelContext* ctx) = 0;
+  void Compute(OpKernelContext* ctx);
+  virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
 
   // Computes shapes for each argument. Uses both the static shape from the
   // metadata, and the dynamic shapes where the static shape is not

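This header change is a non-virtual-interface (template method) refactoring: the previously pure-virtual Compute() becomes a concrete wrapper that owns the cancellation boilerplate and delegates the actual work to the new pure-virtual ComputeInternal(). A minimal sketch of the shape, with simplified stand-ins for the TensorFlow Status and OpKernelContext types:

#include <iostream>
#include <string>

struct Status { bool ok = true; std::string msg; };
struct OpKernelContext { /* elided */ };

class TpuCompileOpKernelCommonLike {
 public:
  virtual ~TpuCompileOpKernelCommonLike() = default;

  // Non-virtual: every subclass gets the same cancellation handling.
  void Compute(OpKernelContext* ctx) {
    // ... register the cancellation callback (see the .cc diff above) ...
    Status s = ComputeInternal(ctx);
    // ... deregister the callback; surface s via OP_REQUIRES_OK in real code.
    std::cout << (s.ok ? "compiled" : s.msg.c_str()) << "\n";
  }

 protected:
  // Subclasses supply only the compilation itself.
  virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
};

class MyCompileOp : public TpuCompileOpKernelCommonLike {
  Status ComputeInternal(OpKernelContext*) override { return {}; }
};

int main() {
  MyCompileOp op;
  OpKernelContext ctx;
  op.Compute(&ctx);
}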
tensorflow/core/tpu/kernels/tpu_util.cc View File

@@ -95,6 +95,5 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
   }
   return Status::OK();
 }
 
 }  // namespace tpu
 }  // namespace tensorflow

tensorflow/core/tpu/kernels/tpu_util.h View File

@@ -68,7 +68,6 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
 // A callback called on exit.
 void LogAndExit(int code);
 
 }  // namespace tpu
 }  // namespace tensorflow

tensorflow/core/tpu/kernels/tpu_util_c_api.h View File

@@ -31,6 +31,12 @@ void TpuCompile_ToTpuShapeRepresentation(
     bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
     SE_Status* status);
 
+// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
+// exit when cancellation is requested for an XLA compile op. Some tests
+// require this behavior to be disabled, and we test for this condition with
+// the following flag function.
+bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
+
 }  // extern "C"
 
 struct TfTpu_UtilApiFn {
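The new flag function follows this header's extern "C" pattern: behavior toggles cross the TPU-library boundary as plain C symbols, which the TF side can also reach through a function-pointer table such as TfTpu_UtilApiFn. A hedged sketch of that pattern; UtilApiFnLike and the stub definition are illustrative, not the real API:

#include <cstdio>

extern "C" {
// Exported by the TPU library; a trivial stand-in definition here.
bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation() { return false; }
}  // extern "C"

// Function-pointer table, analogous in spirit to TfTpu_UtilApiFn.
struct UtilApiFnLike {
  bool (*ShouldTpuCompileOpIgnoreCancellationFn)();
};

int main() {
  UtilApiFnLike api{&TpuCompile_ShouldTpuCompileOpIgnoreCancellation};
  // Callers (e.g. the cancellation callback in the .cc diff above) consult
  // the flag through the table rather than linking the symbol directly.
  if (!api.ShouldTpuCompileOpIgnoreCancellationFn()) {
    std::puts("cancellation would trigger worker exit");
  }
}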