From f895a9e9961788134c2c0fe747776dfbc36edb3c Mon Sep 17 00:00:00 2001
From: Peter Hawkins <phawkins@google.com>
Date: Mon, 12 Nov 2018 09:26:02 -0800
Subject: [PATCH] [TF:XLA] Refactor XlaCompiler, XlaExpression, and
 XlaOpKernelContext in preparation for adding support for compiling small
 computations that don't correspond to a single TF op.

The idea of the refactoring is that XlaExpression is the canonical XLA representation of a symbolic TF value. So in general a computation to compile is a function with type [XlaExpression] -> [XlaExpression], and in a future change we will add a method to XlaCompiler that exposes pretty much exactly that API. The current TF function/graph/op compilation methods are specific ways to build such a function.

* Move XlaExpression into its own file. Improve its ergonomics; it is really a kind of sum type. Also move some useful common methods on XlaExpressions into the XlaExpression class.
* Add support for passing and returning XlaExpressions via XlaOpKernelContext, since they are the underlying representation. The remaining *Input() and *Output() methods are really just conveniences built on top.
* Simplify _Arg and _Retval to just get and set an XlaExpression from an XlaContext. Move the logic that flattens return values out of _Retval and into XlaCompiler so it can be reused when compiling non-graph computations.
* Move logic to assign cores to arguments and return values into a common place in XlaCompiler.

PiperOrigin-RevId: 221104314
---
 tensorflow/compiler/tf2xla/BUILD              |   8 +-
 tensorflow/compiler/tf2xla/graph_compiler.cc  |  51 +-
 tensorflow/compiler/tf2xla/kernels/arg_op.cc  |  12 +-
 .../compiler/tf2xla/kernels/const_op.cc       |   5 -
 .../compiler/tf2xla/kernels/retval_op.cc      |  59 +-
 .../compiler/tf2xla/xla_compilation_device.cc |   9 -
 .../compiler/tf2xla/xla_compilation_device.h  |  41 +-
 tensorflow/compiler/tf2xla/xla_compiler.cc    | 509 +++++++++++-------
 tensorflow/compiler/tf2xla/xla_compiler.h     |   5 +-
 tensorflow/compiler/tf2xla/xla_context.cc     |  50 +-
 tensorflow/compiler/tf2xla/xla_context.h      |  45 +-
 tensorflow/compiler/tf2xla/xla_expression.cc  | 145 +++++
 tensorflow/compiler/tf2xla/xla_expression.h   | 115 ++++
 .../compiler/tf2xla/xla_expression_test.cc    | 135 +++++
 tensorflow/compiler/tf2xla/xla_op_kernel.cc   | 228 +++-----
 tensorflow/compiler/tf2xla/xla_op_kernel.h    |  18 +-
 16 files changed, 843 insertions(+), 592 deletions(-)
 create mode 100644 tensorflow/compiler/tf2xla/xla_expression.cc
 create mode 100644 tensorflow/compiler/tf2xla/xla_expression.h
 create mode 100644 tensorflow/compiler/tf2xla/xla_expression_test.cc

diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index 7bd4e5ae79d..e0171415492 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -166,6 +166,7 @@ cc_library(
         "xla_compilation_device.cc",
         "xla_compiler.cc",
         "xla_context.cc",
+        "xla_expression.cc",
         "xla_helpers.cc",
         "xla_op_kernel.cc",
         "xla_op_registry.cc",
@@ -180,6 +181,7 @@ cc_library(
         "xla_compilation_device.h",
         "xla_compiler.h",
         "xla_context.h",
+        "xla_expression.h",
         "xla_helpers.h",
         "xla_op_kernel.h",
         "xla_op_registry.h",
@@ -364,7 +366,10 @@ tf_cc_test(
 
 tf_cc_test(
     name = "xla_compiler_test",
-    srcs = ["xla_compiler_test.cc"],
+    srcs = [
+        "xla_compiler_test.cc",
+        "xla_expression_test.cc",
+    ],
     deps = [
         ":common",
         ":side_effect_util",
@@ -389,6 +394,7 @@ tf_cc_test(
         "//tensorflow/core:test",
         "//tensorflow/core:test_main",
         "//tensorflow/core:testlib",
+        "@com_google_absl//absl/memory",
         "@com_google_absl//absl/strings",
     ],
 )
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 706ed4f5bbf..efb75749722 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -23,9 +23,9 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/xla/client/client_library.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
@@ -40,6 +40,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/graph/validate.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
@@ -51,12 +52,11 @@ namespace {
 Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
                         const std::vector<const XlaExpression*>& expressions,
                         std::vector<XlaCompiler::Argument>* args) {
-  auto builder = ctx->builder();
   auto client = ctx->compiler()->client();
-  std::vector<bool> compile_time_constant_flags(expressions.size());
+  std::vector<bool> arg_must_be_compile_time_constant(expressions.size());
 
   TF_RETURN_IF_ERROR(
-      BackwardsConstAnalysis(*graph, &compile_time_constant_flags,
+      BackwardsConstAnalysis(*graph, &arg_must_be_compile_time_constant,
                              /*compile_time_const_nodes=*/nullptr));
 
   args->resize(expressions.size());
@@ -65,24 +65,31 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
     arg.type = ctx->input_type(i);
     arg.shape = ctx->InputShape(i);
 
-    if (arg.type == DT_RESOURCE) {
-      return errors::InvalidArgument(
-          "Resource as function argument is not yet implemented.");
-    } else if (expressions[i]->has_constant_value()) {
-      arg.kind = XlaCompiler::Argument::kConstant;
-      arg.constant_value = expressions[i]->constant_value();
-    } else if (compile_time_constant_flags[i]) {
-      arg.kind = XlaCompiler::Argument::kConstant;
-      TF_RET_CHECK(expressions[i]->resource() == nullptr)
-          << "Input with resource is not yet implemented.";
-      TF_ASSIGN_OR_RETURN(auto constant_graph, builder->BuildConstantSubGraph(
-                                                   expressions[i]->handle()));
-      TF_ASSIGN_OR_RETURN(auto literal,
-                          client->ComputeConstant(constant_graph));
-      TF_RETURN_IF_ERROR(
-          LiteralToHostTensor(literal, arg.type, &arg.constant_value));
-    } else {
-      arg.kind = XlaCompiler::Argument::kParameter;
+    switch (expressions[i]->kind()) {
+      case XlaExpression::Kind::kConstant:
+        arg.kind = XlaCompiler::Argument::kConstant;
+        arg.constant_value = expressions[i]->constant_value();
+        break;
+      case XlaExpression::Kind::kXlaOp:
+        if (arg_must_be_compile_time_constant[i]) {
+          TF_ASSIGN_OR_RETURN(absl::optional<Tensor> value,
+                              expressions[i]->ResolveConstant(client));
+          if (!value.has_value()) {
+            return errors::InvalidArgument(
+                "Argument to function must be a compile-time constant, but "
+                "unable to resolve argument value to a constant.");
+          }
+          arg.kind = XlaCompiler::Argument::kConstant;
+          arg.constant_value = *value;
+        } else {
+          arg.kind = XlaCompiler::Argument::kParameter;
+        }
+        break;
+      case XlaExpression::Kind::kResource:
+        return errors::Unimplemented(
+            "Resource as function argument is not yet implemented.");
+      case XlaExpression::Kind::kInvalid:
+        return errors::InvalidArgument("Invalid function argument");
     }
   }
   return Status::OK();
diff --git a/tensorflow/compiler/tf2xla/kernels/arg_op.cc b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
index 276d744c096..2db2514397d 100644
--- a/tensorflow/compiler/tf2xla/kernels/arg_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/arg_op.cc
@@ -14,11 +14,13 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/core/framework/kernel_def_builder.h"
+#include "tensorflow/core/lib/core/errors.h"
 
 namespace tensorflow {
 
@@ -49,13 +51,9 @@ class XlaArgOp : public XlaOpKernel {
     }
 
     const XlaExpression& arg = XlaContext::Get(ctx).args()[index_];
-    if (arg.resource() != nullptr) {
-      ctx->SetResourceOutput(0, arg.resource());
-    } else if (arg.has_constant_value()) {
-      ctx->SetConstantOutput(0, arg.constant_value());
-    } else {
-      ctx->SetOutput(0, arg.handle());
-    }
+    OP_REQUIRES(ctx, arg.kind() != XlaExpression::Kind::kInvalid,
+                errors::InvalidArgument("Invalid/missing argument expression"));
+    ctx->SetOutputExpression(0, arg);
   }
 
  private:
diff --git a/tensorflow/compiler/tf2xla/kernels/const_op.cc b/tensorflow/compiler/tf2xla/kernels/const_op.cc
index 2628ef8e245..dff8af80022 100644
--- a/tensorflow/compiler/tf2xla/kernels/const_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/const_op.cc
@@ -42,11 +42,6 @@ class ConstOp : public XlaOpKernel {
   void Compile(XlaOpKernelContext* ctx) override {
     TensorShape shape(proto_.tensor_shape());
 
-    if (proto_.dtype() == DT_STRING) {
-      LOG(WARNING) << "Not computing Const of type DT_STRING";
-      ctx->SetInvalidOutput(0);
-      return;
-    }
     xla::XlaBuilder* b = ctx->builder();
 
     // To avoid blowups for large constants filled with the same value,
diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
index 53e7624d607..6970dd0a006 100644
--- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc
@@ -47,63 +47,8 @@ class RetvalOp : public XlaOpKernel {
       // compilation.
       OP_REQUIRES_OK(ctx, frame->SetRetval(index_, input));
     } else {
-      xla::XlaOp input = ctx->Input(0);
-      const TensorShape input_shape = ctx->InputShape(0);
-      DataType input_type = ctx->input_type(0);
-      XlaContext& tc = XlaContext::Get(ctx);
-
-      if (input_type == DT_RESOURCE) {
-        XlaResource* resource;
-        OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &resource));
-        ctx->SetStatus(tc.AddResourceRetval(index_, resource));
-        return;
-      }
-
-      auto is_constant = ctx->builder()->IsConstant(input);
-      if (!is_constant.ok()) {
-        ctx->SetStatus(is_constant.status());
-        return;
-      }
-
-      if (tc.resolve_compile_time_constants() &&
-          (input_shape.num_elements() == 0 || is_constant.ValueOrDie())) {
-        xla::Literal literal;
-        OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal));
-        OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal));
-      } else {
-        TensorShape shape = ctx->InputShape(0);
-        ctx->SetStatus(is_constant.status());
-        xla::Shape representation_shape;
-        if (tc.is_entry_computation()) {
-          xla::StatusOr<xla::Shape> shape_or_status =
-              tc.RepresentationShape(shape, ctx->input_type(0));
-          if (!shape_or_status.ok()) {
-            ctx->SetStatus(shape_or_status.status());
-            return;
-          } else {
-            representation_shape = shape_or_status.ValueOrDie();
-          }
-        } else {
-          OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(ctx->input_type(0), shape,
-                                                    &representation_shape));
-        }
-
-        xla::XlaOp output = input;
-        if (tc.is_entry_computation()) {
-          output = xla::Reshape(
-              input, xla::AsInt64Slice(representation_shape.dimensions()));
-        } else {
-          // The core from which a return value is returned depends on the
-          // device assignment of the input to the retval. Since we can't change
-          // the device assignment of "input" at this point, we must always
-          // introduce an operator here, even if the shape does not change.
-          // TODO(b/76097077): propagate device assignments onto arguments and
-          // return values of functions, and then reshape unconditionally.
-          output =
-              xla::GetTupleElement(xla::Tuple(ctx->builder(), {output}), 0);
-        }
-        tc.AddRetval(index_, dtype_, shape, output);
-      }
+      XlaContext& xla_context = XlaContext::Get(ctx);
+      xla_context.SetRetval(index_, ctx->InputExpression(0));
     }
   }
 
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
index cb7843850c3..ddb284966ee 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc
@@ -124,13 +124,4 @@ Status XlaCompilationDevice::MakeTensorFromProto(
       "XLACompilationDevice::MakeTensorFromProto should not be called");
 }
 
