[TF2XLA] [NFC] Break apart the [TF2XLA/MLIR] -> xla_compiler dependency edge
This is needed for invoking the MLIR tf2xla bridge from xla_compiler. This CL breaks apart items from xla_compiler into individual build targets, which are then depended on from the MLIR TF bridge.

PiperOrigin-RevId: 323640340
Change-Id: I78b972503db9e7b5254014ca7e889005490d8339
Parent: 5198b44674
Commit: bcfb60d0a1
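To illustrate the effect of the split, here is a minimal caller sketch (not part of this CL) that reaches the MLIR bridge entry point declared in the compile_mlir_util header changed below using only the decoupled types (XlaArgument, XlaHelpers::ShapeRepresentationFn, XlaCompilationResult), so the calling target no longer needs the monolithic :xla_compiler dependency. The include path, the wrapper name CompileWithMlirBridge, and the device type string are assumptions for illustration.

// Sketch only: exercises the decoupled tf2xla types used by the MLIR bridge.
// The include path and the device type string are assumptions.
#include "tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h"

namespace tensorflow {

Status CompileWithMlirBridge(const Graph& graph,
                             llvm::ArrayRef<const XlaArgument> args,
                             const FunctionLibraryDefinition& flib_def,
                             const GraphDebugInfo& debug_info,
                             XlaCompilationResult* result) {
  // ShapeRepresentationFn now comes from XlaHelpers rather than XlaCompiler,
  // so only the lighter xla_helpers / xla_argument deps are needed here.
  XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr;
  return CompileGraphToXlaHlo(graph, args, /*device_type=*/"XLA_CPU_JIT",
                              /*use_tuple_args=*/false, flib_def, debug_info,
                              shape_representation_fn, result);
}

}  // namespace tensorflow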
@@ -308,6 +308,8 @@ cc_library(
     ],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:framework",
     ],
     alwayslink = 1,
@@ -95,6 +95,7 @@ cc_library(
         ":xla_kernel_creator",  # buildcleaner: keep
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:cpu_plugin",  # buildcleaner: keep
         "//tensorflow/core:core_cpu_internal",
@@ -115,6 +116,7 @@ cc_library(
         ":xla_kernel_creator",  # buildcleaner: keep
         "//tensorflow/compiler/jit/kernels:xla_ops",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/kernels:xla_ops",
         "//tensorflow/compiler/xla/service:gpu_plugin",  # buildcleaner: keep
         "//tensorflow/core:core_cpu_internal",
@@ -172,6 +174,7 @@ XLA_DEVICE_DEPS = [
     "//tensorflow/compiler/tf2xla:common",
     "//tensorflow/compiler/tf2xla:tf2xla_util",
     "//tensorflow/compiler/tf2xla:xla_compiler",
+    "//tensorflow/compiler/tf2xla:xla_op_registry",
     "//tensorflow/compiler/tf2xla/kernels:xla_ops",
     "//tensorflow/compiler/xla:util",
     "//tensorflow/compiler/xla/client:client_library",
@@ -343,6 +346,7 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
         "//tensorflow/compiler/tf2xla:common",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_context",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/compiler/xla/client:client_library",
@@ -406,6 +410,7 @@ cc_library(
         ":compilation_passes",
         "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/core:core_cpu_internal",
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
@@ -641,6 +646,7 @@ cc_library(
         "//tensorflow/compiler/tf2xla:side_effect_util",
         "//tensorflow/compiler/tf2xla:tf2xla_util",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
         "//tensorflow/compiler/tf2xla/cc:xla_ops",
         "//tensorflow/compiler/xla:status_macros",
@@ -700,6 +706,7 @@ cc_library(
     hdrs = ["device_util.h"],
     deps = [
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/core:framework",
@@ -914,6 +921,7 @@ cc_library(
         "//tensorflow/compiler/jit/graphcycles",
         "//tensorflow/compiler/tf2xla:resource_operation_table",
         "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
         "//tensorflow/compiler/xla:statusor",
         "//tensorflow/compiler/xla:util",
         "//tensorflow/core:core_cpu",
@@ -21,6 +21,7 @@ XLA_OPS_DEPS = [
    "//tensorflow/compiler/tf2xla:common",
    "//tensorflow/compiler/tf2xla:tf2xla_util",
    "//tensorflow/compiler/tf2xla:xla_compiler",
+   "//tensorflow/compiler/tf2xla:xla_op_registry",
    "//tensorflow/compiler/xla:executable_run_options",
    "//tensorflow/compiler/xla:status_macros",
    "//tensorflow/compiler/xla:statusor",
@@ -150,6 +150,7 @@ tf_cc_binary(
         "//tensorflow/compiler/mlir/tensorflow:translate_registration",
         "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
         "//tensorflow/compiler/mlir/xla:xla_mlir_translate",
+        "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core:tensorflow",
@@ -1477,10 +1477,13 @@ COMPILE_MLIR_UTIL_DEPS = [
    "//tensorflow/compiler/mlir/xla:xla_legalize_tf",
    "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
    "//tensorflow/compiler/tf2xla:common",
-   "//tensorflow/compiler/tf2xla:xla_compiler",
+   "//tensorflow/compiler/tf2xla:xla_helpers",
+   "//tensorflow/compiler/tf2xla:xla_argument",
+   "//tensorflow/compiler/xla/client:xla_computation",
+   "//tensorflow/core/common_runtime:core_cpu_internal",
+   "//tensorflow/core/platform:logging",
    "//tensorflow/core:framework",
    "//tensorflow/core:protos_all_cc",
-   "//tensorflow/core/platform:logging",
    "//tensorflow/stream_executor/lib",
    "//tensorflow/compiler/xla:xla_data_proto_cc",
    "//tensorflow/compiler/xla/service:hlo",
@@ -83,7 +83,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
 Status GetXlaInputShapes(
     mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
     bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     std::vector<xla::Shape>* xla_input_shapes) {
   xla_input_shapes->clear();
 
@@ -135,9 +135,8 @@ Status GetXlaInputShapes(
 // output based on static shapes in MLIR module
 Status GetOutputInfo(
     mlir::ModuleOp module,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_output_shape,
-    std::vector<XlaCompiler::OutputDescription>* outputs) {
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) {
   auto shape_representation_fn_no_fast_memory =
       [shape_representation_fn](const TensorShape& shape, DataType dtype) {
         return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
@@ -161,7 +160,7 @@ Status GetOutputInfo(
 
     // Construct OutputDescription for result.
     outputs->emplace_back();
-    XlaCompiler::OutputDescription& out_desc = outputs->back();
+    XlaOutputDescription& out_desc = outputs->back();
     TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
     // TODO(ycao): Support constant output.
     out_desc.is_constant = false;
@@ -185,7 +184,7 @@ Status GetOutputInfo(
 // TODO(ycao): Implement logic to compute resource updates when we need to
 // support graphs with resource updates in MLIR-based TF compiler bridge.
 void GetResourceUpdatesForMlir(
-    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
+    std::vector<XlaResourceUpdate>* resource_updates) {
   resource_updates->clear();
 }
 
@@ -265,7 +264,7 @@ Status ConvertMLIRToXlaComputation(
     mlir::ModuleOp module_op, llvm::StringRef device_type,
     xla::XlaComputation* xla_computation, bool use_tuple_args,
     bool return_tuple,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   mlir::PassManager tf2xla(module_op.getContext());
   tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
@@ -341,8 +340,8 @@ Status ConvertMLIRToXlaComputation(
 static Status CompileMlirToXlaHlo(
     mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   if (VLOG_IS_ON(1))
     tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
@@ -391,8 +390,8 @@ static Status CompileMlirToXlaHlo(
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   RegisterDialects();
   mlir::MLIRContext mlir_context;
@@ -411,16 +410,16 @@ Status CompileSerializedMlirToXlaHlo(
 // removed from the signature.
 // Returns the original indices for the other arguments on success.
 static StatusOr<std::vector<int>> RewriteWithArgs(
-    mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) {
+    mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
   mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
   std::vector<int> params;
 
   auto builder = mlir::OpBuilder(main_fn.getBody());
   std::vector<int> args_to_erase;
   for (int idx = 0; idx < args.size(); idx++) {
-    const XlaCompiler::Argument& xla_arg = args[idx];
+    const XlaArgument& xla_arg = args[idx];
     mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
-    if (xla_arg.kind != XlaCompiler::Argument::kConstant) {
+    if (xla_arg.kind != XlaArgument::kConstant) {
       params.push_back(idx);
       continue;
     }
@@ -439,11 +438,11 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
 }
 
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
+    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
   RegisterDialects();
 
@@ -20,7 +20,10 @@ limitations under the License.
 #include "llvm/ADT/StringRef.h"
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "mlir/Pass/Pass.h"  // from @llvm-project
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/xla/client/xla_computation.h"
+#include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/protobuf/graph_debug_info.pb.h"
 #include "tensorflow/stream_executor/lib/statusor.h"
@@ -57,7 +60,7 @@ Status ConvertMLIRToXlaComputation(
     mlir::ModuleOp module_op, llvm::StringRef device_type,
     xla::XlaComputation* xla_computation, bool use_tuple_args,
     bool return_tuple,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 // Compiles a serialized MLIR module into XLA HLO, generates all accompanying
@@ -65,17 +68,17 @@ Status ConvertMLIRToXlaComputation(
 Status CompileSerializedMlirToXlaHlo(
     llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
     llvm::StringRef device_type, bool use_tuple_args,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 // Same as the above but takes input as TensorFlow Graph.
 Status CompileGraphToXlaHlo(
-    const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
+    const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
     llvm::StringRef device_type, bool use_tuple_args,
     const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
-    const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    XlaCompiler::CompilationResult* compilation_result,
+    const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    XlaCompilationResult* compilation_result,
     std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
 
 }  // namespace tensorflow
@@ -92,7 +92,11 @@ cc_library(
         "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
        "//tensorflow/compiler/mlir/tensorflow:translate_utils",
-       "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_compilation_device",
+       "//tensorflow/compiler/tf2xla:xla_context",
+       "//tensorflow/compiler/tf2xla:xla_expression",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/core:core_cpu_lib",
        "//tensorflow/core:framework",
@@ -125,8 +129,10 @@ cc_library(
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
        "//tensorflow/compiler/mlir/hlo:lhlo",
+       "//tensorflow/compiler/xla:debug_options_flags",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
+       "//tensorflow/compiler/xla/service:backend",
        "//tensorflow/compiler/xla/service:buffer_assignment",
        "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/compiler/xla/service:hlo_casting_utils",
@@ -228,7 +234,7 @@ cc_library(
        "//tensorflow/compiler/mlir/tensorflow:convert_type",
        "//tensorflow/compiler/mlir/tensorflow:error_util",
        "//tensorflow/compiler/tf2xla:common",
-       "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/xla:comparison_util",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:shape_util",
@@ -43,7 +43,6 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/xla/client/lib/matrix.h"
 #include "tensorflow/compiler/xla/client/lib/quantize.h"
 #include "tensorflow/compiler/xla/client/lib/slicing.h"
@@ -463,7 +462,7 @@ class ConvertToHloModule {
   // single value.
   explicit ConvertToHloModule(
       mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
-      tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
+      tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
       : module_(module),
         module_builder_("main"),
         use_tuple_args_(use_tuple_args),
@@ -545,7 +544,7 @@ class ConvertToHloModule {
 
   // Shape representation function to determine entry function argument and
   // result shapes.
-  tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
+  tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
 
   // Unique suffix to give to the name of the next lowered region.
   size_t region_id_ = 0;
@@ -1500,7 +1499,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
 
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaCompiler::ShapeRepresentationFn
+                           const tensorflow::XlaHelpers::ShapeRepresentationFn
                                shape_representation_fn) {
   mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
   ConvertToHloModule converter(module, use_tuple_args, return_tuple,
@@ -18,9 +18,10 @@ limitations under the License.
 
 #include "mlir/IR/Module.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/core/framework/tensor_shape.h"
 
 namespace mlir {
 
@@ -33,7 +34,7 @@ namespace mlir {
 // single value.
 Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
                            bool use_tuple_args, bool return_tuple,
-                           const tensorflow::XlaCompiler::ShapeRepresentationFn
+                           const tensorflow::XlaHelpers::ShapeRepresentationFn
                                shape_representation_fn = nullptr);
 
 // Creates XlaOp equivalent of a given MLIR operation using the operand info
@@ -48,7 +48,8 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/core/common_runtime/device.h"
 #include "tensorflow/core/common_runtime/device_factory.h"
@@ -410,7 +411,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
         device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
         shape_or.ValueOrDie());
     tensorflow::Tensor& tensor = tensors.back();
-    tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor);
+    tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
     inputs.emplace_back(&tensor);
   }
 
@@ -438,7 +439,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
   for (int i = 0, e = op_->getNumResults(); i < e; i++) {
     tensorflow::Tensor* output = op_context.mutable_output(i);
     const tensorflow::XlaExpression* expr =
-        tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
+        tensorflow::XlaExpression::CastExpressionFromTensor(*output);
     if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
       return op_->emitError(
           "expects XlaExpression of kind kXlaOp in compiled output");
@@ -37,6 +37,8 @@ limitations under the License.
 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
+#include "tensorflow/compiler/xla/debug_options_flags.h"
+#include "tensorflow/compiler/xla/service/backend.h"
 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -50,6 +50,7 @@ cc_library(
     visibility = ["//visibility:public"],
     deps = [
         ":xla_compiler",
+        ":xla_op_registry",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
@@ -145,6 +146,7 @@ cc_library(
        ":tf2xla_proto_cc",
        ":tf2xla_util",
        ":xla_compiler",
+       ":xla_op_registry",
        "//tensorflow/compiler/aot:aot_only_var_handle_op",
        "//tensorflow/compiler/tf2xla/kernels:xla_ops",
        "//tensorflow/compiler/xla/client",
@@ -316,14 +318,8 @@ cc_library(
     srcs = [
         "const_analysis.cc",
         "graph_compiler.cc",
-        "xla_compilation_device.cc",
         "xla_compiler.cc",
-        "xla_context.cc",
-        "xla_expression.cc",
-        "xla_helpers.cc",
         "xla_op_kernel.cc",
-        "xla_op_registry.cc",
-        "xla_resource.cc",
         "xla_cpu_backend.cc",
     ] + if_cuda_is_configured([
         "xla_gpu_backend.cc",
|
|||||||
hdrs = [
|
hdrs = [
|
||||||
"const_analysis.h",
|
"const_analysis.h",
|
||||||
"graph_compiler.h",
|
"graph_compiler.h",
|
||||||
"xla_compilation_device.h",
|
|
||||||
"xla_compiler.h",
|
"xla_compiler.h",
|
||||||
"xla_context.h",
|
|
||||||
"xla_expression.h",
|
|
||||||
"xla_helpers.h",
|
"xla_helpers.h",
|
||||||
"xla_op_kernel.h",
|
"xla_op_kernel.h",
|
||||||
"xla_op_registry.h",
|
"xla_op_registry.h",
|
||||||
"xla_resource.h",
|
|
||||||
],
|
],
|
||||||
visibility = [":friends"],
|
visibility = [":friends"],
|
||||||
deps = [
|
deps = [
|
||||||
@ -351,10 +343,18 @@ cc_library(
|
|||||||
":sharding_util",
|
":sharding_util",
|
||||||
":side_effect_util",
|
":side_effect_util",
|
||||||
":tf2xla_util",
|
":tf2xla_util",
|
||||||
|
":xla_argument",
|
||||||
|
":xla_compilation_device",
|
||||||
|
":xla_context",
|
||||||
|
":xla_expression",
|
||||||
|
":xla_helpers",
|
||||||
|
":xla_op_registry",
|
||||||
|
":xla_resource",
|
||||||
"//tensorflow/compiler/jit:common",
|
"//tensorflow/compiler/jit:common",
|
||||||
"//tensorflow/compiler/jit:flags",
|
"//tensorflow/compiler/jit:flags",
|
||||||
"//tensorflow/compiler/jit:shape_inference",
|
"//tensorflow/compiler/jit:shape_inference",
|
||||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||||
|
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||||
"//tensorflow/compiler/tf2xla/lib:util",
|
"//tensorflow/compiler/tf2xla/lib:util",
|
||||||
"//tensorflow/compiler/xla:literal",
|
"//tensorflow/compiler/xla:literal",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
@@ -370,6 +370,7 @@ cc_library(
        "//tensorflow/compiler/xla/client:xla_computation",
        "//tensorflow/compiler/xla/client/lib:arithmetic",
        "//tensorflow/compiler/xla/client/lib:constants",
+       "//tensorflow/compiler/xla/service:hlo",
        "//tensorflow/core:core_cpu",
        "//tensorflow/core:core_cpu_internal",
        "//tensorflow/core:framework",
@@ -388,6 +389,172 @@ cc_library(
     alwayslink = 1,
 )
 
+cc_library(
+    name = "xla_compilation_device",
+    srcs = [
+        "xla_compilation_device.cc",
+    ],
+    hdrs = [
+        "xla_compilation_device.h",
+    ],
+    deps = [
+        ":common",
+        ":frontend_attributes_util",
+        ":sharding_util",
+        ":xla_context",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_context",
+    srcs = [
+        "xla_context.cc",
+    ],
+    hdrs = [
+        "xla_context.h",
+    ],
+    deps = [
+        ":common",
+        ":xla_expression",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla:literal",
+        "//tensorflow/compiler/xla:shape_util",
+        "//tensorflow/compiler/xla:status_macros",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_op_registry",
+    srcs = [
+        "xla_op_registry.cc",
+    ],
+    hdrs = [
+        "xla_op_registry.h",
+    ],
+    visibility = [":friends"],
+    deps = [
+        ":common",
+        ":xla_context",
+        "//tensorflow/compiler/jit:flags",
+        "//tensorflow/compiler/jit:xla_cluster_util",
+        "//tensorflow/compiler/xla/client:client_library",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core:session_options",
+        "//tensorflow/core:stream_executor_no_cuda",
+        "//tensorflow/core/common_runtime:core_cpu_internal",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_expression",
+    srcs = [
+        "xla_expression.cc",
+    ],
+    hdrs = [
+        "xla_expression.h",
+    ],
+    deps = [
+        ":common",
+        ":xla_resource",
+        "//tensorflow/compiler/xla:statusor",
+        "//tensorflow/compiler/xla/client",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/types:optional",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_resource",
+    srcs = [
+        "xla_resource.cc",
+    ],
+    hdrs = [
+        "xla_resource.h",
+    ],
+    deps = [
+        ":common",
+        ":sharding_util",
+        ":xla_helpers",
+        "//tensorflow/compiler/xla:xla_data_proto_cc",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "@com_google_absl//absl/memory",
+        "@com_google_absl//absl/strings",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_helpers",
+    srcs = [
+        "xla_helpers.cc",
+    ],
+    hdrs = [
+        "xla_helpers.h",
+    ],
+    visibility = [":friends"],
+    deps = [
+        ":common",
+        ":host_compute_metadata_proto_cc",
+        "//tensorflow/compiler/tf2xla/lib:util",
+        "//tensorflow/compiler/xla:types",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/client:xla_computation",
+        "//tensorflow/compiler/xla/client/lib:arithmetic",
+        "//tensorflow/compiler/xla/client/lib:constants",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
+cc_library(
+    name = "xla_argument",
+    srcs = [
+        "xla_argument.cc",
+    ],
+    hdrs = [
+        "xla_argument.h",
+    ],
+    deps = [
+        ":host_compute_metadata_proto_cc",
+        ":xla_resource",
+        "//tensorflow/compiler/xla/client:xla_builder",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/core:framework",
+        "@com_google_absl//absl/types:span",
+    ],
+    alwayslink = 1,
+)
+
 cc_library(
     name = "common",
     srcs = [
@@ -564,6 +731,8 @@ tf_cc_test(
        ":common",
        ":side_effect_util",
        ":xla_compiler",
+       ":xla_expression",
+       ":xla_resource",
        "//tensorflow/cc:cc_ops",
        "//tensorflow/cc:function_ops",
        "//tensorflow/cc:functional_ops",
@@ -145,7 +145,12 @@ tf_kernel_library(
        "//tensorflow/compiler/jit:xla_activity_listener",
        "//tensorflow/compiler/jit:xla_activity_proto_cc",
        "//tensorflow/compiler/tf2xla:common",
+       "//tensorflow/compiler/tf2xla:xla_compilation_device",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_context",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
+       "//tensorflow/compiler/tf2xla:xla_resource",
        "//tensorflow/compiler/tf2xla/lib:broadcast",
        "//tensorflow/compiler/tf2xla/lib:data_format",
        "//tensorflow/compiler/tf2xla/lib:random",
@@ -223,6 +228,8 @@ cc_library(
     deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla:literal_util",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla:util",
@@ -276,6 +283,8 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla:status_macros",
@@ -296,6 +305,8 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_context",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla/client:xla_builder",
@@ -314,6 +325,8 @@ tf_kernel_library(
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:side_effect_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_context",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/tf2xla/ops:xla_ops",
        "//tensorflow/compiler/xla:literal",
        "//tensorflow/compiler/xla/client:xla_builder",
@@ -333,6 +346,7 @@ tf_kernel_library(
     ],
     deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:array_ops_op_lib",
        "//tensorflow/core:framework",
        "//tensorflow/core:lib",
@@ -38,6 +38,7 @@ cc_library(
     hdrs = ["random.h"],
     deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
+       "//tensorflow/compiler/tf2xla:xla_helpers",
        "//tensorflow/compiler/xla:status_macros",
        "//tensorflow/compiler/xla:statusor",
        "//tensorflow/compiler/xla/client:xla_builder",
tensorflow/compiler/tf2xla/xla_argument.cc (new file, 53 lines)
@@ -0,0 +1,53 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
+
+namespace tensorflow {
+
+bool XlaArgument::operator==(const XlaArgument& other) const {
+  if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
+               tensor_array_gradients) !=
+      std::tie(other.kind, other.resource_kind, other.type, other.name,
+               other.initialized, other.max_array_size,
+               other.tensor_array_gradients)) {
+    return false;
+  }
+  if (absl::holds_alternative<xla::Shape>(shape)) {
+    if (!absl::holds_alternative<xla::Shape>(other.shape)) {
+      return false;
+    }
+    if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
+                             absl::get<xla::Shape>(other.shape))) {
+      return false;
+    }
+  } else {
+    if (!absl::holds_alternative<TensorShape>(other.shape)) {
+      return false;
+    }
+    if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
+      return false;
+    }
+  }
+  if (constant_value.shape() != other.constant_value.shape()) {
+    return false;
+  }
+  if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
+    return false;
+  }
+  return constant_value.tensor_data() == other.constant_value.tensor_data();
+}
+
+}  // end namespace tensorflow
tensorflow/compiler/tf2xla/xla_argument.h (new file, 121 lines)
@@ -0,0 +1,121 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
+
+#include "absl/types/span.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
+#include "tensorflow/core/framework/tensor.h"
+
+namespace tensorflow {
+
+// Describes how to derive the value of each _Arg node in the graph/function
+// being compiled. There must be one Argument for each _Arg index.
+struct XlaArgument {
+  enum Kind {
+    // Default value; not a valid kind.
+    kInvalid,
+
+    // Argument is a compile-time constant. No associated runtime parameter.
+    kConstant,
+
+    // Argument is a Variable, TensorArray, or Stack resource. Has an
+    // associated runtime parameter iff `initialized` is true.
+    kResource,
+
+    // Argument is a run-time parameter.
+    kParameter,
+
+    // Argument is an XLA token.
+    kToken,
+
+    // Argument is a TensorList.
+    kTensorList,
+  };
+
+  Kind kind = kInvalid;
+
+  // The type of the argument. If the argument is a resource, this
+  // is the type of the variable's value, not DT_RESOURCE.
+  DataType type = DT_INVALID;
+
+  // The shape of the argument. For:
+  // * a parameter: the shape of the parameter. We allow setting the xla shape
+  //   if known. This helps avoid conversions to and from TensorShape.
+  // * a constant: ignored; the shape given by constant_value is used
+  //     instead.
+  // * an uninitialized resource: ignored. We don't yet know the shape of an
+  //     uninitialized resource (otherwise we would have initialized it!)
+  // * an initialized variable: the shape of the variable's value.
+  // * an initialized TensorArray or Stack resource: the shape of an entry in
+  //   the TensorArray/Stack. Note this is the size of a single entry, not the
+  //   XLA data structure that represents the complete stack/array.
+  absl::variant<TensorShape, xla::Shape> shape;
+
+  // The value of the argument, if it is a compile-time constant. Must be a
+  // host-memory tensor.
+  Tensor constant_value;
+
+  // The name of this argument, used for debugging.
+  string name;
+
+  // The name of TensorFlow _Arg node, used for debugging.
+  string node_name;
+
+  // For a kResource, what kind of resource is it?
+  XlaResource::Kind resource_kind = XlaResource::kInvalid;
+
+  // For a kResource, has this resource been initialized?
+  bool initialized = false;
+
+  // For a kResource, is this resource on Fast Memory.
+  bool fast_mem = false;
+
+  // For a TensorArray or Stack resource, what is the array's declared size?
+  // (Used for lazy initialization.)
+  int64 max_array_size = -1;
+
+  // TensorArray resource parameters are passed as (array, gradient array 0,
+  // ..., gradient array k), where the gradient arrays are in the same order
+  // as `tensor_array_gradients`.
+  std::set<string> tensor_array_gradients;
+
+  // dynamic dims to arg number map. Empty if no dynamic shapes.
+  std::map<int32, int32> dynamic_dim_to_arg_num_map;
+  bool is_pad_arg = false;
+
+  // Whether this argument will receive the same data across all replicas.
+  bool is_same_data_across_replicas = false;
+
+  bool operator==(const XlaArgument& other) const;
+
+  // Returns a human-readable summary of the argument.
+  string HumanString() const;
+
+  // Returns the dimension sizes for either TensorShape or xla::Shape.
+  std::vector<int64> DimensionSizes() const;
+  absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
+
+  // Returns the human-readable string for either TensorShape or xla::Shape.
+  string ShapeHumanString() const;
+};
+
+}  // end namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
@@ -422,39 +422,6 @@ Status BuildComputation(
 
 }  // namespace
 
-bool XlaCompiler::Argument::operator==(
-    const XlaCompiler::Argument& other) const {
-  if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
-               tensor_array_gradients) !=
-      std::tie(other.kind, other.resource_kind, other.type, other.name,
-               other.initialized, other.max_array_size,
-               other.tensor_array_gradients)) {
-    return false;
-  }
-  if (absl::holds_alternative<xla::Shape>(shape)) {
-    if (!absl::holds_alternative<xla::Shape>(other.shape)) {
-      return false;
-    }
-    if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
-                             absl::get<xla::Shape>(other.shape))) {
-      return false;
-    }
-  } else {
-    if (!absl::holds_alternative<TensorShape>(other.shape)) {
-      return false;
-    }
-    if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
-      return false;
-    }
-  }
-  if (constant_value.shape() != other.constant_value.shape()) {
-    return false;
-  }
-  if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
-    return false;
-  }
-  return constant_value.tensor_data() == other.constant_value.tensor_data();
-}
 
 string XlaCompiler::Argument::HumanString() const {
   string common;
@@ -1494,93 +1461,4 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
   return iter->second;
 }
 
-XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
-  return [](const TensorShape& shape, DataType dtype,
-            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
-    xla::Shape xla_shape;
-    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
-    return xla_shape;
-  };
-}
-
-// Rewrites the layout of xla_shape if there is tiled sharding.
-Status RewriteLayoutWithShardedShape(
-    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_shape) {
-  if (sharding && !sharding->IsTileMaximal()) {
-    // After sharding, per core shape might have different layout. For example,
-    // before sharding, a shape [128, 128] will be assigned default
-    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
-    // the sharded shapes will have minor-to-major {0, 1}.
-    //
-    // As a result, for sharded shapes, we set their layout to per core shape's
-    // layout.
-    //
-    // TODO(endlessroad): for variable input & update, we might have
-    // different layouts which will prevent input output aliasing and
-    // increase memory usage. Investigate such cases.
-    int64 device = *sharding->tile_assignment().begin();
-    std::vector<int64> offset =
-        sharding->TileOffsetForDevice(*xla_shape, device);
-    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
-    std::vector<int64> dimensions(xla_shape->rank());
-    for (int64 i = 0; i < xla_shape->rank(); ++i) {
-      dimensions[i] = limit[i] - offset[i];
-    }
-    xla::Shape per_device_xla_shape =
-        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
-    TensorShape per_device_tensor_shape;
-    TF_RETURN_IF_ERROR(
-        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
-    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                            xla_shape->element_type()));
-    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
-                        shape_representation_fn(per_device_tensor_shape, dtype,
-                                                use_fast_memory));
-    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
-  }
-  return Status::OK();
-}
-
-// There is a shape_representation_fn or sharding for an output, this function
-// uses a reshape to fix the layout.
-xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
-    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
-  if (original_shape.IsTuple()) {
-    std::vector<xla::XlaOp> elements;
-    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
-      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
-      TF_ASSIGN_OR_RETURN(auto element,
-                          ReshapeWithCorrectRepresentationAndSharding(
-                              builder, xla::GetTupleElement(original, i),
-                              original_shape.tuple_shapes(i),
-                              shape_representation_fn, subsharding, fast_mem));
-      elements.push_back(element);
-    }
-    return xla::Tuple(builder, elements);
-  }
-  if (!original_shape.IsArray()) return original;
-  TensorShape shape;
-  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
-  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
-                                          original_shape.element_type()));
-  TF_ASSIGN_OR_RETURN(auto to_shape,
-                      shape_representation_fn(shape, dtype, fast_mem));
-  if (sharding) {
-    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
-                        xla::HloSharding::FromProto(*sharding));
-    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
-        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
-  }
-  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
-    for (int64 i = 0; i < original_shape.rank(); ++i) {
-      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
-    }
-  }
-  return xla::Reshape(to_shape, original);
-}
-
 }  // namespace tensorflow
@@ -21,8 +21,10 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "absl/types/variant.h"
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
+#include "tensorflow/compiler/tf2xla/xla_argument.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -97,96 +99,7 @@ class XlaContext;
 // `tensor_array_gradients` ordered set.
 class XlaCompiler {
  public:
-  // Describes how to derive the value of each _Arg node in the graph/function
-  // being compiled. There must be one Argument for each _Arg index.
-  struct Argument {
-    enum Kind {
-      // Default value; not a valid kind.
-      kInvalid,
-
-      // Argument is a compile-time constant. No associated runtime parameter.
-      kConstant,
-
-      // Argument is a Variable, TensorArray, or Stack resource. Has an
-      // associated runtime parameter iff `initialized` is true.
-      kResource,
-
-      // Argument is a run-time parameter.
-      kParameter,
-
-      // Argument is an XLA token.
-      kToken,
-
-      // Argument is a TensorList.
-      kTensorList,
-    };
-
-    Kind kind = kInvalid;
-
-    // The type of the argument. If the argument is a resource, this
-    // is the type of the variable's value, not DT_RESOURCE.
-    DataType type = DT_INVALID;
-
-    // The shape of the argument. For:
-    // * a parameter: the shape of the parameter. We allow setting the xla shape
-    //   if known. This helps avoid conversions to and from TensorShape.
-    // * a constant: ignored; the shape given by constant_value is used
-    //     instead.
-    // * an uninitialized resource: ignored. We don't yet know the shape of an
-    //     uninitialized resource (otherwise we would have initialized it!)
-    // * an initialized variable: the shape of the variable's value.
-    // * an initialized TensorArray or Stack resource: the shape of an entry in
-    //   the TensorArray/Stack. Note this is the size of a single entry, not the
-    //   XLA data structure that represents the complete stack/array.
-    absl::variant<TensorShape, xla::Shape> shape;
-
-    // The value of the argument, if it is a compile-time constant. Must be a
-    // host-memory tensor.
-    Tensor constant_value;
-
-    // The name of this argument, used for debugging.
-    string name;
-
-    // The name of TensorFlow _Arg node, used for debugging.
-    string node_name;
-
-    // For a kResource, what kind of resource is it?
-    XlaResource::Kind resource_kind = XlaResource::kInvalid;
-
-    // For a kResource, has this resource been initialized?
-    bool initialized = false;
-
-    // For a kResource, is this resource on Fast Memory.
-    bool fast_mem = false;
-
-    // For a TensorArray or Stack resource, what is the array's declared size?
-    // (Used for lazy initialization.)
-    int64 max_array_size = -1;
-
-    // TensorArray resource parameters are passed as (array, gradient array 0,
-    // ..., gradient array k), where the gradient arrays are in the same order
-    // as `tensor_array_gradients`.
-    std::set<string> tensor_array_gradients;
-
-    // dynamic dims to arg number map. Empty if no dynamic shapes.
-    std::map<int32, int32> dynamic_dim_to_arg_num_map;
-    bool is_pad_arg = false;
-
-    // Whether this argument will receive the same data across all replicas.
-    bool is_same_data_across_replicas = false;
-
-    bool operator==(const Argument& other) const;
-
-    // Returns a human-readable summary of the argument.
-    string HumanString() const;
-
-    // Returns the dimension sizes for either TensorShape or xla::Shape.
-    std::vector<int64> DimensionSizes() const;
-    absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
-
-    // Returns the human-readable string for either TensorShape or xla::Shape.
-    string ShapeHumanString() const;
-  };
+  using Argument = ::tensorflow::XlaArgument;
 
   // Options pertaining to an individual call to CompileGraph() or
   // CompileFunction().
@@ -221,77 +134,11 @@ class XlaCompiler {
     bool alias_resource_update = false;
   };
 
-  struct OutputDescription {
-    // Type and shape of the output. The shape is the unflattened shape.
-    // When `type` is DT_RESOURCE, `shape` is the shape of the resource
-    // variable's value.
-    DataType type;
-    TensorShape shape;
-
-    // Constant output value, if known to be constant at JIT compilation time.
-    // 'Tensor' is in host memory.
-    bool is_constant = false;
-    Tensor constant_value;
-
-    // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
-    // the index of the input that contains the resource.
-    int input_index;
-
-    // Whether this output is a TensorList.
-    bool is_tensor_list = false;
-  };
-
-  // Describes a variable write side effect of the computation.
-  struct ResourceUpdate {
-    // Index of the input that contains the variable resource to write to.
-    int input_index;
-
-    // Type and shape of the tensor to be written back.
-    // The `shape` field has the same meaning as the Argument::shape field.
-    DataType type;
-    TensorShape shape;
-
-    // Was the value of the variable modified by the computation?
-    // (Always true, unless `return_updated_values_for_all_resources` is true.)
-    bool modified;
-
-    // If the resource is a TensorArray, the set of gradients read or written.
-    std::set<string> tensor_array_gradients_accessed;
-  };
-
-  struct CompilationResult {
-    // Vector that maps from the parameters of the XLA computation to their
-    // original argument positions. To handle compile-time constant inputs, the
-    // parameters to the XLA computation may be a subset of the original
-    // arguments. The relative ordering of parameters are maintained.
-    std::vector<int> input_mapping;
-
-    // Input shapes of the computation. If we are flattening inputs, these are
-    // the flattened shapes.
-    std::vector<xla::Shape> xla_input_shapes;
-
-    // Output shape in XLA format. The output shape is always a tuple. If we
-    // are flattening outputs, these are the flattened shapes.
-    xla::Shape xla_output_shape;
-
-    // TensorFlow shapes of outputs, together with the values of any
-    // constant arguments. Vector indexed by Tensorflow _Retval number,
-    // containing both constant and non-constant results.
-    std::vector<OutputDescription> outputs;
-
-    // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
-    // matching RecvAtHost/SendFromHost Ops in the outer graph.
-    tf2xla::HostComputeMetadata host_compute_metadata;
-
-    // Resources whose values were updated by the computation, ordered
-    // by return value position (which is the same as the order the resources
-    // were passed as arguments). Resource updates follow the non-constant
-    // results in the outputs of XLA computation.
-    std::vector<ResourceUpdate> resource_updates;
-
-    // The XLA computation built from the tensorflow subgraph.
-    std::shared_ptr<xla::XlaComputation> computation;
-  };
+  using OutputDescription = ::tensorflow::XlaOutputDescription;
+
+  using ResourceUpdate = ::tensorflow::XlaResourceUpdate;
+
+  using CompilationResult = ::tensorflow::XlaCompilationResult;
 
   typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
                                                   bool)>
@@ -518,21 +365,6 @@ class XlaCompiler {
   TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
 };
 
-// Creates an identity shape representation function.
-XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();
-
-// Rewrites the layout of xla_shape if there is tiled sharding.
-Status RewriteLayoutWithShardedShape(
-    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    xla::Shape* xla_shape);
-
-// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
-// sharding is present.
-xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
-    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
-    XlaCompiler::ShapeRepresentationFn shape_representation_fn,
-    absl::optional<xla::OpSharding> sharding, bool fast_mem);
-
 }  // namespace tensorflow
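The public surface of XlaCompiler is unchanged by the hunks above: the nested types become aliases for the standalone tf2xla types, so existing call sites keep compiling. A minimal check that illustrates this (hypothetical snippet, not part of the CL):

// Hypothetical sketch: the nested names are mere aliases for the standalone
// types introduced/used by this CL.
#include <type_traits>
#include "tensorflow/compiler/tf2xla/xla_compiler.h"

static_assert(std::is_same<tensorflow::XlaCompiler::Argument,
                           tensorflow::XlaArgument>::value,
              "XlaCompiler::Argument is an alias for XlaArgument");
static_assert(std::is_same<tensorflow::XlaCompiler::CompilationResult,
                           tensorflow::XlaCompilationResult>::value,
              "XlaCompiler::CompilationResult aliases XlaCompilationResult");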
@@ -24,7 +24,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -20,7 +20,6 @@ limitations under the License.
 
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -33,6 +32,7 @@ limitations under the License.
 namespace tensorflow {
 
 class XlaOpKernelContext;
+class XlaCompiler;
 
 // The XlaContext is the data structure that holds the state of an XLA
 // compilation, that is accessible from OpKernelContexts when compiling a
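A forward declaration is enough here because the header only traffics in XlaCompiler pointers; the full definition can be pulled in by the .cc file instead of every includer. A rough illustration of the pattern (hypothetical class, not from the CL):

// Hypothetical sketch of why the forward declaration suffices: a header that
// only stores and returns a pointer never needs the complete type.
class XlaCompiler;  // forward declaration, as added above

class ContextSketch {
 public:
  explicit ContextSketch(XlaCompiler* compiler) : compiler_(compiler) {}
  XlaCompiler* compiler() const { return compiler_; }  // pointer-only use

 private:
  XlaCompiler* compiler_ = nullptr;  // complete definition needed only in .cc
};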
@@ -163,4 +163,23 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
   }
 }
 
+const XlaExpression* XlaExpression::CastExpressionFromTensor(
+    const Tensor& tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
+  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
+  return expression;
+}
+
+// Assigns an XlaExpression to a tensor on an XLA compilation device.
+void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
+                                             Tensor* tensor) {
+  const XlaExpression* expression =
+      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
+  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
+  *const_cast<XlaExpression*>(expression) = value;
+}
+
 }  // namespace tensorflow
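With these two static helpers on XlaExpression, the tensor-buffer round trip no longer needs XlaOpKernelContext. A rough usage sketch, assuming `t` is a tensor on an XLA compilation device so its buffer really holds an XlaExpression (hypothetical function, not from the CL):

// Hypothetical sketch, not part of the CL: storing an expression into a
// compilation-device tensor and reading it back through the new helpers.
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

void RoundTrip(const XlaExpression& value, Tensor* t) {
  // Writes `value` into the tensor's backing buffer, which must currently
  // hold an uninitialized (kInvalid) expression.
  XlaExpression::AssignExpressionToTensor(value, t);

  // Reads it back; CHECK-fails if the buffer does not hold a valid expression.
  const XlaExpression* read = XlaExpression::CastExpressionFromTensor(*t);
  CHECK(read->kind() == value.kind());
}

}  // namespace tensorflow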
@@ -104,6 +104,13 @@ class XlaExpression {
   // not the shape of the resource's value.
   xla::StatusOr<TensorShape> GetShape() const;
 
+  // Retrieves an XlaExpression that was allocated by a previous Op.
+  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
+
+  // Assigns an XlaExpression to a tensor on an XLA compilation device.
+  static void AssignExpressionToTensor(const XlaExpression& value,
+                                       Tensor* tensor);
+
  private:
   Kind kind_ = Kind::kInvalid;
@@ -22,8 +22,6 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
-#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/lib/arithmetic.h"
 #include "tensorflow/compiler/xla/client/lib/constants.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -128,4 +126,93 @@ xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
   return xla::ConvertElementType(operand, convert_to);
 }
 
+XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
+  return [](const TensorShape& shape, DataType dtype,
+            bool use_fast_memory) -> xla::StatusOr<xla::Shape> {
+    xla::Shape xla_shape;
+    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape));
+    return xla_shape;
+  };
+}
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape) {
+  if (sharding && !sharding->IsTileMaximal()) {
+    // After sharding, per core shape might have different layout. For example,
+    // before sharding, a shape [128, 128] will be assigned default
+    // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
+    // the sharded shapes will have minor-to-major {0, 1}.
+    //
+    // As a result, for sharded shapes, we set their layout to per core shape's
+    // layout.
+    //
+    // TODO(endlessroad): for variable input & update, we might have
+    // different layouts which will prevent input output aliasing and
+    // increase memory usage. Investigate such cases.
+    int64 device = *sharding->tile_assignment().begin();
+    std::vector<int64> offset =
+        sharding->TileOffsetForDevice(*xla_shape, device);
+    std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
+    std::vector<int64> dimensions(xla_shape->rank());
+    for (int64 i = 0; i < xla_shape->rank(); ++i) {
+      dimensions[i] = limit[i] - offset[i];
+    }
+    xla::Shape per_device_xla_shape =
+        xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
+    TensorShape per_device_tensor_shape;
+    TF_RETURN_IF_ERROR(
+        XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
+    TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                            xla_shape->element_type()));
+    TF_ASSIGN_OR_RETURN(per_device_xla_shape,
+                        shape_representation_fn(per_device_tensor_shape, dtype,
+                                                use_fast_memory));
+    *xla_shape->mutable_layout() = per_device_xla_shape.layout();
+  }
+  return Status::OK();
+}
+
+// There is a shape_representation_fn or sharding for an output, this function
+// uses a reshape to fix the layout.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem) {
+  if (original_shape.IsTuple()) {
+    std::vector<xla::XlaOp> elements;
+    for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
+      auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
+      TF_ASSIGN_OR_RETURN(auto element,
+                          ReshapeWithCorrectRepresentationAndSharding(
+                              builder, xla::GetTupleElement(original, i),
+                              original_shape.tuple_shapes(i),
+                              shape_representation_fn, subsharding, fast_mem));
+      elements.push_back(element);
+    }
+    return xla::Tuple(builder, elements);
+  }
+  if (!original_shape.IsArray()) return original;
+  TensorShape shape;
+  TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
+  TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
+                                          original_shape.element_type()));
+  TF_ASSIGN_OR_RETURN(auto to_shape,
+                      shape_representation_fn(shape, dtype, fast_mem));
+  if (sharding) {
+    TF_ASSIGN_OR_RETURN(auto hlo_sharding,
+                        xla::HloSharding::FromProto(*sharding));
+    TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
+        hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
+  }
+  if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
+    for (int64 i = 0; i < original_shape.rank(); ++i) {
+      to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
+    }
+  }
+  return xla::Reshape(to_shape, original);
+}
+
 }  // end namespace tensorflow
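These helpers now live next to XlaHelpers instead of XlaCompiler. A small usage sketch of the identity shape representation function (illustrative only; the wrapper function is hypothetical):

// Hypothetical sketch, not part of the CL: the identity representation simply
// converts a TensorShape plus DataType into the equivalent xla::Shape.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {

xla::StatusOr<xla::Shape> RepresentAsXlaShape(const TensorShape& shape,
                                              DataType dtype) {
  XlaHelpers::ShapeRepresentationFn fn = IdentityShapeRepresentationFn();
  // use_fast_memory is ignored by the identity function.
  return fn(shape, dtype, /*use_fast_memory=*/false);
}

}  // namespace tensorflow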
@@ -19,8 +19,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
 
 #include "absl/types/span.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding.h"
 #include "tensorflow/core/framework/tensor.h"
 
 namespace tensorflow {
@@ -72,6 +73,98 @@ class XlaHelpers {
   // than the xla::PrimitiveType.
   static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
                                        const DataType new_element_type);
+
+  typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
+                                                  bool)>
+      ShapeRepresentationFn;
+};
+
+// Creates an identity shape representation function.
+XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
+
+// Rewrites the layout of xla_shape if there is tiled sharding.
+Status RewriteLayoutWithShardedShape(
+    const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    xla::Shape* xla_shape);
+
+// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
+// sharding is present.
+xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
+    xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
+    XlaHelpers::ShapeRepresentationFn shape_representation_fn,
+    absl::optional<xla::OpSharding> sharding, bool fast_mem);
+
+struct XlaOutputDescription {
+  // Type and shape of the output. The shape is the unflattened shape.
+  // When `type` is DT_RESOURCE, `shape` is the shape of the resource
+  // variable's value.
+  DataType type;
+  TensorShape shape;
+
+  // Constant output value, if known to be constant at JIT compilation time.
+  // 'Tensor' is in host memory.
+  bool is_constant = false;
+  Tensor constant_value;
+
+  // When this output is a resource, i.e. `type == DT_RESOURCE`, this is
+  // the index of the input that contains the resource.
+  int input_index;
+
+  // Whether this output is a TensorList.
+  bool is_tensor_list = false;
+};
+
+// Describes a variable write side effect of the computation.
+struct XlaResourceUpdate {
+  // Index of the input that contains the variable resource to write to.
+  int input_index;
+
+  // Type and shape of the tensor to be written back.
+  // The `shape` field has the same meaning as the Argument::shape field.
+  DataType type;
+  TensorShape shape;
+
+  // Was the value of the variable modified by the computation?
+  // (Always true, unless `return_updated_values_for_all_resources` is true.)
+  bool modified;
+
+  // If the resource is a TensorArray, the set of gradients read or written.
+  std::set<string> tensor_array_gradients_accessed;
+};
+
+struct XlaCompilationResult {
+  // Vector that maps from the parameters of the XLA computation to their
+  // original argument positions. To handle compile-time constant inputs, the
+  // parameters to the XLA computation may be a subset of the original
+  // arguments. The relative ordering of parameters are maintained.
+  std::vector<int> input_mapping;
+
+  // Input shapes of the computation. If we are flattening inputs, these are
+  // the flattened shapes.
+  std::vector<xla::Shape> xla_input_shapes;
+
+  // Output shape in XLA format. The output shape is always a tuple. If we
+  // are flattening outputs, these are the flattened shapes.
+  xla::Shape xla_output_shape;
+
+  // TensorFlow shapes of outputs, together with the values of any
+  // constant arguments. Vector indexed by Tensorflow _Retval number,
+  // containing both constant and non-constant results.
+  std::vector<XlaOutputDescription> outputs;
+
+  // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
+  // matching RecvAtHost/SendFromHost Ops in the outer graph.
+  tf2xla::HostComputeMetadata host_compute_metadata;
+
+  // Resources whose values were updated by the computation, ordered
+  // by return value position (which is the same as the order the resources
+  // were passed as arguments). Resource updates follow the non-constant
+  // results in the outputs of XLA computation.
+  std::vector<XlaResourceUpdate> resource_updates;
+
+  // The XLA computation built from the tensorflow subgraph.
+  std::shared_ptr<xla::XlaComputation> computation;
 };
 
 }  // end namespace tensorflow
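The free functions declared above take an optional sharding. A sketch of the simplest call, with no sharding and no fast memory, is below; the wrapper function is hypothetical and not part of the CL.

// Hypothetical sketch: applying only the shape representation function to one
// array-shaped output.
#include "absl/types/optional.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace tensorflow {

xla::StatusOr<xla::XlaOp> FixOutputLayout(
    xla::XlaBuilder* builder, xla::XlaOp output, const xla::Shape& shape,
    const XlaHelpers::ShapeRepresentationFn& representation_fn) {
  // absl::nullopt sharding means only the representation function is applied.
  return ReshapeWithCorrectRepresentationAndSharding(
      builder, output, shape, representation_fn,
      /*sharding=*/absl::nullopt, /*fast_mem=*/false);
}

}  // namespace tensorflow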
@@ -49,33 +49,13 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
   return xla_context()->compiler();
 }
 
-// Retrieves an XlaExpression that was allocated by a previous Op.
-const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
-    const Tensor& tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
-  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
-      << expression->HumanString();
-  return expression;
-}
-
-// Assigns an XlaExpression to a tensor on an XLA compilation device.
-void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
-                                                  Tensor* tensor) {
-  const XlaExpression* expression =
-      reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
-  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
-      << expression->HumanString();
-  *const_cast<XlaExpression*>(expression) = value;
-}
-
 const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
-  return *CastExpressionFromTensor(context_->input(index));
+  return *XlaExpression::CastExpressionFromTensor(context_->input(index));
 }
 
 const XlaExpression& XlaOpKernelContext::InputExpression(
     absl::string_view name) {
-  return *CastExpressionFromTensor(GetInputTensorByName(name));
+  return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
 }
 
 xla::XlaOp XlaOpKernelContext::Input(int index) {
@@ -108,7 +88,8 @@ DataType XlaOpKernelContext::input_type(int index) const {
   if (type == DT_UINT8) {
     // Masqueraded XlaExpression could have different type. See
     // XlaOpKernelContext::SetOutputExpression for details.
-    auto expression = CastExpressionFromTensor(context_->input(index));
+    auto expression =
+        XlaExpression::CastExpressionFromTensor(context_->input(index));
     type = expression->dtype();
   }
   return type;
@@ -120,7 +101,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) {
   if (type == DT_UINT8) {
     // Masqueraded XlaExpression could have different type. See
     // XlaOpKernelContext::SetOutputExpression for details.
-    auto expression = CastExpressionFromTensor(tensor);
+    auto expression = XlaExpression::CastExpressionFromTensor(tensor);
     type = expression->dtype();
   }
   return type;
@@ -385,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
   handles->clear();
   shapes->clear();
   for (const Tensor& input : inputs) {
-    handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
+    handles->push_back(
+        XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
     shapes->push_back(input.shape());
   }
   return Status::OK();
@@ -408,7 +390,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
                                const XlaOpKernelContext* ctx,
                                TensorShape* shape, xla::XlaOp* value) {
   const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -461,7 +443,8 @@ Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
 Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
                                                    TensorShape* shape) const {
   const Tensor& tensor = context_->input(index);
-  const XlaExpression* expression = CastExpressionFromTensor(tensor);
+  const XlaExpression* expression =
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
@@ -502,8 +485,8 @@ void XlaOpKernelContext::SetOutputExpression(int index,
       TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
       TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
     }
-    XlaOpKernelContext::AssignExpressionToTensor(
-        expression, context_->mutable_output(index));
+    XlaExpression::AssignExpressionToTensor(expression,
+                                            context_->mutable_output(index));
     return Status::OK();
   }();
   if (!status.ok()) {
@@ -542,7 +525,7 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
 
 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
   const XlaExpression* expression =
-      CastExpressionFromTensor(context_->input(index));
+      XlaExpression::CastExpressionFromTensor(context_->input(index));
   TF_RET_CHECK(expression->resource() != nullptr);
   *resource = expression->resource();
   return Status::OK();
@@ -554,7 +537,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
                             const XlaOpKernelContext* ctx, xla::XlaOp handle,
                             xla::XlaBuilder* builder) {
   const XlaExpression* expression =
-      XlaOpKernelContext::CastExpressionFromTensor(tensor);
+      XlaExpression::CastExpressionFromTensor(tensor);
   XlaResource* variable = expression->resource();
   TF_RET_CHECK(variable != nullptr);
   TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
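All of these call-site changes are mechanical: the static helpers moved from XlaOpKernelContext to XlaExpression. A condensed sketch of the new lookup pattern (hypothetical kernel-side code, not from the CL):

// Hypothetical sketch, not part of the CL.
//
//   Before:  XlaOpKernelContext::CastExpressionFromTensor(tensor)
//   After:   XlaExpression::CastExpressionFromTensor(tensor)
#include "tensorflow/compiler/tf2xla/xla_expression.h"
#include "tensorflow/core/framework/tensor.h"

namespace tensorflow {

const XlaExpression& LookupExpression(const Tensor& tensor) {
  // The helper CHECK-fails unless the tensor's buffer holds a valid
  // XlaExpression, so the dereference below is safe whenever it returns.
  return *XlaExpression::CastExpressionFromTensor(tensor);
}

}  // namespace tensorflow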
@@ -17,6 +17,9 @@ limitations under the License.
 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
 
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
@@ -284,13 +287,6 @@ class XlaOpKernelContext {
   // separate specialization of the computation for each DataType.
   const xla::XlaComputation* GetOrCreateMul(const DataType type);
 
-  // Assigns an XlaExpression to a tensor on an XLA compilation device.
-  static void AssignExpressionToTensor(const XlaExpression& value,
-                                       Tensor* tensor);
-
-  // Retrieves an XlaExpression that was assigned to the specified tensor.
-  static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
-
  private:
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
@@ -21,7 +21,6 @@ limitations under the License.
 #include "absl/memory/memory.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/sharding_util.h"
-#include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 
@@ -57,6 +57,7 @@ cc_library(
         ":tpu_defs",
        ":tpu_node_device_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
    ],
    alwayslink = 1,
)
@@ -180,8 +181,8 @@ cc_library(
         "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:tf2xla_util",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/core:core_cpu_internal",
-        "//tensorflow/core:framework",
        "//tensorflow/core:framework_internal",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
@@ -656,6 +656,7 @@ cc_library(
     deps = [
        "//tensorflow/compiler/tf2xla:common",
        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
@@ -673,6 +674,7 @@ cc_library(
     srcs = ["topk_ops.cc"],
    deps = [
        "//tensorflow/compiler/tf2xla:xla_compiler",
+        "//tensorflow/compiler/tf2xla:xla_op_registry",
        "//tensorflow/compiler/xla/client:xla_builder",
        "//tensorflow/compiler/xla/client/lib:arithmetic",
        "//tensorflow/core/tpu:tpu_defs",