From 38c53e2f5953e8b8fd94ba07de6c4bb2c15b0824 Mon Sep 17 00:00:00 2001 From: George Karpenkov <cheshire@google.com> Date: Mon, 9 Nov 2020 14:57:06 -0800 Subject: [PATCH] [TF2XLA] Support must-be-constant resource variables for compilation Performs an explicit copy at runtime from device to host if needed. PiperOrigin-RevId: 341491694 Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5 --- tensorflow/compiler/jit/get_compiler_ir.cc | 2 +- tensorflow/compiler/jit/kernels/xla_ops.cc | 7 +- .../compiler/jit/xla_compilation_cache.cc | 1 + .../compiler/jit/xla_compile_on_demand_op.cc | 3 +- tensorflow/compiler/jit/xla_launch_util.cc | 49 +++++++++---- tensorflow/compiler/jit/xla_launch_util.h | 3 +- tensorflow/compiler/tf2xla/graph_compiler.cc | 2 +- tensorflow/compiler/tf2xla/xla_argument.h | 3 + tensorflow/compiler/tf2xla/xla_compiler.cc | 13 +++- tensorflow/compiler/tf2xla/xla_expression.cc | 28 +++++--- tensorflow/compiler/tf2xla/xla_expression.h | 17 ++++- .../compiler/tf2xla/xla_expression_test.cc | 18 ++++- tensorflow/compiler/tf2xla/xla_op_kernel.cc | 7 ++ tensorflow/compiler/tf2xla/xla_resource.cc | 2 + tensorflow/compiler/tf2xla/xla_resource.h | 3 + .../python/eager/def_function_xla_jit_test.py | 71 +++++++++++++++++++ 16 files changed, 193 insertions(+), 36 deletions(-) diff --git a/tensorflow/compiler/jit/get_compiler_ir.cc b/tensorflow/compiler/jit/get_compiler_ir.cc index 08b3bea1084..1685bec6706 100644 --- a/tensorflow/compiler/jit/get_compiler_ir.cc +++ b/tensorflow/compiler/jit/get_compiler_ir.cc @@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr( xla::StatusOr<std::vector<XlaCompiler::Argument>> args = XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_arg_indices, inputs, variable_infos); + constant_arg_indices, inputs, variable_infos, dev); TF_RETURN_IF_ERROR(args.status()); switch (stage) { diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 0f0f43cbad6..563423b7755 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -206,8 +206,9 @@ static Status CompileToLocalExecutable( may_alias_resource_update; xla::StatusOr<std::vector<XlaCompiler::Argument>> args = - XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs, - variable_infos); + XlaComputationLaunchContext::BuildXlaCompilerArguments( + constants, inputs, variable_infos, + static_cast<Device*>(ctx->device())); TF_RETURN_IF_ERROR(args.status()); return cache->Compile(options, function, *args, compile_options, lazy ? XlaCompilationCache::CompileMode::kLazy @@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) { se::Stream* stream = ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr; - VLOG(1) << "Executing XLA Computation..."; - absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter; se::DeviceMemoryAllocator* allocator = GetAllocator( &tf_allocator_adapter, ctx->device(), diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index ea39331c4fb..6251f0353de 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature( for (const XlaCompiler::Argument& arg : args) { switch (arg.kind) { case XlaCompiler::Argument::kConstant: + case XlaCompiler::Argument::kConstantResource: signature.arg_values.push_back(arg.constant_value); break; case XlaCompiler::Argument::kParameter: diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc index fa32a04a026..4005d0bf0cb 100644 --- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc +++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc @@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile( ctx, variables_indices, variable_infos, variable_args)); args = XlaComputationLaunchContext::BuildXlaCompilerArguments( - constant_input_indices, inputs, variable_infos); + constant_input_indices, inputs, variable_infos, + static_cast<Device*>(ctx->device())); TF_RETURN_IF_ERROR(args.status()); } diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc index 1c5581eb4ab..b7f83301d2d 100644 --- a/tensorflow/compiler/jit/xla_launch_util.cc +++ b/tensorflow/compiler/jit/xla_launch_util.cc @@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>> XlaComputationLaunchContext::BuildXlaCompilerArguments( absl::Span<int const> must_be_constant_idxs, absl::Span<const Tensor* const> inputs, - absl::Span<VariableInfo const> variable_args) { + absl::Span<VariableInfo const> variable_args, Device* device) { CHECK(absl::c_is_sorted(must_be_constant_idxs)); std::vector<XlaCompiler::Argument> out; out.resize(inputs.size()); + // TODO(cheshire): Avoid duplication with framework/op_kernel.h + DeviceContext* device_context = nullptr; + TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context)); + bool using_default_context = false; + auto cleanup = xla::MakeCleanup([&] { + if (device_context != nullptr && !using_default_context) { + device_context->Unref(); + } + }); + if (device_context == nullptr) { + using_default_context = true; + auto* dev_info = device->tensorflow_gpu_device_info(); + if (dev_info) device_context = dev_info->default_context; + } + absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup; for (const VariableInfo& info : variable_args) { CHECK(!info.var() || info.lock_held()) @@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( const Tensor* input = inputs[input_num]; XlaCompiler::Argument& arg = out[input_num]; - if (absl::c_binary_search(must_be_constant_idxs, input_num)) { - // Handles compile-time constants. - - // TODO(b/157241314): Support constants located in resource variables. - TF_RET_CHECK(input->dtype() != DT_RESOURCE) - << "tf2xla bridge does not support must-be-constants located in " - "resource variables; try moving them to a tensor"; - arg.kind = XlaCompiler::Argument::kConstant; - arg.type = input->dtype(); - arg.shape = input->shape(); - arg.constant_value = *input; - } else if (variable_info_lookup.count(input_num)) { + if (variable_info_lookup.count(input_num)) { // Handles resource variables. TF_RET_CHECK(input->dtype() == DT_RESOURCE); const VariableInfo& variable = *variable_info_lookup[input_num]; @@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments( arg.type = DT_INVALID; arg.shape = TensorShape(); } + + if (absl::c_binary_search(must_be_constant_idxs, input_num)) { + TF_RET_CHECK(variable.var() && variable.var()->is_initialized); + const Tensor* value = variable.var()->tensor(); + Tensor value_on_host(value->dtype(), value->shape()); + if (!device_context) { + value_on_host = *value; + } else { + TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync( + value, "", device, &value_on_host)); + } + arg.kind = XlaCompiler::Argument::kConstantResource; + arg.constant_value = value_on_host; + } + } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.type = input->dtype(); + arg.shape = input->shape(); + arg.constant_value = *input; } else { // Normal inputs. TF_RET_CHECK(input->dtype() != DT_RESOURCE); diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h index ac085a022c8..8b939365ee5 100644 --- a/tensorflow/compiler/jit/xla_launch_util.h +++ b/tensorflow/compiler/jit/xla_launch_util.h @@ -143,7 +143,8 @@ class XlaComputationLaunchContext { static xla::StatusOr<std::vector<XlaCompiler::Argument>> BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs, absl::Span<const Tensor* const> inputs, - absl::Span<VariableInfo const> variable_args); + absl::Span<VariableInfo const> variable_args, + Device* device); // Add all inputs within `ctx` as XLA arguments (returned by arguments()). // `variables` is a map from TensorFlow argument number to resource variable. diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc index 30a7e94775b..2cf10974176 100644 --- a/tensorflow/compiler/tf2xla/graph_compiler.cc +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, switch (expressions[i]->kind()) { case XlaExpression::Kind::kConstant: arg.kind = XlaCompiler::Argument::kConstant; - arg.constant_value = expressions[i]->constant_value(); + arg.constant_value = *expressions[i]->constant_value(); break; case XlaExpression::Kind::kXlaOp: if (arg_must_be_compile_time_constant[i]) { diff --git a/tensorflow/compiler/tf2xla/xla_argument.h b/tensorflow/compiler/tf2xla/xla_argument.h index e2cd634e1d5..c304c479f87 100644 --- a/tensorflow/compiler/tf2xla/xla_argument.h +++ b/tensorflow/compiler/tf2xla/xla_argument.h @@ -39,6 +39,9 @@ struct XlaArgument { // associated runtime parameter iff `initialized` is true. kResource, + // A resource variable with a constant value known at compile time. + kConstantResource, + // Argument is a run-time parameter. kParameter, diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 3d6a66c6ebc..56a7e9dd5d8 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -207,7 +207,7 @@ Status BuildComputation( switch (retval.kind()) { case XlaExpression::Kind::kConstant: output.is_constant = true; - output.constant_value = retval.constant_value(); + output.constant_value = *retval.constant_value(); output.shape = output.constant_value.shape(); break; @@ -446,6 +446,9 @@ string XlaCompiler::Argument::HumanString() const { case kConstant: return absl::StrCat("kind=constant", common, " value=", constant_value.DebugString()); + case kConstantResource: + return absl::StrCat("kind=constant-resource", common, + " value=", constant_value.DebugString()); case kResource: { string output = absl::StrCat( "kind=resource", common, @@ -856,6 +859,7 @@ Status XlaCompiler::XLAShapeForArgument( *xla_shape = absl::get<xla::Shape>(arg.shape); return Status::OK(); } + case XlaCompiler::Argument::kConstantResource: case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); @@ -959,6 +963,7 @@ Status XlaCompiler::BuildArguments( const XlaCompiler::Argument& arg = args[i]; XlaExpression& arg_expression = (*arg_expressions)[i]; switch (arg.kind) { + case XlaCompiler::Argument::kConstantResource: case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid); TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape)); @@ -971,7 +976,10 @@ Status XlaCompiler::BuildArguments( /*max_array_size=*/arg.max_array_size, /*tensor_array_gradients=*/arg.tensor_array_gradients, /*tensor_array_multiple_writes_aggregate=*/true)); - arg_expression = XlaExpression::Resource(resource); + arg_expression = + arg.kind == XlaCompiler::Argument::kResource + ? XlaExpression::Resource(resource) + : XlaExpression::ConstantResource(arg.constant_value, resource); if (arg.initialized) { input_to_args->push_back(i); } @@ -1124,6 +1132,7 @@ Status XlaCompiler::BuildArguments( arg_shardings.at(i).DebugString())); XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)]; switch (arg.kind) { + case XlaCompiler::Argument::kConstantResource: case XlaCompiler::Argument::kResource: { TF_RET_CHECK(arg.initialized); XlaResource* resource = arg_expression.resource(); diff --git a/tensorflow/compiler/tf2xla/xla_expression.cc b/tensorflow/compiler/tf2xla/xla_expression.cc index f0cc8d26709..40b154b496e 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.cc +++ b/tensorflow/compiler/tf2xla/xla_expression.cc @@ -38,6 +38,16 @@ XlaExpression XlaExpression::Constant(Tensor value) { return e; } +XlaExpression XlaExpression::ConstantResource(Tensor value, + XlaResource* resource) { + XlaExpression e; + e.kind_ = Kind::kResource; + e.dtype_ = DT_RESOURCE; + e.resource_ = resource; + e.constant_value_ = value; + return e; +} + XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) { XlaExpression e; e.kind_ = Kind::kXlaOp; @@ -83,7 +93,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const { case Kind::kConstant: { xla::BorrowingLiteral literal; TF_RETURN_IF_ERROR( - HostTensorToBorrowingLiteral(constant_value_, &literal)); + HostTensorToBorrowingLiteral(*constant_value_, &literal)); return xla::ConstantLiteral(builder, literal); } case Kind::kTensorList: @@ -106,7 +116,7 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism( switch (kind()) { case Kind::kConstant: { // Constant values are considered static. - Tensor constant_false(DT_BOOL, constant_value().shape()); + Tensor constant_false(DT_BOOL, constant_value()->shape()); auto flat = constant_false.flat<bool>(); for (int64 i = 0; i < flat.size(); ++i) flat(i) = false; return constant_false; @@ -147,13 +157,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant( xla::Client* client, bool dynamic_dimension_is_minus_one) const { switch (kind()) { case Kind::kConstant: - return {constant_value()}; + case Kind::kResource: + return constant_value(); case Kind::kXlaOp: break; case Kind::kTensorList: TF_FALLTHROUGH_INTENDED; - case Kind::kResource: - TF_FALLTHROUGH_INTENDED; case Kind::kInvalid: return errors::InvalidArgument( "ResolveConstant called on XlaExpression: ", HumanString()); @@ -187,7 +196,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant( xla::StatusOr<TensorShape> XlaExpression::GetShape() const { switch (kind_) { case Kind::kConstant: - return constant_value().shape(); + return constant_value()->shape(); + case Kind::kResource: + if (constant_value()) { + return constant_value()->shape(); + } + return TensorShape({}); case Kind::kXlaOp: { TF_ASSIGN_OR_RETURN(xla::Shape xla_shape, handle().builder()->GetShape(handle())); @@ -197,8 +211,6 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const { } case Kind::kTensorList: return TensorShape({}); - case Kind::kResource: - return TensorShape({}); case Kind::kInvalid: return errors::InvalidArgument( "GetShape() called on invalid XlaExpression"); diff --git a/tensorflow/compiler/tf2xla/xla_expression.h b/tensorflow/compiler/tf2xla/xla_expression.h index 3546368ff7b..fd6b311ae6e 100644 --- a/tensorflow/compiler/tf2xla/xla_expression.h +++ b/tensorflow/compiler/tf2xla/xla_expression.h @@ -74,6 +74,9 @@ class XlaExpression { // Builds a resource expression. static XlaExpression Resource(XlaResource* resource); + // Builds a resource whose value is known at a compile time. + static XlaExpression ConstantResource(Tensor value, XlaResource* resource); + Kind kind() const { return kind_; } DataType dtype() const { return dtype_; } @@ -81,7 +84,15 @@ class XlaExpression { // handle() returns the XlaOp that backs a kXlaOp expression. const xla::XlaOp& handle() const { return handle_; } - const Tensor& constant_value() const { return constant_value_; } + // Return a constant value associated with this expression. Always set for + // constants, might be set for resources. + absl::optional<Tensor> constant_value() const { + if (kind_ == Kind::kResource && resource_->IsOverwritten()) { + // The constant is no longer available if the value was overwritten. + return absl::nullopt; + } + return constant_value_; + } XlaResource* resource() const { return resource_; } @@ -124,8 +135,8 @@ class XlaExpression { // a tuple expression if kind_ == kTensorList. xla::XlaOp handle_; - // The value of the constant, if kind_ == kConstant. - Tensor constant_value_; + // The value of the constant, if available. + absl::optional<Tensor> constant_value_; // The resource, if kind_ == kResource. Not owned. XlaResource* resource_ = nullptr; diff --git a/tensorflow/compiler/tf2xla/xla_expression_test.cc b/tensorflow/compiler/tf2xla/xla_expression_test.cc index 84202c93139..6e4c4cf675f 100644 --- a/tensorflow/compiler/tf2xla/xla_expression_test.cc +++ b/tensorflow/compiler/tf2xla/xla_expression_test.cc @@ -110,8 +110,10 @@ TEST_F(XlaExpressionTest, GetShape) { TEST_F(XlaExpressionTest, ResolveConstant) { EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok()); EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok()); - EXPECT_FALSE( - XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok()); + + EXPECT_FALSE(XlaExpression::Resource(resource_.get()) + .ResolveConstant(client_) + ->has_value()); TF_ASSERT_OK_AND_ASSIGN( absl::optional<Tensor> op_constant, @@ -131,5 +133,17 @@ TEST_F(XlaExpressionTest, ResolveConstant) { test::ExpectTensorEqual<int32>(constant_, *constant_constant); } +TEST_F(XlaExpressionTest, ResolveConstantOnResource) { + XlaExpression constant_resource = + XlaExpression::ConstantResource(constant_, resource_.get()); + EXPECT_TRUE(constant_resource.ResolveConstant(client_).ok()); + EXPECT_TRUE(resource_->SetZeroValue(builder_.get()).ok()); + LOG(ERROR) << "Resource is overwritten: " << resource_->IsOverwritten(); + xla::StatusOr<absl::optional<Tensor>> resolved_constant = + constant_resource.ResolveConstant(client_); + EXPECT_TRUE(resolved_constant.ok()); + EXPECT_FALSE(resolved_constant->has_value()); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index c2d1906e47a..1d382fe5b9c 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -477,6 +477,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type, *shape = variable->shape(); } + if (!variable->IsOverwritten() && expression->constant_value()) { + TF_ASSIGN_OR_RETURN(xla::Literal literal, + HostTensorToLiteral(*expression->constant_value())); + *value = xla::ConstantLiteral(ctx->builder(), literal); + return Status::OK(); + } + TF_ASSIGN_OR_RETURN(xla::Shape representation_shape, ctx->compiler()->options().shape_representation_fn( variable->shape(), variable->type(), diff --git a/tensorflow/compiler/tf2xla/xla_resource.cc b/tensorflow/compiler/tf2xla/xla_resource.cc index bec0b46611d..8730c6dad54 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.cc +++ b/tensorflow/compiler/tf2xla/xla_resource.cc @@ -116,10 +116,12 @@ Status XlaResource::SetValue(const xla::XlaOp& value) { "' must be initialized with a valid type before use."); } value_ = value; + is_overwritten_ = true; return Status::OK(); } Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) { + is_overwritten_ = true; if (type_ == DT_INVALID) { return errors::InvalidArgument( "Resource '", name_, diff --git a/tensorflow/compiler/tf2xla/xla_resource.h b/tensorflow/compiler/tf2xla/xla_resource.h index ab3a5bdd9bc..d7b9d2f16d3 100644 --- a/tensorflow/compiler/tf2xla/xla_resource.h +++ b/tensorflow/compiler/tf2xla/xla_resource.h @@ -135,6 +135,8 @@ class XlaResource { Status SetFromPack(const std::set<string>& gradient_sources, const xla::XlaOp& pack, xla::XlaBuilder* builder); + bool IsOverwritten() { return is_overwritten_; } + // TensorArray and Stack specific fields // TODO(phawkins): refactor this code to use subclasses, rather than putting // kind-specific fields in XlaResource. @@ -179,6 +181,7 @@ class XlaResource { bool tensor_array_multiple_writes_aggregate_ = false; std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_; + bool is_overwritten_ = false; }; } // namespace tensorflow diff --git a/tensorflow/python/eager/def_function_xla_jit_test.py b/tensorflow/python/eager/def_function_xla_jit_test.py index 5820bec31be..281ff142dd6 100644 --- a/tensorflow/python/eager/def_function_xla_jit_test.py +++ b/tensorflow/python/eager/def_function_xla_jit_test.py @@ -656,6 +656,77 @@ class DefFunctionTest(xla_test.XLATestCase): self.assertIn('tuple', f.experimental_get_compiler_ir(l)()) + @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not ' + 'support getting constants out of resources') + def testGetConstantOutOfResourceVariable(self): + with ops.device('device:{}:0'.format(self.device)): + + # Use floats to force device placement. + a = variables.Variable(50.0) + b = variables.Variable(2.0) + + @def_function.function(jit_compile=True) + def f(x): + return array_ops.reshape( + x, [math_ops.cast(a, dtypes.int32), + math_ops.cast(b, dtypes.int32)]) + + # OK since the value is known at compile time. + out = f(random_ops.random_normal([10, 10])) + self.assertEqual(out.shape[0], 50) + self.assertEqual(out.shape[1], 2) + + @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not ' + 'support getting constants out of resources') + def testGetConstantOutOfResourceVariableAfterWrite(self): + with ops.device('device:{}:0'.format(self.device)): + + # Use floats to force device placement. + a = variables.Variable(50.0) + b = variables.Variable(2.0) + + @def_function.function(jit_compile=True) + def f(x, val1, val2): + a.assign(math_ops.cast(val1, dtypes.float32)) + b.assign(math_ops.cast(val2, dtypes.float32)) + return array_ops.reshape( + x, [math_ops.cast(a, dtypes.int32), + math_ops.cast(b, dtypes.int32)]) + + val1 = constant_op.constant(2) + val2 = constant_op.constant(50) + + # Returns an error, since the value known at compile time was overriden. + with self.assertRaisesRegex(errors.InvalidArgumentError, + 'concrete values at compile time'): + f(random_ops.random_normal([10, 10]), val1, val2) + + @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not ' + 'support getting constants out of resources') + def testGetConstantOutOfResourceVariableBeforeWrite(self): + with ops.device('device:{}:0'.format(self.device)): + + # Use floats to force device placement. + a = variables.Variable(50.0) + b = variables.Variable(2.0) + + @def_function.function(jit_compile=True) + def f(x, val1, val2): + out = array_ops.reshape( + x, [math_ops.cast(a, dtypes.int32), + math_ops.cast(b, dtypes.int32)]) + a.assign(math_ops.cast(val1, dtypes.float32)) + b.assign(math_ops.cast(val2, dtypes.float32)) + return out + + val1 = constant_op.constant(2) + val2 = constant_op.constant(50) + + # OK since the write happens after the reshape. + out = f(random_ops.random_normal([10, 10]), val1, val2) + self.assertEqual(out.shape[0], 50) + self.assertEqual(out.shape[1], 2) + if __name__ == '__main__': ops.enable_eager_execution()