Adding convenience functions for assigning and reading resource variables.
PiperOrigin-RevId: 316807134 Change-Id: Ia37fdf8299064fe1a471eb0b34568e85164808ce
This commit is contained in:
parent
1444b6fbbd
commit
62beb0fc66
@ -58,6 +58,36 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value) {
|
||||
AbstractOpPtr assign_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle));
|
||||
TF_RETURN_IF_ERROR(assign_op->AddInput(value));
|
||||
|
||||
int num_retvals = 0;
|
||||
TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals));
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output) {
|
||||
AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr));
|
||||
TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype));
|
||||
TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle));
|
||||
|
||||
AbstractTensorHandleInterface* value = nullptr;
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(
|
||||
read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals));
|
||||
output->reset(value);
|
||||
return Status();
|
||||
}
|
||||
|
||||
Status DestroyResource(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* handle) {
|
||||
AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation());
|
||||
|
@ -34,6 +34,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx,
|
||||
DataType dtype, TensorShape shape,
|
||||
AbstractTensorHandlePtr* handle);
|
||||
|
||||
// Executes an AssignVariableOp using `ctx`, assigning the variable associated
|
||||
// with `variable_handle` with `value`. `dtype` must be the datatype of the
|
||||
// underlying variable for `variable_handle`. Note that it is illegal to assign
|
||||
// a variable to a Tensor with a different dtype than what the variable was
|
||||
// created with.
|
||||
Status AssignVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandleInterface* value);
|
||||
|
||||
// Executes a ReadVariableOp using `ctx`. This reads the underlying variable
|
||||
// value of `variable_handle` and copies the value to `output`. `dtype` must be
|
||||
// the dtype of the variable associated with `variable_handle`.
|
||||
Status ReadVariable(AbstractContextInterface* ctx,
|
||||
AbstractTensorHandleInterface* variable_handle,
|
||||
DataType dtype, AbstractTensorHandlePtr* output);
|
||||
|
||||
// Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to
|
||||
// the cleanup that occurs in a tf.Variable's EagerResourceDeleter:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290
|
||||
|
@ -30,6 +30,13 @@ limitations under the License.
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context,
|
||||
float value) {
|
||||
AbstractTensorPtr tensor(context->CreateFloatScalar(value));
|
||||
AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get()));
|
||||
return handle;
|
||||
}
|
||||
|
||||
class VariableOpsTest : public ::testing::Test {
|
||||
public:
|
||||
VariableOpsTest()
|
||||
@ -73,5 +80,28 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) {
|
||||
TF_EXPECT_OK(internal::DestroyResource(context(), handle.get()));
|
||||
}
|
||||
|
||||
// Sanity check for handle assignment and reading
|
||||
TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) {
|
||||
// Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor
|
||||
AbstractTensorHandlePtr variable;
|
||||
TF_EXPECT_OK(internal::CreateUninitializedResourceVariable(
|
||||
context(), DT_FLOAT, {}, &variable));
|
||||
|
||||
// Create a Scalar float TensorHandle with value 42, and assign it to
|
||||
// the variable.
|
||||
AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0);
|
||||
TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT,
|
||||
my_value.get()));
|
||||
|
||||
// Read back the value from the variable, and check that it is 42.
|
||||
AbstractTensorHandlePtr read_value_handle;
|
||||
TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT,
|
||||
&read_value_handle));
|
||||
Status status;
|
||||
AbstractTensorPtr read_value(read_value_handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status);
|
||||
EXPECT_FLOAT_EQ(42.0, *static_cast<float*>(read_value->Data()));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
Loading…
x
Reference in New Issue
Block a user