diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc index 94548e553ad..a3b3ace7be9 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.cc @@ -58,6 +58,36 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, return Status(); } +Status AssignVariable(AbstractContextInterface* ctx, + AbstractTensorHandleInterface* variable_handle, + DataType dtype, AbstractTensorHandleInterface* value) { + AbstractOpPtr assign_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(assign_op->Reset("AssignVariableOp", nullptr)); + TF_RETURN_IF_ERROR(assign_op->SetAttrType("dtype", dtype)); + TF_RETURN_IF_ERROR(assign_op->AddInput(variable_handle)); + TF_RETURN_IF_ERROR(assign_op->AddInput(value)); + + int num_retvals = 0; + TF_RETURN_IF_ERROR(assign_op->Execute({}, &num_retvals)); + return Status(); +} + +Status ReadVariable(AbstractContextInterface* ctx, + AbstractTensorHandleInterface* variable_handle, + DataType dtype, AbstractTensorHandlePtr* output) { + AbstractOpPtr read_op = AbstractOpPtr(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(read_op->Reset("ReadVariableOp", nullptr)); + TF_RETURN_IF_ERROR(read_op->SetAttrType("dtype", dtype)); + TF_RETURN_IF_ERROR(read_op->AddInput(variable_handle)); + + AbstractTensorHandleInterface* value = nullptr; + int num_retvals = 1; + TF_RETURN_IF_ERROR( + read_op->Execute(absl::MakeSpan(&value, num_retvals), &num_retvals)); + output->reset(value); + return Status(); +} + Status DestroyResource(AbstractContextInterface* ctx, AbstractTensorHandleInterface* handle) { AbstractOpPtr destroy_op = AbstractOpPtr(ctx->CreateOperation()); diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h index 1c4d757af8c..8a410328b9e 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops.h @@ -34,6 +34,22 @@ Status CreateUninitializedResourceVariable(AbstractContextInterface* ctx, DataType dtype, TensorShape shape, AbstractTensorHandlePtr* handle); +// Executes an AssignVariableOp using `ctx`, assigning the variable associated +// with `variable_handle` with `value`. `dtype` must be the datatype of the +// underlying variable for `variable_handle`. Note that it is illegal to assign +// a variable to a Tensor with a different dtype than what the variable was +// created with. +Status AssignVariable(AbstractContextInterface* ctx, + AbstractTensorHandleInterface* variable_handle, + DataType dtype, AbstractTensorHandleInterface* value); + +// Executes a ReadVariableOp using `ctx`. This reads the underlying variable +// value of `variable_handle` and copies the value to `output`. `dtype` must be +// the dtype of the variable associated with `variable_handle`. +Status ReadVariable(AbstractContextInterface* ctx, + AbstractTensorHandleInterface* variable_handle, + DataType dtype, AbstractTensorHandlePtr* output); + // Executes DestroyResourceOp on `handle`, using `ctx`. This is equivalent to // the cleanup that occurs in a tf.Variable's EagerResourceDeleter: // https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/ops/resource_variable_ops.py#L289-L290 diff --git a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc index 7a9486f8ebd..3c57ed4d38a 100644 --- a/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc +++ b/tensorflow/c/experimental/saved_model/core/ops/variable_ops_test.cc @@ -30,6 +30,13 @@ limitations under the License. namespace tensorflow { namespace { +AbstractTensorHandlePtr CreateScalarTensorHandle(EagerContext* context, + float value) { + AbstractTensorPtr tensor(context->CreateFloatScalar(value)); + AbstractTensorHandlePtr handle(context->CreateLocalHandle(tensor.get())); + return handle; +} + class VariableOpsTest : public ::testing::Test { public: VariableOpsTest() @@ -73,5 +80,28 @@ TEST_F(VariableOpsTest, DestroyVariableSuccessful) { TF_EXPECT_OK(internal::DestroyResource(context(), handle.get())); } +// Sanity check for handle assignment and reading +TEST_F(VariableOpsTest, AssignVariableAndReadSuccessful) { + // Create a DT_Resource TensorHandle that points to a scalar DT_FLOAT tensor + AbstractTensorHandlePtr variable; + TF_EXPECT_OK(internal::CreateUninitializedResourceVariable( + context(), DT_FLOAT, {}, &variable)); + + // Create a Scalar float TensorHandle with value 42, and assign it to + // the variable. + AbstractTensorHandlePtr my_value = CreateScalarTensorHandle(context(), 42.0); + TF_EXPECT_OK(internal::AssignVariable(context(), variable.get(), DT_FLOAT, + my_value.get())); + + // Read back the value from the variable, and check that it is 42. + AbstractTensorHandlePtr read_value_handle; + TF_EXPECT_OK(internal::ReadVariable(context(), variable.get(), DT_FLOAT, + &read_value_handle)); + Status status; + AbstractTensorPtr read_value(read_value_handle->Resolve(&status)); + TF_EXPECT_OK(status); + EXPECT_FLOAT_EQ(42.0, *static_cast(read_value->Data())); +} + } // namespace } // namespace tensorflow