From bcfb60d0a138d215980b0881e4619a2d9b20e489 Mon Sep 17 00:00:00 2001 From: George Karpenkov <cheshire@google.com> Date: Tue, 28 Jul 2020 13:23:04 -0700 Subject: [PATCH] [TF2XLA] [NFC] Break apart the [TF2XLA/MLIR] -> xla_compiler dependency edge This is needed for invoking the MLIR tf2xla bridge from xla_compiler. This CL breaks apart items from xla_compiler into individual build targets, which are then depended on from the MLIR TF bridge. PiperOrigin-RevId: 323640340 Change-Id: I78b972503db9e7b5254014ca7e889005490d8339 --- tensorflow/compiler/aot/BUILD | 2 + tensorflow/compiler/jit/BUILD | 8 + tensorflow/compiler/jit/kernels/BUILD | 1 + tensorflow/compiler/mlir/BUILD | 1 + tensorflow/compiler/mlir/tensorflow/BUILD | 7 +- .../tensorflow/utils/compile_mlir_util.cc | 33 ++- .../mlir/tensorflow/utils/compile_mlir_util.h | 17 +- tensorflow/compiler/mlir/xla/BUILD | 10 +- .../compiler/mlir/xla/mlir_hlo_to_hlo.cc | 7 +- .../compiler/mlir/xla/mlir_hlo_to_hlo.h | 5 +- .../xla/transforms/legalize_tf_with_tf2xla.cc | 7 +- .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 2 + tensorflow/compiler/tf2xla/BUILD | 189 +++++++++++++++++- tensorflow/compiler/tf2xla/kernels/BUILD | 14 ++ tensorflow/compiler/tf2xla/lib/BUILD | 1 + tensorflow/compiler/tf2xla/xla_argument.cc | 53 +++++ tensorflow/compiler/tf2xla/xla_argument.h | 121 +++++++++++ tensorflow/compiler/tf2xla/xla_compiler.cc | 122 ----------- tensorflow/compiler/tf2xla/xla_compiler.h | 180 +---------------- tensorflow/compiler/tf2xla/xla_context.cc | 1 - tensorflow/compiler/tf2xla/xla_context.h | 2 +- tensorflow/compiler/tf2xla/xla_expression.cc | 19 ++ tensorflow/compiler/tf2xla/xla_expression.h | 7 + tensorflow/compiler/tf2xla/xla_helpers.cc | 91 ++++++++- tensorflow/compiler/tf2xla/xla_helpers.h | 95 ++++++++- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 45 ++--- tensorflow/compiler/tf2xla/xla_op_kernel.h | 10 +- tensorflow/compiler/tf2xla/xla_resource.cc | 1 - tensorflow/core/tpu/BUILD | 3 +- 
tensorflow/core/tpu/kernels/BUILD | 2 + 30 files changed, 668 insertions(+), 388 deletions(-) create mode 100644 tensorflow/compiler/tf2xla/xla_argument.cc create mode 100644 tensorflow/compiler/tf2xla/xla_argument.h diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index d091146c75a..ff255dd9cc1 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -308,6 +308,8 @@ cc_library( ], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:framework", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index b52a350dc48..ecbb1a5d200 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -95,6 +95,7 @@ cc_library( ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", @@ -115,6 +116,7 @@ cc_library( ":xla_kernel_creator", # buildcleaner: keep "//tensorflow/compiler/jit/kernels:xla_ops", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep "//tensorflow/core:core_cpu_internal", @@ -172,6 +174,7 @@ XLA_DEVICE_DEPS = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", @@ -343,6 +346,7 @@ cc_library( 
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla/client:client_library", @@ -406,6 +410,7 @@ cc_library( ":compilation_passes", "//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -641,6 +646,7 @@ cc_library( "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/cc:xla_jit_ops", "//tensorflow/compiler/tf2xla/cc:xla_ops", "//tensorflow/compiler/xla:status_macros", @@ -700,6 +706,7 @@ cc_library( hdrs = ["device_util.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/core:framework", @@ -914,6 +921,7 @@ cc_library( "//tensorflow/compiler/jit/graphcycles", "//tensorflow/compiler/tf2xla:resource_operation_table", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/core:core_cpu", diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index 347bae087df..eb9ad8a2e85 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -21,6 +21,7 @@ XLA_OPS_DEPS = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + 
"//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:executable_run_options", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 57f923caa91..01c187790b7 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -150,6 +150,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/tensorflow:translate_registration", "//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op", "//tensorflow/compiler/mlir/xla:xla_mlir_translate", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 2a800cfc8c4..fe1f47d8d69 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -1477,10 +1477,13 @@ COMPILE_MLIR_UTIL_DEPS = [ "//tensorflow/compiler/mlir/xla:xla_legalize_tf", "//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_argument", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core/common_runtime:core_cpu_internal", + "//tensorflow/core/platform:logging", "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", - "//tensorflow/core/platform:logging", "//tensorflow/stream_executor/lib", "//tensorflow/compiler/xla:xla_data_proto_cc", "//tensorflow/compiler/xla/service:hlo", diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 5e548da55f1..16bc851d3a6 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -83,7 +83,7 
@@ Status ParseMlirModule(llvm::StringRef mlir_module_string, Status GetXlaInputShapes( mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::vector<xla::Shape>* xla_input_shapes) { xla_input_shapes->clear(); @@ -135,9 +135,8 @@ Status GetXlaInputShapes( // output based on static shapes in MLIR module Status GetOutputInfo( mlir::ModuleOp module, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_output_shape, - std::vector<XlaCompiler::OutputDescription>* outputs) { + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) { auto shape_representation_fn_no_fast_memory = [shape_representation_fn](const TensorShape& shape, DataType dtype) { return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false); @@ -161,7 +160,7 @@ Status GetOutputInfo( // Construct OutputDescription for result. outputs->emplace_back(); - XlaCompiler::OutputDescription& out_desc = outputs->back(); + XlaOutputDescription& out_desc = outputs->back(); TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type)); // TODO(ycao): Support constant output. out_desc.is_constant = false; @@ -185,7 +184,7 @@ Status GetOutputInfo( // TODO(ycao): Implement logic to compute resource updates when we need to // support graphs with resource updates in MLIR-based TF compiler bridge. 
void GetResourceUpdatesForMlir( - std::vector<XlaCompiler::ResourceUpdate>* resource_updates) { + std::vector<XlaResourceUpdate>* resource_updates) { resource_updates->clear(); } @@ -265,7 +264,7 @@ Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) { mlir::PassManager tf2xla(module_op.getContext()); tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass()); @@ -341,8 +340,8 @@ Status ConvertMLIRToXlaComputation( static Status CompileMlirToXlaHlo( mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) { if (VLOG_IS_ON(1)) tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op); @@ -391,8 +390,8 @@ static Status CompileMlirToXlaHlo( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) { RegisterDialects(); mlir::MLIRContext mlir_context; @@ -411,16 +410,16 @@ Status CompileSerializedMlirToXlaHlo( // removed from the signature. 
// Returns the original indices for the other arguments on success. static StatusOr<std::vector<int>> RewriteWithArgs( - mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) { + mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) { mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main"); std::vector<int> params; auto builder = mlir::OpBuilder(main_fn.getBody()); std::vector<int> args_to_erase; for (int idx = 0; idx < args.size(); idx++) { - const XlaCompiler::Argument& xla_arg = args[idx]; + const XlaArgument& xla_arg = args[idx]; mlir::BlockArgument mlir_arg = main_fn.getArgument(idx); - if (xla_arg.kind != XlaCompiler::Argument::kConstant) { + if (xla_arg.kind != XlaArgument::kConstant) { params.push_back(idx); continue; } @@ -439,11 +438,11 @@ static StatusOr<std::vector<int>> RewriteWithArgs( } Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args, + const Graph& graph, llvm::ArrayRef<const XlaArgument> args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) { RegisterDialects(); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h index 24b60dcb346..719a96f52d4 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.h @@ -20,7 +20,10 @@ limitations under the License. 
#include "llvm/ADT/StringRef.h" #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/protobuf/graph_debug_info.pb.h" #include "tensorflow/stream_executor/lib/statusor.h" @@ -57,7 +60,7 @@ Status ConvertMLIRToXlaComputation( mlir::ModuleOp module_op, llvm::StringRef device_type, xla::XlaComputation* xla_computation, bool use_tuple_args, bool return_tuple, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {}); // Compiles a serialized MLIR module into XLA HLO, generates all accompanying @@ -65,17 +68,17 @@ Status ConvertMLIRToXlaComputation( Status CompileSerializedMlirToXlaHlo( llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes, llvm::StringRef device_type, bool use_tuple_args, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {}); // Same as the above but takes input as TensorFlow Graph. 
Status CompileGraphToXlaHlo( - const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args, + const Graph& graph, llvm::ArrayRef<const XlaArgument> args, llvm::StringRef device_type, bool use_tuple_args, const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info, - const XlaCompiler::ShapeRepresentationFn shape_representation_fn, - XlaCompiler::CompilationResult* compilation_result, + const XlaHelpers::ShapeRepresentationFn shape_representation_fn, + XlaCompilationResult* compilation_result, std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {}); } // namespace tensorflow diff --git a/tensorflow/compiler/mlir/xla/BUILD b/tensorflow/compiler/mlir/xla/BUILD index 838b060079c..55daec0395e 100644 --- a/tensorflow/compiler/mlir/xla/BUILD +++ b/tensorflow/compiler/mlir/xla/BUILD @@ -92,7 +92,11 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op", "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib", "//tensorflow/compiler/mlir/tensorflow:translate_utils", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_compilation_device", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_expression", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", @@ -125,8 +129,10 @@ cc_library( "//tensorflow/compiler/mlir/hlo", "//tensorflow/compiler/mlir/hlo:hlo_dialect_registration", "//tensorflow/compiler/mlir/hlo:lhlo", + "//tensorflow/compiler/xla:debug_options_flags", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:backend", "//tensorflow/compiler/xla/service:buffer_assignment", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_casting_utils", @@ -228,7 +234,7 @@ cc_library( 
"//tensorflow/compiler/mlir/tensorflow:convert_type", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/tf2xla:common", - "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/xla:comparison_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc index a4c3c43cfbf..e45cf1b56ee 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.cc @@ -43,7 +43,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h" #include "tensorflow/compiler/mlir/xla/type_to_shape.h" #include "tensorflow/compiler/tf2xla/shape_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/xla/client/lib/matrix.h" #include "tensorflow/compiler/xla/client/lib/quantize.h" #include "tensorflow/compiler/xla/client/lib/slicing.h" @@ -463,7 +462,7 @@ class ConvertToHloModule { // single value. explicit ConvertToHloModule( mlir::ModuleOp module, bool use_tuple_args, bool return_tuple, - tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn) + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) : module_(module), module_builder_("main"), use_tuple_args_(use_tuple_args), @@ -545,7 +544,7 @@ class ConvertToHloModule { // Shape representation function to determine entry function argument and // result shapes. - tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_; + tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_; // Unique suffix to give to the name of the next lowered region. 
size_t region_id_ = 0; @@ -1500,7 +1499,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module, Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, - const tensorflow::XlaCompiler::ShapeRepresentationFn + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn) { mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext()); ConvertToHloModule converter(module, use_tuple_args, return_tuple, diff --git a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h index 8bfe4c76b04..d84aa92d3e2 100644 --- a/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h +++ b/tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h @@ -18,9 +18,10 @@ limitations under the License. #include "mlir/IR/Module.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" -#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/core/framework/tensor_shape.h" namespace mlir { @@ -33,7 +34,7 @@ namespace mlir { // single value. Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto, bool use_tuple_args, bool return_tuple, - const tensorflow::XlaCompiler::ShapeRepresentationFn + const tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr); // Creates XlaOp equivalent of a given MLIR operation using the operand info diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 34e12d3300e..1743ae7be17 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -48,7 +48,8 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/device_factory.h" @@ -410,7 +411,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(), shape_or.ValueOrDie()); tensorflow::Tensor& tensor = tensors.back(); - tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor); + tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor); inputs.emplace_back(&tensor); } @@ -438,7 +439,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { for (int i = 0, e = op_->getNumResults(); i < e; i++) { tensorflow::Tensor* output = op_context.mutable_output(i); const tensorflow::XlaExpression* expr = - tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output); + tensorflow::XlaExpression::CastExpressionFromTensor(*output); if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp) return op_->emitError( "expects XlaExpression of kind kXlaOp in compiled output"); diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc index 519068893e7..d45f1ba8ec6 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc @@ -37,6 +37,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h" #include "tensorflow/compiler/mlir/xla/hlo_utils.h" #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h" +#include "tensorflow/compiler/xla/debug_options_flags.h" +#include "tensorflow/compiler/xla/service/backend.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" #include "tensorflow/compiler/xla/service/hlo_casting_utils.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 663e34c2b8e..1e57c11b2cf 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -50,6 +50,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":xla_compiler", + ":xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -145,6 +146,7 @@ cc_library( ":tf2xla_proto_cc", ":tf2xla_util", ":xla_compiler", + ":xla_op_registry", "//tensorflow/compiler/aot:aot_only_var_handle_op", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/compiler/xla/client", @@ -316,14 +318,8 @@ cc_library( srcs = [ "const_analysis.cc", "graph_compiler.cc", - "xla_compilation_device.cc", "xla_compiler.cc", - "xla_context.cc", - "xla_expression.cc", - "xla_helpers.cc", "xla_op_kernel.cc", - "xla_op_registry.cc", - "xla_resource.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", @@ -333,14 +329,10 @@ cc_library( hdrs = [ "const_analysis.h", "graph_compiler.h", - "xla_compilation_device.h", "xla_compiler.h", - "xla_context.h", - "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", - "xla_resource.h", ], visibility = [":friends"], deps = [ @@ -351,10 +343,18 @@ cc_library( ":sharding_util", ":side_effect_util", ":tf2xla_util", + ":xla_argument", + ":xla_compilation_device", + ":xla_context", + ":xla_expression", + ":xla_helpers", + ":xla_op_registry", + 
":xla_resource", "//tensorflow/compiler/jit:common", "//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes", "//tensorflow/compiler/tf2xla/lib:util", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:shape_util", @@ -370,6 +370,7 @@ cc_library( "//tensorflow/compiler/xla/client:xla_computation", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/compiler/xla/client/lib:constants", + "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -388,6 +389,172 @@ cc_library( alwayslink = 1, ) +cc_library( + name = "xla_compilation_device", + srcs = [ + "xla_compilation_device.cc", + ], + hdrs = [ + "xla_compilation_device.h", + ], + deps = [ + ":common", + ":frontend_attributes_util", + ":sharding_util", + ":xla_context", + ":xla_helpers", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:session_options", + "//tensorflow/core/common_runtime:core_cpu_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_context", + srcs = [ + "xla_context.cc", + ], + hdrs = [ + "xla_context.h", + ], + deps = [ + ":common", + ":xla_expression", + ":xla_helpers", + "//tensorflow/compiler/xla:literal", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core/common_runtime:core_cpu_internal", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +cc_library( + name = 
"xla_op_registry", + srcs = [ + "xla_op_registry.cc", + ], + hdrs = [ + "xla_op_registry.h", + ], + visibility = [":friends"], + deps = [ + ":common", + ":xla_context", + "//tensorflow/compiler/jit:flags", + "//tensorflow/compiler/jit:xla_cluster_util", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:session_options", + "//tensorflow/core:stream_executor_no_cuda", + "//tensorflow/core/common_runtime:core_cpu_internal", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_expression", + srcs = [ + "xla_expression.cc", + ], + hdrs = [ + "xla_expression.h", + ], + deps = [ + ":common", + ":xla_resource", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/types:optional", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_resource", + srcs = [ + "xla_resource.cc", + ], + hdrs = [ + "xla_resource.h", + ], + deps = [ + ":common", + ":sharding_util", + ":xla_helpers", + "//tensorflow/compiler/xla:xla_data_proto_cc", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_helpers", + srcs = [ + "xla_helpers.cc", + ], + hdrs = [ + "xla_helpers.h", + ], + visibility = [":friends"], + deps = [ + ":common", + ":host_compute_metadata_proto_cc", + "//tensorflow/compiler/tf2xla/lib:util", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/client:xla_computation", + "//tensorflow/compiler/xla/client/lib:arithmetic", + "//tensorflow/compiler/xla/client/lib:constants", + 
"//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + +cc_library( + name = "xla_argument", + srcs = [ + "xla_argument.cc", + ], + hdrs = [ + "xla_argument.h", + ], + deps = [ + ":host_compute_metadata_proto_cc", + ":xla_resource", + "//tensorflow/compiler/xla/client:xla_builder", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:framework", + "@com_google_absl//absl/types:span", + ], + alwayslink = 1, +) + cc_library( name = "common", srcs = [ @@ -564,6 +731,8 @@ tf_cc_test( ":common", ":side_effect_util", ":xla_compiler", + ":xla_expression", + ":xla_resource", "//tensorflow/cc:cc_ops", "//tensorflow/cc:function_ops", "//tensorflow/cc:functional_ops", diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index ec0cb9c0b66..26051c98cb7 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -145,7 +145,12 @@ tf_kernel_library( "//tensorflow/compiler/jit:xla_activity_listener", "//tensorflow/compiler/jit:xla_activity_proto_cc", "//tensorflow/compiler/tf2xla:common", + "//tensorflow/compiler/tf2xla:xla_compilation_device", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", + "//tensorflow/compiler/tf2xla:xla_resource", "//tensorflow/compiler/tf2xla/lib:broadcast", "//tensorflow/compiler/tf2xla/lib:data_format", "//tensorflow/compiler/tf2xla/lib:random", @@ -223,6 +228,8 @@ cc_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -276,6 +283,8 @@ 
tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla:status_macros", @@ -296,6 +305,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", @@ -314,6 +325,8 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_context", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/ops:xla_ops", "//tensorflow/compiler/xla:literal", "//tensorflow/compiler/xla/client:xla_builder", @@ -333,6 +346,7 @@ tf_kernel_library( ], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:array_ops_op_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/tf2xla/lib/BUILD b/tensorflow/compiler/tf2xla/lib/BUILD index f0bd97c85eb..531679d3905 100644 --- a/tensorflow/compiler/tf2xla/lib/BUILD +++ b/tensorflow/compiler/tf2xla/lib/BUILD @@ -38,6 +38,7 @@ cc_library( hdrs = ["random.h"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_helpers", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla/client:xla_builder", diff --git a/tensorflow/compiler/tf2xla/xla_argument.cc b/tensorflow/compiler/tf2xla/xla_argument.cc 
new file mode 100644 index 00000000000..fe31025386e --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_argument.cc @@ -0,0 +1,53 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_argument.h" + +namespace tensorflow { + +bool XlaArgument::operator==(const XlaArgument& other) const { + if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, + tensor_array_gradients) != + std::tie(other.kind, other.resource_kind, other.type, other.name, + other.initialized, other.max_array_size, + other.tensor_array_gradients)) { + return false; + } + if (absl::holds_alternative<xla::Shape>(shape)) { + if (!absl::holds_alternative<xla::Shape>(other.shape)) { + return false; + } + if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape), + absl::get<xla::Shape>(other.shape))) { + return false; + } + } else { + if (!absl::holds_alternative<TensorShape>(other.shape)) { + return false; + } + if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) { + return false; + } + } + if (constant_value.shape() != other.constant_value.shape()) { + return false; + } + if (is_same_data_across_replicas != other.is_same_data_across_replicas) { + return false; + } + return constant_value.tensor_data() == other.constant_value.tensor_data(); +} + +} // end namespace tensorflow diff --git 
a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h new file mode 100644 index 00000000000..e2cd634e1d5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -0,0 +1,121 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ + +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" +#include "tensorflow/core/framework/tensor.h" + +namespace tensorflow { + +// Describes how to derive the value of each _Arg node in the graph/function +// being compiled. There must be one Argument for each _Arg index. +struct XlaArgument { + enum Kind { + // Default value; not a valid kind. + kInvalid, + + // Argument is a compile-time constant. No associated runtime parameter. + kConstant, + + // Argument is a Variable, TensorArray, or Stack resource. Has an + // associated runtime parameter iff `initialized` is true. + kResource, + + // Argument is a run-time parameter. + kParameter, + + // Argument is an XLA token. + kToken, + + // Argument is a TensorList. 
+ kTensorList, + }; + + Kind kind = kInvalid; + + // The type of the argument. If the argument is a resource, this + // is the type of the variable's value, not DT_RESOURCE. + DataType type = DT_INVALID; + + // The shape of the argument. For: + // * a parameter: the shape of the parameter. We allow setting the xla shape + // if known. This helps avoid conversions to and from TensorShape. + // * a constant: ignored; the shape given by constant_value is used + // instead. + // * an uninitialized resource: ignored. We don't yet know the shape of an + // uninitialized resource (otherwise we would have initialized it!) + // * an initialized variable: the shape of the variable's value. + // * an initialized TensorArray or Stack resource: the shape of an entry in + // the TensorArray/Stack. Note this is the size of a single entry, not the + // XLA data structure that represents the complete stack/array. + absl::variant<TensorShape, xla::Shape> shape; + + // The value of the argument, if it is a compile-time constant. Must be a + // host-memory tensor. + Tensor constant_value; + + // The name of this argument, used for debugging. + string name; + + // The name of TensorFlow _Arg node, used for debugging. + string node_name; + + // For a kResource, what kind of resource is it? + XlaResource::Kind resource_kind = XlaResource::kInvalid; + + // For a kResource, has this resource been initialized? + bool initialized = false; + + // For a kResource, is this resource on Fast Memory. + bool fast_mem = false; + + // For a TensorArray or Stack resource, what is the array's declared size? + // (Used for lazy initialization.) + int64 max_array_size = -1; + + // TensorArray resource parameters are passed as (array, gradient array 0, + // ..., gradient array k), where the gradient arrays are in the same order + // as `tensor_array_gradients`. + std::set<string> tensor_array_gradients; + + // dynamic dims to arg number map. Empty if no dynamic shapes. 
+ std::map<int32, int32> dynamic_dim_to_arg_num_map; + bool is_pad_arg = false; + + // Whether this argument will receive the same data across all replicas. + bool is_same_data_across_replicas = false; + + bool operator==(const XlaArgument& other) const; + + // Returns a human-readable summary of the argument. + string HumanString() const; + + // Returns the dimension sizes for either TensorShape or xla::Shape. + std::vector<int64> DimensionSizes() const; + absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const; + + // Returns the human-readable string for either TensorShape or xla::Shape. + string ShapeHumanString() const; +}; + +} // end namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0722c30787f..db54f2f6563 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -422,39 +422,6 @@ Status BuildComputation( } // namespace -bool XlaCompiler::Argument::operator==( - const XlaCompiler::Argument& other) const { - if (std::tie(kind, resource_kind, type, name, initialized, max_array_size, - tensor_array_gradients) != - std::tie(other.kind, other.resource_kind, other.type, other.name, - other.initialized, other.max_array_size, - other.tensor_array_gradients)) { - return false; - } - if (absl::holds_alternative<xla::Shape>(shape)) { - if (!absl::holds_alternative<xla::Shape>(other.shape)) { - return false; - } - if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape), - absl::get<xla::Shape>(other.shape))) { - return false; - } - } else { - if (!absl::holds_alternative<TensorShape>(other.shape)) { - return false; - } - if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) { - return false; - } - } - if (constant_value.shape() != other.constant_value.shape()) { - return false; - } - if (is_same_data_across_replicas != other.is_same_data_across_replicas) { - 
return false; - } - return constant_value.tensor_data() == other.constant_value.tensor_data(); -} string XlaCompiler::Argument::HumanString() const { string common; @@ -1494,93 +1461,4 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) { return iter->second; } -XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() { - return [](const TensorShape& shape, DataType dtype, - bool use_fast_memory) -> xla::StatusOr<xla::Shape> { - xla::Shape xla_shape; - TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); - return xla_shape; - }; -} - -// Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( - const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape) { - if (sharding && !sharding->IsTileMaximal()) { - // After sharding, per core shape might have different layout. For example, - // before sharding, a shape [128, 128] will be assigned default - // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, - // the sharded shapes will have minor-to-major {0, 1}. - // - // As a result, for sharded shapes, we set their layout to per core shape's - // layout. - // - // TODO(endlessroad): for variable input & update, we might have - // different layouts which will prevent input output aliasing and - // increase memory usage. Investigate such cases. 
- int64 device = *sharding->tile_assignment().begin(); - std::vector<int64> offset = - sharding->TileOffsetForDevice(*xla_shape, device); - std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector<int64> dimensions(xla_shape->rank()); - for (int64 i = 0; i < xla_shape->rank(); ++i) { - dimensions[i] = limit[i] - offset[i]; - } - xla::Shape per_device_xla_shape = - xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); - TensorShape per_device_tensor_shape; - TF_RETURN_IF_ERROR( - XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - xla_shape->element_type())); - TF_ASSIGN_OR_RETURN(per_device_xla_shape, - shape_representation_fn(per_device_tensor_shape, dtype, - use_fast_memory)); - *xla_shape->mutable_layout() = per_device_xla_shape.layout(); - } - return Status::OK(); -} - -// There is a shape_representation_fn or sharding for an output, this function -// uses a reshape to fix the layout. -xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional<xla::OpSharding> sharding, bool fast_mem) { - if (original_shape.IsTuple()) { - std::vector<xla::XlaOp> elements; - for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { - auto subsharding = sharding ? 
sharding->tuple_shardings(i) : sharding; - TF_ASSIGN_OR_RETURN(auto element, - ReshapeWithCorrectRepresentationAndSharding( - builder, xla::GetTupleElement(original, i), - original_shape.tuple_shapes(i), - shape_representation_fn, subsharding, fast_mem)); - elements.push_back(element); - } - return xla::Tuple(builder, elements); - } - if (!original_shape.IsArray()) return original; - TensorShape shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); - TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( - original_shape.element_type())); - TF_ASSIGN_OR_RETURN(auto to_shape, - shape_representation_fn(shape, dtype, fast_mem)); - if (sharding) { - TF_ASSIGN_OR_RETURN(auto hlo_sharding, - xla::HloSharding::FromProto(*sharding)); - TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( - hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); - } - if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64 i = 0; i < original_shape.rank(); ++i) { - to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); - } - } - return xla::Reshape(to_shape, original); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index b95d250636a..b0d93cde846 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,8 +21,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "absl/types/variant.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" +#include "tensorflow/compiler/tf2xla/xla_argument.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -97,96 +99,7 @@ class XlaContext; // `tensor_array_gradients` ordered set. class XlaCompiler { public: - // Describes how to derive the value of each _Arg node in the graph/function - // being compiled. There must be one Argument for each _Arg index. - struct Argument { - enum Kind { - // Default value; not a valid kind. - kInvalid, - - // Argument is a compile-time constant. No associated runtime parameter. - kConstant, - - // Argument is a Variable, TensorArray, or Stack resource. Has an - // associated runtime parameter iff `initialized` is true. - kResource, - - // Argument is a run-time parameter. - kParameter, - - // Argument is an XLA token. - kToken, - - // Argument is a TensorList. - kTensorList, - }; - - Kind kind = kInvalid; - - // The type of the argument. If the argument is a resource, this - // is the type of the variable's value, not DT_RESOURCE. - DataType type = DT_INVALID; - - // The shape of the argument. For: - // * a parameter: the shape of the parameter. We allow setting the xla shape - // if known. This helps avoid conversions to and from TensorShape. - // * a constant: ignored; the shape given by constant_value is used - // instead. - // * an uninitialized resource: ignored. We don't yet know the shape of an - // uninitialized resource (otherwise we would have initialized it!) - // * an initialized variable: the shape of the variable's value. 
- // * an initialized TensorArray or Stack resource: the shape of an entry in - // the TensorArray/Stack. Note this is the size of a single entry, not the - // XLA data structure that represents the complete stack/array. - absl::variant<TensorShape, xla::Shape> shape; - - // The value of the argument, if it is a compile-time constant. Must be a - // host-memory tensor. - Tensor constant_value; - - // The name of this argument, used for debugging. - string name; - - // The name of TensorFlow _Arg node, used for debugging. - string node_name; - - // For a kResource, what kind of resource is it? - XlaResource::Kind resource_kind = XlaResource::kInvalid; - - // For a kResource, has this resource been initialized? - bool initialized = false; - - // For a kResource, is this resource on Fast Memory. - bool fast_mem = false; - - // For a TensorArray or Stack resource, what is the array's declared size? - // (Used for lazy initialization.) - int64 max_array_size = -1; - - // TensorArray resource parameters are passed as (array, gradient array 0, - // ..., gradient array k), where the gradient arrays are in the same order - // as `tensor_array_gradients`. - std::set<string> tensor_array_gradients; - - // dynamic dims to arg number map. Empty if no dynamic shapes. - std::map<int32, int32> dynamic_dim_to_arg_num_map; - bool is_pad_arg = false; - - // Whether this argument will receive the same data across all replicas. - bool is_same_data_across_replicas = false; - - bool operator==(const Argument& other) const; - - // Returns a human-readable summary of the argument. - string HumanString() const; - - // Returns the dimension sizes for either TensorShape or xla::Shape. - std::vector<int64> DimensionSizes() const; - absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const; - - // Returns the human-readable string for either TensorShape or xla::Shape. 
- string ShapeHumanString() const; - }; + using Argument = ::tensorflow::XlaArgument; // Options pertaining to an individual call to CompileGraph() or // CompileFunction(). @@ -221,77 +134,11 @@ class XlaCompiler { bool alias_resource_update = false; }; - struct OutputDescription { - // Type and shape of the output. The shape is the unflattened shape. - // When `type` is DT_RESOURCE, `shape` is the shape of the resource - // variable's value. - DataType type; - TensorShape shape; + using OutputDescription = ::tensorflow::XlaOutputDescription; - // Constant output value, if known to be constant at JIT compilation time. - // 'Tensor' is in host memory. - bool is_constant = false; - Tensor constant_value; + using ResourceUpdate = ::tensorflow::XlaResourceUpdate; - // When this output is a resource, i.e. `type == DT_RESOURCE`, this is - // the index of the input that contains the resource. - int input_index; - - // Whether this output is a TensorList. - bool is_tensor_list = false; - }; - - // Describes a variable write side effect of the computation. - struct ResourceUpdate { - // Index of the input that contains the variable resource to write to. - int input_index; - - // Type and shape of the tensor to be written back. - // The `shape` field has the same meaning as the Argument::shape field. - DataType type; - TensorShape shape; - - // Was the value of the variable modified by the computation? - // (Always true, unless `return_updated_values_for_all_resources` is true.) - bool modified; - - // If the resource is a TensorArray, the set of gradients read or written. - std::set<string> tensor_array_gradients_accessed; - }; - - struct CompilationResult { - // Vector that maps from the parameters of the XLA computation to their - // original argument positions. To handle compile-time constant inputs, the - // parameters to the XLA computation may be a subset of the original - // arguments. The relative ordering of parameters are maintained. 
- std::vector<int> input_mapping; - - // Input shapes of the computation. If we are flattening inputs, these are - // the flattened shapes. - std::vector<xla::Shape> xla_input_shapes; - - // Output shape in XLA format. The output shape is always a tuple. If we - // are flattening outputs, these are the flattened shapes. - xla::Shape xla_output_shape; - - // TensorFlow shapes of outputs, together with the values of any - // constant arguments. Vector indexed by Tensorflow _Retval number, - // containing both constant and non-constant results. - std::vector<OutputDescription> outputs; - - // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their - // matching RecvAtHost/SendFromHost Ops in the outer graph. - tf2xla::HostComputeMetadata host_compute_metadata; - - // Resources whose values were updated by the computation, ordered - // by return value position (which is the same as the order the resources - // were passed as arguments). Resource updates follow the non-constant - // results in the outputs of XLA computation. - std::vector<ResourceUpdate> resource_updates; - - // The XLA computation built from the tensorflow subgraph. - std::shared_ptr<xla::XlaComputation> computation; - }; + using CompilationResult = ::tensorflow::XlaCompilationResult; typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType, bool)> @@ -518,21 +365,6 @@ class XlaCompiler { TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler); }; -// Creates an identity shape representation function. -XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn(); - -// Rewrites the layout of xla_shape if there is tiled sharding. -Status RewriteLayoutWithShardedShape( - const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - xla::Shape* xla_shape); - -// Adds reshapes to fix the layout of an output, if a shape_representation_fn or -// sharding is present. 
-xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( - xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, - XlaCompiler::ShapeRepresentationFn shape_representation_fn, - absl::optional<xla::OpSharding> sharding, bool fast_mem); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index c94c4805d53..cb5bf34208f 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -24,7 +24,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index eb4ad3fe6a1..e44ac05b702 100644 --- a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,7 +20,6 @@ limitations under the License. #include <vector> -#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -33,6 +32,7 @@ limitations under the License. 
namespace tensorflow { class XlaOpKernelContext; +class XlaCompiler; // The XlaContext is the data structure that holds the state of an XLA // compilation, that is accessible from OpKernelContexts when compiling a diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index 49f108ed6c8..34e108bb6bf 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -163,4 +163,23 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const { } } +const XlaExpression* XlaExpression::CastExpressionFromTensor( + const Tensor& tensor) { + const XlaExpression* expression = + reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data()); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); + return expression; +} + +// Assigns an XlaExpression to a tensor on an XLA compilation device. +void XlaExpression::AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor) { + const XlaExpression* expression = + reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data()); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast<XlaExpression*>(expression) = value; +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 5d0bb35b182..3010964c5b7 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -104,6 +104,13 @@ class XlaExpression { // not the shape of the resource's value. xla::StatusOr<TensorShape> GetShape() const; + // Retrieves an XlaExpression that was allocated by a previous Op. + static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); + + // Assigns an XlaExpression to a tensor on an XLA compilation device. 
+ static void AssignExpressionToTensor(const XlaExpression& value, + Tensor* tensor); + private: Kind kind_ = Kind::kInvalid; diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 74247bbaec7..8c4b55aec8a 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -22,8 +22,6 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" -#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" #include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -128,4 +126,93 @@ xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand, return xla::ConvertElementType(operand, convert_to); } +XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() { + return [](const TensorShape& shape, DataType dtype, + bool use_fast_memory) -> xla::StatusOr<xla::Shape> { + xla::Shape xla_shape; + TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &xla_shape)); + return xla_shape; + }; +} + +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape) { + if (sharding && !sharding->IsTileMaximal()) { + // After sharding, per core shape might have different layout. For example, + // before sharding, a shape [128, 128] will be assigned default + // minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2, + // the sharded shapes will have minor-to-major {0, 1}. + // + // As a result, for sharded shapes, we set their layout to per core shape's + // layout. 
+ // + // TODO(endlessroad): for variable input & update, we might have + // different layouts which will prevent input output aliasing and + // increase memory usage. Investigate such cases. + int64 device = *sharding->tile_assignment().begin(); + std::vector<int64> offset = + sharding->TileOffsetForDevice(*xla_shape, device); + std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device); + std::vector<int64> dimensions(xla_shape->rank()); + for (int64 i = 0; i < xla_shape->rank(); ++i) { + dimensions[i] = limit[i] - offset[i]; + } + xla::Shape per_device_xla_shape = + xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions); + TensorShape per_device_tensor_shape; + TF_RETURN_IF_ERROR( + XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + xla_shape->element_type())); + TF_ASSIGN_OR_RETURN(per_device_xla_shape, + shape_representation_fn(per_device_tensor_shape, dtype, + use_fast_memory)); + *xla_shape->mutable_layout() = per_device_xla_shape.layout(); + } + return Status::OK(); +} + +// There is a shape_representation_fn or sharding for an output, this function +// uses a reshape to fix the layout. +xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + absl::optional<xla::OpSharding> sharding, bool fast_mem) { + if (original_shape.IsTuple()) { + std::vector<xla::XlaOp> elements; + for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) { + auto subsharding = sharding ? 
sharding->tuple_shardings(i) : sharding; + TF_ASSIGN_OR_RETURN(auto element, + ReshapeWithCorrectRepresentationAndSharding( + builder, xla::GetTupleElement(original, i), + original_shape.tuple_shapes(i), + shape_representation_fn, subsharding, fast_mem)); + elements.push_back(element); + } + return xla::Tuple(builder, elements); + } + if (!original_shape.IsArray()) return original; + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape)); + TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType( + original_shape.element_type())); + TF_ASSIGN_OR_RETURN(auto to_shape, + shape_representation_fn(shape, dtype, fast_mem)); + if (sharding) { + TF_ASSIGN_OR_RETURN(auto hlo_sharding, + xla::HloSharding::FromProto(*sharding)); + TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape( + hlo_sharding, fast_mem, shape_representation_fn, &to_shape)); + } + if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { + for (int64 i = 0; i < original_shape.rank(); ++i) { + to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); + } + } + return xla::Reshape(to_shape, original); +} + } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index 490923526bd..3a9375ec1f4 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -19,8 +19,9 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_ #include "absl/types/span.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/core/framework/tensor.h" namespace tensorflow { @@ -72,6 +73,98 @@ class XlaHelpers { // than the xla::PrimitiveType. 
static xla::XlaOp ConvertElementType(const xla::XlaOp& operand, const DataType new_element_type); + + typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType, + bool)> + ShapeRepresentationFn; +}; + +// Creates an identity shape representation function. +XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn(); + +// Rewrites the layout of xla_shape if there is tiled sharding. +Status RewriteLayoutWithShardedShape( + const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + xla::Shape* xla_shape); + +// Adds reshapes to fix the layout of an output, if a shape_representation_fn or +// sharding is present. +xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding( + xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape, + XlaHelpers::ShapeRepresentationFn shape_representation_fn, + absl::optional<xla::OpSharding> sharding, bool fast_mem); + +struct XlaOutputDescription { + // Type and shape of the output. The shape is the unflattened shape. + // When `type` is DT_RESOURCE, `shape` is the shape of the resource + // variable's value. + DataType type; + TensorShape shape; + + // Constant output value, if known to be constant at JIT compilation time. + // 'Tensor' is in host memory. + bool is_constant = false; + Tensor constant_value; + + // When this output is a resource, i.e. `type == DT_RESOURCE`, this is + // the index of the input that contains the resource. + int input_index; + + // Whether this output is a TensorList. + bool is_tensor_list = false; +}; + +// Describes a variable write side effect of the computation. +struct XlaResourceUpdate { + // Index of the input that contains the variable resource to write to. + int input_index; + + // Type and shape of the tensor to be written back. + // The `shape` field has the same meaning as the Argument::shape field. 
+ DataType type; + TensorShape shape; + + // Was the value of the variable modified by the computation? + // (Always true, unless `return_updated_values_for_all_resources` is true.) + bool modified; + + // If the resource is a TensorArray, the set of gradients read or written. + std::set<string> tensor_array_gradients_accessed; +}; + +struct XlaCompilationResult { + // Vector that maps from the parameters of the XLA computation to their + // original argument positions. To handle compile-time constant inputs, the + // parameters to the XLA computation may be a subset of the original + // arguments. The relative ordering of parameters is maintained. + std::vector<int> input_mapping; + + // Input shapes of the computation. If we are flattening inputs, these are + // the flattened shapes. + std::vector<xla::Shape> xla_input_shapes; + + // Output shape in XLA format. The output shape is always a tuple. If we + // are flattening outputs, these are the flattened shapes. + xla::Shape xla_output_shape; + + // TensorFlow shapes of outputs, together with the values of any + // constant arguments. Vector indexed by TensorFlow _Retval number, + // containing both constant and non-constant results. + std::vector<XlaOutputDescription> outputs; + + // TensorFlow shapes and types of sends/recvs from HostCompute Ops to their + // matching RecvAtHost/SendFromHost Ops in the outer graph. + tf2xla::HostComputeMetadata host_compute_metadata; + + // Resources whose values were updated by the computation, ordered + // by return value position (which is the same as the order the resources + // were passed as arguments). Resource updates follow the non-constant + // results in the outputs of the XLA computation. + std::vector<XlaResourceUpdate> resource_updates; + + // The XLA computation built from the TensorFlow subgraph.
+ std::shared_ptr<xla::XlaComputation> computation; }; } // end namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 27766408716..735a6c7291e 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -49,33 +49,13 @@ XlaCompiler* XlaOpKernelContext::compiler() const { return xla_context()->compiler(); } -// Retrieves an XlaExpression that was allocated by a previous Op. -const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor( - const Tensor& tensor) { - const XlaExpression* expression = - reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data()); - CHECK(expression->kind() != XlaExpression::Kind::kInvalid) - << expression->HumanString(); - return expression; -} - -// Assigns an XlaExpression to a tensor on an XLA compilation device. -void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value, - Tensor* tensor) { - const XlaExpression* expression = - reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data()); - CHECK(expression->kind() == XlaExpression::Kind::kInvalid) - << expression->HumanString(); - *const_cast<XlaExpression*>(expression) = value; -} - const XlaExpression& XlaOpKernelContext::InputExpression(int index) { - return *CastExpressionFromTensor(context_->input(index)); + return *XlaExpression::CastExpressionFromTensor(context_->input(index)); } const XlaExpression& XlaOpKernelContext::InputExpression( absl::string_view name) { - return *CastExpressionFromTensor(GetInputTensorByName(name)); + return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name)); } xla::XlaOp XlaOpKernelContext::Input(int index) { @@ -108,7 +88,8 @@ DataType XlaOpKernelContext::input_type(int index) const { if (type == DT_UINT8) { // Masqueraded XlaExpression could have different type. See // XlaOpKernelContext::SetOutputExpression for details. 
- auto expression = CastExpressionFromTensor(context_->input(index)); + auto expression = + XlaExpression::CastExpressionFromTensor(context_->input(index)); type = expression->dtype(); } return type; @@ -120,7 +101,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) { if (type == DT_UINT8) { // Masqueraded XlaExpression could have different type. See // XlaOpKernelContext::SetOutputExpression for details. - auto expression = CastExpressionFromTensor(tensor); + auto expression = XlaExpression::CastExpressionFromTensor(tensor); type = expression->dtype(); } return type; @@ -385,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); + handles->push_back( + XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); @@ -408,7 +390,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, TensorShape* shape, xla::XlaOp* value) { const XlaExpression* expression = - XlaOpKernelContext::CastExpressionFromTensor(tensor); + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); @@ -461,7 +443,8 @@ Status XlaOpKernelContext::ReadVariableInput(absl::string_view name, Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, TensorShape* shape) const { const Tensor& tensor = context_->input(index); - const XlaExpression* expression = CastExpressionFromTensor(tensor); + const XlaExpression* expression = + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); @@ -502,8 +485,8 @@ void 
XlaOpKernelContext::SetOutputExpression(int index, TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); } - XlaOpKernelContext::AssignExpressionToTensor( - expression, context_->mutable_output(index)); + XlaExpression::AssignExpressionToTensor(expression, + context_->mutable_output(index)); return Status::OK(); }(); if (!status.ok()) { @@ -542,7 +525,7 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { const XlaExpression* expression = - CastExpressionFromTensor(context_->input(index)); + XlaExpression::CastExpressionFromTensor(context_->input(index)); TF_RET_CHECK(expression->resource() != nullptr); *resource = expression->resource(); return Status::OK(); @@ -554,7 +537,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type, const XlaOpKernelContext* ctx, xla::XlaOp handle, xla::XlaBuilder* builder) { const XlaExpression* expression = - XlaOpKernelContext::CastExpressionFromTensor(tensor); + XlaExpression::CastExpressionFromTensor(tensor); XlaResource* variable = expression->resource(); TF_RET_CHECK(variable != nullptr); TF_RET_CHECK(variable->kind() == XlaResource::kVariable); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 6987b6fbb98..3cf51e6ec6f 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -17,6 +17,9 @@ limitations under the License. 
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_ #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -284,13 +287,6 @@ class XlaOpKernelContext { // separate specialization of the computation for each DataType. const xla::XlaComputation* GetOrCreateMul(const DataType type); - // Assigns an XlaExpression to a tensor on an XLA compilation device. - static void AssignExpressionToTensor(const XlaExpression& value, - Tensor* tensor); - - // Retrieves an XlaExpression that was assigned to the specified tensor. - static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor); - private: // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index 32d42cb8a42..bec0b46611d 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "absl/memory/memory.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/sharding_util.h" -#include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/xla/client/xla_builder.h" diff --git a/tensorflow/core/tpu/BUILD b/tensorflow/core/tpu/BUILD index 30a90c1da6c..0a17ba3d408 100644 --- a/tensorflow/core/tpu/BUILD +++ b/tensorflow/core/tpu/BUILD @@ -57,6 +57,7 @@ cc_library( ":tpu_defs", ":tpu_node_device_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", ], alwayslink = 1, ) @@ -180,8 +181,8 @@ cc_library( "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:tf2xla_util", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index 7f64758d238..e5f49158231 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -656,6 +656,7 @@ cc_library( deps = [ "//tensorflow/compiler/tf2xla:common", "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto_cc", @@ -673,6 +674,7 @@ cc_library( srcs = ["topk_ops.cc"], deps = [ "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/xla/client:xla_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", "//tensorflow/core/tpu:tpu_defs",