diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 7bd4e5ae79d..e0171415492 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -166,6 +166,7 @@ cc_library( "xla_compilation_device.cc", "xla_compiler.cc", "xla_context.cc", + "xla_expression.cc", "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", @@ -180,6 +181,7 @@ cc_library( "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", + "xla_expression.h", "xla_helpers.h", "xla_op_kernel.h", "xla_op_registry.h", @@ -364,7 +366,10 @@ tf_cc_test( tf_cc_test( name = "xla_compiler_test", - srcs = ["xla_compiler_test.cc"], + srcs = [ + "xla_compiler_test.cc", + "xla_expression_test.cc", + ], deps = [ ":common", ":side_effect_util", @@ -389,6 +394,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "@com_google_absl//absl/memory", "@com_google_absl//absl/strings", ], ) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 706ed4f5bbf..efb75749722 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -23,9 +23,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/xla_builder.h" @@ -40,6 +40,7 @@ limitations under the License. 
#include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" #include "tensorflow/core/graph/validate.h" +#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -51,12 +52,11 @@ namespace { Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, const std::vector& expressions, std::vector* args) { - auto builder = ctx->builder(); auto client = ctx->compiler()->client(); - std::vector compile_time_constant_flags(expressions.size()); + std::vector arg_must_be_compile_time_constant(expressions.size()); TF_RETURN_IF_ERROR( - BackwardsConstAnalysis(*graph, &compile_time_constant_flags, + BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant, /*compile_time_const_nodes=*/nullptr)); args->resize(expressions.size()); @@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, arg.type = ctx->input_type(i); arg.shape = ctx->InputShape(i); - if (arg.type == DT_RESOURCE) { - return errors::InvalidArgument( - "Resource as function argument is not yet implemented."); - } else if (expressions[i]->has_constant_value()) { - arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); - } else if (compile_time_constant_flags[i]) { - arg.kind = XlaCompiler::Argument::kConstant; - TF_RET_CHECK(expressions[i]->resource() == nullptr) - << "Input with resource is not yet implemented."; - TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph( - expressions[i]->handle())); - TF_ASSIGN_OR_RETURN(auto literal, - client->ComputeConstant(constant_graph)); - TF_RETURN_IF_ERROR( - LiteralToHostTensor(literal, arg.type, &arg.constant_value)); - } else { - arg.kind = XlaCompiler::Argument::kParameter; + switch (expressions[i]->kind()) { + case XlaExpression::Kind::kConstant: + arg.kind = XlaCompiler::Argument::kConstant; + 
arg.constant_value = expressions[i]->constant_value(); + break; + case XlaExpression::Kind::kXlaOp: + if (arg_must_be_compile_time_constant[i]) { + TF_ASSIGN_OR_RETURN(absl::optional value, + expressions[i]->ResolveConstant(client)); + if (!value.has_value()) { + return errors::InvalidArgument( + "Argument to function must be a compile-time constant, but " + "unable to resolve argument value to a constant."); + } + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = *value; + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + break; + case XlaExpression::Kind::kResource: + return errors::Unimplemented( + "Resource as function argument is not yet implemented."); + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument("Invalid function argument"); } } return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc index 276d744c096..2db2514397d 100644 --- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc @@ -14,11 +14,13 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/lib/core/errors.h" namespace tensorflow { @@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel { } const XlaExpression& arg = XlaContext::Get(ctx).args()[index_]; - if (arg.resource() != nullptr) { - ctx->SetResourceOutput(0, arg.resource()); - } else if (arg.has_constant_value()) { - ctx->SetConstantOutput(0, arg.constant_value()); - } else { - ctx->SetOutput(0, arg.handle()); - } + OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid, + errors::InvalidArgument("Invalid/missing argument expression")); + ctx->SetOutputExpression(0, arg); } private: diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc index 2628ef8e245..dff8af80022 100644 --- a/tensorflow/compiler/tf2xla/kernels/const_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc @@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { TensorShape shape(proto_.tensor_shape()); - if (proto_.dtype() == DT_STRING) { - LOG(WARNING) << "Not computing Const of type DT_STRING"; - ctx->SetInvalidOutput(0); - return; - } xla::XlaBuilder* b = ctx->builder(); // To avoid blowups for large constants filled with the same value, diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 53e7624d607..6970dd0a006 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -47,63 +47,8 @@ class RetvalOp : 
public XlaOpKernel { // compilation. OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input)); } else { - xla::XlaOp input = ctx->Input(0); - const TensorShape input_shape = ctx->InputShape(0); - DataType input_type = ctx->input_type(0); - XlaContext& tc = XlaContext::Get(ctx); - - if (input_type == DT_RESOURCE) { - XlaResource* resource; - OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource)); - ctx->SetStatus(tc.AddResourceRetval(index_, resource)); - return; - } - - auto is_constant = ctx->builder()->IsConstant(input); - if (!is_constant.ok()) { - ctx->SetStatus(is_constant.status()); - return; - } - - if (tc.resolve_compile_time_constants() && - (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) { - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); - OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); - } else { - TensorShape shape = ctx->InputShape(0); - ctx->SetStatus(is_constant.status()); - xla::Shape representation_shape; - if (tc.is_entry_computation()) { - xla::StatusOr shape_or_status = - tc.RepresentationShape(shape, ctx->input_type(0)); - if (!shape_or_status.ok()) { - ctx->SetStatus(shape_or_status.status()); - return; - } else { - representation_shape = shape_or_status.ValueOrDie(); - } - } else { - OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(ctx->input_type(0), shape, - &representation_shape)); - } - - xla::XlaOp output = input; - if (tc.is_entry_computation()) { - output = xla::Reshape( - input, xla::AsInt64Slice(representation_shape.dimensions())); - } else { - // The core from which a return value is returned depends on the - // device assignment of the input to the retval. Since we can't change - // the device assignment of "input" at this point, we must always - // introduce an operator here, even if the shape does not change. - // TODO(b/76097077): propagate device assignments onto arguments and - // return values of functions, and then reshape unconditionally. 
- output = - xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0); - } - tc.AddRetval(index_, dtype_, shape, output); - } + XlaContext& xla_context = XlaContext::Get(ctx); + xla_context.SetRetval(index_, ctx->InputExpression(0)); } } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index cb7843850c3..ddb284966ee 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto( "XLACompilationDevice::MakeTensorFromProto should not be called"); } -XlaExpression::XlaExpression() = default; - -void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; } - -void XlaExpression::set_constant_value(Tensor value) { - has_constant_value_ = true; - constant_value_ = std::move(value); -} - } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index a6e78825334..de6a3356e05 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -18,9 +18,6 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_resource.h" -#include "tensorflow/compiler/xla/client/xla_builder.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/common_runtime/local_device.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/tensor.h" @@ -38,8 +35,8 @@ class XlaCompilationAllocator; // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that -// backs each Tensor with metadata indicating the computation the -// Tensor represents. +// backs each Tensor with an XlaExpression. 
The shape of the Tensor +// matches the shape of XlaExpression. // // We deliberately don't register a device factory because we *never* // want placement to put Ops on a compilation device. The device is created @@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice { std::unique_ptr allocator_; }; -// A XlaExpression wraps an XLA computation. Each Tensor on an -// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor -// matches the shape of the subcomputation in the XlaOp. Each -// expression is either a constant, or a function of previously-compiled -// expressions. -class XlaExpression { - public: - XlaExpression(); - - // handle() stores the XLA handle of the computation that the - // expression represents. - void set_handle(const xla::XlaOp& h); - const xla::XlaOp& handle() const { return handle_; } - - void set_constant_value(Tensor value); - bool has_constant_value() const { return has_constant_value_; } - const Tensor& constant_value() const { return constant_value_; } - - void set_resource(XlaResource* resource) { resource_ = resource; } - XlaResource* resource() const { return resource_; } - - private: - // The XLA handle of the expression's computation. - xla::XlaOp handle_; - - // If this expression is a constant with a known value, 'constant_value' is a - // host-memory Tensor containing the value. Used to avoid invoking XLA for - // expressions that are trivially constant. - bool has_constant_value_ = false; - Tensor constant_value_; - - XlaResource* resource_ = nullptr; // Not owned. -}; - } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index e6d7710c244..a08d030ce71 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -36,11 +36,13 @@ limitations under the License. 
#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/common_runtime/graph_optimizer.h" #include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" #include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/platform/logging.h" @@ -64,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types, return Status::OK(); } +// Uses the _Arg and _Retval nodes in the graph to determine a core assignment +// for each argument and return value. +xla::StatusOr, std::map>> +ComputeArgAndRetvalCores(const Graph& graph) { + auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr { + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + std::map arg_cores; + std::map retval_cores; + for (const Node* n : graph.nodes()) { + if (n->type_string() == FunctionLibraryDefinition::kArgOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Arg index"; + arg_cores[index] = core; + } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) { + TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n)); + if (core < 0) continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0) << "Negative _Retval index"; + TF_ASSIGN_OR_RETURN(retval_cores[index], 
get_sharding_for_node(n)); + retval_cores[index] = core; + } + } + return std::make_pair(std::move(arg_cores), std::move(retval_cores)); +} + +Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, + XlaCompilationDevice* device, FunctionLibraryRuntime* flib, + int64 step_id) { + // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the + // resource manager takes ownership via Create, and unrefs via Cleanup. We + // explicitly add a reference to ensure the refcount at entry is maintained at + // all exit points; Create and Cleanup are always called in this function. + // + // The Executor requires us to use ScopedStepContainer. We wrap it in a + // unique_ptr so we can capture the cleanup status in the end. + xla_context->Ref(); + Status status; + auto step_container = absl::make_unique( + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); + }); + TF_RETURN_IF_ERROR(device->resource_manager()->Create( + step_container->name(), XlaContext::kXlaContextResourceName, + xla_context)); + + GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); + // Explicitly clean up the step container, to capture the cleanup status. + step_container.reset(); + return status; +} + +// Builds the XLA computation. +// - `args` is the list of input arguments +// - `retvals` is the list of retvals produced by _Retval operators, in index +// order. +// - `arg_cores` and `retval_cores` are mappings from arg/return indices to core +// assignments. +// - If `return_updated_values_for_all_resources` is true, all resources will be +// included in `resource_updates`, regardless of whether their value changed. +// - Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
+// - Sets `*resource_updates` to a description of resources whose values are +// written by the computation; the variable writes are the last +// `resource_updates.size()` return values from the computation. Each entry in +// `resource_updates` is a ResourceUpdate, whose `input_index` is the index of +// a resource variable argument to the computation to be updated, and `type` +// is the type of the final output. +Status BuildComputation( + const std::vector& args, + const std::vector& retvals, + const std::map& arg_cores, const std::map& retval_cores, + const std::vector>& resources, + std::unique_ptr token_output, + const XlaCompiler::ShapeRepresentationFn& shape_representation_fn, + bool return_updated_values_for_all_resources, bool always_return_tuple, + xla::XlaBuilder* builder, xla::XlaComputation* computation, + int* num_computation_outputs, int* num_nonconst_outputs, + std::vector* outputs, + std::vector* resource_updates) { + // Attach a common operator name as metadata. This has no semantic effect — it + // merely makes the HLO graph more readable when visualized via TensorBoard, + // since TensorBoard forms groups out of operators with similar names. + xla::OpMetadata retval_metadata; + retval_metadata.set_op_name("XLA_Retvals"); + builder->SetOpMetadata(retval_metadata); + auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); }); + + // Builds a no-op XLA computation. We need to set the sharding of outputs, but + // cannot change the sharding of the existing output op. To do this, we build + // a new identity op to which shardings can be applied.
+ auto identity_op = [builder](xla::XlaOp op) { + return xla::GetTupleElement(xla::Tuple(builder, {op}), 0); + }; + + std::vector elems; + elems.reserve(retvals.size()); + for (int i = 0; i < retvals.size(); ++i) { + XlaCompiler::OutputDescription& output = (*outputs)[i]; + const XlaExpression& retval = retvals[i]; + output.type = retval.dtype(); + switch (retval.kind()) { + case XlaExpression::Kind::kConstant: + output.is_constant = true; + output.constant_value = retval.constant_value(); + output.shape = output.constant_value.shape(); + break; + + case XlaExpression::Kind::kXlaOp: { + output.is_constant = false; + TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape()); + xla::XlaOp value = retval.handle(); + auto it = retval_cores.find(i); + xla::XlaScopedShardingAssignment assign_sharding( + builder, it == retval_cores.end() + ? absl::optional() + : xla::sharding_builder::AssignDevice(it->second)); + if (shape_representation_fn) { + // If there is a shape representation function, reshape the output + // tensor to the shape given by the representation shape function. + TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn( + output.shape, output.type)); + value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions())); + } else if (it != retval_cores.end()) { + // Apply the sharding to the output, if there is a core assignment. + value = identity_op(value); + } + elems.push_back(value); + break; + } + + case XlaExpression::Kind::kResource: + output.is_constant = false; + output.input_index = retval.resource()->arg_num(); + output.shape = retval.resource()->shape(); + break; + + case XlaExpression::Kind::kInvalid: + return errors::InvalidArgument( + "Invalid expression returned by computation. " + "This probably means a return value was not set."); + } + } + *num_nonconst_outputs = elems.size(); + + // Add return values for resources whose values have changed. 
+ std::vector arg_resources; + arg_resources.reserve(resources.size()); + for (const auto& resource : resources) { + if (resource->arg_num() >= 0) { + arg_resources.push_back(resource.get()); + } + } + std::sort(arg_resources.begin(), arg_resources.end(), + [](const XlaResource* a, const XlaResource* b) { + return a->arg_num() < b->arg_num(); + }); + + for (const XlaResource* resource : arg_resources) { + DCHECK_LT(resource->arg_num(), args.size()); + const XlaCompiler::Argument& arg = args[resource->arg_num()]; + auto it = arg_cores.find(resource->arg_num()); + const int core = it == arg_cores.end() ? -1 : it->second; + bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); + // TensorArray gradients were modified if their values changed or there are + // any newly created gradients. + for (const auto& grad : resource->tensor_array_gradients()) { + modified = + modified || + !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || + arg.tensor_array_gradients.count(grad.first) == 0; + } + if (return_updated_values_for_all_resources || modified) { + resource_updates->emplace_back(); + XlaCompiler::ResourceUpdate& update = resource_updates->back(); + update.input_index = resource->arg_num(); + update.type = resource->type(); + update.shape = resource->shape(); + update.modified = modified; + for (const auto& grad : resource->tensor_array_gradients()) { + update.tensor_array_gradients_accessed.insert(grad.first); + } + + // Request that the value be returned on a specific core. + xla::XlaScopedShardingAssignment assign_sharding( + builder, core == -1 ? absl::optional() + : xla::sharding_builder::AssignDevice(core)); + + xla::XlaOp handle; + TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Ensures the correct sharding is applied to the output. + handle = identity_op(handle); + + elems.push_back(handle); + } + } + + // If we have token output, append it as the last one. 
+ if (token_output) { + elems.push_back(*token_output); + } + + *num_computation_outputs = elems.size(); + + // Builds the XLA computation. We *always* form a tuple here to ensure that + // the output value is the last thing added into the XLA computation, even + // if there is only one output value. + auto tuple = xla::Tuple(builder, elems); + if (!always_return_tuple && elems.size() == 1) { + xla::GetTupleElement(tuple, 0); + } + + xla::StatusOr computation_status = builder->Build(); + if (!computation_status.ok()) { + return computation_status.status(); + } + *computation = computation_status.ConsumeValueOrDie(); + return Status::OK(); +} + } // namespace bool XlaCompiler::Argument::operator==( @@ -252,14 +488,16 @@ Status XlaCompiler::CompileFunction( // lowest-numbered core that consumes the argument. We choose the // lowest-numbered core so the assignment is deterministic. for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Arg") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kArgOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); } } // Do _Retval as a second loop, in case the retval's input is an _Arg (which // may have gotten a device assignment from the first loop). for (Node* n : graph->nodes()) { - if (absl::string_view(n->type_string()) == "_Retval") { + if (absl::string_view(n->type_string()) == + FunctionLibraryDefinition::kRetOp) { TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); } } @@ -353,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg, } } -namespace { - -Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, - XlaCompilationDevice* device, FunctionLibraryRuntime* flib, - int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the - // resource manager takes ownership via Create, and unrefs via Cleanup. 
We - // explicitly add a reference to ensure the refcount at entry is maintained at - // all exit points; Create and Cleanup are always called in this function. - // - // The Executor requires us to use ScopedStepContainer. We wrap it in a - // unique_ptr so we can capture the cleanup status in the end. - xla_context->Ref(); - Status status; - auto step_container = absl::make_unique( - step_id, [&status, device](const string& name) { - status = device->resource_manager()->Cleanup(name); - }); - TF_RETURN_IF_ERROR(device->resource_manager()->Create( - step_container->name(), XlaContext::kXlaContextResourceName, - xla_context)); - - GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get()); - TF_RETURN_IF_ERROR(graph_compiler.Compile()); - // Explicitly clean up the step container, to capture the cleanup status. - step_container.reset(); - return Status::OK(); -} - -// Builds the XLA computation. -// `args` is the list of input arguments, `retvals` is the list of retvals -// produced by _Retval operators, in index order. -// If `return_updated_values_for_all_resources` is true, all resources will be -// included in `resource_updates`, regardless of whether their value changed. -// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`. -// Sets `*resource_updates` to a description of resources whose values are -// written by the computation; the variable writes are the last -// `resource_updates.size()` return values from the computation. Each entry in -// `resource_updates` is a (input_index, type) pair, where `input_index` is the -// index of a resource variable argument to the computation, and `type` is the -// type of the final output. 
-Status BuildComputation( - const std::vector& args, - const std::vector& arg_cores, - const std::vector& retvals, - const std::vector>& resources, - std::unique_ptr token_output, - bool return_updated_values_for_all_resources, bool always_return_tuple, - xla::XlaBuilder* builder, xla::XlaComputation* computation, - int* num_computation_outputs, int* num_nonconst_outputs, - std::vector* outputs, - std::vector* resource_updates) { - std::vector elems; - elems.reserve(retvals.size()); - for (int i = 0; i < retvals.size(); ++i) { - XlaCompiler::OutputDescription& output = (*outputs)[i]; - output.type = retvals[i].type; - output.shape = retvals[i].shape; - const XlaExpression& retval = retvals[i].expression; - if (retval.has_constant_value()) { - output.is_constant = true; - output.constant_value = retval.constant_value(); - } else if (retval.resource() != nullptr) { - output.is_constant = false; - output.input_index = retval.resource()->arg_num(); - } else { - output.is_constant = false; - elems.push_back(retval.handle()); - } - } - *num_nonconst_outputs = elems.size(); - - // Add return values for resources whose values have changed. - std::vector arg_resources; - arg_resources.reserve(resources.size()); - for (const auto& resource : resources) { - if (resource->arg_num() >= 0) { - arg_resources.push_back(resource.get()); - } - } - std::sort(arg_resources.begin(), arg_resources.end(), - [](const XlaResource* a, const XlaResource* b) { - return a->arg_num() < b->arg_num(); - }); - - // Attach a common operator name as metadata. This has no semantic effect — it - // merely makes the HLO graph more readable when visualized via TensorBoard, - // since TensorBoard forms groups out of operators with similar names. 
- xla::OpMetadata retval_metadata; - retval_metadata.set_op_name("XLA_Retvals"); - builder->SetOpMetadata(retval_metadata); - - for (const XlaResource* resource : arg_resources) { - const XlaCompiler::Argument& arg = args[resource->arg_num()]; - const int core = arg_cores[resource->arg_num()]; - DCHECK_LT(resource->arg_num(), arg_cores.size()); - bool modified = !resource->value().IsIdenticalTo(resource->initial_value()); - // TensorArray gradients were modified if their values changed or there are - // any newly created gradients. - for (const auto& grad : resource->tensor_array_gradients()) { - modified = - modified || - !grad.second->value().IsIdenticalTo(grad.second->initial_value()) || - arg.tensor_array_gradients.count(grad.first) == 0; - } - if (return_updated_values_for_all_resources || modified) { - resource_updates->emplace_back(); - XlaCompiler::ResourceUpdate& update = resource_updates->back(); - update.input_index = resource->arg_num(); - update.type = resource->type(); - update.shape = resource->shape(); - update.modified = modified; - for (const auto& grad : resource->tensor_array_gradients()) { - update.tensor_array_gradients_accessed.insert(grad.first); - } - - // Request that the value be returned on a specific core. - xla::XlaScopedShardingAssignment assign_sharding( - builder, core == -1 ? absl::optional() - : xla::sharding_builder::AssignDevice(core)); - - xla::XlaOp handle; - TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); - - // Since we can't change the sharding metadata of as this point, - // create a tuple/get-tuple-element combination so that sharding - // assignment will be placed on this value, which will cause the resource - // update to be returned from the same device that provided the resource. - handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0); - elems.push_back(handle); - } - } - - // If we have token output, append it as the last one. 
- if (token_output) { - elems.push_back(*token_output); - } - - *num_computation_outputs = elems.size(); - - // Builds the XLA computation. We *always* form a tuple here to ensure that - // the output value is the last thing added into the XLA computation, even - // if there is only one output value. - auto tuple = xla::Tuple(builder, elems); - if (!always_return_tuple && elems.size() == 1) { - xla::GetTupleElement(tuple, 0); - } - builder->ClearOpMetadata(); - - xla::StatusOr computation_status = builder->Build(); - if (!computation_status.ok()) { - return computation_status.status(); - } - *computation = computation_status.ConsumeValueOrDie(); - return Status::OK(); -} - -} // namespace - // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. Status XlaCompiler::BuildArguments( const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context, - std::vector* arg_cores, std::vector* arg_expressions, + const std::map& arg_cores, + std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, bool is_entry_computation) { arg_expressions->resize(args.size()); - *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. 
@@ -543,7 +622,7 @@ Status XlaCompiler::BuildArguments( arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(), /*tensor_array_size=*/arg.tensor_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource)); - arg_expression.set_resource(resource); + arg_expression = XlaExpression::Resource(resource); if (arg.initialized) { input_mapping->push_back(i); } @@ -555,7 +634,7 @@ Status XlaCompiler::BuildArguments( break; } case XlaCompiler::Argument::kConstant: - arg_expression.set_constant_value(arg.constant_value); + arg_expression = XlaExpression::Constant(arg.constant_value); break; case XlaCompiler::Argument::kInvalid: return errors::Internal( @@ -580,26 +659,6 @@ Status XlaCompiler::BuildArguments( *input_shapes = arg_shapes; } - // Use the _Arg nodes in the graph to resolve core assignments. - for (const Node* n : graph.nodes()) { - if (absl::string_view(n->type_string()) != "_Arg") continue; - int index; - TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); - TF_RET_CHECK(index >= 0 && index < args.size()) - << "_Arg out of bounds: " << index << " vs " << args.size(); - TF_ASSIGN_OR_RETURN( - auto sharding, - ParseShardingFromDevice(*n, std::numeric_limits::max())); - if (sharding.has_value()) { - TF_RET_CHECK(sharding.value().type() == - xla::OpSharding::Type::OpSharding_Type_MAXIMAL); - const int core = sharding.value().tile_assignment_devices(0); - if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { - (*arg_cores)[index] = core; - } - } - } - // Attach a common operator name as metadata. This has no semantic effect — it // merely makes the HLO graph more readable when visualized via TensorBoard, // since TensorBoard forms groups out of operators with similar names. 
@@ -615,11 +674,10 @@ Status XlaCompiler::BuildArguments( xla::OpSharding tuple_sharding; tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE); for (int64 parameter : *input_mapping) { - const int core = (*arg_cores)[parameter]; - const int root_device = 0; + auto it = arg_cores.find(parameter); + const int core = it == arg_cores.end() ? 0 : it->second; *tuple_sharding.add_tuple_shardings() = - core == -1 ? xla::sharding_builder::AssignDevice(root_device) - : xla::sharding_builder::AssignDevice(core); + xla::sharding_builder::AssignDevice(core); } xla::XlaScopedShardingAssignment assign_tuple_sharding(builder, tuple_sharding); @@ -628,7 +686,8 @@ Status XlaCompiler::BuildArguments( tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple"); } for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -636,7 +695,8 @@ Status XlaCompiler::BuildArguments( } } else { for (std::vector::size_type i = 0; i < input_mapping->size(); ++i) { - const int core = (*arg_cores)[input_mapping->at(i)]; + auto it = arg_cores.find(i); + const int core = it == arg_cores.end() ? -1 : it->second; xla::XlaScopedShardingAssignment assign_sharding( builder, core == -1 ? absl::optional() : xla::sharding_builder::AssignDevice(core)); @@ -671,14 +731,14 @@ Status XlaCompiler::BuildArguments( // TODO(b/76097077): propagate device assignments onto arguments and // return values of functions, and then reshape unconditionally. 
if (is_entry_computation) { - arg_expression.set_handle( - xla::Reshape(arg_handles[i], arg.shape.dim_sizes())); + arg_expression = XlaExpression::XlaOp( + xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type); } else { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); } break; case XlaCompiler::Argument::kToken: { - arg_expression.set_handle(arg_handles[i]); + arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type); break; } case XlaCompiler::Argument::kConstant: @@ -710,7 +770,7 @@ Status XlaCompiler::CompileSingleOp( Node* node; string arg_name = absl::StrCat("_arg", i); Status status = - NodeBuilder(arg_name, "_Arg") + NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp) .ControlInput(graph->source_node()) .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE : args[i].type) @@ -724,7 +784,7 @@ Status XlaCompiler::CompileSingleOp( for (int64 i = 0; i < result_types.size(); ++i) { Node* node; string retval_name = absl::StrCat("_retval", i); - Status status = NodeBuilder(retval_name, "_Retval") + Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp) .Input(main_node, i) .Attr("T", result_types[i]) .Attr("index", i) @@ -788,6 +848,32 @@ Status ValidateGraph(const Graph* graph, return Status::OK(); } +// Converts the value of any expressions whose values are known at compile-time +// to constants. 
+Status ResolveConstantExpressionsToConstants( + xla::Client* client, absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kXlaOp) { + TF_ASSIGN_OR_RETURN(absl::optional constant, + expression.ResolveConstant(client)); + if (constant.has_value()) { + expression = XlaExpression::Constant(*constant); + } + } + } + return Status::OK(); +} + +void ConvertConstantsToExpressions(xla::XlaBuilder* builder, + absl::Span expressions) { + for (XlaExpression& expression : expressions) { + if (expression.kind() == XlaExpression::Kind::kConstant) { + expression = + XlaExpression::XlaOp(expression.AsXlaOp(builder), expression.dtype()); + } + } +} + } // namespace Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, @@ -815,10 +901,9 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, options_.device_type, name)); xla::XlaBuilder builder(name); - XlaContext* context = new XlaContext( - this, &builder, options_.allow_cpu_custom_calls, - options.resolve_compile_time_constants, options.is_entry_computation, - &options_.shape_representation_fn); + XlaContext* context = + new XlaContext(this, &builder, options_.allow_cpu_custom_calls, + &options_.shape_representation_fn); core::ScopedUnref context_unref(context); std::vector real_args(args.begin(), args.end()); @@ -833,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, real_args.push_back(token_arg); } + std::map arg_cores; + std::map retval_cores; + TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores), + ComputeArgAndRetvalCores(*graph)); + std::vector arg_expressions; - std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores, + *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores, &arg_expressions, &result->input_mapping, &result->xla_input_shapes, options.is_entry_computation)); 
context->set_args(std::move(arg_expressions)); @@ -884,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); result->outputs.resize(context->retvals().size()); + std::vector retvals = context->retvals(); + if (options.resolve_compile_time_constants) { + TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants( + client(), absl::Span(retvals))); + } else { + ConvertConstantsToExpressions(&builder, absl::Span(retvals)); + } TF_RETURN_IF_ERROR(BuildComputation( - real_args, arg_cores, context->retvals(), context->resources(), - std::move(token_output), options.return_updated_values_for_all_resources, + real_args, retvals, arg_cores, retval_cores, context->resources(), + std::move(token_output), + options.is_entry_computation ? options_.shape_representation_fn + : ShapeRepresentationFn{}, + options.return_updated_values_for_all_resources, options.always_return_tuple, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->outputs, &result->resource_updates)); diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index f10cfbe0c65..63426124686 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -21,8 +21,10 @@ limitations under the License. 
#include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/core/common_runtime/device.h" @@ -415,7 +417,8 @@ class XlaCompiler { Status BuildArguments(const Graph& graph, const std::vector& args, bool use_tuple_arg, xla::XlaBuilder* builder, - XlaContext* context, std::vector* arg_cores, + XlaContext* context, + const std::map& arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes, diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc index 1e819dbb694..43095fbb473 100644 --- a/tensorflow/compiler/tf2xla/xla_context.cc +++ b/tensorflow/compiler/tf2xla/xla_context.cc @@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector args) { XlaContext::XlaContext( XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, + bool allow_cpu_custom_calls, const std::function( const TensorShape&, DataType)>* shape_representation_fn) : compiler_(compiler), builder_(builder), allow_cpu_custom_calls_(allow_cpu_custom_calls), - resolve_compile_time_constants_(resolve_compile_time_constants), - is_entry_computation_(is_entry_computation), shape_representation_fn_(shape_representation_fn) {} string XlaContext::DebugString() { return "TLA JIT context"; } -// This is called by the Retval Op to associate a computed value -// with a specific return value of the subgraph. 
-void XlaContext::AddRetval(int retval_index, DataType type, - const TensorShape& shape, const xla::XlaOp& handle) { - VLOG(1) << "Added retval index " << retval_index << " to XLA computation"; - // Add the return value to the list being built up. - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); +void XlaContext::SetRetval(int index, const XlaExpression& expression) { + if (retvals_.size() <= index) { + retvals_.resize(index + 1); } - XlaExpression e; - e.set_handle(handle); - retvals_[retval_index] = Retval{type, shape, e}; + retvals_[index] = expression; } -Status XlaContext::AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal) { - VLOG(1) << "Adding retval index " << retval_index - << " with non-data-dependent tensor to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - Tensor value; - TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value)); - XlaExpression e; - e.set_constant_value(value); - retvals_[retval_index] = Retval{dtype, value.shape(), e}; - return Status::OK(); -} - -Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) { - VLOG(1) << "Adding retval index " << retval_index << " with resource " - << resource->name() << ":" << resource->shape().DebugString() - << " to XLA computation"; - if (retvals_.size() <= retval_index) { - retvals_.resize(retval_index + 1); - } - XlaExpression e; - e.set_resource(resource); - retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e}; - return Status::OK(); -} - -xla::XlaBuilder* XlaContext::builder() { return builder_; } - Status XlaContext::CreateResource( XlaResource::Kind kind, int arg_num, string name, DataType type, TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size, diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h index 8aad6cbced0..dbfd344c9ba 100644 --- 
a/tensorflow/compiler/tf2xla/xla_context.h +++ b/tensorflow/compiler/tf2xla/xla_context.h @@ -20,8 +20,8 @@ limitations under the License. #include -#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -46,8 +46,7 @@ class XlaContext : public ResourceBase { // Creates a new XlaContext. See the documentation on the class data fields // for descriptions of the arguments. XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder, - bool allow_cpu_custom_calls, bool resolve_compile_time_constants, - bool is_entry_computation, + bool allow_cpu_custom_calls, const std::function( const TensorShape&, DataType)>* shape_representation_fn); @@ -57,37 +56,19 @@ class XlaContext : public ResourceBase { XlaCompiler* compiler() const { return compiler_; } // Returns the XlaBuilder that Ops use for compiling new expressions. - xla::XlaBuilder* builder(); + xla::XlaBuilder* builder() { return builder_; } bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; } - bool resolve_compile_time_constants() const { - return resolve_compile_time_constants_; - } - bool is_entry_computation() const { return is_entry_computation_; } - const std::vector& args() const { return args_; } void set_args(std::vector args); - struct Retval { - DataType type; - TensorShape shape; - // An XlaExpression representing the Retval's value. - XlaExpression expression; - }; - const std::vector& retvals() { return retvals_; } + const std::vector& retvals() { return retvals_; } - // This is called by the Retval Op to associate a computed value - // with a specific return value of the subgraph. 
- void AddRetval(int retval_index, DataType type, const TensorShape& shape, - const xla::XlaOp& handle); - - // As for Retval, but for return values that are compile-time constants. - Status AddConstRetval(int retval_index, DataType dtype, - const xla::LiteralSlice& literal); - - // As for Retval, but for return values that are resource handles. - Status AddResourceRetval(int retval_index, XlaResource* resource); + // Sets a return value. + // Since we do not always know in advance how many return values there are, + // grows the return values vector to size index+1 if it is smaller. + void SetRetval(int index, const XlaExpression& expression); // Creates a resource with resource `kind` and initial value `handle`. `name` // is a descriptive name for use in error messages. See the `XlaResource` @@ -140,24 +121,16 @@ class XlaContext : public ResourceBase { // Allow ops to emit CustomCall operations for CPU. const bool allow_cpu_custom_calls_; - // If true, constant return values are returned as Tensors instead of - // run-time computation outputs. - const bool resolve_compile_time_constants_; - // Arguments to the Tensorflow graph, indexed by _Arg index. // Includes both compile-time constant arguments and runtime parameters. std::vector args_; // Return values of the Tensorflow graph, indexed by _Retval index. - std::vector retvals_; + std::vector retvals_; // Holds ownership of resources. The resources are not ordered. std::vector> resources_; - // Is this a top-level computation, or an inner computation (e.g., a while - // body)? - const bool is_entry_computation_; - // Describes the on-host shapes of parameters and return values. Also see: // XlaDevice::Options::shape_representation_fn. 
const std::function(const TensorShape&, DataType)>* diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc new file mode 100644 index 00000000000..ca0309166b7 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_expression.h" + +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" + +namespace tensorflow { + +XlaExpression::XlaExpression() = default; + +XlaExpression XlaExpression::Invalid() { + XlaExpression e; + e.kind_ = Kind::kInvalid; + return e; +} + +XlaExpression XlaExpression::Constant(Tensor value) { + XlaExpression e; + e.kind_ = Kind::kConstant; + e.dtype_ = value.dtype(); + e.constant_value_ = value; + return e; +} + +XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { + XlaExpression e; + e.kind_ = Kind::kXlaOp; + e.dtype_ = dtype; + e.handle_ = value; + return e; +} + +XlaExpression XlaExpression::Resource(XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + return e; +} + +string XlaExpression::HumanString() const { + switch 
(kind_) { + case Kind::kInvalid: + return "invalid"; + case Kind::kConstant: + return "constant"; + case Kind::kXlaOp: + return "xla_op"; + case Kind::kResource: + return "resource"; + } +} + +xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { + return builder->ReportErrorOrReturn([&]() -> xla::StatusOr { + switch (kind_) { + case Kind::kConstant: { + xla::BorrowingLiteral literal; + TF_RETURN_IF_ERROR( + HostTensorToBorrowingLiteral(constant_value_, &literal)); + return xla::ConstantLiteral(builder, literal); + } + case Kind::kXlaOp: + if (builder != handle_.builder()) { + return errors::InvalidArgument( + "Mismatched builders in XlaExpression::AsXlaOp"); + } + return handle_; + default: + return errors::InvalidArgument("AsXlaOp called on XlaExpression: ", + HumanString()); + } + }); +} + +xla::StatusOr> XlaExpression::ResolveConstant( + xla::Client* client) const { + switch (kind()) { + case Kind::kConstant: + return {constant_value()}; + case Kind::kXlaOp: + break; + case Kind::kResource: + case Kind::kInvalid: + return errors::InvalidArgument( + "ResolveConstant called on XlaExpression: ", HumanString()); + } + + TF_ASSIGN_OR_RETURN(bool is_constant, + handle().builder()->IsConstant(handle())); + if (!is_constant) return {absl::nullopt}; + + TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph, + handle().builder()->BuildConstantSubGraph(handle())); + + TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape()); + + // The XLA layout is specified minor to major, and TensorFlow uses a major to + // minor order. 
+ std::vector layout_indices(shape.dims()); + std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); + xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); + TF_ASSIGN_OR_RETURN(xla::Literal literal, + client->ComputeConstant(constant_graph, &layout)); + Tensor tensor; + TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor)); + return {tensor}; +} + +xla::StatusOr XlaExpression::GetShape() const { + switch (kind_) { + case Kind::kConstant: + return constant_value().shape(); + case Kind::kXlaOp: { + TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, + handle().builder()->GetShape(handle())); + TensorShape shape; + TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape)); + return shape; + } + case Kind::kResource: + return TensorShape({}); + case Kind::kInvalid: + return errors::InvalidArgument( + "GetShape() called on invalid XlaExpression"); + } +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h new file mode 100644 index 00000000000..bed6761d362 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ + +#include "absl/types/optional.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA +// compilation. +// An expression is one of: +// * a constant tensor. +// * an xla::XlaOp, representing a symbolic XLA value. +// * a resource, e.g., a variable, represented as an XlaResource pointer. +// +// Constant tensors are mostly an optimization to avoid passing large constants +// to XLA, but are also sometimes used to represent tensors that have no XLA +// representation, for example, DT_STRING tensors. A canonical use case might be +// an error message string. +class XlaExpression { + public: + enum class Kind { + kInvalid, + kConstant, + kXlaOp, + kResource, + }; + + XlaExpression(); + XlaExpression(const XlaExpression&) = default; + XlaExpression& operator=(const XlaExpression&) = default; + + // Builds an invalid expression. (Same as the default constructor, but makes + // the intent clearer.) + static XlaExpression Invalid(); + + // Builds a constant XLA expression. + static XlaExpression Constant(Tensor value); + + // Builds a XlaOp expression. Since the mapping from TF data types to XLA + // types is not 1-1, the TF type must also be provided; in general it cannot + // be derived from the XLA type. + static XlaExpression XlaOp(xla::XlaOp value, DataType dtype); + + // Builds a resource expression. 
+ static XlaExpression Resource(XlaResource* resource); + + Kind kind() const { return kind_; } + + DataType dtype() const { return dtype_; } + + // handle() returns the XlaOp that backs a kXlaOp expression. + const xla::XlaOp& handle() const { return handle_; } + + const Tensor& constant_value() const { return constant_value_; } + + XlaResource* resource() const { return resource_; } + + // Returns a human-readable summary of the expression. + string HumanString() const; + + // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns + // an erroneous XlaOp if the expression is not a constant or an expression. + xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const; + + // If a kXlaOp or kConstant expression can be resolved to a compile-time + // constant, returns the value as a host-memory Tensor. Returns an empty + // optional if it cannot be resolved. Returns an error if passed a resource + // expression. + xla::StatusOr> ResolveConstant( + xla::Client* client) const; + + // Returns the shape of the tensor. + // The shape of a resource is the shape of a resource handle (i.e., a scalar), + // not the shape of the resource's value. + xla::StatusOr GetShape() const; + + private: + Kind kind_ = Kind::kInvalid; + + DataType dtype_ = DT_INVALID; + + // The XLA handle of the expression's computation, if kind_ == kXlaOp. + xla::XlaOp handle_; + + // The value of the constant, if kind_ == kConstant. + Tensor constant_value_; + + // The resource, if kind_ == kResource. Not owned. + XlaResource* resource_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc new file mode 100644 index 00000000000..84202c93139 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -0,0 +1,135 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "absl/memory/memory.h" +#include "tensorflow/compiler/tf2xla/xla_expression.h" +#include "tensorflow/compiler/tf2xla/xla_resource.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/client/xla_builder.h" +#include "tensorflow/compiler/xla/literal.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/tests/literal_test_util.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +class XlaExpressionTest : public ::testing::Test { + protected: + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + builder_ = absl::make_unique("acomputation"); + constant_ = test::AsScalar(42); + op_ = xla::ConstantR0(builder_.get(), 7); + non_constant_op_ = xla::Parameter( + builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x"); + resource_ = absl::make_unique( + XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"), + DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1, + /*tensor_array_gradients=*/std::set(), + /*tensor_array_multiple_writes_aggregate=*/false); + } + + 
xla::Client* client_; + std::unique_ptr builder_; + Tensor constant_; + xla::XlaOp op_; + xla::XlaOp non_constant_op_; + std::unique_ptr resource_; +}; + +TEST_F(XlaExpressionTest, Kind) { + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind()); + EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind()); + EXPECT_TRUE(XlaExpression::Kind::kConstant == + XlaExpression::Constant(constant_).kind()); + EXPECT_TRUE(XlaExpression::Kind::kXlaOp == + XlaExpression::XlaOp(op_, DT_INT32).kind()); + EXPECT_TRUE(XlaExpression::Kind::kResource == + XlaExpression::Resource(resource_.get()).kind()); +} + +TEST_F(XlaExpressionTest, HumanString) { + EXPECT_EQ("invalid", XlaExpression().HumanString()); + EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString()); + EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString()); + EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString()); + EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString()); +} + +TEST_F(XlaExpressionTest, AsXlaOp) { + xla::XlaOp op_as_op = + XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get()); + EXPECT_TRUE(op_.IsIdenticalTo(op_as_op)); + + xla::XlaOp const_as_op = + XlaExpression::Constant(constant_).AsXlaOp(builder_.get()); + TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation, + builder_->BuildConstantSubGraph(const_as_op)); + TF_ASSERT_OK_AND_ASSIGN(xla::Literal value, + client_->ComputeConstant(computation)); + EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0(42), + value)); +} + +TEST_F(XlaExpressionTest, GetShape) { + EXPECT_FALSE(XlaExpression().GetShape().ok()); + EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok()); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape, + XlaExpression::Resource(resource_.get()).GetShape()); + EXPECT_EQ(TensorShape({}), resource_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape, + XlaExpression::XlaOp(op_, DT_INT32).GetShape()); + 
EXPECT_EQ(TensorShape({}), op_shape); + + TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape, + XlaExpression::Constant(constant_).GetShape()); + EXPECT_EQ(TensorShape({}), constant_shape); +} + +TEST_F(XlaExpressionTest, ResolveConstant) { + EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); + EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); + EXPECT_FALSE( + XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional op_constant, + XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_)); + ASSERT_TRUE(op_constant.has_value()); + test::ExpectTensorEqual(test::AsScalar(7), *op_constant); + + TF_ASSERT_OK_AND_ASSIGN(absl::optional op_nonconstant, + XlaExpression::XlaOp(non_constant_op_, DT_FLOAT) + .ResolveConstant(client_)); + EXPECT_FALSE(op_nonconstant.has_value()); + + TF_ASSERT_OK_AND_ASSIGN( + absl::optional constant_constant, + XlaExpression::Constant(constant_).ResolveConstant(client_)); + ASSERT_TRUE(constant_constant.has_value()); + test::ExpectTensorEqual(constant_, *constant_constant); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index 227915f5703..8dd8def0549 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -20,6 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/shape_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/xla/client/xla_builder.h" #include "tensorflow/compiler/xla/client/xla_computation.h" @@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const { static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) { const XlaExpression* expression = reinterpret_cast(tensor.tensor_data().data()); - CHECK(expression->handle().valid() || expression->resource() != nullptr); - VLOG(1) << "Fetched T" << expression->handle(); + CHECK(expression->kind() != XlaExpression::Kind::kInvalid) + << expression->HumanString(); return expression; } -// Retrieves an uninitialized XlaExpression from a newly-allocated tensor. -static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) { +// Assigns an XlaExpression to a tensor on an XLA compilation device. +static void AssignExpressionToTensor(Tensor* tensor, + const XlaExpression& value) { const XlaExpression* expression = reinterpret_cast(tensor->tensor_data().data()); - CHECK(!expression->handle().valid()); - return const_cast(expression); + CHECK(expression->kind() == XlaExpression::Kind::kInvalid) + << expression->HumanString(); + *const_cast(expression) = value; } -// Retrieves the XlaOp from an input Tensor to an Op. This computation was -// constructed by an Op that executed previously and created the output Tensor -// using CreateOutputTensorFromComputation or CreateConstantOutputTensor. 
-static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) { - return CastExpressionFromTensor(tensor)->handle(); +const XlaExpression& XlaOpKernelContext::InputExpression(int index) { + return *CastExpressionFromTensor(context_->input(index)); } -const xla::XlaOp& XlaOpKernelContext::Input(int index) { - return GetComputationFromTensor(context_->input(index)); +const XlaExpression& XlaOpKernelContext::InputExpression( + absl::string_view name) { + return *CastExpressionFromTensor(GetInputTensorByName(name)); } -const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) { - return GetComputationFromTensor(GetInputTensorByName(name)); +xla::XlaOp XlaOpKernelContext::Input(int index) { + return InputExpression(index).AsXlaOp(builder()); +} + +xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) { + return InputExpression(name).AsXlaOp(builder()); } TensorShape XlaOpKernelContext::InputShape(int index) { @@ -125,59 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name, Status XlaOpKernelContext::ConstantInputReshaped( int index, absl::Span new_dims, xla::Literal* constant_literal) { - const Tensor& tensor = context_->input(index); - TensorShape new_shape(new_dims); - if (tensor.NumElements() != new_shape.num_elements()) { - return errors::InvalidArgument( - context_->op_kernel().name(), " input ", index, " has shape ", - tensor.shape().DebugString(), - " but was asked to be reshaped to incompatible shape ", - new_shape.DebugString()); - } - const XlaExpression* expression = CastExpressionFromTensor(tensor); - - // If the tensor has a known constant value, there is no need to invoke XLA. - if (expression->has_constant_value()) { - Tensor temp(tensor.dtype()); - if (!temp.CopyFrom(expression->constant_value(), new_shape)) { - // This should never happen. The constant should have a shape compatible - // with the enclosing Tensor. 
- return errors::Internal("Incompatible shapes in ConstantInputReshaped."); - } - - TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); - return Status::OK(); - } - - // Make sure we treat zero-element tensors as constant. - if (new_shape.num_elements() == 0) { - Tensor temp(tensor.dtype(), new_shape); - TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); - return Status::OK(); - } - - xla::XlaOp handle = expression->handle(); - if (new_shape != tensor.shape()) { - // Reshape the handle to the desired shape. - handle = xla::Reshape(handle, new_shape.dim_sizes()); - } - - // The XLA layout is specified minor to major, and TensorFlow's minor - // dimension is the last one. - std::vector layout_indices(new_shape.dims()); - std::iota(layout_indices.rbegin(), layout_indices.rend(), 0); - xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices); - - xla::StatusOr is_constant = builder()->IsConstant(handle); - if (!is_constant.ok()) { - Status status = is_constant.status(); + XlaExpression e = InputExpression(index); + xla::StatusOr> constant_or_status = + e.ResolveConstant(compiler()->client()); + if (!constant_or_status.ok()) { + Status status = constant_or_status.status(); errors::AppendToMessage(&status, "while evaluating input ", index, " of ", context_->op_kernel().type_string(), " operator as a compile-time constant."); return status; } - - if (!is_constant.ValueOrDie()) { + absl::optional constant = constant_or_status.ValueOrDie(); + if (!constant.has_value()) { return errors::InvalidArgument( "Input ", index, " to ", context_->op_kernel().type_string(), " operator must be a compile-time constant.\n" @@ -190,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped( "stateful operation such as a random number generator."); } - // Ask the XLA compiler to evaluate the data handle to a literal. 
- xla::StatusOr constant_graph = - builder()->BuildConstantSubGraph(handle); - if (!constant_graph.ok()) { - return errors::Internal( - "Error getting a compile-time constant graph for ", - context_->op_kernel().name(), " input ", index, - ".\nError: ", constant_graph.status().error_message()); + Tensor temp(constant->dtype()); + if (!temp.CopyFrom(*constant, TensorShape(new_dims))) { + return errors::InvalidArgument( + context_->op_kernel().name(), " input ", index, " has shape ", + constant->shape().DebugString(), + " but was asked to be reshaped to incompatible shape ", + TensorShape(new_dims).DebugString()); } - xla::StatusOr computed = compiler()->client()->ComputeConstant( - constant_graph.ValueOrDie(), &layout); - if (!computed.ok()) { - return errors::Internal("Error evaluating ", context_->op_kernel().name(), - " input ", index, - " as a compile-time constant.\nError: ", - computed.status().error_message()); - } - *constant_literal = std::move(computed).ValueOrDie(); + TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp)); return Status::OK(); } @@ -363,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name, handles->clear(); shapes->clear(); for (const Tensor& input : inputs) { - handles->push_back(GetComputationFromTensor(input)); + handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder())); shapes->push_back(input.shape()); } return Status::OK(); @@ -449,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type, return Status::OK(); } -Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape, - Tensor** output) { - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. 
- if (expected_output_dtype(index) == DT_VARIANT) { - // tensor_data() is not supported for variant Tensor (i.e., - // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the - // XlaExpression inside the Tensor's tensor_data() does not work for - // variant. Instead construct a uint8 tensor and store the expression in its - // value. - // TODO(jpienaar): This should be refactored to stop masquerading - // XlaExpressions as Tensors. - *output = new Tensor(); - TensorShape tensor_shape; - TF_RETURN_IF_ERROR( - context_->allocate_temp(DT_UINT8, tensor_shape, *output)); - context_->set_output(index, **output); - } else { - TensorShape tensor_shape; - TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape)); - TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output)); +void XlaOpKernelContext::SetOutputExpression(int index, + const XlaExpression& expression) { + Status status = [&] { + // The step's default allocator is the dummy XlaCompilationAllocator which + // simply allocates a metadata buffer to hold the expression to which it + // corresponds. + Tensor* output = nullptr; + // Provides a special behavior for DT_VARIANT: a variant is treated as + // DT_UINT8 scalar as the type to allow mapping for variant to more generic + // types. + if (expression.dtype() == DT_VARIANT) { + // tensor_data() is not supported for variant Tensor (i.e., + // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the + // XlaExpression inside the Tensor's tensor_data() does not work for + // variant. Instead construct a uint8 tensor and store the expression in + // its value. + // TODO(jpienaar): This should be refactored to stop masquerading + // XlaExpressions as Tensors. 
+ output = new Tensor(); + TensorShape tensor_shape; + TF_RETURN_IF_ERROR( + context_->allocate_temp(DT_UINT8, tensor_shape, output)); + context_->set_output(index, *output); + } else { + TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape()); + TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output)); + } + AssignExpressionToTensor(output, expression); + return Status::OK(); + }(); + if (!status.ok()) { + SetStatus(status); } - return Status::OK(); } void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) { - // Makes the host Tensor that will refer to the expression. - Tensor* output = nullptr; - auto shape_or = builder()->GetShape(handle); - if (!shape_or.ok()) { - SetStatus(shape_or.status()); - return; - } - - OP_REQUIRES_OK(context_, - allocate_output(index, shape_or.ValueOrDie(), &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); + SetOutputExpression( + index, + XlaExpression::XlaOp(handle, context_->expected_output_dtype(index))); } void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) { - const TensorShape& shape = constant.shape(); - - xla::BorrowingLiteral literal; - OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal)); - - xla::XlaOp handle = xla::ConstantLiteral(builder(), literal); - CHECK(handle.valid()); - - // Make the Tensor that will refer to the expression. - Tensor* output = nullptr; - // The step's default allocator is the dummy XlaCompilationAllocator which - // simply allocates a metadata buffer to hold the expression to which it - // corresponds. - OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output)); - - // The expression is stored in the tensor's data buffer. Fill in the - // fields now. 
- XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_handle(handle); - expression->set_constant_value(constant); -} - -void XlaOpKernelContext::SetInvalidOutput(int index) { - Tensor* output = nullptr; - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape({}), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - xla::XlaOp handle; - expression->set_handle(handle); + SetOutputExpression(index, XlaExpression::Constant(constant)); } void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) { - Tensor* output = nullptr; - // The shape of the output tensor is the shape of the resource itself - // (i.e., a scalar), not the shape of the resource's value. - OP_REQUIRES_OK(context_, - context_->allocate_output(index, TensorShape(), &output)); - XlaExpression* expression = CastExpressionFromUninitializedTensor(output); - expression->set_resource(resource); + SetOutputExpression(index, XlaExpression::Resource(resource)); } Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) { diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h index 3d9499f5fae..c06efa2c474 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.h +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h @@ -88,9 +88,9 @@ class XlaOpKernelContext { // Returns input `index` as a XlaOp. Unlike // OpKernelContext::Input returns a symbolic value rather than a concrete // Tensor. - const xla::XlaOp& Input(int index); + xla::XlaOp Input(int index); // Returns input `name` as a XlaOp. - const xla::XlaOp& Input(absl::string_view name); + xla::XlaOp Input(absl::string_view name); // Returns true if all inputs are the same shape, otherwise sets the // status to a non-OK value and returns false. 
@@ -142,6 +142,10 @@ class XlaOpKernelContext { Status ConstantInputList(absl::string_view name, std::vector* literals); + // Returns an XlaExpression describing the value of input 'index'. + const XlaExpression& InputExpression(int index); + const XlaExpression& InputExpression(absl::string_view name); + // Outputs int num_outputs() const { return context_->num_outputs(); } @@ -159,9 +163,8 @@ class XlaOpKernelContext { // SetConstantOutput where possible. void SetConstantOutput(int index, const Tensor& host_tensor); - // Sets output `index` to an invalid value. - // Any subsequent attempt to consume this output will cause an error. - void SetInvalidOutput(int index); + // Sets output `index` to the value described by `expression`. + void SetOutputExpression(int index, const XlaExpression& expression); // Status handling. void SetStatus(const Status& status) { context_->SetStatus(status); } @@ -249,11 +252,6 @@ class XlaOpKernelContext { // Returns the tensor of input `name`. const Tensor& GetInputTensorByName(absl::string_view name); - // Wraps OpKernelContext's allocate_output method while providing special - // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the - // type to allow mapping for variant to more generic types. - Status allocate_output(int index, const xla::Shape& shape, Tensor** output); - // Evaluates input `index`, reshapes it to `new_shape` if new_shape != // InputShape(index), and stores it in `*constant_literal`. If the input // cannot be evaluated, e.g., because it depends on unbound parameters,