[TF2XLA] [NFC] Break apart the [TF2XLA/MLIR] -> xla_compiler dependency edge
This is needed for invoking the MLIR tf2xla bridge from xla_compiler. This CL breaks apart items from xla_compiler into individual build targets, which are then depended on from the MLIR TF bridge. PiperOrigin-RevId: 323640340 Change-Id: I78b972503db9e7b5254014ca7e889005490d8339
This commit is contained in:
parent
5198b44674
commit
bcfb60d0a1
tensorflow
@ -308,6 +308,8 @@ cc_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:framework",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -95,6 +95,7 @@ cc_library(
|
||||
":xla_kernel_creator", # buildcleaner: keep
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:cpu_plugin", # buildcleaner: keep
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -115,6 +116,7 @@ cc_library(
|
||||
":xla_kernel_creator", # buildcleaner: keep
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/service:gpu_plugin", # buildcleaner: keep
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
@ -172,6 +174,7 @@ XLA_DEVICE_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
@ -343,6 +346,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
@ -406,6 +410,7 @@ cc_library(
|
||||
":compilation_passes",
|
||||
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
@ -641,6 +646,7 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
|
||||
"//tensorflow/compiler/tf2xla/cc:xla_ops",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -700,6 +706,7 @@ cc_library(
|
||||
hdrs = ["device_util.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/core:framework",
|
||||
@ -914,6 +921,7 @@ cc_library(
|
||||
"//tensorflow/compiler/jit/graphcycles",
|
||||
"//tensorflow/compiler/tf2xla:resource_operation_table",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/core:core_cpu",
|
||||
|
@ -21,6 +21,7 @@ XLA_OPS_DEPS = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:executable_run_options",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
|
@ -150,6 +150,7 @@ tf_cc_binary(
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_registration",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/xla:xla_mlir_translate",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:tensorflow",
|
||||
|
@ -1477,10 +1477,13 @@ COMPILE_MLIR_UTIL_DEPS = [
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf",
|
||||
"//tensorflow/compiler/mlir/xla:xla_legalize_tf_with_tf2xla",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_argument",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/platform:logging",
|
||||
"//tensorflow/stream_executor/lib",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
|
@ -83,7 +83,7 @@ Status ParseMlirModule(llvm::StringRef mlir_module_string,
|
||||
Status GetXlaInputShapes(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
std::vector<xla::Shape>* xla_input_shapes) {
|
||||
xla_input_shapes->clear();
|
||||
|
||||
@ -135,9 +135,8 @@ Status GetXlaInputShapes(
|
||||
// output based on static shapes in MLIR module
|
||||
Status GetOutputInfo(
|
||||
mlir::ModuleOp module,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_output_shape,
|
||||
std::vector<XlaCompiler::OutputDescription>* outputs) {
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_output_shape, std::vector<XlaOutputDescription>* outputs) {
|
||||
auto shape_representation_fn_no_fast_memory =
|
||||
[shape_representation_fn](const TensorShape& shape, DataType dtype) {
|
||||
return shape_representation_fn(shape, dtype, /*use_fast_memory=*/false);
|
||||
@ -161,7 +160,7 @@ Status GetOutputInfo(
|
||||
|
||||
// Construct OutputDescription for result.
|
||||
outputs->emplace_back();
|
||||
XlaCompiler::OutputDescription& out_desc = outputs->back();
|
||||
XlaOutputDescription& out_desc = outputs->back();
|
||||
TF_RETURN_IF_ERROR(ConvertToDataType(tensor_type, &out_desc.type));
|
||||
// TODO(ycao): Support constant output.
|
||||
out_desc.is_constant = false;
|
||||
@ -185,7 +184,7 @@ Status GetOutputInfo(
|
||||
// TODO(ycao): Implement logic to compute resource updates when we need to
|
||||
// support graphs with resource updates in MLIR-based TF compiler bridge.
|
||||
void GetResourceUpdatesForMlir(
|
||||
std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
|
||||
std::vector<XlaResourceUpdate>* resource_updates) {
|
||||
resource_updates->clear();
|
||||
}
|
||||
|
||||
@ -265,7 +264,7 @@ Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
mlir::PassManager tf2xla(module_op.getContext());
|
||||
tf2xla.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
|
||||
@ -341,8 +340,8 @@ Status ConvertMLIRToXlaComputation(
|
||||
static Status CompileMlirToXlaHlo(
|
||||
mlir::ModuleOp module_op, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
if (VLOG_IS_ON(1))
|
||||
tensorflow::DumpMlirOpToFile("mlir_compile_before", module_op);
|
||||
@ -391,8 +390,8 @@ static Status CompileMlirToXlaHlo(
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
mlir::MLIRContext mlir_context;
|
||||
@ -411,16 +410,16 @@ Status CompileSerializedMlirToXlaHlo(
|
||||
// removed from the signature.
|
||||
// Returns the original indices for the other arguments on success.
|
||||
static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
mlir::ModuleOp module, llvm::ArrayRef<const XlaCompiler::Argument> args) {
|
||||
mlir::ModuleOp module, llvm::ArrayRef<const XlaArgument> args) {
|
||||
mlir::FuncOp main_fn = module.lookupSymbol<mlir::FuncOp>("main");
|
||||
std::vector<int> params;
|
||||
|
||||
auto builder = mlir::OpBuilder(main_fn.getBody());
|
||||
std::vector<int> args_to_erase;
|
||||
for (int idx = 0; idx < args.size(); idx++) {
|
||||
const XlaCompiler::Argument& xla_arg = args[idx];
|
||||
const XlaArgument& xla_arg = args[idx];
|
||||
mlir::BlockArgument mlir_arg = main_fn.getArgument(idx);
|
||||
if (xla_arg.kind != XlaCompiler::Argument::kConstant) {
|
||||
if (xla_arg.kind != XlaArgument::kConstant) {
|
||||
params.push_back(idx);
|
||||
continue;
|
||||
}
|
||||
@ -439,11 +438,11 @@ static StatusOr<std::vector<int>> RewriteWithArgs(
|
||||
}
|
||||
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes) {
|
||||
RegisterDialects();
|
||||
|
||||
|
@ -20,7 +20,10 @@ limitations under the License.
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/protobuf/graph_debug_info.pb.h"
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
@ -57,7 +60,7 @@ Status ConvertMLIRToXlaComputation(
|
||||
mlir::ModuleOp module_op, llvm::StringRef device_type,
|
||||
xla::XlaComputation* xla_computation, bool use_tuple_args,
|
||||
bool return_tuple,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn = nullptr,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn = nullptr,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
// Compiles a serialized MLIR module into XLA HLO, generates all accompanying
|
||||
@ -65,17 +68,17 @@ Status ConvertMLIRToXlaComputation(
|
||||
Status CompileSerializedMlirToXlaHlo(
|
||||
llvm::StringRef mlir_module_string, llvm::ArrayRef<TensorShape> arg_shapes,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
// Same as the above but takes input as TensorFlow Graph.
|
||||
Status CompileGraphToXlaHlo(
|
||||
const Graph& graph, llvm::ArrayRef<const XlaCompiler::Argument> args,
|
||||
const Graph& graph, llvm::ArrayRef<const XlaArgument> args,
|
||||
llvm::StringRef device_type, bool use_tuple_args,
|
||||
const FunctionLibraryDefinition& flib_def, const GraphDebugInfo& debug_info,
|
||||
const XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompiler::CompilationResult* compilation_result,
|
||||
const XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
XlaCompilationResult* compilation_result,
|
||||
std::vector<std::unique_ptr<mlir::Pass>> custom_legalization_passes = {});
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -92,7 +92,11 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:export_tf_dialect_op",
|
||||
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
|
||||
"//tensorflow/compiler/mlir/tensorflow:translate_utils",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_compilation_device",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/tf2xla:xla_expression",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:core_cpu_lib",
|
||||
"//tensorflow/core:framework",
|
||||
@ -125,8 +129,10 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/hlo",
|
||||
"//tensorflow/compiler/mlir/hlo:hlo_dialect_registration",
|
||||
"//tensorflow/compiler/mlir/hlo:lhlo",
|
||||
"//tensorflow/compiler/xla:debug_options_flags",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla/service:backend",
|
||||
"//tensorflow/compiler/xla/service:buffer_assignment",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/compiler/xla/service:hlo_casting_utils",
|
||||
@ -228,7 +234,7 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow:convert_type",
|
||||
"//tensorflow/compiler/mlir/tensorflow:error_util",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/xla:comparison_util",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
|
@ -43,7 +43,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/convert_type.h"
|
||||
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/matrix.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/quantize.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/slicing.h"
|
||||
@ -463,7 +462,7 @@ class ConvertToHloModule {
|
||||
// single value.
|
||||
explicit ConvertToHloModule(
|
||||
mlir::ModuleOp module, bool use_tuple_args, bool return_tuple,
|
||||
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn)
|
||||
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn)
|
||||
: module_(module),
|
||||
module_builder_("main"),
|
||||
use_tuple_args_(use_tuple_args),
|
||||
@ -545,7 +544,7 @@ class ConvertToHloModule {
|
||||
|
||||
// Shape representation function to determine entry function argument and
|
||||
// result shapes.
|
||||
tensorflow::XlaCompiler::ShapeRepresentationFn shape_representation_fn_;
|
||||
tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn_;
|
||||
|
||||
// Unique suffix to give to the name of the next lowered region.
|
||||
size_t region_id_ = 0;
|
||||
@ -1500,7 +1499,7 @@ LogicalResult AddDynamicParameterBindings(mlir::ModuleOp module,
|
||||
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, xla::HloProto* hlo_proto,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
const tensorflow::XlaCompiler::ShapeRepresentationFn
|
||||
const tensorflow::XlaHelpers::ShapeRepresentationFn
|
||||
shape_representation_fn) {
|
||||
mlir::StatusScopedDiagnosticHandler diag_handler(module.getContext());
|
||||
ConvertToHloModule converter(module, use_tuple_args, return_tuple,
|
||||
|
@ -18,9 +18,10 @@ limitations under the License.
|
||||
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
|
||||
namespace mlir {
|
||||
|
||||
@ -33,7 +34,7 @@ namespace mlir {
|
||||
// single value.
|
||||
Status ConvertMlirHloToHlo(mlir::ModuleOp module, ::xla::HloProto* hlo_proto,
|
||||
bool use_tuple_args, bool return_tuple,
|
||||
const tensorflow::XlaCompiler::ShapeRepresentationFn
|
||||
const tensorflow::XlaHelpers::ShapeRepresentationFn
|
||||
shape_representation_fn = nullptr);
|
||||
|
||||
// Creates XlaOp equivalent of a given MLIR operation using the operand info
|
||||
|
@ -48,7 +48,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/core/common_runtime/device.h"
|
||||
#include "tensorflow/core/common_runtime/device_factory.h"
|
||||
@ -410,7 +411,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
|
||||
device_->GetAllocator(tensorflow::AllocatorAttributes()), expr.dtype(),
|
||||
shape_or.ValueOrDie());
|
||||
tensorflow::Tensor& tensor = tensors.back();
|
||||
tensorflow::XlaOpKernelContext::AssignExpressionToTensor(expr, &tensor);
|
||||
tensorflow::XlaExpression::AssignExpressionToTensor(expr, &tensor);
|
||||
inputs.emplace_back(&tensor);
|
||||
}
|
||||
|
||||
@ -438,7 +439,7 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() {
|
||||
for (int i = 0, e = op_->getNumResults(); i < e; i++) {
|
||||
tensorflow::Tensor* output = op_context.mutable_output(i);
|
||||
const tensorflow::XlaExpression* expr =
|
||||
tensorflow::XlaOpKernelContext::CastExpressionFromTensor(*output);
|
||||
tensorflow::XlaExpression::CastExpressionFromTensor(*output);
|
||||
if (expr->kind() != tensorflow::XlaExpression::Kind::kXlaOp)
|
||||
return op_->emitError(
|
||||
"expects XlaExpression of kind kXlaOp in compiled output");
|
||||
|
@ -37,6 +37,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
|
||||
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
|
||||
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
|
||||
#include "tensorflow/compiler/xla/debug_options_flags.h"
|
||||
#include "tensorflow/compiler/xla/service/backend.h"
|
||||
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
|
@ -50,6 +50,7 @@ cc_library(
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
":xla_compiler",
|
||||
":xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
@ -145,6 +146,7 @@ cc_library(
|
||||
":tf2xla_proto_cc",
|
||||
":tf2xla_util",
|
||||
":xla_compiler",
|
||||
":xla_op_registry",
|
||||
"//tensorflow/compiler/aot:aot_only_var_handle_op",
|
||||
"//tensorflow/compiler/tf2xla/kernels:xla_ops",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
@ -316,14 +318,8 @@ cc_library(
|
||||
srcs = [
|
||||
"const_analysis.cc",
|
||||
"graph_compiler.cc",
|
||||
"xla_compilation_device.cc",
|
||||
"xla_compiler.cc",
|
||||
"xla_context.cc",
|
||||
"xla_expression.cc",
|
||||
"xla_helpers.cc",
|
||||
"xla_op_kernel.cc",
|
||||
"xla_op_registry.cc",
|
||||
"xla_resource.cc",
|
||||
"xla_cpu_backend.cc",
|
||||
] + if_cuda_is_configured([
|
||||
"xla_gpu_backend.cc",
|
||||
@ -333,14 +329,10 @@ cc_library(
|
||||
hdrs = [
|
||||
"const_analysis.h",
|
||||
"graph_compiler.h",
|
||||
"xla_compilation_device.h",
|
||||
"xla_compiler.h",
|
||||
"xla_context.h",
|
||||
"xla_expression.h",
|
||||
"xla_helpers.h",
|
||||
"xla_op_kernel.h",
|
||||
"xla_op_registry.h",
|
||||
"xla_resource.h",
|
||||
],
|
||||
visibility = [":friends"],
|
||||
deps = [
|
||||
@ -351,10 +343,18 @@ cc_library(
|
||||
":sharding_util",
|
||||
":side_effect_util",
|
||||
":tf2xla_util",
|
||||
":xla_argument",
|
||||
":xla_compilation_device",
|
||||
":xla_context",
|
||||
":xla_expression",
|
||||
":xla_helpers",
|
||||
":xla_op_registry",
|
||||
":xla_resource",
|
||||
"//tensorflow/compiler/jit:common",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:shape_inference",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/mlir/tensorflow:compile_mlir_util_no_tf_dialect_passes",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -370,6 +370,7 @@ cc_library(
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:core_cpu",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
@ -388,6 +389,172 @@ cc_library(
|
||||
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_compilation_device.cc and exports
# xla_compilation_device.h as a standalone unit.
# No `visibility` attribute is set, so the package's default visibility applies.
# NOTE(review): per the commit description this looks like one of the pieces
# split out of :xla_compiler — confirm against the full BUILD file.
cc_library(
|
||||
name = "xla_compilation_device",
|
||||
srcs = [
|
||||
"xla_compilation_device.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_compilation_device.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":frontend_attributes_util",
|
||||
":sharding_util",
|
||||
":xla_context",
|
||||
":xla_helpers",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_context.cc and exports xla_context.h.
# Within this package it depends on :common, :xla_expression and :xla_helpers.
# No `visibility` attribute is set, so the package's default visibility applies.
cc_library(
|
||||
name = "xla_context",
|
||||
srcs = [
|
||||
"xla_context.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_context.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_expression",
|
||||
":xla_helpers",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_op_registry.cc and exports xla_op_registry.h.
# Within this package it depends only on :common and :xla_context.
cc_library(
|
||||
name = "xla_op_registry",
|
||||
srcs = [
|
||||
"xla_op_registry.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_op_registry.h",
|
||||
],
# Visibility is limited to the `:friends` label of this package
# (presumably a package_group declared elsewhere in this BUILD file — not
# visible in this excerpt).
visibility = [":friends"],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_context",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:xla_cluster_util",
|
||||
"//tensorflow/compiler/xla/client:client_library",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:session_options",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/core/common_runtime:core_cpu_internal",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed so registrations made at static-init time
# are not dropped by the linker — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_expression.cc and exports xla_expression.h.
# Within this package it depends on :common and :xla_resource.
# No `visibility` attribute is set, so the package's default visibility applies.
cc_library(
|
||||
name = "xla_expression",
|
||||
srcs = [
|
||||
"xla_expression.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_expression.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":xla_resource",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/types:optional",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_resource.cc and exports xla_resource.h.
# Within this package it depends on :common, :sharding_util and :xla_helpers.
# No `visibility` attribute is set, so the package's default visibility applies.
cc_library(
|
||||
name = "xla_resource",
|
||||
srcs = [
|
||||
"xla_resource.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_resource.h",
|
||||
],
|
||||
deps = [
|
||||
":common",
|
||||
":sharding_util",
|
||||
":xla_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_helpers.cc and exports xla_helpers.h.
# Within this package it depends on :common and :host_compute_metadata_proto_cc.
cc_library(
|
||||
name = "xla_helpers",
|
||||
srcs = [
|
||||
"xla_helpers.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_helpers.h",
|
||||
],
# Visibility is limited to the `:friends` label of this package
# (presumably a package_group declared elsewhere in this BUILD file — not
# visible in this excerpt).
visibility = [":friends"],
|
||||
deps = [
|
||||
":common",
|
||||
":host_compute_metadata_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla/lib:util",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client:xla_computation",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/compiler/xla/client/lib:constants",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
# Library target that builds xla_argument.cc and exports xla_argument.h.
# Within this package it depends on :host_compute_metadata_proto_cc and
# :xla_resource.
# No `visibility` attribute is set, so the package's default visibility applies.
cc_library(
|
||||
name = "xla_argument",
|
||||
srcs = [
|
||||
"xla_argument.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"xla_argument.h",
|
||||
],
|
||||
deps = [
|
||||
":host_compute_metadata_proto_cc",
|
||||
":xla_resource",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:framework",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
# alwayslink = 1 makes Bazel link every object file of this library into
# dependent binaries even when none of its symbols are referenced.
# NOTE(review): presumably needed for static registrations — confirm.
alwayslink = 1,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
@ -564,6 +731,8 @@ tf_cc_test(
|
||||
":common",
|
||||
":side_effect_util",
|
||||
":xla_compiler",
|
||||
":xla_expression",
|
||||
":xla_resource",
|
||||
"//tensorflow/cc:cc_ops",
|
||||
"//tensorflow/cc:function_ops",
|
||||
"//tensorflow/cc:functional_ops",
|
||||
|
@ -145,7 +145,12 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/jit:xla_activity_listener",
|
||||
"//tensorflow/compiler/jit:xla_activity_proto_cc",
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compilation_device",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla:xla_resource",
|
||||
"//tensorflow/compiler/tf2xla/lib:broadcast",
|
||||
"//tensorflow/compiler/tf2xla/lib:data_format",
|
||||
"//tensorflow/compiler/tf2xla/lib:random",
|
||||
@ -223,6 +228,8 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
@ -276,6 +283,8 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
@ -296,6 +305,8 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -314,6 +325,8 @@ tf_kernel_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:side_effect_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_context",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/tf2xla/ops:xla_ops",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
@ -333,6 +346,7 @@ tf_kernel_library(
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:array_ops_op_lib",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
|
@ -38,6 +38,7 @@ cc_library(
|
||||
hdrs = ["random.h"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_helpers",
|
||||
"//tensorflow/compiler/xla:status_macros",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
|
53
tensorflow/compiler/tf2xla/xla_argument.cc
Normal file
53
tensorflow/compiler/tf2xla/xla_argument.cc
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Field-wise equality for XlaArgument.
//
// Two arguments are equal when all of the following match:
//   * kind, resource_kind, type, name, initialized, max_array_size and
//     tensor_array_gradients;
//   * shape: both sides must hold the same variant alternative and the held
//     values must compare equal;
//   * is_same_data_across_replicas;
//   * constant_value: same shape and same raw bytes.
// No other member is read by this comparison.
bool XlaArgument::operator==(const XlaArgument& other) const {
  const auto lhs_fields =
      std::tie(kind, resource_kind, type, name, initialized, max_array_size,
               tensor_array_gradients);
  const auto rhs_fields =
      std::tie(other.kind, other.resource_kind, other.type, other.name,
               other.initialized, other.max_array_size,
               other.tensor_array_gradients);
  if (lhs_fields != rhs_fields) return false;

  if (absl::holds_alternative<xla::Shape>(shape)) {
    if (!absl::holds_alternative<xla::Shape>(other.shape)) return false;
    const xla::Shape& lhs_shape = absl::get<xla::Shape>(shape);
    const xla::Shape& rhs_shape = absl::get<xla::Shape>(other.shape);
    if (!xla::Shape::Equal()(lhs_shape, rhs_shape)) return false;
  } else {
    if (!absl::holds_alternative<TensorShape>(other.shape)) return false;
    const TensorShape& lhs_shape = absl::get<TensorShape>(shape);
    const TensorShape& rhs_shape = absl::get<TensorShape>(other.shape);
    if (lhs_shape != rhs_shape) return false;
  }

  if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
    return false;
  }

  // Constant payloads must agree in shape before their bytes are compared.
  if (constant_value.shape() != other.constant_value.shape()) return false;
  return constant_value.tensor_data() == other.constant_value.tensor_data();
}
|
||||
|
||||
} // end namespace tensorflow
|
121
tensorflow/compiler/tf2xla/xla_argument.h
Normal file
121
tensorflow/compiler/tf2xla/xla_argument.h
Normal file
@ -0,0 +1,121 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// Describes how to derive the value of each _Arg node in the graph/function
|
||||
// being compiled. There must be one Argument for each _Arg index.
|
||||
//
// XlaCompiler::Argument is declared as an alias of this struct.
struct XlaArgument {
|
||||
enum Kind {
|
||||
// Default value; not a valid kind.
|
||||
kInvalid,
|
||||
|
||||
// Argument is a compile-time constant. No associated runtime parameter.
|
||||
kConstant,
|
||||
|
||||
// Argument is a Variable, TensorArray, or Stack resource. Has an
|
||||
// associated runtime parameter iff `initialized` is true.
|
||||
kResource,
|
||||
|
||||
// Argument is a run-time parameter.
|
||||
kParameter,
|
||||
|
||||
// Argument is an XLA token.
|
||||
kToken,
|
||||
|
||||
// Argument is a TensorList.
|
||||
kTensorList,
|
||||
};
|
||||
|
||||
// What this argument is; see Kind. Starts out as kInvalid.
Kind kind = kInvalid;
|
||||
|
||||
// The type of the argument. If the argument is a resource, this
|
||||
// is the type of the variable's value, not DT_RESOURCE.
|
||||
DataType type = DT_INVALID;
|
||||
|
||||
// The shape of the argument. For:
|
||||
// * a parameter: the shape of the parameter. We allow setting the xla shape
|
||||
// if known. This helps avoid conversions to and from TensorShape.
|
||||
// * a constant: ignored; the shape given by constant_value is used
|
||||
// instead.
|
||||
// * an uninitialized resource: ignored. We don't yet know the shape of an
|
||||
// uninitialized resource (otherwise we would have initialized it!)
|
||||
// * an initialized variable: the shape of the variable's value.
|
||||
// * an initialized TensorArray or Stack resource: the shape of an entry in
|
||||
// the TensorArray/Stack. Note this is the size of a single entry, not the
|
||||
// XLA data structure that represents the complete stack/array.
|
||||
absl::variant<TensorShape, xla::Shape> shape;
|
||||
|
||||
// The value of the argument, if it is a compile-time constant. Must be a
|
||||
// host-memory tensor.
|
||||
Tensor constant_value;
|
||||
|
||||
// The name of this argument, used for debugging.
|
||||
string name;
|
||||
|
||||
// The name of TensorFlow _Arg node, used for debugging.
|
||||
string node_name;
|
||||
|
||||
// For a kResource, what kind of resource is it?
|
||||
XlaResource::Kind resource_kind = XlaResource::kInvalid;
|
||||
|
||||
// For a kResource, has this resource been initialized?
|
||||
bool initialized = false;
|
||||
|
||||
// For a kResource, is this resource on Fast Memory.
|
||||
bool fast_mem = false;
|
||||
|
||||
// For a TensorArray or Stack resource, what is the array's declared size?
|
||||
// (Used for lazy initialization.)
|
||||
int64 max_array_size = -1;
|
||||
|
||||
// TensorArray resource parameters are passed as (array, gradient array 0,
|
||||
// ..., gradient array k), where the gradient arrays are in the same order
|
||||
// as `tensor_array_gradients`.
|
||||
std::set<string> tensor_array_gradients;
|
||||
|
||||
// dynamic dims to arg number map. Empty if no dynamic shapes.
|
||||
std::map<int32, int32> dynamic_dim_to_arg_num_map;
|
||||
// NOTE(review): `is_pad_arg` is undocumented here; presumably it marks an
// argument that only carries padding information - confirm against callers.
bool is_pad_arg = false;
|
||||
|
||||
// Whether this argument will receive the same data across all replicas.
|
||||
bool is_same_data_across_replicas = false;
|
||||
|
||||
// Equality as implemented in xla_argument.cc compares kind, resource_kind,
// type, name, initialized, max_array_size, tensor_array_gradients, shape,
// is_same_data_across_replicas, and the shape and raw bytes of
// constant_value. node_name, fast_mem, dynamic_dim_to_arg_num_map and
// is_pad_arg do not take part in the comparison.
bool operator==(const XlaArgument& other) const;
|
||||
|
||||
// Returns a human-readable summary of the argument.
|
||||
string HumanString() const;
|
||||
|
||||
// Returns the dimension sizes for either TensorShape or xla::Shape.
|
||||
std::vector<int64> DimensionSizes() const;
|
||||
// NOTE(review): presumably the same sizes as DimensionSizes(), returned as an
// absl::InlinedVector; the definition is not part of this header - confirm.
absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
|
||||
|
||||
// Returns the human-readable string for either TensorShape or xla::Shape.
|
||||
string ShapeHumanString() const;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_ARGUMENT_H_
|
@ -422,39 +422,6 @@ Status BuildComputation(
|
||||
|
||||
} // namespace
|
||||
|
||||
bool XlaCompiler::Argument::operator==(
|
||||
const XlaCompiler::Argument& other) const {
|
||||
if (std::tie(kind, resource_kind, type, name, initialized, max_array_size,
|
||||
tensor_array_gradients) !=
|
||||
std::tie(other.kind, other.resource_kind, other.type, other.name,
|
||||
other.initialized, other.max_array_size,
|
||||
other.tensor_array_gradients)) {
|
||||
return false;
|
||||
}
|
||||
if (absl::holds_alternative<xla::Shape>(shape)) {
|
||||
if (!absl::holds_alternative<xla::Shape>(other.shape)) {
|
||||
return false;
|
||||
}
|
||||
if (!xla::Shape::Equal()(absl::get<xla::Shape>(shape),
|
||||
absl::get<xla::Shape>(other.shape))) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!absl::holds_alternative<TensorShape>(other.shape)) {
|
||||
return false;
|
||||
}
|
||||
if (absl::get<TensorShape>(shape) != absl::get<TensorShape>(other.shape)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (constant_value.shape() != other.constant_value.shape()) {
|
||||
return false;
|
||||
}
|
||||
if (is_same_data_across_replicas != other.is_same_data_across_replicas) {
|
||||
return false;
|
||||
}
|
||||
return constant_value.tensor_data() == other.constant_value.tensor_data();
|
||||
}
|
||||
|
||||
string XlaCompiler::Argument::HumanString() const {
|
||||
string common;
|
||||
@ -1494,93 +1461,4 @@ xla::StatusOr<xla::XlaOp> XlaCompiler::GetNodeToken(const string& node_name) {
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
// Creates an identity shape representation function: the returned callable
// converts (shape, dtype) with TensorShapeToXLAShape and does not use its
// fast-memory flag.
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  auto identity_fn = [](const TensorShape& shape, DataType dtype,
                        bool /*use_fast_memory*/) -> xla::StatusOr<xla::Shape> {
    xla::Shape converted;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &converted));
    return converted;
  };
  return identity_fn;
}
|
||||
|
||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||
Status RewriteLayoutWithShardedShape(
|
||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_shape) {
|
||||
if (sharding && !sharding->IsTileMaximal()) {
|
||||
// After sharding, per core shape might have different layout. For example,
|
||||
// before sharding, a shape [128, 128] will be assigned default
|
||||
// minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
|
||||
// the sharded shapes will have minor-to-major {0, 1}.
|
||||
//
|
||||
// As a result, for sharded shapes, we set their layout to per core shape's
|
||||
// layout.
|
||||
//
|
||||
// TODO(endlessroad): for variable input & update, we might have
|
||||
// different layouts which will prevent input output aliasing and
|
||||
// increase memory usage. Investigate such cases.
|
||||
int64 device = *sharding->tile_assignment().begin();
|
||||
std::vector<int64> offset =
|
||||
sharding->TileOffsetForDevice(*xla_shape, device);
|
||||
std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
|
||||
std::vector<int64> dimensions(xla_shape->rank());
|
||||
for (int64 i = 0; i < xla_shape->rank(); ++i) {
|
||||
dimensions[i] = limit[i] - offset[i];
|
||||
}
|
||||
xla::Shape per_device_xla_shape =
|
||||
xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
|
||||
TensorShape per_device_tensor_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
|
||||
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
||||
xla_shape->element_type()));
|
||||
TF_ASSIGN_OR_RETURN(per_device_xla_shape,
|
||||
shape_representation_fn(per_device_tensor_shape, dtype,
|
||||
use_fast_memory));
|
||||
*xla_shape->mutable_layout() = per_device_xla_shape.layout();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// There is a shape_representation_fn or sharding for an output, this function
|
||||
// uses a reshape to fix the layout.
|
||||
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
|
||||
if (original_shape.IsTuple()) {
|
||||
std::vector<xla::XlaOp> elements;
|
||||
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
|
||||
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
|
||||
TF_ASSIGN_OR_RETURN(auto element,
|
||||
ReshapeWithCorrectRepresentationAndSharding(
|
||||
builder, xla::GetTupleElement(original, i),
|
||||
original_shape.tuple_shapes(i),
|
||||
shape_representation_fn, subsharding, fast_mem));
|
||||
elements.push_back(element);
|
||||
}
|
||||
return xla::Tuple(builder, elements);
|
||||
}
|
||||
if (!original_shape.IsArray()) return original;
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
|
||||
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
||||
original_shape.element_type()));
|
||||
TF_ASSIGN_OR_RETURN(auto to_shape,
|
||||
shape_representation_fn(shape, dtype, fast_mem));
|
||||
if (sharding) {
|
||||
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
|
||||
xla::HloSharding::FromProto(*sharding));
|
||||
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
|
||||
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
|
||||
}
|
||||
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
|
||||
for (int64 i = 0; i < original_shape.rank(); ++i) {
|
||||
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
|
||||
}
|
||||
}
|
||||
return xla::Reshape(to_shape, original);
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -21,8 +21,10 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "absl/types/variant.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_argument.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
@ -97,96 +99,7 @@ class XlaContext;
|
||||
// `tensor_array_gradients` ordered set.
|
||||
class XlaCompiler {
|
||||
public:
|
||||
// Describes how to derive the value of each _Arg node in the graph/function
|
||||
// being compiled. There must be one Argument for each _Arg index.
|
||||
struct Argument {
|
||||
enum Kind {
|
||||
// Default value; not a valid kind.
|
||||
kInvalid,
|
||||
|
||||
// Argument is a compile-time constant. No associated runtime parameter.
|
||||
kConstant,
|
||||
|
||||
// Argument is a Variable, TensorArray, or Stack resource. Has an
|
||||
// associated runtime parameter iff `initialized` is true.
|
||||
kResource,
|
||||
|
||||
// Argument is a run-time parameter.
|
||||
kParameter,
|
||||
|
||||
// Argument is an XLA token.
|
||||
kToken,
|
||||
|
||||
// Argument is a TensorList.
|
||||
kTensorList,
|
||||
};
|
||||
|
||||
Kind kind = kInvalid;
|
||||
|
||||
// The type of the argument. If the argument is a resource, this
|
||||
// is the type of the variable's value, not DT_RESOURCE.
|
||||
DataType type = DT_INVALID;
|
||||
|
||||
// The shape of the argument. For:
|
||||
// * a parameter: the shape of the parameter. We allow setting the xla shape
|
||||
// if known. This helps avoid conversions to and from TensorShape.
|
||||
// * a constant: ignored; the shape given by constant_value is used
|
||||
// instead.
|
||||
// * an uninitialized resource: ignored. We don't yet know the shape of an
|
||||
// uninitialized resource (otherwise we would have initialized it!)
|
||||
// * an initialized variable: the shape of the variable's value.
|
||||
// * an initialized TensorArray or Stack resource: the shape of an entry in
|
||||
// the TensorArray/Stack. Note this is the size of a single entry, not the
|
||||
// XLA data structure that represents the complete stack/array.
|
||||
absl::variant<TensorShape, xla::Shape> shape;
|
||||
|
||||
// The value of the argument, if it is a compile-time constant. Must be a
|
||||
// host-memory tensor.
|
||||
Tensor constant_value;
|
||||
|
||||
// The name of this argument, used for debugging.
|
||||
string name;
|
||||
|
||||
// The name of TensorFlow _Arg node, used for debugging.
|
||||
string node_name;
|
||||
|
||||
// For a kResource, what kind of resource is it?
|
||||
XlaResource::Kind resource_kind = XlaResource::kInvalid;
|
||||
|
||||
// For a kResource, has this resource been initialized?
|
||||
bool initialized = false;
|
||||
|
||||
// For a kResource, is this resource on Fast Memory.
|
||||
bool fast_mem = false;
|
||||
|
||||
// For a TensorArray or Stack resource, what is the array's declared size?
|
||||
// (Used for lazy initialization.)
|
||||
int64 max_array_size = -1;
|
||||
|
||||
// TensorArray resource parameters are passed as (array, gradient array 0,
|
||||
// ..., gradient array k), where the gradient arrays are in the same order
|
||||
// as `tensor_array_gradients`.
|
||||
std::set<string> tensor_array_gradients;
|
||||
|
||||
// dynamic dims to arg number map. Empty if no dynamic shapes.
|
||||
std::map<int32, int32> dynamic_dim_to_arg_num_map;
|
||||
bool is_pad_arg = false;
|
||||
|
||||
// Whether this argument will receive the same data across all replicas.
|
||||
bool is_same_data_across_replicas = false;
|
||||
|
||||
bool operator==(const Argument& other) const;
|
||||
|
||||
// Returns a human-readable summary of the argument.
|
||||
string HumanString() const;
|
||||
|
||||
// Returns the dimension sizes for either TensorShape or xla::Shape.
|
||||
std::vector<int64> DimensionSizes() const;
|
||||
absl::InlinedVector<int64, 4> DimensionSizesAsInlinedVector() const;
|
||||
|
||||
// Returns the human-readable string for either TensorShape or xla::Shape.
|
||||
string ShapeHumanString() const;
|
||||
};
|
||||
using Argument = ::tensorflow::XlaArgument;
|
||||
|
||||
// Options pertaining to an individual call to CompileGraph() or
|
||||
// CompileFunction().
|
||||
@ -221,77 +134,11 @@ class XlaCompiler {
|
||||
bool alias_resource_update = false;
|
||||
};
|
||||
|
||||
struct OutputDescription {
|
||||
// Type and shape of the output. The shape is the unflattened shape.
|
||||
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
|
||||
// variable's value.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
using OutputDescription = ::tensorflow::XlaOutputDescription;
|
||||
|
||||
// Constant output value, if known to be constant at JIT compilation time.
|
||||
// 'Tensor' is in host memory.
|
||||
bool is_constant = false;
|
||||
Tensor constant_value;
|
||||
using ResourceUpdate = ::tensorflow::XlaResourceUpdate;
|
||||
|
||||
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
|
||||
// the index of the input that contains the resource.
|
||||
int input_index;
|
||||
|
||||
// Whether this output is a TensorList.
|
||||
bool is_tensor_list = false;
|
||||
};
|
||||
|
||||
// Describes a variable write side effect of the computation.
|
||||
struct ResourceUpdate {
|
||||
// Index of the input that contains the variable resource to write to.
|
||||
int input_index;
|
||||
|
||||
// Type and shape of the tensor to be written back.
|
||||
// The `shape` field has the same meaning as the Argument::shape field.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
|
||||
// Was the value of the variable modified by the computation?
|
||||
// (Always true, unless `return_updated_values_for_all_resources` is true.)
|
||||
bool modified;
|
||||
|
||||
// If the resource is a TensorArray, the set of gradients read or written.
|
||||
std::set<string> tensor_array_gradients_accessed;
|
||||
};
|
||||
|
||||
struct CompilationResult {
|
||||
// Vector that maps from the parameters of the XLA computation to their
|
||||
// original argument positions. To handle compile-time constant inputs, the
|
||||
// parameters to the XLA computation may be a subset of the original
|
||||
// arguments. The relative ordering of parameters are maintained.
|
||||
std::vector<int> input_mapping;
|
||||
|
||||
// Input shapes of the computation. If we are flattening inputs, these are
|
||||
// the flattened shapes.
|
||||
std::vector<xla::Shape> xla_input_shapes;
|
||||
|
||||
// Output shape in XLA format. The output shape is always a tuple. If we
|
||||
// are flattening outputs, these are the flattened shapes.
|
||||
xla::Shape xla_output_shape;
|
||||
|
||||
// TensorFlow shapes of outputs, together with the values of any
|
||||
// constant arguments. Vector indexed by Tensorflow _Retval number,
|
||||
// containing both constant and non-constant results.
|
||||
std::vector<OutputDescription> outputs;
|
||||
|
||||
// TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
|
||||
// matching RecvAtHost/SendFromHost Ops in the outer graph.
|
||||
tf2xla::HostComputeMetadata host_compute_metadata;
|
||||
|
||||
// Resources whose values were updated by the computation, ordered
|
||||
// by return value position (which is the same as the order the resources
|
||||
// were passed as arguments). Resource updates follow the non-constant
|
||||
// results in the outputs of XLA computation.
|
||||
std::vector<ResourceUpdate> resource_updates;
|
||||
|
||||
// The XLA computation built from the tensorflow subgraph.
|
||||
std::shared_ptr<xla::XlaComputation> computation;
|
||||
};
|
||||
using CompilationResult = ::tensorflow::XlaCompilationResult;
|
||||
|
||||
typedef std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType,
|
||||
bool)>
|
||||
@ -518,21 +365,6 @@ class XlaCompiler {
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompiler);
|
||||
};
|
||||
|
||||
// Creates an identity shape representation function.
|
||||
// The returned function converts its TensorShape/DataType arguments with
// TensorShapeToXLAShape and does not use its bool (fast-memory) argument.
XlaCompiler::ShapeRepresentationFn IdentityShapeRepresentationFn();
|
||||
|
||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||
// `xla_shape` is modified only when `sharding` is set and is not
// tile-maximal: its layout is replaced by the layout that
// `shape_representation_fn` returns for the tile of the first device in the
// tile assignment. Errors from the conversions are returned as the Status.
Status RewriteLayoutWithShardedShape(
|
||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_shape);
|
||||
|
||||
// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
|
||||
// sharding is present.
|
||||
// Tuple shapes are processed element by element and rebuilt as a tuple;
// shapes that are neither tuples nor arrays are returned unchanged.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||
XlaCompiler::ShapeRepresentationFn shape_representation_fn,
|
||||
absl::optional<xla::OpSharding> sharding, bool fast_mem);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -24,7 +24,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/client_library.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
|
@ -20,7 +20,6 @@ limitations under the License.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
@ -33,6 +32,7 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
|
||||
class XlaOpKernelContext;
|
||||
class XlaCompiler;
|
||||
|
||||
// The XlaContext is the data structure that holds the state of an XLA
|
||||
// compilation, that is accessible from OpKernelContexts when compiling a
|
||||
|
@ -163,4 +163,23 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
|
||||
}
|
||||
}
|
||||
|
||||
// Reinterprets the data buffer of `tensor` as an XlaExpression and returns a
// pointer to it. CHECK-fails (logging the expression's HumanString()) when the
// expression found there has kind kInvalid.
const XlaExpression* XlaExpression::CastExpressionFromTensor(
    const Tensor& tensor) {
  const char* raw = tensor.tensor_data().data();
  const auto* expression = reinterpret_cast<const XlaExpression*>(raw);
  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
      << expression->HumanString();
  return expression;
}
|
||||
|
||||
// Assigns an XlaExpression to a tensor on an XLA compilation device.
|
||||
void XlaExpression::AssignExpressionToTensor(const XlaExpression& value,
|
||||
Tensor* tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
|
||||
CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
|
||||
<< expression->HumanString();
|
||||
*const_cast<XlaExpression*>(expression) = value;
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -104,6 +104,13 @@ class XlaExpression {
|
||||
// not the shape of the resource's value.
|
||||
xla::StatusOr<TensorShape> GetShape() const;
|
||||
|
||||
// Retrieves an XlaExpression that was allocated by a previous Op.
|
||||
// The tensor's data buffer is reinterpreted as an XlaExpression; the call
// CHECK-fails if that expression has kind kInvalid.
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
|
||||
|
||||
// Assigns an XlaExpression to a tensor on an XLA compilation device.
|
||||
// The expression currently stored in `tensor` must have kind kInvalid
// (CHECK-fails otherwise); it is overwritten in place with `value`.
static void AssignExpressionToTensor(const XlaExpression& value,
|
||||
Tensor* tensor);
|
||||
|
||||
private:
|
||||
Kind kind_ = Kind::kInvalid;
|
||||
|
||||
|
@ -22,8 +22,6 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/tf2xla/literal_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/type_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
@ -128,4 +126,93 @@ xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
|
||||
return xla::ConvertElementType(operand, convert_to);
|
||||
}
|
||||
|
||||
// Creates an identity shape representation function: the returned callable
// converts (shape, dtype) with TensorShapeToXLAShape and does not use its
// fast-memory flag.
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn() {
  auto identity_fn = [](const TensorShape& shape, DataType dtype,
                        bool /*use_fast_memory*/) -> xla::StatusOr<xla::Shape> {
    xla::Shape converted;
    TF_RETURN_IF_ERROR(TensorShapeToXLAShape(dtype, shape, &converted));
    return converted;
  };
  return identity_fn;
}
|
||||
|
||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||
Status RewriteLayoutWithShardedShape(
|
||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_shape) {
|
||||
if (sharding && !sharding->IsTileMaximal()) {
|
||||
// After sharding, per core shape might have different layout. For example,
|
||||
// before sharding, a shape [128, 128] will be assigned default
|
||||
// minor-to-major {1, 0}. But after we shard this shape to [128, 64] * 2,
|
||||
// the sharded shapes will have minor-to-major {0, 1}.
|
||||
//
|
||||
// As a result, for sharded shapes, we set their layout to per core shape's
|
||||
// layout.
|
||||
//
|
||||
// TODO(endlessroad): for variable input & update, we might have
|
||||
// different layouts which will prevent input output aliasing and
|
||||
// increase memory usage. Investigate such cases.
|
||||
int64 device = *sharding->tile_assignment().begin();
|
||||
std::vector<int64> offset =
|
||||
sharding->TileOffsetForDevice(*xla_shape, device);
|
||||
std::vector<int64> limit = sharding->TileLimitForDevice(*xla_shape, device);
|
||||
std::vector<int64> dimensions(xla_shape->rank());
|
||||
for (int64 i = 0; i < xla_shape->rank(); ++i) {
|
||||
dimensions[i] = limit[i] - offset[i];
|
||||
}
|
||||
xla::Shape per_device_xla_shape =
|
||||
xla::ShapeUtil::MakeShape(xla_shape->element_type(), dimensions);
|
||||
TensorShape per_device_tensor_shape;
|
||||
TF_RETURN_IF_ERROR(
|
||||
XLAShapeToTensorShape(per_device_xla_shape, &per_device_tensor_shape));
|
||||
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
||||
xla_shape->element_type()));
|
||||
TF_ASSIGN_OR_RETURN(per_device_xla_shape,
|
||||
shape_representation_fn(per_device_tensor_shape, dtype,
|
||||
use_fast_memory));
|
||||
*xla_shape->mutable_layout() = per_device_xla_shape.layout();
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// There is a shape_representation_fn or sharding for an output, this function
|
||||
// uses a reshape to fix the layout.
|
||||
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
absl::optional<xla::OpSharding> sharding, bool fast_mem) {
|
||||
if (original_shape.IsTuple()) {
|
||||
std::vector<xla::XlaOp> elements;
|
||||
for (int64 i = 0; i < original_shape.tuple_shapes_size(); ++i) {
|
||||
auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding;
|
||||
TF_ASSIGN_OR_RETURN(auto element,
|
||||
ReshapeWithCorrectRepresentationAndSharding(
|
||||
builder, xla::GetTupleElement(original, i),
|
||||
original_shape.tuple_shapes(i),
|
||||
shape_representation_fn, subsharding, fast_mem));
|
||||
elements.push_back(element);
|
||||
}
|
||||
return xla::Tuple(builder, elements);
|
||||
}
|
||||
if (!original_shape.IsArray()) return original;
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(original_shape, &shape));
|
||||
TF_ASSIGN_OR_RETURN(DataType dtype, EncodePrimitiveTypeAsDataType(
|
||||
original_shape.element_type()));
|
||||
TF_ASSIGN_OR_RETURN(auto to_shape,
|
||||
shape_representation_fn(shape, dtype, fast_mem));
|
||||
if (sharding) {
|
||||
TF_ASSIGN_OR_RETURN(auto hlo_sharding,
|
||||
xla::HloSharding::FromProto(*sharding));
|
||||
TF_RETURN_IF_ERROR(RewriteLayoutWithShardedShape(
|
||||
hlo_sharding, fast_mem, shape_representation_fn, &to_shape));
|
||||
}
|
||||
if (xla::ShapeUtil::Compatible(original_shape, to_shape)) {
|
||||
for (int64 i = 0; i < original_shape.rank(); ++i) {
|
||||
to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i));
|
||||
}
|
||||
}
|
||||
return xla::Reshape(to_shape, original);
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -19,8 +19,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_HELPERS_H_
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -72,6 +73,98 @@ class XlaHelpers {
|
||||
// than the xla::PrimitiveType.
|
||||
static xla::XlaOp ConvertElementType(const xla::XlaOp& operand,
|
||||
const DataType new_element_type);
|
||||
|
||||
// Callable that picks the XLA shape used to represent a TensorFlow tensor of
// the given shape and dtype. The trailing bool is the fast-memory flag that
// callers in xla_helpers.cc pass as `use_fast_memory` / `fast_mem`.
using ShapeRepresentationFn = std::function<xla::StatusOr<xla::Shape>(
    const TensorShape&, DataType, bool)>;
|
||||
};
|
||||
|
||||
// Creates an identity shape representation function.
|
||||
// The returned function converts its TensorShape/DataType arguments with
// TensorShapeToXLAShape and does not use its bool (fast-memory) argument.
XlaHelpers::ShapeRepresentationFn IdentityShapeRepresentationFn();
|
||||
|
||||
// Rewrites the layout of xla_shape if there is tiled sharding.
|
||||
// `xla_shape` is modified only when `sharding` is set and is not
// tile-maximal: its layout is replaced by the layout that
// `shape_representation_fn` returns for the tile of the first device in the
// tile assignment. Errors from the conversions are returned as the Status.
Status RewriteLayoutWithShardedShape(
|
||||
const absl::optional<xla::HloSharding>& sharding, bool use_fast_memory,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
xla::Shape* xla_shape);
|
||||
|
||||
// Adds reshapes to fix the layout of an output, if a shape_representation_fn or
|
||||
// sharding is present.
|
||||
// Tuple shapes are processed element by element and rebuilt as a tuple;
// shapes that are neither tuples nor arrays are returned unchanged.
xla::StatusOr<xla::XlaOp> ReshapeWithCorrectRepresentationAndSharding(
|
||||
xla::XlaBuilder* builder, xla::XlaOp original, xla::Shape original_shape,
|
||||
XlaHelpers::ShapeRepresentationFn shape_representation_fn,
|
||||
absl::optional<xla::OpSharding> sharding, bool fast_mem);
|
||||
|
||||
struct XlaOutputDescription {
|
||||
// Type and shape of the output. The shape is the unflattened shape.
|
||||
// When `type` is DT_RESOURCE, `shape` is the shape of the resource
|
||||
// variable's value.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
|
||||
// Constant output value, if known to be constant at JIT compilation time.
|
||||
// 'Tensor' is in host memory.
|
||||
bool is_constant = false;
|
||||
Tensor constant_value;
|
||||
|
||||
// When this output is a resource, i.e. `type == DT_RESOURCE`, this is
|
||||
// the index of the input that contains the resource.
|
||||
int input_index;
|
||||
|
||||
// Whether this output is a TensorList.
|
||||
bool is_tensor_list = false;
|
||||
};
|
||||
|
||||
// Describes a variable write side effect of the computation.
|
||||
struct XlaResourceUpdate {
|
||||
// Index of the input that contains the variable resource to write to.
|
||||
int input_index;
|
||||
|
||||
// Type and shape of the tensor to be written back.
|
||||
// The `shape` field has the same meaning as the Argument::shape field.
|
||||
DataType type;
|
||||
TensorShape shape;
|
||||
|
||||
// Was the value of the variable modified by the computation?
|
||||
// (Always true, unless `return_updated_values_for_all_resources` is true.)
|
||||
bool modified;
|
||||
|
||||
// If the resource is a TensorArray, the set of gradients read or written.
|
||||
std::set<string> tensor_array_gradients_accessed;
|
||||
};
|
||||
|
||||
struct XlaCompilationResult {
|
||||
// Vector that maps from the parameters of the XLA computation to their
|
||||
// original argument positions. To handle compile-time constant inputs, the
|
||||
// parameters to the XLA computation may be a subset of the original
|
||||
// arguments. The relative ordering of parameters are maintained.
|
||||
std::vector<int> input_mapping;
|
||||
|
||||
// Input shapes of the computation. If we are flattening inputs, these are
|
||||
// the flattened shapes.
|
||||
std::vector<xla::Shape> xla_input_shapes;
|
||||
|
||||
// Output shape in XLA format. The output shape is always a tuple. If we
|
||||
// are flattening outputs, these are the flattened shapes.
|
||||
xla::Shape xla_output_shape;
|
||||
|
||||
// TensorFlow shapes of outputs, together with the values of any
|
||||
// constant arguments. Vector indexed by Tensorflow _Retval number,
|
||||
// containing both constant and non-constant results.
|
||||
std::vector<XlaOutputDescription> outputs;
|
||||
|
||||
// TensorFlow shapes and types of sends/recvs from HostCompute Ops to their
|
||||
// matching RecvAtHost/SendFromHost Ops in the outer graph.
|
||||
tf2xla::HostComputeMetadata host_compute_metadata;
|
||||
|
||||
// Resources whose values were updated by the computation, ordered
|
||||
// by return value position (which is the same as the order the resources
|
||||
// were passed as arguments). Resource updates follow the non-constant
|
||||
// results in the outputs of XLA computation.
|
||||
std::vector<XlaResourceUpdate> resource_updates;
|
||||
|
||||
// The XLA computation built from the tensorflow subgraph.
|
||||
std::shared_ptr<xla::XlaComputation> computation;
|
||||
};
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@ -49,33 +49,13 @@ XlaCompiler* XlaOpKernelContext::compiler() const {
|
||||
return xla_context()->compiler();
|
||||
}
|
||||
|
||||
// Retrieves an XlaExpression that was allocated by a previous Op.
|
||||
const XlaExpression* XlaOpKernelContext::CastExpressionFromTensor(
|
||||
const Tensor& tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
|
||||
CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
|
||||
<< expression->HumanString();
|
||||
return expression;
|
||||
}
|
||||
|
||||
// Assigns an XlaExpression to a tensor on an XLA compilation device.
|
||||
void XlaOpKernelContext::AssignExpressionToTensor(const XlaExpression& value,
|
||||
Tensor* tensor) {
|
||||
const XlaExpression* expression =
|
||||
reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
|
||||
CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
|
||||
<< expression->HumanString();
|
||||
*const_cast<XlaExpression*>(expression) = value;
|
||||
}
|
||||
|
||||
const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
|
||||
return *CastExpressionFromTensor(context_->input(index));
|
||||
return *XlaExpression::CastExpressionFromTensor(context_->input(index));
|
||||
}
|
||||
|
||||
const XlaExpression& XlaOpKernelContext::InputExpression(
|
||||
absl::string_view name) {
|
||||
return *CastExpressionFromTensor(GetInputTensorByName(name));
|
||||
return *XlaExpression::CastExpressionFromTensor(GetInputTensorByName(name));
|
||||
}
|
||||
|
||||
xla::XlaOp XlaOpKernelContext::Input(int index) {
|
||||
@ -108,7 +88,8 @@ DataType XlaOpKernelContext::input_type(int index) const {
|
||||
if (type == DT_UINT8) {
|
||||
// Masqueraded XlaExpression could have different type. See
|
||||
// XlaOpKernelContext::SetOutputExpression for details.
|
||||
auto expression = CastExpressionFromTensor(context_->input(index));
|
||||
auto expression =
|
||||
XlaExpression::CastExpressionFromTensor(context_->input(index));
|
||||
type = expression->dtype();
|
||||
}
|
||||
return type;
|
||||
@ -120,7 +101,7 @@ DataType XlaOpKernelContext::InputType(absl::string_view name) {
|
||||
if (type == DT_UINT8) {
|
||||
// Masqueraded XlaExpression could have different type. See
|
||||
// XlaOpKernelContext::SetOutputExpression for details.
|
||||
auto expression = CastExpressionFromTensor(tensor);
|
||||
auto expression = XlaExpression::CastExpressionFromTensor(tensor);
|
||||
type = expression->dtype();
|
||||
}
|
||||
return type;
|
||||
@ -385,7 +366,8 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
|
||||
handles->clear();
|
||||
shapes->clear();
|
||||
for (const Tensor& input : inputs) {
|
||||
handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
|
||||
handles->push_back(
|
||||
XlaExpression::CastExpressionFromTensor(input)->AsXlaOp(builder()));
|
||||
shapes->push_back(input.shape());
|
||||
}
|
||||
return Status::OK();
|
||||
@ -408,7 +390,7 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
|
||||
const XlaOpKernelContext* ctx,
|
||||
TensorShape* shape, xla::XlaOp* value) {
|
||||
const XlaExpression* expression =
|
||||
XlaOpKernelContext::CastExpressionFromTensor(tensor);
|
||||
XlaExpression::CastExpressionFromTensor(tensor);
|
||||
XlaResource* variable = expression->resource();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
|
||||
@ -461,7 +443,8 @@ Status XlaOpKernelContext::ReadVariableInput(absl::string_view name,
|
||||
Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
|
||||
TensorShape* shape) const {
|
||||
const Tensor& tensor = context_->input(index);
|
||||
const XlaExpression* expression = CastExpressionFromTensor(tensor);
|
||||
const XlaExpression* expression =
|
||||
XlaExpression::CastExpressionFromTensor(tensor);
|
||||
XlaResource* variable = expression->resource();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
|
||||
@ -502,8 +485,8 @@ void XlaOpKernelContext::SetOutputExpression(int index,
|
||||
TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
|
||||
TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
|
||||
}
|
||||
XlaOpKernelContext::AssignExpressionToTensor(
|
||||
expression, context_->mutable_output(index));
|
||||
XlaExpression::AssignExpressionToTensor(expression,
|
||||
context_->mutable_output(index));
|
||||
return Status::OK();
|
||||
}();
|
||||
if (!status.ok()) {
|
||||
@ -542,7 +525,7 @@ void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
|
||||
|
||||
Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
|
||||
const XlaExpression* expression =
|
||||
CastExpressionFromTensor(context_->input(index));
|
||||
XlaExpression::CastExpressionFromTensor(context_->input(index));
|
||||
TF_RET_CHECK(expression->resource() != nullptr);
|
||||
*resource = expression->resource();
|
||||
return Status::OK();
|
||||
@ -554,7 +537,7 @@ Status AssignVariableTensor(const Tensor& tensor, DataType type,
|
||||
const XlaOpKernelContext* ctx, xla::XlaOp handle,
|
||||
xla::XlaBuilder* builder) {
|
||||
const XlaExpression* expression =
|
||||
XlaOpKernelContext::CastExpressionFromTensor(tensor);
|
||||
XlaExpression::CastExpressionFromTensor(tensor);
|
||||
XlaResource* variable = expression->resource();
|
||||
TF_RET_CHECK(variable != nullptr);
|
||||
TF_RET_CHECK(variable->kind() == XlaResource::kVariable);
|
||||
|
@ -17,6 +17,9 @@ limitations under the License.
|
||||
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_KERNEL_H_
|
||||
|
||||
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_expression.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_resource.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_computation.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
@ -284,13 +287,6 @@ class XlaOpKernelContext {
|
||||
// separate specialization of the computation for each DataType.
|
||||
const xla::XlaComputation* GetOrCreateMul(const DataType type);
|
||||
|
||||
// Assigns an XlaExpression to a tensor on an XLA compilation device.
|
||||
static void AssignExpressionToTensor(const XlaExpression& value,
|
||||
Tensor* tensor);
|
||||
|
||||
// Retrieves an XlaExpression that was assigned to the specified tensor.
|
||||
static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor);
|
||||
|
||||
private:
|
||||
// Returns the tensor of input `name`.
|
||||
const Tensor& GetInputTensorByName(absl::string_view name);
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "absl/memory/memory.h"
|
||||
#include "tensorflow/compiler/tf2xla/shape_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/sharding_util.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_context.h"
|
||||
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
|
||||
|
@ -57,6 +57,7 @@ cc_library(
|
||||
":tpu_defs",
|
||||
":tpu_node_device_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
],
|
||||
alwayslink = 1,
|
||||
)
|
||||
@ -180,8 +181,8 @@ cc_library(
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:tf2xla_util",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/core:core_cpu_internal",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -656,6 +656,7 @@ cc_library(
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:common",
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
@ -673,6 +674,7 @@ cc_library(
|
||||
srcs = ["topk_ops.cc"],
|
||||
deps = [
|
||||
"//tensorflow/compiler/tf2xla:xla_compiler",
|
||||
"//tensorflow/compiler/tf2xla:xla_op_registry",
|
||||
"//tensorflow/compiler/xla/client:xla_builder",
|
||||
"//tensorflow/compiler/xla/client/lib:arithmetic",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
|
Loading…
Reference in New Issue
Block a user