[TF2XLA] Support must-be-constant resource variables for compilation
Performs an explicit copy at runtime from device to host if needed.

PiperOrigin-RevId: 341491694
Change-Id: If4a6c0c76a1110637a06e96595c6013c8fac17e5
Parent: aed0ad4916
Commit: 38c53e2f59
Changed files:
  tensorflow/compiler/jit/get_compiler_ir.cc
  tensorflow/compiler/jit/kernels/
  tensorflow/compiler/jit/xla_compilation_cache.cc
  tensorflow/compiler/jit/xla_compile_on_demand_op.cc
  tensorflow/compiler/jit/xla_launch_util.cc
  tensorflow/compiler/jit/xla_launch_util.h
  tensorflow/compiler/tf2xla/
  tensorflow/python/eager/
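The hunks below thread a Device* into XlaComputationLaunchContext::BuildXlaCompilerArguments so that a must-be-constant argument held in a resource variable can be materialized on the host (copying from the device when a device context is present, via DeviceContext::CopyDeviceTensorToCPUSync) and handed to the compiler as the new kConstantResource argument kind. As a minimal, self-contained sketch of that host-materialization decision (toy types only, not TensorFlow APIs; see the xla_launch_util.cc hunks below for the real logic):

    #include <iostream>
    #include <vector>

    // Toy model of the decision this change adds: when a must-be-constant
    // argument lives in a resource variable, its value has to be readable on
    // the host before compilation, so an explicit device-to-host copy is made
    // if the tensor currently lives on the device.
    struct ToyTensor {
      std::vector<float> data;
      bool on_device = false;
    };

    // Stand-in for DeviceContext::CopyDeviceTensorToCPUSync.
    ToyTensor CopyToHost(const ToyTensor& t) {
      ToyTensor host = t;
      host.on_device = false;
      return host;
    }

    ToyTensor MaterializeForCompilation(const ToyTensor& value,
                                        bool have_device_context) {
      // No device context (e.g. a plain CPU device): the tensor is already
      // host-accessible and can be used as-is.
      if (!have_device_context || !value.on_device) return value;
      return CopyToHost(value);
    }

    int main() {
      ToyTensor var_value{{50.0f, 2.0f}, /*on_device=*/true};
      ToyTensor host =
          MaterializeForCompilation(var_value, /*have_device_context=*/true);
      std::cout << "constant folded from " << host.data.size() << " elements\n";
      return 0;
    }

The remaining hunks propagate the new argument kind through the compilation cache signature, XlaCompiler, XlaExpression, and XlaResource, and add Python tests for reading constants out of resource variables under jit_compile=True.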
@@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
       XlaComputationLaunchContext::BuildXlaCompilerArguments(
-          constant_arg_indices, inputs, variable_infos);
+          constant_arg_indices, inputs, variable_infos, dev);
   TF_RETURN_IF_ERROR(args.status());

   switch (stage) {
@@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
       may_alias_resource_update;

   xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
-      XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
-                                                             variable_infos);
+      XlaComputationLaunchContext::BuildXlaCompilerArguments(
+          constants, inputs, variable_infos,
+          static_cast<Device*>(ctx->device()));
   TF_RETURN_IF_ERROR(args.status());
   return cache->Compile(options, function, *args, compile_options,
                         lazy ? XlaCompilationCache::CompileMode::kLazy
@@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
   se::Stream* stream =
       ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;

-  VLOG(1) << "Executing XLA Computation...";
-
   absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
   se::DeviceMemoryAllocator* allocator = GetAllocator(
       &tf_allocator_adapter, ctx->device(),
@@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
   for (const XlaCompiler::Argument& arg : args) {
     switch (arg.kind) {
       case XlaCompiler::Argument::kConstant:
+      case XlaCompiler::Argument::kConstantResource:
        signature.arg_values.push_back(arg.constant_value);
        break;
      case XlaCompiler::Argument::kParameter:
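Folding kConstantResource into the cache signature exactly like kConstant means the variable's value itself becomes part of the cache key, so executions with different variable values produce distinct compiled executables. A toy illustration of that keying behavior (illustrative names only, not the actual XlaCompilationCache code):

    #include <iostream>
    #include <string>
    #include <vector>

    // Toy model: the constant values of the arguments are folded into the
    // signature, so two different values yield two different cache keys.
    struct ToySignature {
      std::vector<int> arg_values;  // stand-in for signature.arg_values

      std::string Key() const {
        std::string key = "sig";
        for (int v : arg_values) key += ":" + std::to_string(v);
        return key;
      }
    };

    int main() {
      ToySignature a{{50, 2}};
      ToySignature b{{25, 4}};
      std::cout << a.Key() << "\n";  // sig:50:2
      std::cout << b.Key() << "\n";  // sig:25:4 -> a distinct cache entry
      return 0;
    }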
@@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
         ctx, variables_indices, variable_infos, variable_args));

     args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
-        constant_input_indices, inputs, variable_infos);
+        constant_input_indices, inputs, variable_infos,
+        static_cast<Device*>(ctx->device()));
     TF_RETURN_IF_ERROR(args.status());
   }

@@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
 XlaComputationLaunchContext::BuildXlaCompilerArguments(
     absl::Span<int const> must_be_constant_idxs,
     absl::Span<const Tensor* const> inputs,
-    absl::Span<VariableInfo const> variable_args) {
+    absl::Span<VariableInfo const> variable_args, Device* device) {
   CHECK(absl::c_is_sorted(must_be_constant_idxs));
   std::vector<XlaCompiler::Argument> out;
   out.resize(inputs.size());
+
+  // TODO(cheshire): Avoid duplication with framework/op_kernel.h
+  DeviceContext* device_context = nullptr;
+  TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
+  bool using_default_context = false;
+  auto cleanup = xla::MakeCleanup([&] {
+    if (device_context != nullptr && !using_default_context) {
+      device_context->Unref();
+    }
+  });
+  if (device_context == nullptr) {
+    using_default_context = true;
+    auto* dev_info = device->tensorflow_gpu_device_info();
+    if (dev_info) device_context = dev_info->default_context;
+  }

   absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
   for (const VariableInfo& info : variable_args) {
     CHECK(!info.var() || info.lock_held())
@@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
     const Tensor* input = inputs[input_num];

     XlaCompiler::Argument& arg = out[input_num];
-    if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
-      // Handles compile-time constants.
-
-      // TODO(b/157241314): Support constants located in resource variables.
-      TF_RET_CHECK(input->dtype() != DT_RESOURCE)
-          << "tf2xla bridge does not support must-be-constants located in "
-             "resource variables; try moving them to a tensor";
-      arg.kind = XlaCompiler::Argument::kConstant;
-      arg.type = input->dtype();
-      arg.shape = input->shape();
-      arg.constant_value = *input;
-    } else if (variable_info_lookup.count(input_num)) {
+    if (variable_info_lookup.count(input_num)) {
       // Handles resource variables.
       TF_RET_CHECK(input->dtype() == DT_RESOURCE);
       const VariableInfo& variable = *variable_info_lookup[input_num];
@@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
         arg.type = DT_INVALID;
         arg.shape = TensorShape();
       }
+
+      if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+        TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
+        const Tensor* value = variable.var()->tensor();
+        Tensor value_on_host(value->dtype(), value->shape());
+        if (!device_context) {
+          value_on_host = *value;
+        } else {
+          TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
+              value, "", device, &value_on_host));
+        }
+        arg.kind = XlaCompiler::Argument::kConstantResource;
+        arg.constant_value = value_on_host;
+      }
+    } else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
+      arg.kind = XlaCompiler::Argument::kConstant;
+      arg.type = input->dtype();
+      arg.shape = input->shape();
+      arg.constant_value = *input;
     } else {
       // Normal inputs.
       TF_RET_CHECK(input->dtype() != DT_RESOURCE);
@@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
   static xla::StatusOr<std::vector<XlaCompiler::Argument>>
   BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
                             absl::Span<const Tensor* const> inputs,
-                            absl::Span<VariableInfo const> variable_args);
+                            absl::Span<VariableInfo const> variable_args,
+                            Device* device);

   // Add all inputs within `ctx` as XLA arguments (returned by arguments()).
   // `variables` is a map from TensorFlow argument number to resource variable.
@@ -73,7 +73,7 @@ Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph,
     switch (expressions[i]->kind()) {
       case XlaExpression::Kind::kConstant:
         arg.kind = XlaCompiler::Argument::kConstant;
-        arg.constant_value = expressions[i]->constant_value();
+        arg.constant_value = *expressions[i]->constant_value();
         break;
       case XlaExpression::Kind::kXlaOp:
         if (arg_must_be_compile_time_constant[i]) {
@@ -39,6 +39,9 @@ struct XlaArgument {
     // associated runtime parameter iff `initialized` is true.
     kResource,

+    // A resource variable with a constant value known at compile time.
+    kConstantResource,
+
     // Argument is a run-time parameter.
     kParameter,

@@ -207,7 +207,7 @@ Status BuildComputation(
     switch (retval.kind()) {
       case XlaExpression::Kind::kConstant:
         output.is_constant = true;
-        output.constant_value = retval.constant_value();
+        output.constant_value = *retval.constant_value();
         output.shape = output.constant_value.shape();
         break;
@@ -446,6 +446,9 @@ string XlaCompiler::Argument::HumanString() const {
     case kConstant:
       return absl::StrCat("kind=constant", common,
                           " value=", constant_value.DebugString());
+    case kConstantResource:
+      return absl::StrCat("kind=constant-resource", common,
+                          " value=", constant_value.DebugString());
     case kResource: {
       string output = absl::StrCat(
           "kind=resource", common,
@@ -856,6 +859,7 @@ Status XlaCompiler::XLAShapeForArgument(
       *xla_shape = absl::get<xla::Shape>(arg.shape);
       return Status::OK();
     }
+    case XlaCompiler::Argument::kConstantResource:
     case XlaCompiler::Argument::kResource: {
       TF_RET_CHECK(arg.initialized);

@@ -959,6 +963,7 @@ Status XlaCompiler::BuildArguments(
     const XlaCompiler::Argument& arg = args[i];
     XlaExpression& arg_expression = (*arg_expressions)[i];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.resource_kind != XlaResource::kInvalid);
         TF_RET_CHECK(absl::holds_alternative<TensorShape>(arg.shape));
@@ -971,7 +976,10 @@ Status XlaCompiler::BuildArguments(
             /*max_array_size=*/arg.max_array_size,
             /*tensor_array_gradients=*/arg.tensor_array_gradients,
             /*tensor_array_multiple_writes_aggregate=*/true));
-        arg_expression = XlaExpression::Resource(resource);
+        arg_expression =
+            arg.kind == XlaCompiler::Argument::kResource
+                ? XlaExpression::Resource(resource)
+                : XlaExpression::ConstantResource(arg.constant_value, resource);
         if (arg.initialized) {
           input_to_args->push_back(i);
         }
@@ -1124,6 +1132,7 @@ Status XlaCompiler::BuildArguments(
                                        arg_shardings.at(i).DebugString()));
     XlaExpression& arg_expression = (*arg_expressions)[input_to_args->at(i)];
     switch (arg.kind) {
+      case XlaCompiler::Argument::kConstantResource:
       case XlaCompiler::Argument::kResource: {
         TF_RET_CHECK(arg.initialized);
         XlaResource* resource = arg_expression.resource();
@@ -38,6 +38,16 @@ XlaExpression XlaExpression::Constant(Tensor value) {
   return e;
 }

+XlaExpression XlaExpression::ConstantResource(Tensor value,
+                                              XlaResource* resource) {
+  XlaExpression e;
+  e.kind_ = Kind::kResource;
+  e.dtype_ = DT_RESOURCE;
+  e.resource_ = resource;
+  e.constant_value_ = value;
+  return e;
+}
+
 XlaExpression XlaExpression::XlaOp(xla::XlaOp value, DataType dtype) {
   XlaExpression e;
   e.kind_ = Kind::kXlaOp;
@@ -83,7 +93,7 @@ xla::XlaOp XlaExpression::AsXlaOp(xla::XlaBuilder* builder) const {
       case Kind::kConstant: {
         xla::BorrowingLiteral literal;
         TF_RETURN_IF_ERROR(
-            HostTensorToBorrowingLiteral(constant_value_, &literal));
+            HostTensorToBorrowingLiteral(*constant_value_, &literal));
         return xla::ConstantLiteral(builder, literal);
       }
       case Kind::kTensorList:
@@ -106,7 +116,7 @@ xla::StatusOr<Tensor> XlaExpression::ResolveDynamism(
   switch (kind()) {
     case Kind::kConstant: {
       // Constant values are considered static.
-      Tensor constant_false(DT_BOOL, constant_value().shape());
+      Tensor constant_false(DT_BOOL, constant_value()->shape());
       auto flat = constant_false.flat<bool>();
       for (int64 i = 0; i < flat.size(); ++i) flat(i) = false;
       return constant_false;
@@ -147,13 +157,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
     xla::Client* client, bool dynamic_dimension_is_minus_one) const {
   switch (kind()) {
     case Kind::kConstant:
-      return {constant_value()};
+    case Kind::kResource:
+      return constant_value();
     case Kind::kXlaOp:
       break;
     case Kind::kTensorList:
       TF_FALLTHROUGH_INTENDED;
-    case Kind::kResource:
-      TF_FALLTHROUGH_INTENDED;
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "ResolveConstant called on XlaExpression: ", HumanString());
@@ -187,7 +196,12 @@ xla::StatusOr<absl::optional<Tensor>> XlaExpression::ResolveConstant(
 xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
   switch (kind_) {
     case Kind::kConstant:
-      return constant_value().shape();
+      return constant_value()->shape();
+    case Kind::kResource:
+      if (constant_value()) {
+        return constant_value()->shape();
+      }
+      return TensorShape({});
     case Kind::kXlaOp: {
       TF_ASSIGN_OR_RETURN(xla::Shape xla_shape,
                           handle().builder()->GetShape(handle()));
@@ -197,8 +211,6 @@ xla::StatusOr<TensorShape> XlaExpression::GetShape() const {
     }
     case Kind::kTensorList:
       return TensorShape({});
-    case Kind::kResource:
-      return TensorShape({});
     case Kind::kInvalid:
       return errors::InvalidArgument(
           "GetShape() called on invalid XlaExpression");
@@ -74,6 +74,9 @@ class XlaExpression {
   // Builds a resource expression.
   static XlaExpression Resource(XlaResource* resource);

+  // Builds a resource whose value is known at a compile time.
+  static XlaExpression ConstantResource(Tensor value, XlaResource* resource);
+
   Kind kind() const { return kind_; }

   DataType dtype() const { return dtype_; }
@@ -81,7 +84,15 @@ class XlaExpression {
   // handle() returns the XlaOp that backs a kXlaOp expression.
   const xla::XlaOp& handle() const { return handle_; }

-  const Tensor& constant_value() const { return constant_value_; }
+  // Return a constant value associated with this expression. Always set for
+  // constants, might be set for resources.
+  absl::optional<Tensor> constant_value() const {
+    if (kind_ == Kind::kResource && resource_->IsOverwritten()) {
+      // The constant is no longer available if the value was overwritten.
+      return absl::nullopt;
+    }
+    return constant_value_;
+  }

   XlaResource* resource() const { return resource_; }

@@ -124,8 +135,8 @@ class XlaExpression {
   // a tuple expression if kind_ == kTensorList.
   xla::XlaOp handle_;

-  // The value of the constant, if kind_ == kConstant.
-  Tensor constant_value_;
+  // The value of the constant, if available.
+  absl::optional<Tensor> constant_value_;

   // The resource, if kind_ == kResource. Not owned.
   XlaResource* resource_ = nullptr;
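With constant_value_ now an absl::optional<Tensor>, the accessor can report that no compile-time value is available once the backing resource has been overwritten; the call sites in graph_compiler.cc, xla_compiler.cc, and xla_expression.cc above dereference the optional instead of taking the tensor directly. A rough standalone model of that invalidation behavior (toy types, not the TensorFlow classes):

    #include <iostream>
    #include <optional>

    // Toy model: an expression may cache a compile-time value, but once the
    // resource it reads from has been overwritten the cached value must no
    // longer be reported as a usable constant.
    struct ToyResource {
      bool overwritten = false;
    };

    struct ToyExpression {
      std::optional<int> cached_value;       // value captured at compile time
      const ToyResource* resource = nullptr;

      std::optional<int> constant_value() const {
        if (resource != nullptr && resource->overwritten) return std::nullopt;
        return cached_value;
      }
    };

    int main() {
      ToyResource r;
      ToyExpression e{42, &r};
      std::cout << e.constant_value().has_value() << "\n";  // 1: still foldable
      r.overwritten = true;
      std::cout << e.constant_value().has_value() << "\n";  // 0: invalidated
      return 0;
    }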
@@ -110,8 +110,10 @@ TEST_F(XlaExpressionTest, GetShape) {
 TEST_F(XlaExpressionTest, ResolveConstant) {
   EXPECT_FALSE(XlaExpression().ResolveConstant(client_).ok());
   EXPECT_FALSE(XlaExpression::Invalid().ResolveConstant(client_).ok());
-  EXPECT_FALSE(
-      XlaExpression::Resource(resource_.get()).ResolveConstant(client_).ok());
+
+  EXPECT_FALSE(XlaExpression::Resource(resource_.get())
+                   .ResolveConstant(client_)
+                   ->has_value());

   TF_ASSERT_OK_AND_ASSIGN(
       absl::optional<Tensor> op_constant,
@@ -131,5 +133,17 @@ TEST_F(XlaExpressionTest, ResolveConstant) {
   test::ExpectTensorEqual<int32>(constant_, *constant_constant);
 }

+TEST_F(XlaExpressionTest, ResolveConstantOnResource) {
+  XlaExpression constant_resource =
+      XlaExpression::ConstantResource(constant_, resource_.get());
+  EXPECT_TRUE(constant_resource.ResolveConstant(client_).ok());
+  EXPECT_TRUE(resource_->SetZeroValue(builder_.get()).ok());
+  LOG(ERROR) << "Resource is overwritten: " << resource_->IsOverwritten();
+  xla::StatusOr<absl::optional<Tensor>> resolved_constant =
+      constant_resource.ResolveConstant(client_);
+  EXPECT_TRUE(resolved_constant.ok());
+  EXPECT_FALSE(resolved_constant->has_value());
+}
+
 }  // namespace
 }  // namespace tensorflow
@@ -477,6 +477,13 @@ Status ReadVariableInputTensor(const Tensor& tensor, DataType type,
     *shape = variable->shape();
   }

+  if (!variable->IsOverwritten() && expression->constant_value()) {
+    TF_ASSIGN_OR_RETURN(xla::Literal literal,
+                        HostTensorToLiteral(*expression->constant_value()));
+    *value = xla::ConstantLiteral(ctx->builder(), literal);
+    return Status::OK();
+  }
+
   TF_ASSIGN_OR_RETURN(xla::Shape representation_shape,
                       ctx->compiler()->options().shape_representation_fn(
                           variable->shape(), variable->type(),
@@ -116,10 +116,12 @@ Status XlaResource::SetValue(const xla::XlaOp& value) {
         "' must be initialized with a valid type before use.");
   }
   value_ = value;
+  is_overwritten_ = true;
   return Status::OK();
 }

 Status XlaResource::SetZeroValue(xla::XlaBuilder* builder) {
+  is_overwritten_ = true;
   if (type_ == DT_INVALID) {
     return errors::InvalidArgument(
         "Resource '", name_,
@@ -135,6 +135,8 @@ class XlaResource {
   Status SetFromPack(const std::set<string>& gradient_sources,
                      const xla::XlaOp& pack, xla::XlaBuilder* builder);

+  bool IsOverwritten() { return is_overwritten_; }
+
   // TensorArray and Stack specific fields
   // TODO(phawkins): refactor this code to use subclasses, rather than putting
   // kind-specific fields in XlaResource.
@@ -179,6 +181,7 @@ class XlaResource {
   bool tensor_array_multiple_writes_aggregate_ = false;

   std::map<string, std::unique_ptr<XlaResource>> tensor_array_gradients_;
+  bool is_overwritten_ = false;
 };

 }  // namespace tensorflow
@@ -656,6 +656,77 @@ class DefFunctionTest(xla_test.XLATestCase):
       self.assertIn('tuple',
                     f.experimental_get_compiler_ir(l)())

+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariable(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x):
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      # OK since the value is known at compile time.
+      out = f(random_ops.random_normal([10, 10]))
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableAfterWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # Returns an error, since the value known at compile time was overriden.
+      with self.assertRaisesRegex(errors.InvalidArgumentError,
+                                  'concrete values at compile time'):
+        f(random_ops.random_normal([10, 10]), val1, val2)
+
+  @test_util.disable_mlir_bridge('TODO(b/172845417): MLIR bridge does not '
+                                 'support getting constants out of resources')
+  def testGetConstantOutOfResourceVariableBeforeWrite(self):
+    with ops.device('device:{}:0'.format(self.device)):
+
+      # Use floats to force device placement.
+      a = variables.Variable(50.0)
+      b = variables.Variable(2.0)
+
+      @def_function.function(jit_compile=True)
+      def f(x, val1, val2):
+        out = array_ops.reshape(
+            x, [math_ops.cast(a, dtypes.int32),
+                math_ops.cast(b, dtypes.int32)])
+        a.assign(math_ops.cast(val1, dtypes.float32))
+        b.assign(math_ops.cast(val2, dtypes.float32))
+        return out
+
+      val1 = constant_op.constant(2)
+      val2 = constant_op.constant(50)
+
+      # OK since the write happens after the reshape.
+      out = f(random_ops.random_normal([10, 10]), val1, val2)
+      self.assertEqual(out.shape[0], 50)
+      self.assertEqual(out.shape[1], 2)
+

 if __name__ == '__main__':
   ops.enable_eager_execution()