Add TpuCompileOpImpl
to core/tpu/kernels
.
PiperOrigin-RevId: 321069369 Change-Id: Ic276b3e5a8b9dd8c8708a0ba4a2142bd76a2a9e4
This commit is contained in:
parent
1e1bcbbf80
commit
8b52123ca3
@ -508,3 +508,29 @@ cc_library(
|
|||||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "tpu_compile_op_impl",
|
||||||
|
srcs = ["tpu_compile_op_impl.cc"],
|
||||||
|
hdrs = ["tpu_compile_op_impl.h"],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/compiler/jit:shape_inference",
|
||||||
|
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||||
|
"//tensorflow/compiler/xla:status",
|
||||||
|
"//tensorflow/core:core_cpu_internal",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compilation_cache_key",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compile_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compile_op_common",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compile_op_support",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_mesh_state_c_api_hdrs",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_program_group",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_program_group_interface",
|
||||||
|
"//tensorflow/core/tpu/kernels:tpu_util",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_executor",
|
||||||
|
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||||
|
"@com_google_absl//absl/types:variant",
|
||||||
|
],
|
||||||
|
alwayslink = 1,
|
||||||
|
)
|
||||||
|
39
tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc
Normal file
39
tensorflow/core/tpu/kernels/tpu_compile_op_impl.cc
Normal file
@ -0,0 +1,39 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_impl.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/status.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
Status TpuCompileOpKernelImpl::Compile(
|
||||||
|
const std::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
|
||||||
|
const XLA_TpuMeshState* mesh_state,
|
||||||
|
const std::vector<TensorShape>& arg_shapes,
|
||||||
|
TpuProgramGroupInterface* tpu_program_group) {
|
||||||
|
TF_ASSIGN_OR_RETURN(
|
||||||
|
TpuCompilationRequestProto compilation_request,
|
||||||
|
CreateTpuCompilationRequest(computation, metadata_, arg_shapes));
|
||||||
|
|
||||||
|
return TpuProgramGroup::CompileAndBuild(compilation_request, mesh_state,
|
||||||
|
tpu_program_group);
|
||||||
|
}
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
66
tensorflow/core/tpu/kernels/tpu_compile_op_impl.h
Normal file
66
tensorflow/core/tpu/kernels/tpu_compile_op_impl.h
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_
|
||||||
|
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "absl/types/variant.h"
|
||||||
|
#include "tensorflow/compiler/jit/shape_inference.h"
|
||||||
|
#include "tensorflow/core/framework/function.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_compile_op_common.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||||
|
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace tpu {
|
||||||
|
|
||||||
|
// Base class for TpuCompileOp and TpuCompileMlirOp.
|
||||||
|
// Depends on whether it is given a computation in the form of serialized MLIR
|
||||||
|
// module or a Tensorflow function, TpuCompileOpKernelImpl converts computation
|
||||||
|
// into XLA HLO and then into a TPU execuable binary.
|
||||||
|
class TpuCompileOpKernelImpl : public TpuCompileOpKernelCommon {
|
||||||
|
public:
|
||||||
|
TpuCompileOpKernelImpl(const std::string& mlir_module,
|
||||||
|
const tpu::TPUCompileMetadataProto& metadata,
|
||||||
|
int num_computations, bool return_hlo_protos,
|
||||||
|
bool unload_cache_on_session_close)
|
||||||
|
: TpuCompileOpKernelCommon(mlir_module, metadata, num_computations,
|
||||||
|
return_hlo_protos,
|
||||||
|
unload_cache_on_session_close) {}
|
||||||
|
|
||||||
|
TpuCompileOpKernelImpl(const NameAttrList& function,
|
||||||
|
const tpu::TPUCompileMetadataProto& metadata,
|
||||||
|
int num_computations, bool return_hlo_protos,
|
||||||
|
bool unload_cache_on_session_close)
|
||||||
|
: TpuCompileOpKernelCommon(
|
||||||
|
function, metadata, num_computations, return_hlo_protos,
|
||||||
|
unload_cache_on_session_close, /*persistent_cache=*/nullptr) {}
|
||||||
|
|
||||||
|
private:
|
||||||
|
FRIEND_TEST(TpuCompileOpImplTest, Compile);
|
||||||
|
|
||||||
|
Status Compile(
|
||||||
|
const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
|
||||||
|
const XLA_TpuMeshState* mesh_state,
|
||||||
|
const std::vector<TensorShape>& arg_shapes,
|
||||||
|
TpuProgramGroupInterface* tpu_program_group) override;
|
||||||
|
};
|
||||||
|
} // namespace tpu
|
||||||
|
} // namespace tensorflow
|
||||||
|
|
||||||
|
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_IMPL_H_
|
Loading…
x
Reference in New Issue
Block a user