-XlaExpression::XlaExpression() = default;
-
-void XlaExpression::set_handle(const xla::XlaOp& h) { handle_ = h; }
-
-void XlaExpression::set_constant_value(Tensor value) {
-  has_constant_value_ = true;
-  constant_value_ = std::move(value);
-}
-
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h
index a6e78825334..de6a3356e05 100644
--- a/tensorflow/compiler/tf2xla/xla_compilation_device.h
+++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h
@@ -18,9 +18,6 @@ limitations under the License.
 
 #include <memory>
 
-#include "tensorflow/compiler/tf2xla/xla_resource.h"
-#include "tensorflow/compiler/xla/client/xla_builder.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/common_runtime/local_device.h"
 #include "tensorflow/core/framework/device_base.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -38,8 +35,8 @@ class XlaCompilationAllocator;
 // This is a 'dummy' TensorFlow device that is only used to execute a
 // subgraph of XLA compilation Ops to construct a compiled version
 // of the subgraph's computation. It has a 'dummy' allocator that
-// backs each Tensor with metadata indicating the computation the
-// Tensor represents.
+// backs each Tensor with an XlaExpression. The shape of the Tensor
+// matches the shape of XlaExpression.
 //
 // We deliberately don't register a device factory because we *never*
 // want placement to put Ops on a compilation device. The device is created
@@ -67,40 +64,6 @@ class XlaCompilationDevice : public LocalDevice {
   std::unique_ptr<XlaCompilationAllocator> allocator_;
 };
 
-// A XlaExpression wraps an XLA computation. Each Tensor on an
-// XlaCompilationDevice contains an XlaExpression, and the shape of the Tensor
-// matches the shape of the subcomputation in the XlaOp. Each
-// expression is either a constant, or a function of previously-compiled
-// expressions.
-class XlaExpression {
- public:
-  XlaExpression();
-
-  // handle() stores the XLA handle of the computation that the
-  // expression represents.
-  void set_handle(const xla::XlaOp& h);
-  const xla::XlaOp& handle() const { return handle_; }
-
-  void set_constant_value(Tensor value);
-  bool has_constant_value() const { return has_constant_value_; }
-  const Tensor& constant_value() const { return constant_value_; }
-
-  void set_resource(XlaResource* resource) { resource_ = resource; }
-  XlaResource* resource() const { return resource_; }
-
- private:
-  // The XLA handle of the expression's computation.
-  xla::XlaOp handle_;
-
-  // If this expression is a constant with a known value, 'constant_value' is a
-  // host-memory Tensor containing the value. Used to avoid invoking XLA for
-  // expressions that are trivially constant.
-  bool has_constant_value_ = false;
-  Tensor constant_value_;
-
-  XlaResource* resource_ = nullptr;  // Not owned.
-};
-
 }  // namespace tensorflow
 
 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILATION_DEVICE_H_
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index e6d7710c244..a08d030ce71 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -36,11 +36,13 @@ limitations under the License.
 #include "tensorflow/core/common_runtime/function.h"
 #include "tensorflow/core/common_runtime/graph_optimizer.h"
 #include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/node_def_util.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/graph/algorithm.h"
 #include "tensorflow/core/graph/graph_constructor.h"
 #include "tensorflow/core/graph/node_builder.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
 #include "tensorflow/core/lib/hash/hash.h"
 #include "tensorflow/core/platform/logging.h"
 
@@ -64,6 +66,240 @@ Status CheckSignature(const DataTypeVector& types,
   return Status::OK();
 }
 
