From 38c53e2f5953e8b8fd94ba07de6c4bb2c15b0824 Mon Sep 17 00:00:00 2001
From: George Karpenkov <cheshire@google.com>
Date: Mon, 9 Nov 2020 14:57:06 -0800
Subject: [PATCH] [TF2XLA] Support must-be-constant resource variables for
 compilation

Performs an explicit copy at runtime from device to host if needed.

PiperOrigin-RevId: 341491694
Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5
---
 tensorflow/compiler/jit/get_compiler_ir.cc    |  2 +-
 tensorflow/compiler/jit/kernels/xla_ops.cc    |  7 +-
 .../compiler/jit/xla_compilation_cache.cc     |  1 +
 .../compiler/jit/xla_compile_on_demand_op.cc  |  3 +-
 tensorflow/compiler/jit/xla_launch_util.cc    | 49 +++++++++----
 tensorflow/compiler/jit/xla_launch_util.h     |  3 +-
 tensorflow/compiler/tf2xla/graph_compiler.cc  |  2 +-
 tensorflow/compiler/tf2xla/xla_argument.h     |  3 +
 tensorflow/compiler/tf2xla/xla_compiler.cc    | 13 +++-
 tensorflow/compiler/tf2xla/xla_expression.cc  | 28 +++++---
 tensorflow/compiler/tf2xla/xla_expression.h   | 17 ++++-
 .../compiler/tf2xla/xla_expression_test.cc    | 18 ++++-
 tensorflow/compiler/tf2xla/xla_op_kernel.cc   |  7 ++
 tensorflow/compiler/tf2xla/xla_resource.cc    |  2 +
 tensorflow/compiler/tf2xla/xla_resource.h     |  3 +
 .../python/eager/def_function_xla_jit_test.py | 71 +++++++++++++++++++
 16 files changed, 193 insertions(+), 36 deletions(-)

diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc
index 08b3bea1084..1685bec6706 100644
--- a/tensorflow/compiler/jit/get_compiler_ir.cc
+++ b/tensorflow/compiler/jit/get_compiler_ir.cc
@@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
 
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
       XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constant_arg_indices, inputs, variable_infos);
+          constant_arg_indices, inputs, variable_infos, dev);
   TF_RETURN_IF_ERROR(args.status());
 
   switch (stage) {
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index 0f0f43cbad6..563423b7755 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
                                           may_alias_resource_update;
 
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
-      XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
-                                                             variable_infos);
+      XlaComputationLaunchContext::BuildXlaCompilerArguments(
+          constants, inputs, variable_infos,
+          static_cast<Device*>(ctx->device()));
   TF_RETURN_IF_ERROR(args.status());
   return cache->Compile(options, function, *args, compile_options,
                         lazy ? XlaCompilationCache::CompileMode::kLazy
@@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
 
-  VLOG(1) << "Executing XLA Computation...";
-
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator = GetAllocator(
       &tf_allocator_adapter, ctx->device(),
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc
index ea39331c4fb..6251f0353de 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.cc
+++ b/tensorflow/compiler/jit/xla_compilation_cache.cc
@@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
   for (const XlaCompiler::Argument& arg : args) {
     switch (arg.kind) {
       case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kConstantResource:
         signature.arg_values.push_back(arg.constant_value);
         break;
       case XlaCompiler::Argument::kParameter:
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index fa32a04a026..4005d0bf0cb 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
         ctx, variables_indices, variable_infos, variable_args));
 
     args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
-        constant_input_indices, inputs, variable_infos);
+        constant_input_indices, inputs, variable_infos,
+        static_cast<Device*>(ctx->device()));
     TF_RETURN_IF_ERROR(args.status());
   }
 
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 1c5581eb4ab..b7f83301d2d 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
 XlaComputationLaunchContext::BuildXlaCompilerArguments(
     absl::Span<int const> must_be_constant_idxs,
     absl::Span<const Tensor* const> inputs,
-    absl::Span<VariableInfo const> variable_args) {
+    absl::Span<VariableInfo const> variable_args, Device* device) {
   CHECK(absl::c_is_sorted(must_be_constant_idxs));
   std::vector<XlaCompiler::Argument> out;
   out.resize(inputs.size());
 
+  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
+  DeviceContext* device_context = nullptr;
+  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
+  bool using_default_context = false;
+  auto cleanup = xla::MakeCleanup([&] {
+    if (device_context != nullptr && !using_default_context) {
+      device_context->Unref();
+    }
+  });
+  if (device_context == nullptr) {
+    using_default_context = true;
+    auto* dev_info = device->tensorflow_gpu_device_info();
+    if (dev_info) device_context = dev_info->default_context;
+  }
+
   absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
   for (const VariableInfo& info : variable_args) {
     CHECK(!info.var() || info.lock_held())
@@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
     const Tensor* input = inputs[input_num];
 
     XlaCompiler::Argument& arg = out[input_num];
-    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
-      // Handles compile-time constants.
-
-      // TODO(b/157241314): Support constants located in resource variables.
-      TF_RET_CHECK(input->dtype() != DT_RESOURCE)
-          << "tf2xla bridge does not support must-be-constants located in "
-             "resource variables; try moving them to a tensor";
-      arg.kind = XlaCompiler::Argument::kConstant;
-      arg.type = input->dtype();
-      arg.shape = input->shape();
-      arg.constant_value = *input;
-    } else if (variable_info_lookup.count(input_num)) {
+    if (variable_info_lookup.count(input_num)) {
       // Handles resource variables.
       TF_RET_CHECK(input->dtype() == DT_RESOURCE);
       const VariableInfo& variable = *variable_info_lookup[input_num];
@@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
         arg.type = DT_INVALID;
         arg.shape = TensorShape();
       }
+
+      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
+        const Tensor* value = variable.var()->tensor();
+        Tensor value_on_host(value->dtype(), value->shape());
+        if (!device_context) {
+          value_on_host = *value;
+        } else {
+          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
+              value, "", device, &value_on_host));
+        }
+        arg.kind = XlaCompiler::Argument::kConstantResource;
+        arg.constant_value = value_on_host;
+      }
+    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+      arg.kind = XlaCompiler::Argument::kConstant;
+      arg.type = input->dtype();
+      arg.shape = input->shape();
+      arg.constant_value = *input;
     } else {
       // Normal inputs.
       TF_RET_CHECK(input->dtype() != DT_RESOURCE);
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index ac085a022c8..8b939365ee5 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
   static xla::StatusOr<std::vector<XlaCompiler::Argument>>
   BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
                             absl::Span<const Tensor* const> inputs,
-                            absl::Span<VariableInfo const> variable_args);
+                            absl::Span<VariableInfo const> variable_args,
+                            Device* device);
 
   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc
index 30a7e94775b..2cf10974176 100644
--- a/tensorflow/compiler/tf2xla/graph_compiler.cc
+++ b/tensorflow/compiler/tf2xla/graph_compiler.cc
@@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
     switch (expressions[i]->kind()) {
       case XlaExpression::Kind::kConstant:
         arg.kind = XlaCompiler::Argument::kConstant;
-        arg.constant_value = expressions[i]->constant_value();
+        arg.constant_value = *expressions[i]->constant_value();
         break;
       case XlaExpression::Kind::kXlaOp:
         if (arg_must_be_compile_time_constant[i]) {
diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h
index e2cd634e1d5..c304c479f87 100644
--- a/tensorflow/compiler/tf2xla/xla_argument.h
+++ b/tensorflow/compiler/tf2xla/xla_argument.h
@@ -39,6 +39,9 @@ struct XlaArgument {
     // associated runtime parameter iff `initialized` is true.
     kResource,
 
+    // A resource variable with a constant value known at compile time.
+    kConstantResource,
+
     // Argument is a run-time parameter.
     kParameter,
 
diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc
index 3d6a66c6ebc..56a7e9dd5d8 100644
--- a/tensorflow/compiler/tf2xla/xla_compiler.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiler.cc
@@ -207,7 +207,7 @@ Status BuildComputation(
     switch (retval.kind()) {
       case XlaExpression::Kind::kConstant:
         output.is_constant = true;
-        output.constant_value = retval.constant_value();
+        output.constant_value = *retval.constant_value();
         output.shape = output.constant_value.shape();
         break;
 
@@ -446,6 +446,9 @@ string XlaCompiler::Argument::HumanString() const {
     case kConstant:
       return absl::StrCat("kind=constant", common,
                           " value=", constant_value.DebugString());
+    case kConstantResource:
+      return absl::StrCat("kind=constant-resource", common,
+                          " value=", constant_value.DebugString());
     case kResource: {
       string output = absl::StrCat(
           "kind=resource", common,
@@ -856,6 +859,7 @@ Status XlaCompiler::XLAShapeForArgument(
       *xla_shape = absl::get<xla::Shape>(arg.shape);
       return Status::OK();
     }
+    case XlaCompiler::Argument::kConstantResource:
     case XlaCompiler::Argument::kResource: {
       TF_RET_CHECK(arg.initialized);
 
@@ -959,6 +963,7 @@ Status XlaCompiler::BuildArguments(
     const XlaCompiler::Argument& arg = args[i];
     XlaExpression& arg_expression = (*arg_expressions)[i];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
@@ -971,7 +976,10 @@ Status XlaCompiler::BuildArguments(
                 /*max_array_size=*/arg.max_array_size,
                 /*tensor_array_gradients=*/arg.tensor_array_gradients,
                 /*tensor_array_multiple_writes_aggregate=*/true));
-        arg_expression = XlaExpression::Resource(resource);
+        arg_expression =
+            arg.kind == XlaCompiler::Argument::kResource
+                ? XlaExpression::Resource(resource)
+                : XlaExpression::ConstantResource(arg.constant_value, resource);
         if (arg.initialized) {
           input_to_args->push_back(i);
         }
@@ -1124,6 +1132,7 @@ Status XlaCompiler::BuildArguments(
                                    arg_shardings.at(i).DebugString()));
     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.initialized);
         XlaResource* resource = arg_expression.resource();
diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc
index f0cc8d26709..40b154b496e 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression.cc
@@ -38,6 +38,16 @@ XlaExpression XlaExpression::Constant(Tensor value) {
   return e;
 }
 
+XlaExpression XlaExpression::ConstantResource(Tensor value,
+                                              XlaResource* resource) {
+  XlaExpression e;
+  e.kind_ = Kind::kResource;
+  e.dtype_ = DT_RESOURCE;
+  e.resource_ = resource;
+  e.constant_value_ = value;
+  return e;
+}
+
 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
   XlaExpression e;
   e.kind_ = Kind::kXlaOp;
@@ -83,7 +93,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
       case Kind::kConstant: {
         xla::BorrowingLiteral literal;
         TF_RETURN_IF_ERROR(
-            HostTensorToBorrowingLiteral(constant_value_, &literal));
+            HostTensorToBorrowingLiteral(*constant_value_, &literal));
         return xla::ConstantLiteral(builder, literal);
       }
       case Kind::kTensorList:
@@ -106,7 +116,7 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
   switch (kind()) {
     case Kind::kConstant: {
       // Constant values are considered static.
-      Tensor constant_false(DT_BOOL, constant_value().shape());
+      Tensor constant_false(DT_BOOL, constant_value()->shape());
       auto flat = constant_false.flat<bool>();
       for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
       return constant_false;
@@ -147,13 +157,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
     xla::Client* client, bool dynamic_dimension_is_minus_one) const {
   switch (kind()) {
     case Kind::kConstant:
-      return {constant_value()};
+    case Kind::kResource:
+      return constant_value();
     case Kind::kXlaOp:
       break;
     case Kind::kTensorList:
       TF_FALLTHROUGH_INTENDED;
-    case Kind::kResource:
-      TF_FALLTHROUGH_INTENDED;
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "ResolveConstant called on XlaExpression: ", HumanString());
@@ -187,7 +196,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
 xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
   switch (kind_) {
     case Kind::kConstant:
-      return constant_value().shape();
+      return constant_value()->shape();
+    case Kind::kResource:
+      if (constant_value()) {
+        return constant_value()->shape();
+      }
+      return TensorShape({});
     case Kind::kXlaOp: {
       TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
                           handle().builder()->GetShape(handle()));
@@ -197,8 +211,6 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
     }
     case Kind::kTensorList:
       return TensorShape({});
-    case Kind::kResource:
-      return TensorShape({});
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "GetShape() called on invalid XlaExpression");
diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h
index 3546368ff7b..fd6b311ae6e 100644
--- a/tensorflow/compiler/tf2xla/xla_expression.h
+++ b/tensorflow/compiler/tf2xla/xla_expression.h
@@ -74,6 +74,9 @@ class XlaExpression {
   // Builds a resource expression.
   static XlaExpression Resource(XlaResource* resource);
 
+  // Builds a resource whose value is known at a compile time.
+  static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
+
   Kind kind() const { return kind_; }
 
   DataType dtype() const { return dtype_; }
@@ -81,7 +84,15 @@ class XlaExpression {
   // handle() returns the XlaOp that backs a kXlaOp expression.
   const xla::XlaOp& handle() const { return handle_; }
 
-  const Tensor& constant_value() const { return constant_value_; }
+  // Return a constant value associated with this expression. Always set for
+  // constants, might be set for resources.
+  absl::optional<Tensor> constant_value() const {
+    if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
+      // The constant is no longer available if the value was overwritten.
+      return absl::nullopt;
+    }
+    return constant_value_;
+  }
 
   XlaResource* resource() const { return resource_; }
 
@@ -124,8 +135,8 @@ class XlaExpression {
   // a tuple expression if kind_ == kTensorList.
   xla::XlaOp handle_;
 
-  // The value of the constant, if kind_ == kConstant.
-  Tensor constant_value_;
+  // The value of the constant, if available.
+  absl::optional<Tensor> constant_value_;
 
   // The resource, if kind_ == kResource. Not owned.
   XlaResource* resource_ = nullptr;
diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc
index 84202c93139..6e4c4cf675f 100644
--- a/tensorflow/compiler/tf2xla/xla_expression_test.cc
+++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc
@@ -110,8 +110,10 @@ TEST_F(XlaExpressionTest, GetShape) {
 TEST_F(XlaExpressionTest, ResolveConstant) {
   EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
   EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
-  EXPECT_FALSE(
-      XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
+
+  EXPECT_FALSE(XlaExpression::Resource(resource_.get())
+                   .ResolveConstant(client_)
+                   ->has_value());
 
   TF_ASSERT_OK_AND_ASSIGN(
       absl::optional<Tensor> op_constant,
@@ -131,5 +133,17 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
   test::ExpectTensorEqual<int32>(constant_, *constant_constant);
 }
 
+TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
+  XlaExpression constant_resource =
+      XlaExpression::ConstantResource(constant_, resource_.get());
+  EXPECT_TRUE(constant_resource.ResolveConstant(client_).ok());
+  EXPECT_TRUE(resource_->SetZeroValue(builder_.get()).ok());
+  LOG(ERROR) << "Resource is overwritten: " << resource_->IsOverwritten();
+  xla::StatusOr<absl::optional<Tensor>> resolved_constant =
+      constant_resource.ResolveConstant(client_);
+  EXPECT_TRUE(resolved_constant.ok());
+  EXPECT_FALSE(resolved_constant->has_value());
+}
+
 }  // namespace
 }  // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index c2d1906e47a..1d382fe5b9c 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -477,6 +477,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
     *shape = variable->shape();
   }
 
+  if (!variable->IsOverwritten() && expression->constant_value()) {
+    TF_ASSIGN_OR_RETURN(xla::Literal literal,
+                        HostTensorToLiteral(*expression->constant_value()));
+    *value = xla::ConstantLiteral(ctx->builder(), literal);
+    return Status::OK();
+  }
+
   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                       ctx->compiler()->options().shape_representation_fn(
                           variable->shape(), variable->type(),
diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc
index bec0b46611d..8730c6dad54 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.cc
+++ b/tensorflow/compiler/tf2xla/xla_resource.cc
@@ -116,10 +116,12 @@ Status XlaResource::SetValue(const xla::XlaOp& value) {
         "' must be initialized with a valid type before use.");
   }
   value_ = value;
+  is_overwritten_ = true;
   return Status::OK();
 }
 
 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
+  is_overwritten_ = true;
   if (type_ == DT_INVALID) {
     return errors::InvalidArgument(
         "Resource '", name_,
diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h
index ab3a5bdd9bc..d7b9d2f16d3 100644
--- a/tensorflow/compiler/tf2xla/xla_resource.h
+++ b/tensorflow/compiler/tf2xla/xla_resource.h
@@ -135,6 +135,8 @@ class XlaResource {
   Status SetFromPack(const std::set<string>& gradient_sources,
                      const xla::XlaOp& pack, xla::XlaBuilder* builder);
 
+  bool IsOverwritten() { return is_overwritten_; }
+
   // TensorArray and Stack specific fields
   // TODO(phawkins): refactor this code to use subclasses, rather than putting
   // kind-specific fields in XlaResource.
@@ -179,6 +181,7 @@ class XlaResource {
   bool tensor_array_multiple_writes_aggregate_ = false;
 
   std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
+  bool is_overwritten_ = false;
 };
 
 }  // namespace tensorflow
diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py
index 5820bec31be..281ff142dd6 100644
--- a/tensorflow/python/eager/def_function_xla_jit_test.py
+++ b/tensorflow/python/eager/def_function_xla_jit_test.py
@@ -656,6 +656,77 @@ class DefFunctionTest(xla_test.XLATestCase):
       self.assertIn('tuple',
                     f.experimental_get_compiler_ir(l)())
 
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariable(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      # OK since the value is known at compile time.
+      out = f(random_ops.random_normal([10, 10]))
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableAfterWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # Returns an error, since the value known at compile time was overriden.
+      with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                  'concrete values at compile time'):
+        f(random_ops.random_normal([10, 10]), val1, val2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableBeforeWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        out = array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return out
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # OK since the write happens after the reshape.
+      out = f(random_ops.random_normal([10, 10]), val1, val2)
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
 
 if __name__ == '__main__':
   ops.enable_eager_execution()