Prep change for publishing TPU Ops.
PiperOrigin-RevId: 318081188
Change-Id: Ic2facba026e4abc4766ad3466f782f1535abf243
parent 553ef2313e
commit cb7907d992
tensorflow/core/tpu/kernels/BUILD

@@ -28,10 +28,12 @@ cc_library(
     srcs = ["tpu_compile_op_common.cc"],
     hdrs = ["tpu_compile_op_common.h"],
     deps = [
+        ":tpu_compile_op_options",
         ":tpu_compile_op_support",
         ":tpu_mesh_state_interface",
         ":tpu_program_group_interface",
         ":tpu_util",
+        ":tpu_util_c_api_hdrs",
         ":tpu_util_hdrs",
         "//tensorflow/compiler/jit:flags",
         "//tensorflow/compiler/jit:shape_inference",
@@ -50,6 +52,7 @@ cc_library(
         "//tensorflow/core/tpu:tpu_configuration",
         "//tensorflow/core/tpu:tpu_defs",
         "//tensorflow/stream_executor/tpu:tpu_platform_interface",
+        "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
     alwayslink = 1,
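Note on the BUILD changes: the new `:tpu_compile_op_options` and `:tpu_util_c_api_hdrs` deps match headers now included by `tpu_compile_op_common.cc` (next hunks), and `@com_google_absl//absl/strings` backs the new `absl/strings/string_view.h` include. The `alwayslink = 1` attribute is load-bearing for kernel libraries like this one: TensorFlow op kernels register themselves through static initializers, so the linker must keep object files even when nothing references their symbols directly. A minimal sketch of that self-registration idiom (generic C++, not TensorFlow's actual registry):

```cpp
#include <iostream>
#include <string>

// Stand-in for a kernel registry entry: construction at static-initialization
// time is the only "use" of this object, which is exactly what alwayslink = 1
// protects from being dropped by the linker.
struct KernelRegistrar {
  explicit KernelRegistrar(const std::string& name) {
    std::cout << "registered kernel: " << name << "\n";
  }
};

static KernelRegistrar tpu_compile_registrar("TpuCompileOp");

int main() { return 0; }  // registration already happened before main ran
```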
tensorflow/core/tpu/kernels/tpu_compile_op_common.cc

@@ -16,6 +16,7 @@ limitations under the License.
 
 #include <string>
 
+#include "absl/strings/string_view.h"
 #include "tensorflow/compiler/jit/flags.h"
 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
@@ -28,8 +29,10 @@ limitations under the License.
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
 #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
 #include "tensorflow/core/tpu/kernels/tpu_util.h"
+#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
 #include "tensorflow/core/tpu/tpu_configuration.h"
 #include "tensorflow/core/tpu/tpu_defs.h"
@@ -518,5 +521,41 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
   return Status::OK();
 }
 
+void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
+  VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
+
+  std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
+
+  CancellationToken token =
+      ctx->cancellation_manager()->get_cancellation_token();
+  const bool already_cancelled =
+      !ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
+        if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
+          return;
+        }
+
+        // Sleep and exit in another thread so the cancellation manager can
+        // continue running callbacks.
+        ctx->env()->SchedClosure([ctx, done]() { ExitCountdown(ctx, done); });
+      });
+
+  // If the RPC was cancelled before we registered the cancellation callback,
+  // don't compile the TPU program.
+  OP_REQUIRES(ctx, !already_cancelled,
+              errors::Cancelled("RPC cancelled, not compiling TPU program"));
+
+  // We only want to abort the process if a cancellation actually occurs during
+  // compilation; we must deregister the callback in the success case. It
+  // doesn't hurt to also deregister the callback in the failure case; the
+  // CancellationManager ensures that already-registered callbacks will be run
+  // once cancellation has started.
+  auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
+    ctx->cancellation_manager()->DeregisterCallback(token);
+    done->store(true);
+  });
+
+  OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
+}
+
 } // namespace tpu
 } // namespace tensorflow
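The new `Compute` body wraps the subclass's `ComputeInternal` with cancellation handling: it registers a callback that (unless disabled for tests) exits the process, since XLA compilation cannot be interrupted once started, and it deregisters that callback on every exit path via `xla::MakeCleanup`. A self-contained sketch of that scope-exit idiom; `ScopeExit` below is an illustrative stand-in, not xla's actual cleanup type:

```cpp
#include <atomic>
#include <cstdio>
#include <memory>
#include <utility>

// Runs a functor when the enclosing scope exits, on success and failure alike.
template <typename F>
class ScopeExit {
 public:
  explicit ScopeExit(F f) : f_(std::move(f)) {}
  ~ScopeExit() { f_(); }
  ScopeExit(const ScopeExit&) = delete;
  ScopeExit& operator=(const ScopeExit&) = delete;

 private:
  F f_;
};

int main() {
  auto done = std::make_shared<std::atomic<bool>>(false);
  ScopeExit cleanup([done] {
    // In the op, this is where DeregisterCallback(token) happens.
    done->store(true);
    std::puts("cleanup ran");
  });
  // ... compilation work; an early return here would still run the cleanup ...
  return 0;
}
```

Tying deregistration to scope exit covers both the `OP_REQUIRES`/`OP_REQUIRES_OK` early returns and the normal return without duplicating the cleanup code.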
tensorflow/core/tpu/kernels/tpu_compile_op_common.h

@@ -53,7 +53,8 @@ class TpuCompileOpKernelCommon {
 
   virtual ~TpuCompileOpKernelCommon() = default;
 
-  virtual void Compute(OpKernelContext* ctx) = 0;
+  void Compute(OpKernelContext* ctx);
+  virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
 
   // Computes shapes for each argument. Uses both the static shape from the
   // metadata, and the dynamic shapes where the static shape is not
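The header change converts `Compute` from a pure-virtual hook into a concrete base-class method (the non-virtual-interface / template-method pattern): the shared cancellation handling lives in `Compute`, and subclasses now override only `ComputeInternal`. A minimal sketch of the shape of this pattern, with `Status` and `OpKernelContext` stubbed out rather than the real TensorFlow types:

```cpp
#include <iostream>

struct Status { bool ok = true; };   // stub, not tensorflow::Status
struct OpKernelContext {};           // stub, not tensorflow::OpKernelContext

class CompileOpBase {
 public:
  virtual ~CompileOpBase() = default;

  // Concrete entry point: shared setup/teardown brackets the virtual hook.
  void Compute(OpKernelContext* ctx) {
    // ... register cancellation callback (shared logic) ...
    Status s = ComputeInternal(ctx);
    std::cout << (s.ok ? "compile ok" : "compile failed") << "\n";
    // ... deregister cancellation callback (shared logic) ...
  }

  // Subclasses supply only the compilation step.
  virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
};

class FakeCompileOp : public CompileOpBase {
 public:
  Status ComputeInternal(OpKernelContext*) override { return Status{}; }
};

int main() {
  FakeCompileOp op;
  OpKernelContext ctx;
  op.Compute(&ctx);
  return 0;
}
```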
tensorflow/core/tpu/kernels/tpu_util.cc

@@ -95,6 +95,5 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
   }
   return Status::OK();
 }
 
 } // namespace tpu
 } // namespace tensorflow
tensorflow/core/tpu/kernels/tpu_util.h

@@ -68,7 +68,6 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
 
 // A callback called on exit.
 void LogAndExit(int code);
 
 } // namespace tpu
 } // namespace tensorflow
tensorflow/core/tpu/kernels/tpu_util_c_api.h

@@ -31,6 +31,12 @@ void TpuCompile_ToTpuShapeRepresentation(
     bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
     SE_Status* status);
 
+// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
+// exit when cancellation is requested for an XLA compile op. Some tests
+// require this behavior to be disabled, and we test for this condition with
+// the following flag function.
+bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
+
 } // extern "C"
 
 struct TfTpu_UtilApiFn {
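The flag function is exposed through the C API so the cancellation callback in `Compute` can ask, across the TPU-library boundary, whether the exit-on-cancellation behavior is disabled. A hedged sketch of that control flow; the stub implementation below is illustrative only, the real function is supplied by the TPU library:

```cpp
#include <cstdio>

// Illustrative stub: a test build might return true so the worker stays
// alive; the production TPU library provides the real implementation.
extern "C" bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation() {
  return false;
}

// Mirrors the cancellation callback registered in
// TpuCompileOpKernelCommon::Compute: tests opt out, otherwise a process
// exit gets scheduled.
void OnCompileCancelled() {
  if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
    return;
  }
  std::puts("cancelled during XLA compile: scheduling worker exit");
}

int main() {
  OnCompileCancelled();
  return 0;
}
```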