+// Uses the _Arg and _Retval nodes in the graph to determine a core assignment
+// for each argument and return value. Nodes without a sharding are omitted.
+xla::StatusOr<std::pair<std::map<int, int>, std::map<int, int>>>
+ComputeArgAndRetvalCores(const Graph& graph) {
+  // Returns the node's MAXIMAL-sharding device, or -1 if it has no sharding.
+  auto get_sharding_for_node = [](const Node* n) -> xla::StatusOr<int> {
+    TF_ASSIGN_OR_RETURN(
+        auto sharding,
+        ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
+    if (sharding.has_value()) {
+      TF_RET_CHECK(sharding.value().type() ==
+                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
+      return sharding.value().tile_assignment_devices(0);
+    } else {
+      return -1;
+    }
+  };
+  std::map<int, int> arg_cores;
+  std::map<int, int> retval_cores;
+  for (const Node* n : graph.nodes()) {
+    if (n->type_string() == FunctionLibraryDefinition::kArgOp) {
+      TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
+      if (core < 0) continue;
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0) << "Negative _Arg index";
+      arg_cores[index] = core;
+    } else if (n->type_string() == FunctionLibraryDefinition::kRetOp) {
+      TF_ASSIGN_OR_RETURN(int core, get_sharding_for_node(n));
+      if (core < 0) continue;
+      int index;
+      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
+      TF_RET_CHECK(index >= 0) << "Negative _Retval index";
+      retval_cores[index] = core;
+    }
+  }
+  return std::make_pair(std::move(arg_cores), std::move(retval_cores));
+}
+
+// Runs GraphCompiler over `graph`, with `xla_context` registered on `device`.
+Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
+                    XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
+                    int64 step_id) {
+  // XlaContext is a ref-counted resource: the resource manager takes ownership
+  // via Create and unrefs via Cleanup, so take an extra reference here to keep
+  // the caller's refcount unchanged once the step container is cleaned up.
+  //
+  // The Executor requires a ScopedStepContainer. It is held in a unique_ptr so
+  // it can be destroyed explicitly and the cleanup status captured in `status`.
+  xla_context->Ref();
+  Status status;
+  auto step_container = absl::make_unique<ScopedStepContainer>(
+      step_id, [&status, device](const string& name) {
+        status = device->resource_manager()->Cleanup(name);
+      });
+  TF_RETURN_IF_ERROR(device->resource_manager()->Create(
+      step_container->name(), XlaContext::kXlaContextResourceName,
+      xla_context));
+
+  GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
+  TF_RETURN_IF_ERROR(graph_compiler.Compile());
+  // Destroy the step container now so its cleanup result lands in `status`.
+  step_container.reset();
+  return status;
+}
+
+// Builds the XLA computation.
+// - `args` is the list of input arguments.
+// - `retvals` is the list of retvals produced by _Retval operators, in index
+//   order.
+// - `arg_cores` and `retval_cores` are mappings from arg/return indices to
+//   core assignments.
+// - If `return_updated_values_for_all_resources` is true, all resources will be
+//   included in `resource_updates`, regardless of whether their value changed.
+// - Sets `*num_nonconst_outputs` to the number of kXlaOp retvals emitted.
+// - Sets `*resource_updates` to a description of resources whose values are
+//   written by the computation; the variable writes follow the retval outputs
+//   (and precede the token output, if any) in the computation's result. Each
+//   entry is a ResourceUpdate, whose `input_index` is the index of the resource
+//   variable argument to be updated, and `type` is the resource's type.
+// - Sets `*num_computation_outputs` to the total number of tuple elements.
+Status BuildComputation(
+    const std::vector<XlaCompiler::Argument>& args,
+    const std::vector<XlaExpression>& retvals,
+    const std::map<int, int>& arg_cores, const std::map<int, int>& retval_cores,
+    const std::vector<std::unique_ptr<XlaResource>>& resources,
+    std::unique_ptr<xla::XlaOp> token_output,
+    const XlaCompiler::ShapeRepresentationFn& shape_representation_fn,
+    bool return_updated_values_for_all_resources, bool always_return_tuple,
+    xla::XlaBuilder* builder, xla::XlaComputation* computation,
+    int* num_computation_outputs, int* num_nonconst_outputs,
+    std::vector<XlaCompiler::OutputDescription>* outputs,
+    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
+  // Attach a common operator name as metadata. This has no semantic effect — it
+  // merely makes the HLO graph more readable when visualized via TensorBoard,
+  // since TensorBoard forms groups out of operators with similar names.
+  xla::OpMetadata retval_metadata;
+  retval_metadata.set_op_name("XLA_Retvals");
+  builder->SetOpMetadata(retval_metadata);
+  auto cleanup = gtl::MakeCleanup([builder]() { builder->ClearOpMetadata(); });
+
+  // Builds a no-op XLA computation. We need to set the sharding of outputs, but
+  // cannot change the sharding of the existing output op. To do this, we build
+  // a new identity op to which shardings can be applied.
+  auto identity_op = [builder](xla::XlaOp op) {
+    return xla::GetTupleElement(xla::Tuple(builder, {op}), 0);
+  };
+
+  std::vector<xla::XlaOp> elems;
+  elems.reserve(retvals.size());
+  for (int i = 0; i < retvals.size(); ++i) {
+    XlaCompiler::OutputDescription& output = (*outputs)[i];
+    const XlaExpression& retval = retvals[i];
+    output.type = retval.dtype();
+    switch (retval.kind()) {
+      case XlaExpression::Kind::kConstant:
+        output.is_constant = true;
+        output.constant_value = retval.constant_value();
+        output.shape = output.constant_value.shape();
+        break;
+
+      case XlaExpression::Kind::kXlaOp: {
+        output.is_constant = false;
+        TF_ASSIGN_OR_RETURN(output.shape, retval.GetShape());
+        xla::XlaOp value = retval.handle();
+        auto it = retval_cores.find(i);
+        xla::XlaScopedShardingAssignment assign_sharding(
+            builder, it == retval_cores.end()
+                         ? absl::optional<xla::OpSharding>()
+                         : xla::sharding_builder::AssignDevice(it->second));
+        if (shape_representation_fn) {
+          // If there is a shape representation function, reshape the output
+          // tensor to the shape given by the representation shape function.
+          TF_ASSIGN_OR_RETURN(xla::Shape shape, shape_representation_fn(
+                                                    output.shape, output.type));
+          value = xla::Reshape(value, xla::AsInt64Slice(shape.dimensions()));
+        } else if (it != retval_cores.end()) {
+          // Apply the sharding to the output, if there is a core assignment.
+          value = identity_op(value);
+        }
+        elems.push_back(value);
+        break;
+      }
+
+      case XlaExpression::Kind::kResource:
+        output.is_constant = false;
+        output.input_index = retval.resource()->arg_num();
+        output.shape = retval.resource()->shape();
+        break;
+
+      case XlaExpression::Kind::kInvalid:
+        return errors::InvalidArgument(
+            "Invalid expression returned by computation. "
+            "This probably means a return value was not set.");
+    }
+  }
+  *num_nonconst_outputs = elems.size();
+
+  // Collect argument-backed resources, ordered by argument number.
+  std::vector<const XlaResource*> arg_resources;
+  arg_resources.reserve(resources.size());
+  for (const auto& resource : resources) {
+    if (resource->arg_num() >= 0) {
+      arg_resources.push_back(resource.get());
+    }
+  }
+  std::sort(arg_resources.begin(), arg_resources.end(),
+            [](const XlaResource* a, const XlaResource* b) {
+              return a->arg_num() < b->arg_num();
+            });
+
+  for (const XlaResource* resource : arg_resources) {
+    DCHECK_LT(resource->arg_num(), args.size());
+    const XlaCompiler::Argument& arg = args[resource->arg_num()];
+    auto it = arg_cores.find(resource->arg_num());
+    const int core = it == arg_cores.end() ? -1 : it->second;
+    bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
+    // TensorArray gradients were modified if their values changed or there are
+    // any newly created gradients.
+    for (const auto& grad : resource->tensor_array_gradients()) {
+      modified =
+          modified ||
+          !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
+          arg.tensor_array_gradients.count(grad.first) == 0;
+    }
+    if (return_updated_values_for_all_resources || modified) {
+      resource_updates->emplace_back();
+      XlaCompiler::ResourceUpdate& update = resource_updates->back();
+      update.input_index = resource->arg_num();
+      update.type = resource->type();
+      update.shape = resource->shape();
+      update.modified = modified;
+      for (const auto& grad : resource->tensor_array_gradients()) {
+        update.tensor_array_gradients_accessed.insert(grad.first);
+      }
+
+      // Request that the value be returned on a specific core.
+      xla::XlaScopedShardingAssignment assign_sharding(
+          builder, core == -1 ? absl::optional<xla::OpSharding>()
+                              : xla::sharding_builder::AssignDevice(core));
+
+      xla::XlaOp handle;
+      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
+
+      // Ensures the correct sharding is applied to the output.
+      handle = identity_op(handle);
+
+      elems.push_back(handle);
+    }
+  }
+
+  // If we have token output, append it as the last one.
+  if (token_output) {
+    elems.push_back(*token_output);
+  }
+
+  *num_computation_outputs = elems.size();
+
+  // Builds the XLA computation. We *always* form a tuple here to ensure that
+  // the output value is the last thing added into the XLA computation, even if
+  // there is only one; the GetTupleElement below is then the last op added.
+  auto tuple = xla::Tuple(builder, elems);
+  if (!always_return_tuple && elems.size() == 1) {
+    xla::GetTupleElement(tuple, 0);
+  }
+
+  xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
+  if (!computation_status.ok()) {
+    return computation_status.status();
+  }
+  *computation = computation_status.ConsumeValueOrDie();
+  return Status::OK();
+}
+
 }  // namespace
 
 bool XlaCompiler::Argument::operator==(
@@ -252,14 +488,16 @@ Status XlaCompiler::CompileFunction(
   // lowest-numbered core that consumes the argument. We choose the
   // lowest-numbered core so the assignment is deterministic.
   for (Node* n : graph->nodes()) {
-    if (absl::string_view(n->type_string()) == "_Arg") {
+    if (absl::string_view(n->type_string()) ==
+        FunctionLibraryDefinition::kArgOp) {
       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true));
     }
   }
   // Do _Retval as a second loop, in case the retval's input is an _Arg (which
   // may have gotten a device assignment from the first loop).
   for (Node* n : graph->nodes()) {
-    if (absl::string_view(n->type_string()) == "_Retval") {
+    if (absl::string_view(n->type_string()) ==
+        FunctionLibraryDefinition::kRetOp) {
       TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false));
     }
   }
@@ -353,175 +591,16 @@ Status XlaCompiler::XLAShapeForArgument(const XlaCompiler::Argument& arg,
   }
 }
 
