diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD
index dfa37348c8d..b256790a0fb 100644
--- a/tensorflow/core/tpu/kernels/BUILD
+++ b/tensorflow/core/tpu/kernels/BUILD
@@ -508,3 +508,29 @@ cc_library(
         "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
     ],
 )
+
+cc_library(
+    name = "tpu_compile_op_impl",
+    srcs = ["tpu_compile_op_impl.cc"],
+    hdrs = ["tpu_compile_op_impl.h"],
+    deps = [
+        "//tensorflow/compiler/jit:shape_inference",
+        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/xla:status",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:framework",
+        "//tensorflow/core/tpu/kernels:tpu_compilation_cache_key",
+        "//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
+        "//tensorflow/core/tpu/kernels:tpu_compile_op_common",
+        "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
+        "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
+        "//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
+        "//tensorflow/core/tpu/kernels:tpu_program_group",
+        "//tensorflow/core/tpu/kernels:tpu_program_group_interface",
+        "//tensorflow/core/tpu/kernels:tpu_util",
+        "//tensorflow/stream_executor/tpu:tpu_executor",
+        "//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
+        "@com_google_absl//absl/types:variant",
+    ],
+    alwayslink = 1,
+)
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc
new file mode 100644
index 00000000000..0d514997142
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc
@@ -0,0 +1,39 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl.h"
+
+#include "tensorflow/compiler/xla/status.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+Status TpuCompileOpKernelImpl::Compile(
+    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
+    const XLA_TpuMeshState* mesh_state,
+    const std::vector<TensorShape>& arg_shapes,
+    TpuProgramGroupInterface* tpu_program_group) {
+  TF_ASSIGN_OR_RETURN(
+      TpuCompilationRequestProto compilation_request,
+      CreateTpuCompilationRequest(computation, metadata_, arg_shapes));
+
+  return TpuProgramGroup::CompileAndBuild(compilation_request, mesh_state,
+                                          tpu_program_group);
+}
+}  // namespace tpu
+}  // namespace tensorflow
diff --git a/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h
new file mode 100644
index 00000000000..cd8ef78614a
--- /dev/null
+++ b/tensorflow/core/tpu/kernels/tpu_compile_op_impl.h
@@ -0,0 +1,66 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_
+#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_
+
+#include <string>
+#include <vector>
+
+#include "absl/types/variant.h"
+#include "tensorflow/compiler/jit/shape_inference.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
+#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
+#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
+#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
+
+namespace tensorflow {
+namespace tpu {
+
+// Base class for TpuCompileOp and TpuCompileMlirOp.
+// Depending on whether it is given a computation in the form of a serialized
+// MLIR module or a TensorFlow function, TpuCompileOpKernelImpl converts the
+// computation into XLA HLO and then into a TPU executable binary.
+class TpuCompileOpKernelImpl : public TpuCompileOpKernelCommon {
+ public:
+  TpuCompileOpKernelImpl(const std::string& mlir_module,
+                         const tpu::TPUCompileMetadataProto& metadata,
+                         int num_computations, bool return_hlo_protos,
+                         bool unload_cache_on_session_close)
+      : TpuCompileOpKernelCommon(mlir_module, metadata, num_computations,
+                                 return_hlo_protos,
+                                 unload_cache_on_session_close) {}
+
+  TpuCompileOpKernelImpl(const NameAttrList& function,
+                         const tpu::TPUCompileMetadataProto& metadata,
+                         int num_computations, bool return_hlo_protos,
+                         bool unload_cache_on_session_close)
+      : TpuCompileOpKernelCommon(
+            function, metadata, num_computations, return_hlo_protos,
+            unload_cache_on_session_close, /*persistent_cache=*/nullptr) {}
+
+ private:
+  FRIEND_TEST(TpuCompileOpImplTest, Compile);
+
+  Status Compile(
+      const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
+      const XLA_TpuMeshState* mesh_state,
+      const std::vector<TensorShape>& arg_shapes,
+      TpuProgramGroupInterface* tpu_program_group) override;
+};
+}  // namespace tpu
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_