Adding convenience functions for assigning and reading resource variables.

PiperOrigin-RevId: 316807134
Change-Id: Ia37fdf8299064fe1a471eb0b34568e85164808ce
This commit is contained in:
Brian Zhao 2020-06-16 19:48:23 -07:00 committed by TensorFlower Gardener
parent 1444b6fbbd
commit 62beb0fc66
3 changed files with 76 additions and 0 deletions

View File

@ -58,6 +58,36 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
return Status();
}
Status AssignVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value) {
AbstractOpPtr assign_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
int num_retvals = 0;
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
return Status();
}
Status ReadVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output) {
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
AbstractTensorHandleInterface* value = nullptr;
int num_retvals = 1;
TF_RETURN_IF_ERROR(
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
output->reset(value);
return Status();
}
Status DestroyResource(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* handle) {
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());

View File

@ -34,6 +34,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
DataType dtype, TensorShape shape,
AbstractTensorHandlePtr* handle);
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
// with `variable_handle` with `value`. `dtype` must be the datatype of the
// underlying variable for `variable_handle`. Note that it is illegal to assign
// a variable to a Tensor with a different dtype than what the variable was
// created with.
Status AssignVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandleInterface* value);
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
// value of `variable_handle` and copies the value to `output`. `dtype` must be
// the dtype of the variable associated with `variable_handle`.
Status ReadVariable(AbstractContextInterface* ctx,
AbstractTensorHandleInterface* variable_handle,
DataType dtype, AbstractTensorHandlePtr* output);
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290

View File

@ -30,6 +30,13 @@ limitations under the License.
namespace tensorflow {
namespace {
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
float value) {
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
return handle;
}
class VariableOpsTest : public ::testing::Test {
public:
VariableOpsTest()
@ -73,5 +80,28 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
}
// Sanity check for handle assignment and reading
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
AbstractTensorHandlePtr variable;
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
context(), DT_FLOAT, {}, &variable));
// Create a Scalar float TensorHandle with value 42, and assign it to
// the variable.
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
my_value.get()));
// Read back the value from the variable, and check that it is 42.
AbstractTensorHandlePtr read_value_handle;
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
&read_value_handle));
Status status;
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
TF_EXPECT_OK(status);
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
}
} // namespace
} // namespace tensorflow