-namespace {
-
-Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr<Graph> graph,
-                    XlaCompilationDevice* device, FunctionLibraryRuntime* flib,
-                    int64 step_id) {
-  // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the
-  // resource manager takes ownership via Create, and unrefs via Cleanup.  We
-  // explicitly add a reference to ensure the refcount at entry is maintained at
-  // all exit points; Create and Cleanup are always called in this function.
-  //
-  // The Executor requires us to use ScopedStepContainer. We wrap it in a
-  // unique_ptr so we can capture the cleanup status in the end.
-  xla_context->Ref();
-  Status status;
-  auto step_container = absl::make_unique<ScopedStepContainer>(
-      step_id, [&status, device](const string& name) {
-        status = device->resource_manager()->Cleanup(name);
-      });
-  TF_RETURN_IF_ERROR(device->resource_manager()->Create(
-      step_container->name(), XlaContext::kXlaContextResourceName,
-      xla_context));
-
-  GraphCompiler graph_compiler(device, graph.get(), flib, step_container.get());
-  TF_RETURN_IF_ERROR(graph_compiler.Compile());
-  // Explicitly clean up the step container, to capture the cleanup status.
-  step_container.reset();
-  return Status::OK();
-}
-
-// Builds the XLA computation.
-// `args` is the list of input arguments, `retvals` is the list of retvals
-// produced by _Retval operators, in index order.
-// If `return_updated_values_for_all_resources` is true, all resources will be
-// included in `resource_updates`, regardless of whether their value changed.
-// Sets `*num_nonconst_outputs` to the number of outputs of the `computation`.
-// Sets `*resource_updates` to a description of resources whose values are
-// written by the computation; the variable writes are the last
-// `resource_updates.size()` return values from the computation. Each entry in
-// `resource_updates` is a (input_index, type) pair, where `input_index` is the
-// index of a resource variable argument to the computation, and `type` is the
-// type of the final output.
-Status BuildComputation(
-    const std::vector<XlaCompiler::Argument>& args,
-    const std::vector<int>& arg_cores,
-    const std::vector<XlaContext::Retval>& retvals,
-    const std::vector<std::unique_ptr<XlaResource>>& resources,
-    std::unique_ptr<xla::XlaOp> token_output,
-    bool return_updated_values_for_all_resources, bool always_return_tuple,
-    xla::XlaBuilder* builder, xla::XlaComputation* computation,
-    int* num_computation_outputs, int* num_nonconst_outputs,
-    std::vector<XlaCompiler::OutputDescription>* outputs,
-    std::vector<XlaCompiler::ResourceUpdate>* resource_updates) {
-  std::vector<xla::XlaOp> elems;
-  elems.reserve(retvals.size());
-  for (int i = 0; i < retvals.size(); ++i) {
-    XlaCompiler::OutputDescription& output = (*outputs)[i];
-    output.type = retvals[i].type;
-    output.shape = retvals[i].shape;
-    const XlaExpression& retval = retvals[i].expression;
-    if (retval.has_constant_value()) {
-      output.is_constant = true;
-      output.constant_value = retval.constant_value();
-    } else if (retval.resource() != nullptr) {
-      output.is_constant = false;
-      output.input_index = retval.resource()->arg_num();
-    } else {
-      output.is_constant = false;
-      elems.push_back(retval.handle());
-    }
-  }
-  *num_nonconst_outputs = elems.size();
-
-  // Add return values for resources whose values have changed.
-  std::vector<const XlaResource*> arg_resources;
-  arg_resources.reserve(resources.size());
-  for (const auto& resource : resources) {
-    if (resource->arg_num() >= 0) {
-      arg_resources.push_back(resource.get());
-    }
-  }
-  std::sort(arg_resources.begin(), arg_resources.end(),
-            [](const XlaResource* a, const XlaResource* b) {
-              return a->arg_num() < b->arg_num();
-            });
-
-  // Attach a common operator name as metadata. This has no semantic effect — it
-  // merely makes the HLO graph more readable when visualized via TensorBoard,
-  // since TensorBoard forms groups out of operators with similar names.
-  xla::OpMetadata retval_metadata;
-  retval_metadata.set_op_name("XLA_Retvals");
-  builder->SetOpMetadata(retval_metadata);
-
-  for (const XlaResource* resource : arg_resources) {
-    const XlaCompiler::Argument& arg = args[resource->arg_num()];
-    const int core = arg_cores[resource->arg_num()];
-    DCHECK_LT(resource->arg_num(), arg_cores.size());
-    bool modified = !resource->value().IsIdenticalTo(resource->initial_value());
-    // TensorArray gradients were modified if their values changed or there are
-    // any newly created gradients.
-    for (const auto& grad : resource->tensor_array_gradients()) {
-      modified =
-          modified ||
-          !grad.second->value().IsIdenticalTo(grad.second->initial_value()) ||
-          arg.tensor_array_gradients.count(grad.first) == 0;
-    }
-    if (return_updated_values_for_all_resources || modified) {
-      resource_updates->emplace_back();
-      XlaCompiler::ResourceUpdate& update = resource_updates->back();
-      update.input_index = resource->arg_num();
-      update.type = resource->type();
-      update.shape = resource->shape();
-      update.modified = modified;
-      for (const auto& grad : resource->tensor_array_gradients()) {
-        update.tensor_array_gradients_accessed.insert(grad.first);
-      }
-
-      // Request that the value be returned on a specific core.
-      xla::XlaScopedShardingAssignment assign_sharding(
-          builder, core == -1 ? absl::optional<xla::OpSharding>()
-                              : xla::sharding_builder::AssignDevice(core));
-
-      xla::XlaOp handle;
-      TF_RETURN_IF_ERROR(resource->Pack(&handle, builder));
-
-      // Since we can't change the sharding metadata of <value> as this point,
-      // create a tuple/get-tuple-element combination so that sharding
-      // assignment will be placed on this value, which will cause the resource
-      // update to be returned from the same device that provided the resource.
-      handle = xla::GetTupleElement(xla::Tuple(builder, {handle}), 0);
-      elems.push_back(handle);
-    }
-  }
-
-  // If we have token output, append it as the last one.
-  if (token_output) {
-    elems.push_back(*token_output);
-  }
-
-  *num_computation_outputs = elems.size();
-
-  // Builds the XLA computation. We *always* form a tuple here to ensure that
-  // the output value is the last thing added into the XLA computation, even
-  // if there is only one output value.
-  auto tuple = xla::Tuple(builder, elems);
-  if (!always_return_tuple && elems.size() == 1) {
-    xla::GetTupleElement(tuple, 0);
-  }
-  builder->ClearOpMetadata();
-
-  xla::StatusOr<xla::XlaComputation> computation_status = builder->Build();
-  if (!computation_status.ok()) {
-    return computation_status.status();
-  }
-  *computation = computation_status.ConsumeValueOrDie();
-  return Status::OK();
-}
-
-}  // namespace
-
 // Builds XLA computations for each of the arguments to the computation.
 // `args` are the arguments to the computation.
 Status XlaCompiler::BuildArguments(
     const Graph& graph, const std::vector<XlaCompiler::Argument>& args,
     bool use_tuple_arg, xla::XlaBuilder* builder, XlaContext* context,
-    std::vector<int>* arg_cores, std::vector<XlaExpression>* arg_expressions,
+    const std::map<int, int>& arg_cores,
+    std::vector<XlaExpression>* arg_expressions,
     std::vector<int>* input_mapping, std::vector<xla::Shape>* input_shapes,
     bool is_entry_computation) {
   arg_expressions->resize(args.size());
-  *arg_cores = std::vector<int>(args.size(), -1);
 
   // Argument numbers of arguments and resources that are to be passed to the
   // XLA computation as runtime parameters.
@@ -543,7 +622,7 @@ Status XlaCompiler::BuildArguments(
             arg.resource_kind, i, arg.name, arg.type, arg.shape, xla::XlaOp(),
             /*tensor_array_size=*/arg.tensor_array_size,
             /*tensor_array_gradients=*/arg.tensor_array_gradients, &resource));
-        arg_expression.set_resource(resource);
+        arg_expression = XlaExpression::Resource(resource);
         if (arg.initialized) {
           input_mapping->push_back(i);
         }
@@ -555,7 +634,7 @@ Status XlaCompiler::BuildArguments(
         break;
       }
       case XlaCompiler::Argument::kConstant:
-        arg_expression.set_constant_value(arg.constant_value);
+        arg_expression = XlaExpression::Constant(arg.constant_value);
         break;
       case XlaCompiler::Argument::kInvalid:
         return errors::Internal(
@@ -580,26 +659,6 @@ Status XlaCompiler::BuildArguments(
     *input_shapes = arg_shapes;
   }
 
-  // Use the _Arg nodes in the graph to resolve core assignments.
-  for (const Node* n : graph.nodes()) {
-    if (absl::string_view(n->type_string()) != "_Arg") continue;
-    int index;
-    TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
-    TF_RET_CHECK(index >= 0 && index < args.size())
-        << "_Arg out of bounds: " << index << " vs " << args.size();
-    TF_ASSIGN_OR_RETURN(
-        auto sharding,
-        ParseShardingFromDevice(*n, std::numeric_limits<int32>::max()));
-    if (sharding.has_value()) {
-      TF_RET_CHECK(sharding.value().type() ==
-                   xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
-      const int core = sharding.value().tile_assignment_devices(0);
-      if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) {
-        (*arg_cores)[index] = core;
-      }
-    }
-  }
-
   // Attach a common operator name as metadata. This has no semantic effect — it
   // merely makes the HLO graph more readable when visualized via TensorBoard,
   // since TensorBoard forms groups out of operators with similar names.
@@ -615,11 +674,10 @@ Status XlaCompiler::BuildArguments(
       xla::OpSharding tuple_sharding;
       tuple_sharding.set_type(xla::OpSharding::Type::OpSharding_Type_TUPLE);
       for (int64 parameter : *input_mapping) {
-        const int core = (*arg_cores)[parameter];
-        const int root_device = 0;
+        auto it = arg_cores.find(parameter);
+        const int core = it == arg_cores.end() ? 0 : it->second;
         *tuple_sharding.add_tuple_shardings() =
-            core == -1 ? xla::sharding_builder::AssignDevice(root_device)
-                       : xla::sharding_builder::AssignDevice(core);
+            xla::sharding_builder::AssignDevice(core);
       }
       xla::XlaScopedShardingAssignment assign_tuple_sharding(builder,
                                                              tuple_sharding);
@@ -628,7 +686,8 @@ Status XlaCompiler::BuildArguments(
       tuple = xla::Parameter(builder, 0, (*input_shapes)[0], "arg_tuple");
     }
     for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
-      const int core = (*arg_cores)[input_mapping->at(i)];
+      auto it = arg_cores.find(input_mapping->at(i));  // Key is the arg index.
+      const int core = it == arg_cores.end() ? -1 : it->second;
       xla::XlaScopedShardingAssignment assign_sharding(
           builder, core == -1 ? absl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
@@ -636,7 +695,8 @@ Status XlaCompiler::BuildArguments(
     }
   } else {
     for (std::vector<int>::size_type i = 0; i < input_mapping->size(); ++i) {
-      const int core = (*arg_cores)[input_mapping->at(i)];
+      auto it = arg_cores.find(input_mapping->at(i));  // Key is the arg index.
+      const int core = it == arg_cores.end() ? -1 : it->second;
       xla::XlaScopedShardingAssignment assign_sharding(
           builder, core == -1 ? absl::optional<xla::OpSharding>()
                               : xla::sharding_builder::AssignDevice(core));
@@ -671,14 +731,14 @@ Status XlaCompiler::BuildArguments(
         // TODO(b/76097077): propagate device assignments onto arguments and
         // return values of functions, and then reshape unconditionally.
         if (is_entry_computation) {
-          arg_expression.set_handle(
-              xla::Reshape(arg_handles[i], arg.shape.dim_sizes()));
+          arg_expression = XlaExpression::XlaOp(
+              xla::Reshape(arg_handles[i], arg.shape.dim_sizes()), arg.type);
         } else {
-          arg_expression.set_handle(arg_handles[i]);
+          arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
         }
         break;
       case XlaCompiler::Argument::kToken: {
-        arg_expression.set_handle(arg_handles[i]);
+        arg_expression = XlaExpression::XlaOp(arg_handles[i], arg.type);
         break;
       }
       case XlaCompiler::Argument::kConstant:
@@ -710,7 +770,7 @@ Status XlaCompiler::CompileSingleOp(
     Node* node;
     string arg_name = absl::StrCat("_arg", i);
     Status status =
-        NodeBuilder(arg_name, "_Arg")
+        NodeBuilder(arg_name, FunctionLibraryDefinition::kArgOp)
             .ControlInput(graph->source_node())
             .Attr("T", args[i].kind == Argument::kResource ? DT_RESOURCE
                                                            : args[i].type)
@@ -724,7 +784,7 @@ Status XlaCompiler::CompileSingleOp(
   for (int64 i = 0; i < result_types.size(); ++i) {
     Node* node;
     string retval_name = absl::StrCat("_retval", i);
-    Status status = NodeBuilder(retval_name, "_Retval")
+    Status status = NodeBuilder(retval_name, FunctionLibraryDefinition::kRetOp)
                         .Input(main_node, i)
                         .Attr("T", result_types[i])
                         .Attr("index", i)
@@ -788,6 +848,32 @@ Status ValidateGraph(const Graph* graph,
   return Status::OK();
 }
 
+// Replaces every kXlaOp expression whose value is known at compile time with
+// the equivalent kConstant expression. Other expressions are left untouched.
+Status ResolveConstantExpressionsToConstants(
+    xla::Client* client, absl::Span<XlaExpression> expressions) {
+  for (XlaExpression& expr : expressions) {
+    // Only symbolic XLA values are candidates for resolution.
+    if (expr.kind() != XlaExpression::Kind::kXlaOp) continue;
+    TF_ASSIGN_OR_RETURN(absl::optional<Tensor> resolved,
+                        expr.ResolveConstant(client));
+    // No value means the op could not be resolved; keep it as an XlaOp.
+    if (!resolved.has_value()) continue;
+    expr = XlaExpression::Constant(*resolved);
+  }
+  return Status::OK();
+}
+
+// Rewrites every kConstant expression as a kXlaOp built on `builder`.
+void ConvertConstantsToExpressions(xla::XlaBuilder* builder,
+                                   absl::Span<XlaExpression> expressions) {
+  for (XlaExpression& expr : expressions) {
+    if (expr.kind() != XlaExpression::Kind::kConstant) continue;
+    const DataType dtype = expr.dtype();
+    expr = XlaExpression::XlaOp(expr.AsXlaOp(builder), dtype);
+  }
+}
+
 }  // namespace
 
 Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
@@ -815,10 +901,9 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
                                    options_.device_type, name));
 
   xla::XlaBuilder builder(name);
-  XlaContext* context = new XlaContext(
-      this, &builder, options_.allow_cpu_custom_calls,
-      options.resolve_compile_time_constants, options.is_entry_computation,
-      &options_.shape_representation_fn);
+  XlaContext* context =
+      new XlaContext(this, &builder, options_.allow_cpu_custom_calls,
+                     &options_.shape_representation_fn);
   core::ScopedUnref context_unref(context);
 
   std::vector<XlaCompiler::Argument> real_args(args.begin(), args.end());
@@ -833,10 +918,14 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
     real_args.push_back(token_arg);
   }
 
+  std::map<int, int> arg_cores;
+  std::map<int, int> retval_cores;
+  TF_ASSIGN_OR_RETURN(std::tie(arg_cores, retval_cores),
+                      ComputeArgAndRetvalCores(*graph));
+
   std::vector<XlaExpression> arg_expressions;
-  std::vector<int> arg_cores;
   TF_RETURN_IF_ERROR(BuildArguments(
-      *graph, real_args, options.use_tuple_arg, &builder, context, &arg_cores,
+      *graph, real_args, options.use_tuple_arg, &builder, context, arg_cores,
       &arg_expressions, &result->input_mapping, &result->xla_input_shapes,
       options.is_entry_computation));
   context->set_args(std::move(arg_expressions));
@@ -884,9 +973,19 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options,
   int num_computation_outputs;
   result->computation = std::make_shared<xla::XlaComputation>();
   result->outputs.resize(context->retvals().size());
+  std::vector<XlaExpression> retvals = context->retvals();
+  if (options.resolve_compile_time_constants) {
+    TF_RETURN_IF_ERROR(ResolveConstantExpressionsToConstants(
+        client(), absl::Span<XlaExpression>(retvals)));
+  } else {
+    ConvertConstantsToExpressions(&builder, absl::Span<XlaExpression>(retvals));
+  }
   TF_RETURN_IF_ERROR(BuildComputation(
-      real_args, arg_cores, context->retvals(), context->resources(),
-      std::move(token_output), options.return_updated_values_for_all_resources,
+      real_args, retvals, arg_cores, retval_cores, context->resources(),
+      std::move(token_output),
+      options.is_entry_computation ? options_.shape_representation_fn
+                                   : ShapeRepresentationFn{},
+      options.return_updated_values_for_all_resources,
       options.always_return_tuple, &builder, result->computation.get(),
       &num_computation_outputs, &num_nonconst_outputs, &result->outputs,
       &result->resource_updates));
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h
index f10cfbe0c65..63426124686 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.h
+++ b/tensorflow/compiler/tf2xla/xla_compiler.h
@@ -21,8 +21,10 @@ limitations under the License.
 #include "absl/types/span.h"
 #include "tensorflow/compiler/tf2xla/host_compute_metadata.pb.h"
 #include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
 #include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/core/common_runtime/device.h"
@@ -415,7 +417,8 @@ class XlaCompiler {
   Status BuildArguments(const Graph& graph,
                         const std::vector<XlaCompiler::Argument>& args,
                         bool use_tuple_arg, xla::XlaBuilder* builder,
-                        XlaContext* context, std::vector<int>* arg_cores,
+                        XlaContext* context,
+                        const std::map<int, int>& arg_cores,
                         std::vector<XlaExpression>* arg_expressions,
                         std::vector<int>* input_mapping,
                         std::vector<xla::Shape>* input_shapes,
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 1e819dbb694..43095fbb473 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -64,63 +64,23 @@ void XlaContext::set_args(std::vector<XlaExpression> args) {
 
 XlaContext::XlaContext(
     XlaCompiler* compiler, xla::XlaBuilder* builder,
-    bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
-    bool is_entry_computation,
+    bool allow_cpu_custom_calls,
     const std::function<xla::StatusOr<xla::Shape>(
         const TensorShape&, DataType)>* shape_representation_fn)
     : compiler_(compiler),
       builder_(builder),
       allow_cpu_custom_calls_(allow_cpu_custom_calls),
-      resolve_compile_time_constants_(resolve_compile_time_constants),
-      is_entry_computation_(is_entry_computation),
       shape_representation_fn_(shape_representation_fn) {}
 
 string XlaContext::DebugString() { return "TLA JIT context"; }
 
-// This is called by the Retval Op to associate a computed value
-// with a specific return value of the subgraph.
-void XlaContext::AddRetval(int retval_index, DataType type,
-                           const TensorShape& shape, const xla::XlaOp& handle) {
-  VLOG(1) << "Added retval index " << retval_index << " to XLA computation";
-  // Add the return value to the list being built up.
-  if (retvals_.size() <= retval_index) {
-    retvals_.resize(retval_index + 1);
+void XlaContext::SetRetval(int index, const XlaExpression& expression) {
+  if (retvals_.size() <= index) {
+    retvals_.resize(index + 1);
   }
-  XlaExpression e;
-  e.set_handle(handle);
-  retvals_[retval_index] = Retval{type, shape, e};
+  retvals_[index] = expression;
 }
 
-Status XlaContext::AddConstRetval(int retval_index, DataType dtype,
-                                  const xla::LiteralSlice& literal) {
-  VLOG(1) << "Adding retval index " << retval_index
-          << " with non-data-dependent tensor to XLA computation";
-  if (retvals_.size() <= retval_index) {
-    retvals_.resize(retval_index + 1);
-  }
-  Tensor value;
-  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype, &value));
-  XlaExpression e;
-  e.set_constant_value(value);
-  retvals_[retval_index] = Retval{dtype, value.shape(), e};
-  return Status::OK();
-}
-
-Status XlaContext::AddResourceRetval(int retval_index, XlaResource* resource) {
-  VLOG(1) << "Adding retval index " << retval_index << " with resource "
-          << resource->name() << ":" << resource->shape().DebugString()
-          << " to XLA computation";
-  if (retvals_.size() <= retval_index) {
-    retvals_.resize(retval_index + 1);
-  }
-  XlaExpression e;
-  e.set_resource(resource);
-  retvals_[retval_index] = Retval{DT_RESOURCE, resource->shape(), e};
-  return Status::OK();
-}
-
-xla::XlaBuilder* XlaContext::builder() { return builder_; }
-
 Status XlaContext::CreateResource(
     XlaResource::Kind kind, int arg_num, string name, DataType type,
     TensorShape shape, const xla::XlaOp& handle, int64 tensor_array_size,
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index 8aad6cbced0..dbfd344c9ba 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -20,8 +20,8 @@ limitations under the License.
 
 #include <vector>
 
-#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
 #include "tensorflow/compiler/xla/status_macros.h"
@@ -46,8 +46,7 @@ class XlaContext : public ResourceBase {
   // Creates a new XlaContext. See the documentation on the class data fields
   // for descriptions of the arguments.
   XlaContext(XlaCompiler* compiler, xla::XlaBuilder* builder,
-             bool allow_cpu_custom_calls, bool resolve_compile_time_constants,
-             bool is_entry_computation,
+             bool allow_cpu_custom_calls,
              const std::function<xla::StatusOr<xla::Shape>(
                  const TensorShape&, DataType)>* shape_representation_fn);
 
@@ -57,37 +56,19 @@ class XlaContext : public ResourceBase {
   XlaCompiler* compiler() const { return compiler_; }
 
   // Returns the XlaBuilder that Ops use for compiling new expressions.
-  xla::XlaBuilder* builder();
+  xla::XlaBuilder* builder() { return builder_; }
 
   bool allow_cpu_custom_calls() const { return allow_cpu_custom_calls_; }
 
-  bool resolve_compile_time_constants() const {
-    return resolve_compile_time_constants_;
-  }
-  bool is_entry_computation() const { return is_entry_computation_; }
-
   const std::vector<XlaExpression>& args() const { return args_; }
   void set_args(std::vector<XlaExpression> args);
 
-  struct Retval {
-    DataType type;
-    TensorShape shape;
-    // An XlaExpression representing the Retval's value.
-    XlaExpression expression;
-  };
-  const std::vector<Retval>& retvals() { return retvals_; }
+  const std::vector<XlaExpression>& retvals() { return retvals_; }
 
-  // This is called by the Retval Op to associate a computed value
-  // with a specific return value of the subgraph.
-  void AddRetval(int retval_index, DataType type, const TensorShape& shape,
-                 const xla::XlaOp& handle);
-
-  // As for Retval, but for return values that are compile-time constants.
-  Status AddConstRetval(int retval_index, DataType dtype,
-                        const xla::LiteralSlice& literal);
-
-  // As for Retval, but for return values that are resource handles.
-  Status AddResourceRetval(int retval_index, XlaResource* resource);
+  // Sets a return value.
+  // Since we do not always know in advance how many return values there are,
+  // grows the return values vector to size index+1 if it is smaller.
+  void SetRetval(int index, const XlaExpression& expression);
 
   // Creates a resource with resource `kind` and initial value `handle`. `name`
   // is a descriptive name for use in error messages. See the `XlaResource`
@@ -140,24 +121,16 @@ class XlaContext : public ResourceBase {
   // Allow ops to emit CustomCall operations for CPU.
   const bool allow_cpu_custom_calls_;
 
-  // If true, constant return values are returned as Tensors instead of
-  // run-time computation outputs.
-  const bool resolve_compile_time_constants_;
-
   // Arguments to the Tensorflow graph, indexed by _Arg index.
   // Includes both compile-time constant arguments and runtime parameters.
   std::vector<XlaExpression> args_;
 
   // Return values of the Tensorflow graph, indexed by _Retval index.
-  std::vector<Retval> retvals_;
+  std::vector<XlaExpression> retvals_;
 
   // Holds ownership of resources. The resources are not ordered.
   std::vector<std::unique_ptr<XlaResource>> resources_;
 
-  // Is this a top-level computation, or an inner computation (e.g., a while
-  // body)?
-  const bool is_entry_computation_;
-
   // Describes the on-host shapes of parameters and return values. Also see:
   // XlaDevice::Options::shape_representation_fn.
   const std::function<xla::StatusOr<xla::Shape>(const TensorShape&, DataType)>*
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
new file mode 100644
index 00000000000..ca0309166b7
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -0,0 +1,145 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
+
+#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+XlaExpression::XlaExpression() = default;  // Equivalent to Invalid().
+
+XlaExpression XlaExpression::Invalid() {
+  XlaExpression invalid;
+  invalid.kind_ = Kind::kInvalid;  // Set explicitly rather than by default.
+  return invalid;
+}
+
+XlaExpression XlaExpression::Constant(Tensor value) {
+  XlaExpression constant;
+  constant.dtype_ = value.dtype();  // The dtype is taken from the tensor.
+  constant.constant_value_ = value;
+  constant.kind_ = Kind::kConstant;
+  return constant;
+}
+
+XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
+  XlaExpression op;
+  op.handle_ = value;
+  op.dtype_ = dtype;  // Supplied by the caller alongside the handle.
+  op.kind_ = Kind::kXlaOp;
+  return op;
+}
+
+XlaExpression XlaExpression::Resource(XlaResource* resource) {
+  XlaExpression wrapper;
+  wrapper.resource_ = resource;
+  wrapper.dtype_ = DT_RESOURCE;  // Resource expressions always use this dtype.
+  wrapper.kind_ = Kind::kResource;
+  return wrapper;
+}
+
+string XlaExpression::HumanString() const {  // Returns a short name for kind_.
+  switch (kind_) {
+    case Kind::kConstant:
+      return "constant";
+    case Kind::kXlaOp:
+      return "xla_op";
+    case Kind::kResource:
+      return "resource";
+    default:  // Kind::kInvalid; a default keeps every path returning a value.
+      return "invalid";
+  }
+}
+
+xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
+  return builder->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
+    // Constants are materialized as a literal op on the requested builder.
+    if (kind_ == Kind::kConstant) {
+      xla::BorrowingLiteral literal;
+      TF_RETURN_IF_ERROR(
+          HostTensorToBorrowingLiteral(constant_value_, &literal));
+      return xla::ConstantLiteral(builder, literal);
+    }
+    // Any kind other than kConstant or kXlaOp is rejected.
+    if (kind_ != Kind::kXlaOp) {
+      return errors::InvalidArgument("AsXlaOp called on XlaExpression: ",
+                                     HumanString());
+    }
+    if (builder != handle_.builder()) {
+      return errors::InvalidArgument(
+          "Mismatched builders in XlaExpression::AsXlaOp");
+    }
+    return handle_;
+  });
+}
+
+xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
+    xla::Client* client) const {
+  switch (kind()) {
+    case Kind::kConstant:
+      return {constant_value()};  // Already a constant; nothing to evaluate.
+    case Kind::kXlaOp:
+      break;  // Try to evaluate the op below.
+    case Kind::kResource:
+    case Kind::kInvalid:
+      return errors::InvalidArgument(
+          "ResolveConstant called on XlaExpression: ", HumanString());
+  }
+
+  TF_ASSIGN_OR_RETURN(bool is_constant,
+                      handle().builder()->IsConstant(handle()));
+  if (!is_constant) return {absl::nullopt};  // Not an error: no value known.
+
+  TF_ASSIGN_OR_RETURN(xla::XlaComputation constant_graph,
+                      handle().builder()->BuildConstantSubGraph(handle()));
+
+  TF_ASSIGN_OR_RETURN(TensorShape shape, GetShape());
+
+  // XLA layouts are listed minor-to-major but TensorFlow tensors are stored
+  // major-to-minor, so request the constant with a reversed dimension order.
+  std::vector<int64> layout_indices(shape.dims());
+  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
+  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
+  TF_ASSIGN_OR_RETURN(xla::Literal literal,
+                      client->ComputeConstant(constant_graph, &layout));
+  Tensor tensor;
+  TF_RETURN_IF_ERROR(LiteralToHostTensor(literal, dtype(), &tensor));
+  return {tensor};
+}
+
+xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
+  switch (kind_) {
+    case Kind::kConstant:
+      return constant_value().shape();
+    case Kind::kXlaOp: {
+      TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
+                          handle().builder()->GetShape(handle()));
+      TensorShape shape;
+      TF_RETURN_IF_ERROR(XLAShapeToTensorShape(xla_shape, &shape));
+      return shape;
+    }
+    case Kind::kResource:
+      return TensorShape({});  // Resources report an empty (rank-0) shape.
+    default:  // Kind::kInvalid; a default keeps every path returning a value.
+      return errors::InvalidArgument(
+          "GetShape() called on invalid XlaExpression");
+  }
+}
+
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
new file mode 100644
index 00000000000..bed6761d362
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -0,0 +1,115 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
+#define TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
+
+#include "absl/types/optional.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+// A XlaExpression represents a symbolic TensorFlow value in a TF->XLA
+// compilation.
+// An expression is one of:
+// * a constant tensor.
+// * an xla::XlaOp, representing a symbolic XLA value.
+// * a resource, e.g., a variable, represented as an XlaResource pointer.
+//
+// Constant tensors are mostly an optimization to avoid passing large constants
+// to XLA, but are also sometimes used to represent tensors that have no XLA
+// representation, for example, DT_STRING tensors. A canonical use case might be
+// an error message string.
+class XlaExpression {
+ public:
+  enum class Kind {
+    kInvalid,
+    kConstant,
+    kXlaOp,
+    kResource,
+  };
+
+  XlaExpression();
+  XlaExpression(const XlaExpression&) = default;
+  XlaExpression& operator=(const XlaExpression&) = default;
+
+  // Builds an invalid expression. (Same as the default constructor, but makes
+  // the intent clearer.)
+  static XlaExpression Invalid();
+
+  // Builds a constant XLA expression.
+  static XlaExpression Constant(Tensor value);
+
+  // Builds a XlaOp expression. Since the mapping from TF data types to XLA
+  // types is not 1-1, the TF type must also be provided; in general it cannot
+  // be derived from the XLA type.
+  static XlaExpression XlaOp(xla::XlaOp value, DataType dtype);
+
+  // Builds a resource expression.
+  static XlaExpression Resource(XlaResource* resource);
+
+  Kind kind() const { return kind_; }
+
+  DataType dtype() const { return dtype_; }
+
+  // handle() returns the XlaOp that backs a kXlaOp expression.
+  const xla::XlaOp& handle() const { return handle_; }
+
+  const Tensor& constant_value() const { return constant_value_; }
+
+  XlaResource* resource() const { return resource_; }
+
+  // Returns a human-readable summary of the expression.
+  string HumanString() const;
+
+  // Returns the value of a kConstant or kXlaOp as an xla::XlaOp. Returns
+  // an erroneous XlaOp if the expression is not a constant or an XlaOp.
+  xla::XlaOp AsXlaOp(xla::XlaBuilder* builder) const;
+
+  // If a kXlaOp or kConstant expression can be resolved to a compile-time
+  // constant, returns the value as a host-memory Tensor. Returns an empty
+  // optional if it cannot be resolved. Returns an error if passed a resource
+  // expression.
+  xla::StatusOr<absl::optional<Tensor>> ResolveConstant(
+      xla::Client* client) const;
+
+  // Returns the shape of the tensor.
+  // The shape of a resource is the shape of a resource handle (i.e., a scalar),
+  // not the shape of the resource's value.
+  xla::StatusOr<TensorShape> GetShape() const;
+
+ private:
+  Kind kind_ = Kind::kInvalid;
+
+  DataType dtype_ = DT_INVALID;
+
+  // The XLA handle of the expression's computation, if kind_ == kXlaOp.
+  xla::XlaOp handle_;
+
+  // The value of the constant, if kind_ == kConstant.
+  Tensor constant_value_;
+
+  // The resource, if kind_ == kResource. Not owned.
+  XlaResource* resource_ = nullptr;
+};
+
+}  // namespace tensorflow
+
+#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_EXPRESSION_H_
diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc
new file mode 100644
index 00000000000..84202c93139
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc
@@ -0,0 +1,135 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <memory>
+
+#include "absl/memory/memory.h"
+#include "tensorflow/compiler/tf2xla/xla_expression.h"
+#include "tensorflow/compiler/tf2xla/xla_resource.h"
+#include "tensorflow/compiler/xla/client/client_library.h"
+#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_builder.h"
+#include "tensorflow/compiler/xla/literal.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/tests/literal_test_util.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace {
+
+class XlaExpressionTest : public ::testing::Test {
+ protected:
+  void SetUp() override {
+    client_ = xla::ClientLibrary::LocalClientOrDie();
+    builder_ = absl::make_unique<xla::XlaBuilder>("acomputation");
+    constant_ = test::AsScalar<int32>(42);
+    op_ = xla::ConstantR0<int32>(builder_.get(), 7);
+    non_constant_op_ = xla::Parameter(
+        builder_.get(), 0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
+    resource_ = absl::make_unique<XlaResource>(
+        XlaResource::kVariable, /*arg_num=*/0, /*name=*/string("avariable"),
+        DT_INT32, TensorShape({17, 3}), op_, /*tensor_array_size=*/-1,
+        /*tensor_array_gradients=*/std::set<string>(),
+        /*tensor_array_multiple_writes_aggregate=*/false);
+  }
+
+  xla::Client* client_;
+  std::unique_ptr<xla::XlaBuilder> builder_;
+  Tensor constant_;
+  xla::XlaOp op_;
+  xla::XlaOp non_constant_op_;
+  std::unique_ptr<XlaResource> resource_;
+};
+
+TEST_F(XlaExpressionTest, Kind) {
+  EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression().kind());
+  EXPECT_TRUE(XlaExpression::Kind::kInvalid == XlaExpression::Invalid().kind());
+  EXPECT_TRUE(XlaExpression::Kind::kConstant ==
+              XlaExpression::Constant(constant_).kind());
+  EXPECT_TRUE(XlaExpression::Kind::kXlaOp ==
+              XlaExpression::XlaOp(op_, DT_INT32).kind());
+  EXPECT_TRUE(XlaExpression::Kind::kResource ==
+              XlaExpression::Resource(resource_.get()).kind());
+}
+
+TEST_F(XlaExpressionTest, HumanString) {
+  EXPECT_EQ("invalid", XlaExpression().HumanString());
+  EXPECT_EQ("invalid", XlaExpression::Invalid().HumanString());
+  EXPECT_EQ("constant", XlaExpression::Constant(constant_).HumanString());
+  EXPECT_EQ("xla_op", XlaExpression::XlaOp(op_, DT_INT32).HumanString());
+  EXPECT_EQ("resource", XlaExpression::Resource(resource_.get()).HumanString());
+}
+
+TEST_F(XlaExpressionTest, AsXlaOp) {
+  xla::XlaOp op_as_op =
+      XlaExpression::XlaOp(op_, DT_INT32).AsXlaOp(builder_.get());
+  EXPECT_TRUE(op_.IsIdenticalTo(op_as_op));
+
+  xla::XlaOp const_as_op =
+      XlaExpression::Constant(constant_).AsXlaOp(builder_.get());
+  TF_ASSERT_OK_AND_ASSIGN(xla::XlaComputation computation,
+                          builder_->BuildConstantSubGraph(const_as_op));
+  TF_ASSERT_OK_AND_ASSIGN(xla::Literal value,
+                          client_->ComputeConstant(computation));
+  EXPECT_TRUE(xla::LiteralTestUtil::Equal(xla::LiteralUtil::CreateR0<int32>(42),
+                                          value));
+}
+
+TEST_F(XlaExpressionTest, GetShape) {
+  EXPECT_FALSE(XlaExpression().GetShape().ok());
+  EXPECT_FALSE(XlaExpression::Invalid().GetShape().ok());
+
+  TF_ASSERT_OK_AND_ASSIGN(TensorShape resource_shape,
+                          XlaExpression::Resource(resource_.get()).GetShape());
+  EXPECT_EQ(TensorShape({}), resource_shape);
+
+  TF_ASSERT_OK_AND_ASSIGN(TensorShape op_shape,
+                          XlaExpression::XlaOp(op_, DT_INT32).GetShape());
+  EXPECT_EQ(TensorShape({}), op_shape);
+
+  TF_ASSERT_OK_AND_ASSIGN(TensorShape constant_shape,
+                          XlaExpression::Constant(constant_).GetShape());
+  EXPECT_EQ(TensorShape({}), constant_shape);
+}
+
+TEST_F(XlaExpressionTest, ResolveConstant) {
+  EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
+  EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
+  EXPECT_FALSE(
+      XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      absl::optional<Tensor> op_constant,
+      XlaExpression::XlaOp(op_, DT_INT32).ResolveConstant(client_));
+  ASSERT_TRUE(op_constant.has_value());
+  test::ExpectTensorEqual<int32>(test::AsScalar<int32>(7), *op_constant);
+
+  TF_ASSERT_OK_AND_ASSIGN(absl::optional<Tensor> op_nonconstant,
+                          XlaExpression::XlaOp(non_constant_op_, DT_FLOAT)
+                              .ResolveConstant(client_));
+  EXPECT_FALSE(op_nonconstant.has_value());
+
+  TF_ASSERT_OK_AND_ASSIGN(
+      absl::optional<Tensor> constant_constant,
+      XlaExpression::Constant(constant_).ResolveConstant(client_));
+  ASSERT_TRUE(constant_constant.has_value());
+  test::ExpectTensorEqual<int32>(constant_, *constant_constant);
+}
+
+}  // namespace
+}  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 227915f5703..8dd8def0549 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include "tensorflow/compiler/tf2xla/literal_util.h"
 #include "tensorflow/compiler/tf2xla/shape_util.h"
 #include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_compilation_device.h"
 #include "tensorflow/compiler/tf2xla/xla_context.h"
 #include "tensorflow/compiler/xla/client/xla_builder.h"
 #include "tensorflow/compiler/xla/client/xla_computation.h"
@@ -43,32 +44,36 @@ xla::XlaBuilder* XlaOpKernelContext::builder() const {
 static const XlaExpression* CastExpressionFromTensor(const Tensor& tensor) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor.tensor_data().data());
-  CHECK(expression->handle().valid() || expression->resource() != nullptr);
-  VLOG(1) << "Fetched T" << expression->handle();
+  CHECK(expression->kind() != XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
   return expression;
 }
 
-// Retrieves an uninitialized XlaExpression from a newly-allocated tensor.
-static XlaExpression* CastExpressionFromUninitializedTensor(Tensor* tensor) {
+// Assigns an XlaExpression to a tensor on an XLA compilation device.
+static void AssignExpressionToTensor(Tensor* tensor,
+                                     const XlaExpression& value) {
   const XlaExpression* expression =
       reinterpret_cast<const XlaExpression*>(tensor->tensor_data().data());
-  CHECK(!expression->handle().valid());
-  return const_cast<XlaExpression*>(expression);
+  CHECK(expression->kind() == XlaExpression::Kind::kInvalid)
+      << expression->HumanString();
+  *const_cast<XlaExpression*>(expression) = value;
 }
 
-// Retrieves the XlaOp from an input Tensor to an Op. This computation was
-// constructed by an Op that executed previously and created the output Tensor
-// using CreateOutputTensorFromComputation or CreateConstantOutputTensor.
-static const xla::XlaOp& GetComputationFromTensor(const Tensor& tensor) {
-  return CastExpressionFromTensor(tensor)->handle();
+const XlaExpression& XlaOpKernelContext::InputExpression(int index) {
+  return *CastExpressionFromTensor(context_->input(index));
 }
 
-const xla::XlaOp& XlaOpKernelContext::Input(int index) {
-  return GetComputationFromTensor(context_->input(index));
+const XlaExpression& XlaOpKernelContext::InputExpression(
+    absl::string_view name) {
+  return *CastExpressionFromTensor(GetInputTensorByName(name));
 }
 
-const xla::XlaOp& XlaOpKernelContext::Input(absl::string_view name) {
-  return GetComputationFromTensor(GetInputTensorByName(name));
+xla::XlaOp XlaOpKernelContext::Input(int index) {
+  return InputExpression(index).AsXlaOp(builder());
+}
+
+xla::XlaOp XlaOpKernelContext::Input(absl::string_view name) {
+  return InputExpression(name).AsXlaOp(builder());
 }
 
 TensorShape XlaOpKernelContext::InputShape(int index) {
@@ -125,59 +130,18 @@ Status XlaOpKernelContext::ConstantInput(absl::string_view name,
 Status XlaOpKernelContext::ConstantInputReshaped(
     int index, absl::Span<const int64> new_dims,
     xla::Literal* constant_literal) {
-  const Tensor& tensor = context_->input(index);
-  TensorShape new_shape(new_dims);
-  if (tensor.NumElements() != new_shape.num_elements()) {
-    return errors::InvalidArgument(
-        context_->op_kernel().name(), " input ", index, " has shape ",
-        tensor.shape().DebugString(),
-        " but was asked to be reshaped to incompatible shape ",
-        new_shape.DebugString());
-  }
-  const XlaExpression* expression = CastExpressionFromTensor(tensor);
-
-  // If the tensor has a known constant value, there is no need to invoke XLA.
-  if (expression->has_constant_value()) {
-    Tensor temp(tensor.dtype());
-    if (!temp.CopyFrom(expression->constant_value(), new_shape)) {
-      // This should never happen. The constant should have a shape compatible
-      // with the enclosing Tensor.
-      return errors::Internal("Incompatible shapes in ConstantInputReshaped.");
-    }
-
-    TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
-    return Status::OK();
-  }
-
-  // Make sure we treat zero-element tensors as constant.
-  if (new_shape.num_elements() == 0) {
-    Tensor temp(tensor.dtype(), new_shape);
-    TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
-    return Status::OK();
-  }
-
-  xla::XlaOp handle = expression->handle();
-  if (new_shape != tensor.shape()) {
-    // Reshape the handle to the desired shape.
-    handle = xla::Reshape(handle, new_shape.dim_sizes());
-  }
-
-  // The XLA layout is specified minor to major, and TensorFlow's minor
-  // dimension is the last one.
-  std::vector<int64> layout_indices(new_shape.dims());
-  std::iota(layout_indices.rbegin(), layout_indices.rend(), 0);
-  xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
-
-  xla::StatusOr<bool> is_constant = builder()->IsConstant(handle);
-  if (!is_constant.ok()) {
-    Status status = is_constant.status();
+  XlaExpression e = InputExpression(index);
+  xla::StatusOr<absl::optional<Tensor>> constant_or_status =
+      e.ResolveConstant(compiler()->client());
+  if (!constant_or_status.ok()) {
+    Status status = constant_or_status.status();
     errors::AppendToMessage(&status, "while evaluating input ", index, " of ",
                             context_->op_kernel().type_string(),
                             " operator as a compile-time constant.");
     return status;
   }
-
-  if (!is_constant.ValueOrDie()) {
+  absl::optional<Tensor> constant = constant_or_status.ValueOrDie();
+  if (!constant.has_value()) {
     return errors::InvalidArgument(
         "Input ", index, " to ", context_->op_kernel().type_string(),
         " operator must be a compile-time constant.\n"
@@ -190,25 +154,16 @@ Status XlaOpKernelContext::ConstantInputReshaped(
         "stateful operation such as a random number generator.");
   }
 
-  // Ask the XLA compiler to evaluate the data handle to a literal.
-  xla::StatusOr<xla::XlaComputation> constant_graph =
-      builder()->BuildConstantSubGraph(handle);
-  if (!constant_graph.ok()) {
-    return errors::Internal(
-        "Error getting a compile-time constant graph for ",
-        context_->op_kernel().name(), " input ", index,
-        ".\nError: ", constant_graph.status().error_message());
+  Tensor temp(constant->dtype());
+  if (!temp.CopyFrom(*constant, TensorShape(new_dims))) {
+    return errors::InvalidArgument(
+        context_->op_kernel().name(), " input ", index, " has shape ",
+        constant->shape().DebugString(),
+        " but was asked to be reshaped to incompatible shape ",
+        TensorShape(new_dims).DebugString());
   }
-  xla::StatusOr<xla::Literal> computed = compiler()->client()->ComputeConstant(
-      constant_graph.ValueOrDie(), &layout);
-  if (!computed.ok()) {
-    return errors::Internal("Error evaluating ", context_->op_kernel().name(),
-                            " input ", index,
-                            " as a compile-time constant.\nError: ",
-                            computed.status().error_message());
-  }
-  *constant_literal = std::move(computed).ValueOrDie();
 
+  TF_ASSIGN_OR_RETURN(*constant_literal, HostTensorToLiteral(temp));
   return Status::OK();
 }
 
@@ -363,7 +318,7 @@ Status XlaOpKernelContext::InputList(absl::string_view name,
   handles->clear();
   shapes->clear();
   for (const Tensor& input : inputs) {
-    handles->push_back(GetComputationFromTensor(input));
+    handles->push_back(CastExpressionFromTensor(input)->AsXlaOp(builder()));
     shapes->push_back(input.shape());
   }
   return Status::OK();
@@ -449,90 +404,53 @@ Status XlaOpKernelContext::GetVariableTypeAndShape(int index, DataType* type,
   return Status::OK();
 }
 
-Status XlaOpKernelContext::allocate_output(int index, const xla::Shape& shape,
-                                           Tensor** output) {
-  // The step's default allocator is the dummy XlaCompilationAllocator which
-  // simply allocates a metadata buffer to hold the expression to which it
-  // corresponds.
-  if (expected_output_dtype(index) == DT_VARIANT) {
-    // tensor_data() is not supported for variant Tensor (i.e.,
-    // DataTypeCanUseMemcpy is false for DT_VARIANT), and so storing the
-    // XlaExpression inside the Tensor's tensor_data() does not work for
-    // variant. Instead construct a uint8 tensor and store the expression in its
-    // value.
-    // TODO(jpienaar): This should be refactored to stop masquerading
-    // XlaExpressions as Tensors.
-    *output = new Tensor();
-    TensorShape tensor_shape;
-    TF_RETURN_IF_ERROR(
-        context_->allocate_temp(DT_UINT8, tensor_shape, *output));
-    context_->set_output(index, **output);
-  } else {
-    TensorShape tensor_shape;
-    TF_RETURN_IF_ERROR(XLAShapeToTensorShape(shape, &tensor_shape));
-    TF_RETURN_IF_ERROR(context_->allocate_output(index, tensor_shape, output));
+// Stores `expression` as output `index`. The expression is written into the
+// data buffer of the output tensor; any failure is reported via SetStatus().
+void XlaOpKernelContext::SetOutputExpression(int index,
+                                             const XlaExpression& expression) {
+  Status status = [&] {
+    // The step's default allocator is the dummy XlaCompilationAllocator which
+    // simply allocates a metadata buffer to hold the expression to which it
+    // corresponds.
+    Tensor* output = nullptr;
+    if (expression.dtype() == DT_VARIANT) {
+      // tensor_data() is not supported for variant Tensors (i.e.,
+      // DataTypeCanUseMemcpy is false for DT_VARIANT), so the XlaExpression
+      // cannot be stored in a variant Tensor's buffer. Instead, back the
+      // output with a DT_UINT8 scalar tensor and store the expression there.
+      // TODO(jpienaar): This should be refactored to stop masquerading
+      // XlaExpressions as Tensors.
+      // set_output() keeps its own Tensor sharing the buffer, so a stack
+      // temporary suffices; heap-allocating it here would leak.
+      Tensor temp;
+      TF_RETURN_IF_ERROR(
+          context_->allocate_temp(DT_UINT8, TensorShape(), &temp));
+      context_->set_output(index, temp);
+      output = context_->mutable_output(index);
+    } else {
+      TF_ASSIGN_OR_RETURN(TensorShape shape, expression.GetShape());
+      TF_RETURN_IF_ERROR(context_->allocate_output(index, shape, &output));
+    }
+    AssignExpressionToTensor(output, expression);
+    return Status::OK();
+  }();
+  if (!status.ok()) {
+    SetStatus(status);
+  }
-  return Status::OK();
 }
 
 void XlaOpKernelContext::SetOutput(int index, const xla::XlaOp& handle) {
-  // Makes the host Tensor that will refer to the expression.
-  Tensor* output = nullptr;
-  auto shape_or = builder()->GetShape(handle);
-  if (!shape_or.ok()) {
-    SetStatus(shape_or.status());
-    return;
-  }
-
-  OP_REQUIRES_OK(context_,
-                 allocate_output(index, shape_or.ValueOrDie(), &output));
-
-  // The expression is stored in the tensor's data buffer. Fill in the
-  // fields now.
-  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
-  expression->set_handle(handle);
+  SetOutputExpression(
+      index,
+      XlaExpression::XlaOp(handle, context_->expected_output_dtype(index)));
 }
 
 void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
-  const TensorShape& shape = constant.shape();
-
-  xla::BorrowingLiteral literal;
-  OP_REQUIRES_OK(context_, HostTensorToBorrowingLiteral(constant, &literal));
-
-  xla::XlaOp handle = xla::ConstantLiteral(builder(), literal);
-  CHECK(handle.valid());
-
-  // Make the Tensor that will refer to the expression.
-  Tensor* output = nullptr;
-  // The step's default allocator is the dummy XlaCompilationAllocator which
-  // simply allocates a metadata buffer to hold the expression to which it
-  // corresponds.
-  OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
-
-  // The expression is stored in the tensor's data buffer. Fill in the
-  // fields now.
-  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
-  expression->set_handle(handle);
-  expression->set_constant_value(constant);
-}
-
-void XlaOpKernelContext::SetInvalidOutput(int index) {
-  Tensor* output = nullptr;
-  OP_REQUIRES_OK(context_,
-                 context_->allocate_output(index, TensorShape({}), &output));
-  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
-  xla::XlaOp handle;
-  expression->set_handle(handle);
+  SetOutputExpression(index, XlaExpression::Constant(constant));
 }
 
 void XlaOpKernelContext::SetResourceOutput(int index, XlaResource* resource) {
-  Tensor* output = nullptr;
-  // The shape of the output tensor is the shape of the resource itself
-  // (i.e., a scalar), not the shape of the resource's value.
-  OP_REQUIRES_OK(context_,
-                 context_->allocate_output(index, TensorShape(), &output));
-  XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
-  expression->set_resource(resource);
+  SetOutputExpression(index, XlaExpression::Resource(resource));
 }
 
 Status XlaOpKernelContext::GetResourceInput(int index, XlaResource** resource) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 3d9499f5fae..c06efa2c474 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -88,9 +88,9 @@ class XlaOpKernelContext {
   // Returns input `index` as a XlaOp. Unlike
   // OpKernelContext::Input returns a symbolic value rather than a concrete
   // Tensor.
-  const xla::XlaOp& Input(int index);
+  xla::XlaOp Input(int index);
   // Returns input `name` as a XlaOp.
-  const xla::XlaOp& Input(absl::string_view name);
+  xla::XlaOp Input(absl::string_view name);
 
   // Returns true if all inputs are the same shape, otherwise sets the
   // status to a non-OK value and returns false.
@@ -142,6 +142,10 @@ class XlaOpKernelContext {
   Status ConstantInputList(absl::string_view name,
                            std::vector<xla::Literal>* literals);
 
+  // Returns an XlaExpression describing the value of input 'index' or 'name'.
+  const XlaExpression& InputExpression(int index);
+  const XlaExpression& InputExpression(absl::string_view name);
+
   // Outputs
 
   int num_outputs() const { return context_->num_outputs(); }
@@ -159,9 +163,8 @@ class XlaOpKernelContext {
   // SetConstantOutput where possible.
   void SetConstantOutput(int index, const Tensor& host_tensor);
 
-  // Sets output `index` to an invalid value.
-  // Any subsequent attempt to consume this output will cause an error.
-  void SetInvalidOutput(int index);
+  // Sets output `index` to the XlaExpression `expression`.
+  void SetOutputExpression(int index, const XlaExpression& expression);
 
   // Status handling.
   void SetStatus(const Status& status) { context_->SetStatus(status); }
@@ -249,11 +252,6 @@ class XlaOpKernelContext {
   // Returns the tensor of input `name`.
   const Tensor& GetInputTensorByName(absl::string_view name);
 
-  // Wraps OpKernelContext's allocate_output method while providing special
-  // behavior for DT_VARIANT: a variant is treated as DT_UINT8 scalar as the
-  // type to allow mapping for variant to more generic types.
-  Status allocate_output(int index, const xla::Shape& shape, Tensor** output);
-
   // Evaluates input `index`, reshapes it to `new_shape` if new_shape !=
   // InputShape(index), and stores it in `*constant_literal`. If the input
   // cannot be evaluated, e.g., because it depends on unbound parameters,