diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 427420cb17c..e9920c02532 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -248,6 +248,47 @@ tf_cuda_cc_test( ], ) +cc_library( + name = "gradients_testutil", + testonly = True, + srcs = [ + "gradients_testutil.cc", + ], + hdrs = [ + "gradients_testutil.h", + ], + visibility = [ + "//tensorflow:internal", + ], + deps = [ + ":abstract_context", + ":abstract_operation", + ":tape", + ":abstract_tensor_handle", + ":gradients", + ":c_api_experimental", + ":c_api_test_util", + ":c_api_unified_internal", + ":gradients_internal", + "//tensorflow/c/experimental/ops:array_ops", + "//tensorflow/c/experimental/ops:math_ops", + "//tensorflow/c/experimental/ops:nn_ops", + "//tensorflow/c:c_api", + "//tensorflow/c:c_test_util", + "//tensorflow/c:tf_status_helper", + "//tensorflow/cc/profiler", + "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/llvm_rtti", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_absl//absl/container:flat_hash_map", + ], +) + cc_library( name = "mnist_gradients_testutil", srcs = [ @@ -295,6 +336,7 @@ cc_library( ":c_api_unified_internal", ":gradients_internal", ":mnist_gradients_testutil", + ":gradients_testutil", "//tensorflow/c:c_api", "//tensorflow/c:c_test_util", "//tensorflow/c:tf_status_helper", diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc index f8b4827d691..532a0f726c7 100644 --- a/tensorflow/c/eager/gradient_checker.cc +++ b/tensorflow/c/eager/gradient_checker.cc @@ -36,82 +36,7 @@ limitations under the License. 
using namespace std; -// ================== TensorHandle generating functions ================= - -// Get a scalar TensorHandle with given value -Status TestScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a TensorHandle with given float values and dimensions -Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a TensorHandle with given int values and dimensions -Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return Status::OK(); -} - -AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, - float vals[], int64_t dims[], - int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} - -AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], - int64_t dims[], int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} +// ================== Helper functions ================= // Fills data with values [start,end) with given step size. void Range(int data[], int start, int end, int step = 1) { @@ -146,12 +71,13 @@ Status RunModelAndSum(AbstractContext* ctx, Model forward, std::vector model_outputs(1); // Run the model. 
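In the hunk below, `Status` values that were previously stored in a local `s` and never checked are routed through `TF_RETURN_IF_ERROR` so that failures propagate to the caller instead of being silently dropped. As a reminder of what that macro does, a simplified sketch is shown here (illustrative only, not the literal TensorFlow definition, which also adds branch-prediction hints):

  // Rough illustration of the early-return pattern TF_RETURN_IF_ERROR provides.
  #define RETURN_IF_ERROR_SKETCH(...)               \
    do {                                            \
      ::tensorflow::Status _status = (__VA_ARGS__); \
      if (!_status.ok()) return _status;            \
    } while (0)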
- Status s = RunModel(forward, ctx, inputs, absl::MakeSpan(model_outputs), - use_function, registry); + TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs, + absl::MakeSpan(model_outputs), use_function, + registry)); AbstractTensorHandle* f_toSum = model_outputs[0]; TF_Tensor* model_out_tensor; - s = GetValue(f_toSum, &model_out_tensor); + TF_RETURN_IF_ERROR(GetValue(f_toSum, &model_out_tensor)); int num_dims_out = TF_NumDims(model_out_tensor); // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1]. @@ -163,8 +89,8 @@ Status RunModelAndSum(AbstractContext* ctx, Model forward, sum_inputs[0] = f_toSum; sum_inputs[1] = sum_dims.get(); - s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(model_outputs), - "sum_output"); + TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs), + absl::MakeSpan(model_outputs), "sum_output")); outputs[0] = model_outputs[0]; return Status::OK(); } @@ -182,21 +108,20 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward, } else { s = RunModelAndSum(ctx, forward, inputs, outputs, use_function); } - return Status::OK(); + return s; } -// ========================= End Util Functions============================== +// ========================= End Helper Functions============================== -Status GradientCheck(AbstractContext* ctx, Model forward, - std::vector inputs, - float* dtheta_approx, int gradIndex, bool use_function, - bool is_scalar_out) { - Status s; +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + std::vector inputs, + float* dtheta_approx, int input_index, + bool use_function, bool is_scalar_out) { AbstractTensorHandle* theta = - inputs[gradIndex]; // parameter we are grad checking + inputs[input_index]; // parameter we are grad checking // Convert from AbstractTensor to TF_Tensor. TF_Tensor* theta_tensor; - s = GetValue(theta, &theta_tensor); + TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor)); // Get number of elements and fill data. int num_elems = TF_TensorElementCount(theta_tensor); @@ -219,6 +144,8 @@ Status GradientCheck(AbstractContext* ctx, Model forward, // Get relative epsilon value float epsilon = std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0 + AbstractTensorHandlePtr two_eps = + GetScalarTensorHandleUtil(ctx, 2 * epsilon); // Initialize theta[i] + epsilon. memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor), @@ -235,33 +162,39 @@ Status GradientCheck(AbstractContext* ctx, Model forward, GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims); // Get f(theta + eps): - inputs[gradIndex] = thetaPlus.get(); - s = RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs), - absl::MakeSpan(f_outputs), use_function, is_scalar_out); + inputs[input_index] = thetaPlus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs), + absl::MakeSpan(f_outputs), use_function, + is_scalar_out)); AbstractTensorHandle* fPlus = f_outputs[0]; // Get f(theta - eps): - inputs[gradIndex] = thetaMinus.get(); - s = RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs), - absl::MakeSpan(f_outputs), use_function, is_scalar_out); + inputs[input_index] = thetaMinus.get(); + TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs), + absl::MakeSpan(f_outputs), use_function, + is_scalar_out)); AbstractTensorHandle* fMinus = f_outputs[0]; // Take Difference of both estimates: (f(x + eps) - f(x - eps)). 
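For context on the arithmetic in this loop: per element `i`, the checker forms the second-order central difference (f(theta + eps) - f(theta - eps)) / (2 * eps), which the change below evaluates with the `DivNoNan` op rather than host-side float division. A self-contained sketch of the same estimator on an ordinary scalar C++ function (illustrative only, not TensorFlow API):

  #include <cmath>
  #include <functional>

  // Central-difference estimate of df/dx at x; truncation error is O(eps^2),
  // versus O(eps) for the one-sided forward difference.
  double CentralDiff(const std::function<double(double)>& f, double x) {
    const double eps = std::abs(x) * 1e-4 + 1e-4;  // same relative-epsilon rule as above
    return (f(x + eps) - f(x - eps)) / (2.0 * eps);
  }
  // e.g. CentralDiff([](double v) { return v * v; }, 3.0) returns roughly 6.0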
- s = ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"); + TF_RETURN_IF_ERROR( + ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top")); AbstractTensorHandle* fDiff = f_outputs[0]; - // Get difference value for calculation. - TF_Tensor* fDiff_tensor; - s = GetValue(fDiff, &fDiff_tensor); - float fDiff_data[1]; - memcpy(&fDiff_data[0], TF_TensorData(fDiff_tensor), - TF_TensorByteSize(fDiff_tensor)); - // Calculate using the difference quotient definition: // (f(x + eps) - f(x - eps)) / (2 * eps). - float grad_approx = fDiff_data[0] / (2.0 * epsilon); - dtheta_approx[i] = grad_approx; + TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()}, + absl::MakeSpan(f_outputs), + "diff_quotient")); + AbstractTensorHandle* diff_quotient = f_outputs[0]; + + TF_Tensor* grad_tensor; + TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor)); + float grad_data[1]; + memcpy(&grad_data[0], TF_TensorData(grad_tensor), + TF_TensorByteSize(grad_tensor)); + + dtheta_approx[i] = grad_data[0]; } return Status::OK(); -} +} \ No newline at end of file diff --git a/tensorflow/c/eager/gradient_checker.h b/tensorflow/c/eager/gradient_checker.h index 12ae688b62a..8234052b0e6 100644 --- a/tensorflow/c/eager/gradient_checker.h +++ b/tensorflow/c/eager/gradient_checker.h @@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/c/eager/mnist_gradients_testutil.h" - #include #include "absl/types/span.h" @@ -24,6 +22,8 @@ limitations under the License. #include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/eager/gradients_testutil.h" +#include "tensorflow/c/eager/mnist_gradients_testutil.h" #include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h" #include "tensorflow/c/experimental/ops/array_ops.h" @@ -37,17 +37,19 @@ using Model = std::function, absl::Span, const GradientRegistry&)>; -/** Returns numerical grad inside `dtheta_approx` given `forward` model and parameter - * specified by `gradIndex` - * - * `use_function` indicates whether to use graph mode(true) or eager(false) - * - * `is_scalar_out` should be true when `forward` returns a scalar TensorHandle; - * else GradientCheck will reduce_sum the tensor to get a scalar to estimate - * the gradient with. Default is false. - */ -Status GradientCheck(AbstractContext* ctx, Model forward, - std::vector inputs, - float* dtheta_approx, - int gradIndex, bool use_function, - bool is_scalar_out=false); +/** Returns numerical grad inside `dtheta_approx` given `forward` model and + * parameter specified by `input_index`. + * + * I.e. if y = and w = inputs[input_index], + * this will calculate dy/dw numerically. + * + * `use_function` indicates whether to use graph mode(true) or eager(false). + * + * `is_scalar_out` should be true when `forward` returns a scalar TensorHandle; + * else this function will reduce_sum the tensor to get a scalar to estimate + * the gradient with. Default is false. 
+ */ +Status CalcNumericalGrad(AbstractContext* ctx, Model forward, + std::vector inputs, + float* dtheta_approx, int input_index, + bool use_function, bool is_scalar_out = false); diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc index 4eb728b68c2..8effa68b72d 100644 --- a/tensorflow/c/eager/gradient_checker_test.cc +++ b/tensorflow/c/eager/gradient_checker_test.cc @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ #include "tensorflow/c/eager/gradient_checker.h" +// #include "tensorflow/c/eager/gradients_testutil.h" #include @@ -52,85 +53,6 @@ Status RegisterGradients(GradientRegistry* registry) { return Status::OK(); } -// ========================= Test Util Functions ============================== - -// Get a scalar TensorHandle with given value -Status TestScalarTensorHandle(AbstractContext* ctx, float value, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a TensorHandle with given float values and dimensions -Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -// Get a TensorHandle with given int values and dimensions -Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], - int64_t dims[], int num_dims, - AbstractTensorHandle** tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_Context* eager_ctx = - TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_TensorHandle* input_eager = - TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); - *tensor = - unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); - return Status::OK(); -} - -Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_TensorHandle* result_t = - TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); - TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); - return Status::OK(); -} - -AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, - float vals[], int64_t dims[], - int num_dims) { - AbstractTensorHandlePtr A; - AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} - -AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], - int64_t dims[], int num_dims) { - AbstractTensorHandlePtr A; - 
AbstractTensorHandle* a_raw = nullptr; - Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); - A.reset(a_raw); - return A; -} - -// =========================== Start Tests ================================ - TEST_P(GradientCheckerTest, TestGradCheckMatMul) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); @@ -159,9 +81,9 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) { inputs.push_back(B.get()); float dapprox[4] = {0}; - Status s = - GradientCheck(ctx.get(), MatMulModel, inputs, dapprox, /*gradIndex=*/0, - /*use_function=*/!std::get<2>(GetParam())); + Status s = CalcNumericalGrad(ctx.get(), MatMulModel, inputs, dapprox, + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam())); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f}; @@ -171,21 +93,66 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) { } } +TEST_P(GradientCheckerTest, TestGradCheckMul) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + + AbstractContextPtr ctx; + { + AbstractContext* ctx_raw = nullptr; + Status s = + BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ctx.reset(ctx_raw); + } + + AbstractTensorHandlePtr x; + { + AbstractTensorHandle* x_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + x.reset(x_raw); + } + + AbstractTensorHandlePtr y; + { + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw); + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + y.reset(y_raw); + } + + // Will perform z = x*y. + // dz/dx = y + + std::vector inputs; + inputs.push_back(x.get()); + inputs.push_back(y.get()); + float dapprox[1] = {0}; + Status s = + CalcNumericalGrad(ctx.get(), MulModel, inputs, dapprox, /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam()), + /*is_scalar_out=*/true); + + ASSERT_EQ(errors::OK, s.code()) << s.error_message(); + ASSERT_NEAR(dapprox[0], 7.0f, /*tolerance=*/1e-3); +} + TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); /** Test to show how to use this API with analytical gradients: - * - * We have `SoftmaxLossGradModel`, which is a wrapper for the - * Softmax analytical gradient found in c/experimental/nn_grads. - * - * We will use the GradientChecker by applying finite differences - * to the forward pass wrapped in `SoftmaxModel` and verify that - * both the analytical and numerical gradients are relatively - * close. - * - */ + * + * We have `SoftmaxLossGradModel`, which is a wrapper for the + * Softmax analytical gradient found in c/experimental/nn_grads. + * + * We will use the GradientChecker by applying finite differences + * to the forward pass wrapped in `SoftmaxModel` and verify that + * both the analytical and numerical gradients are relatively + * close. + * + */ AbstractContextPtr ctx; { @@ -235,8 +202,9 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { // Run numerical gradient approximation using the GradientChecker API. float dapprox[9] = {0}; // Will contain numerical approximation data. 
- s = GradientCheck(ctx.get(), SoftmaxModel, inputs, dapprox, /*gradIndex=*/0, - /*use_function=*/!std::get<2>(GetParam())); + s = CalcNumericalGrad(ctx.get(), SoftmaxModel, inputs, dapprox, + /*input_index=*/0, + /*use_function=*/!std::get<2>(GetParam())); ASSERT_EQ(errors::OK, s.code()) << s.error_message(); // Now compare the two implementations: @@ -249,8 +217,6 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) { TF_DeleteTensor(dX_tensor); } -// TODO(b/160888630): Enable this test with mlir after AddInputList is -// supported. It is needed for AddN op which is used for gradient aggregation. #ifdef PLATFORM_GOOGLE INSTANTIATE_TEST_SUITE_P( UnifiedCAPI, GradientCheckerTest, diff --git a/tensorflow/c/eager/gradients_testutil.cc b/tensorflow/c/eager/gradients_testutil.cc new file mode 100644 index 00000000000..ab0db85fa50 --- /dev/null +++ b/tensorflow/c/eager/gradients_testutil.cc @@ -0,0 +1,271 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/eager/gradients_testutil.h" + +#include + +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +using namespace std; + +// ================== TensorHandle generating functions ================= + +// Get a scalar TensorHandle with given value +Status TestScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given float values and dimensions +Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + 
TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +// Get a TensorHandle with given int values and dimensions +Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_Context* eager_ctx = + TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_TensorHandle* input_eager = + TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); + *tensor = + unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); + return StatusFromTF_Status(status.get()); +} + +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_TensorHandle* result_t = + TF_AbstractTensorGetEagerTensor(wrap(t), status.get()); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + *result_tensor = TFE_TensorHandleResolve(result_t, status.get()); + return StatusFromTF_Status(status.get()); +} + +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims) { + AbstractTensorHandlePtr A; + AbstractTensorHandle* a_raw = nullptr; + Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); + if (s.ok()) { + A.reset(a_raw); + } + return A; +} + +AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val) { + AbstractTensorHandlePtr y; + AbstractTensorHandle* y_raw = nullptr; + Status s = TestScalarTensorHandle(ctx, val, &y_raw); + if (s.ok()) { + y.reset(y_raw); + } + return y; +} + +Status UpdateWeights(AbstractContext* ctx, vector& grads, + vector& weights, + AbstractTensorHandle* learning_rate) { + /* Update weights one by one using gradient update rule: + * + * w -= lr*grad[w] + * + * NOTE: assuming learning rate is positive + */ + + int num_grads = grads.size(); + vector temp_outputs(1); + std::string update_str; + + // Negate learning rate for gradient descent + TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, + absl::MakeSpan(temp_outputs), + "neg_lr")); // Compute -lr + learning_rate = temp_outputs[0]; + + for (int i = 0; i < num_grads; i++) { + // Compute dW = -lr * grad(w[i]) + update_str = "update_mul_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + AbstractTensorHandle* dW = temp_outputs[0]; + + // Compute temp = weights[i] + dW + update_str = "update_add_" + std::to_string(i); + TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW}, + absl::MakeSpan(temp_outputs), + update_str.c_str())); + + // Update the weights + weights[i] = temp_outputs[0]; + } + + return Status::OK(); +} + +AbstractContext* BuildFunction(const char* fn_name) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); + return unwrap(graph_ctx); +} + +Status 
CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + vector* params) { + tracing::TracingTensorHandle* handle = nullptr; + for (auto input : inputs) { + TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( + input->DataType(), &handle)); + params->emplace_back(handle); + } + return Status::OK(); +} + +Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry) { + if (use_function) { + const char* fn_name = "test_fn"; + std::unique_ptr scoped_func; + // Returning null tensors from a tf.function is not supported, so we keep + // track of indices in the model's outputs are nullptr in this set. + // The FunctionDef only outputs the non-null tensors. We later pad the + // function op outputs to have nullptrs at the `null_indices`. + absl::flat_hash_set null_indices; + { + AbstractContextPtr func_ctx(BuildFunction(fn_name)); + vector func_inputs; + func_inputs.reserve(inputs.size()); + TF_RETURN_IF_ERROR( + CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); + vector model_outputs; + model_outputs.resize(outputs.size()); + TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), + absl::MakeSpan(model_outputs), registry)); + for (auto func_input : func_inputs) { + func_input->Unref(); + } + AbstractFunction* func = nullptr; + OutputList output_list; + output_list.expected_num_outputs = 0; + output_list.outputs.reserve(outputs.size()); + for (int i = 0; i < model_outputs.size(); i++) { + if (model_outputs[i]) { + output_list.outputs.emplace_back(model_outputs[i]); + output_list.expected_num_outputs += 1; + } else { + null_indices.insert(i); + } + } + TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) + ->Finalize(&output_list, &func)); + scoped_func.reset(func); + for (auto output : output_list.outputs) { + output->Unref(); + } + TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); + } + + AbstractOperationPtr fn_op(ctx->CreateOperation()); + TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); + for (auto input : inputs) { + TF_RETURN_IF_ERROR(fn_op->AddInput(input)); + } + int retvals = outputs.size() - null_indices.size(); + vector fn_outputs(retvals); + TF_RETURN_IF_ERROR(fn_op->Execute( + absl::Span(fn_outputs.data(), fn_outputs.size()), + &retvals)); + int skipped_indices = 0; + for (int i = 0; i < outputs.size(); i++) { + if (!null_indices.contains(i)) { + outputs[i] = fn_outputs[i - skipped_indices]; + } else { + skipped_indices += 1; + } + } + TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); + return Status::OK(); + } else { + return model(ctx, inputs, outputs, registry); + } +} + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetTfrt(opts, use_tfrt); + *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); + TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); + TFE_DeleteContextOptions(opts); + return Status::OK(); +} \ No newline at end of file diff --git a/tensorflow/c/eager/gradients_testutil.h b/tensorflow/c/eager/gradients_testutil.h new file mode 100644 index 00000000000..3866e0598ca --- /dev/null +++ b/tensorflow/c/eager/gradients_testutil.h @@ -0,0 +1,94 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/types/span.h" +#include "tensorflow/c/eager/abstract_tensor_handle.h" +#include "tensorflow/c/eager/c_api_experimental.h" +#include "tensorflow/c/eager/c_api_test_util.h" +#include "tensorflow/c/eager/c_api_unified_experimental.h" +#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" +#include "tensorflow/c/eager/gradients.h" +#include "tensorflow/c/eager/gradients_internal.h" +#include "tensorflow/c/experimental/ops/array_ops.h" +#include "tensorflow/c/experimental/ops/math_ops.h" +#include "tensorflow/c/experimental/ops/nn_ops.h" +#include "tensorflow/c/tf_status_helper.h" +#include "tensorflow/c/tf_tensor.h" +#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/test.h" + +using namespace std; +using namespace tensorflow; +using namespace tensorflow::gradients; +using namespace tensorflow::gradients::internal; + +// Get a scalar TensorHandle with given value. +Status TestScalarTensorHandle(AbstractContext* ctx, float value, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given float values and dimensions. +Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor); + +// Get a TensorHandle with given int values and dimensions. +Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], + int64_t dims[], int num_dims, + AbstractTensorHandle** tensor); + +// Places data from `t` into *result_tensor. +Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx, + float vals[], int64_t dims[], + int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data and dims. +AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[], + int64_t dims[], int num_dims); + +// Util function that wraps an AbstractTensorHandle* with given data. +AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx, + float val); + +// Performs gradient update for each weight using given learning rate. +Status UpdateWeights(AbstractContext* ctx, vector& grads, + vector& weights, + AbstractTensorHandle* learning_rate); + +// Helper function for RunModel to build the function for graph mode. +AbstractContext* BuildFunction(const char* fn_name); + +// Helper function for RunModel to add params for graph mode. +Status CreateParamsForInputs(AbstractContext* ctx, + absl::Span inputs, + vector* params); + +using Model = std::function, + absl::Span, const GradientRegistry&)>; + +// Runs given model in either graph or eager mode depending on value of +// use_function. 
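// Descriptive note (mirrors the implementation in gradients_testutil.cc above):
// in graph mode the model is first traced into a FunctionDef using a tracing
// context obtained from BuildFunction, the finalized function is registered on
// `ctx`, executed as a single function op, and then removed again; output
// slots the model left as nullptr are tracked in `null_indices` and padded
// back into `outputs` afterwards.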
+Status RunModel(Model model, AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, bool use_function, + const GradientRegistry& registry); + +Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); \ No newline at end of file diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc index 0371cb99071..6b502ba9119 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.cc +++ b/tensorflow/c/eager/mnist_gradients_testutil.cc @@ -242,7 +242,8 @@ Status MNISTForwardModel(AbstractContext* ctx, * hidden_layer = tf.nn.relu(mm_out_1) * scores = tf.matmul(hidden_layer,W2) * softmax = - * tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels) + * tf.nn.sparse_softmax_cross_entropy_with_logits(scores, + * y_labels) * return scores, softmax * * Use this convention for inputs: @@ -455,10 +456,9 @@ Status ScalarMulModel(AbstractContext* ctx, } Status MatMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { AbstractTensorHandle* X = inputs[0]; AbstractTensorHandle* W1 = inputs[1]; @@ -476,10 +476,9 @@ Status MatMulModel(AbstractContext* ctx, } Status MulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { AbstractTensorHandle* x = inputs[0]; AbstractTensorHandle* y = inputs[1]; @@ -487,7 +486,7 @@ Status MulModel(AbstractContext* ctx, auto tape = new Tape(/*persistent=*/false); std::vector temp_outputs(1); TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs), - "mul0", registry)); // Compute x*y + "mul0", registry)); // Compute x*y outputs[0] = temp_outputs[0]; @@ -496,169 +495,23 @@ Status MulModel(AbstractContext* ctx, } Status SoftmaxModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry) { - + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry) { AbstractTensorHandle* x = inputs[0]; AbstractTensorHandle* labels = inputs[1]; TapeVSpace vspace(ctx); auto tape = new Tape(/*persistent=*/false); std::vector temp_outputs(2); - TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), - "sm_loss", registry)); + TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels}, + absl::MakeSpan(temp_outputs), + "sm_loss", registry)); - outputs[0] = temp_outputs[0]; // loss values + outputs[0] = temp_outputs[0]; // loss values delete tape; return Status::OK(); } // ============================= End Models ================================ - -Status UpdateWeights(AbstractContext* ctx, vector& grads, - vector& weights, - AbstractTensorHandle* learning_rate) { - /* Update weights one by one using gradient update rule: - * - * w -= lr*grad[w] - * - * NOTE: assuming learning rate is positive - */ - - Status s; - int num_grads = grads.size(); - vector temp_outputs(1); - std::string update_str; - - // Negate learning rate for gradient descent - TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate}, - absl::MakeSpan(temp_outputs), - "neg_lr")); // Compute -lr - learning_rate = temp_outputs[0]; - - for (int i = 0; i < num_grads; i++) { - // Compute dW = -lr * grad(w[i]) - update_str = "update_mul_" + std::to_string(i); - s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs), - 
update_str.c_str()); - - AbstractTensorHandle* dW = temp_outputs[0]; - - // Compute temp = weights[i] + dW - update_str = "update_add_" + std::to_string(i); - s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs), - update_str.c_str()); - - // Update the weights - weights[i] = temp_outputs[0]; - } - - return Status::OK(); -} - -AbstractContext* BuildFunction(const char* fn_name) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get()); - return unwrap(graph_ctx); -} - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - vector* params) { - tracing::TracingTensorHandle* handle = nullptr; - for (auto input : inputs) { - TF_RETURN_IF_ERROR(dyn_cast(ctx)->AddParameter( - input->DataType(), &handle)); - params->emplace_back(handle); - } - return Status::OK(); -} - -Status RunModel(Model model, AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, bool use_function, - const GradientRegistry& registry) { - if (use_function) { - const char* fn_name = "test_fn"; - std::unique_ptr scoped_func; - // Returning null tensors from a tf.function is not supported, so we keep - // track of indices in the model's outputs are nullptr in this set. - // The FunctionDef only outputs the non-null tensors. We later pad the - // function op outputs to have nullptrs at the `null_indices`. - absl::flat_hash_set null_indices; - { - AbstractContextPtr func_ctx(BuildFunction(fn_name)); - vector func_inputs; - func_inputs.reserve(inputs.size()); - TF_RETURN_IF_ERROR( - CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs)); - vector model_outputs; - model_outputs.resize(outputs.size()); - TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs), - absl::MakeSpan(model_outputs), registry)); - for (auto func_input : func_inputs) { - func_input->Unref(); - } - AbstractFunction* func = nullptr; - OutputList output_list; - output_list.expected_num_outputs = 0; - output_list.outputs.reserve(outputs.size()); - for (int i = 0; i < model_outputs.size(); i++) { - if (model_outputs[i]) { - output_list.outputs.emplace_back(model_outputs[i]); - output_list.expected_num_outputs += 1; - } else { - null_indices.insert(i); - } - } - TF_RETURN_IF_ERROR(dyn_cast(func_ctx.get()) - ->Finalize(&output_list, &func)); - scoped_func.reset(func); - for (auto output : output_list.outputs) { - output->Unref(); - } - TF_RETURN_IF_ERROR(ctx->RegisterFunction(func)); - } - - AbstractOperationPtr fn_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr)); - for (auto input : inputs) { - TF_RETURN_IF_ERROR(fn_op->AddInput(input)); - } - int retvals = outputs.size() - null_indices.size(); - vector fn_outputs(retvals); - TF_RETURN_IF_ERROR(fn_op->Execute( - absl::Span(fn_outputs.data(), fn_outputs.size()), - &retvals)); - int skipped_indices = 0; - for (int i = 0; i < outputs.size(); i++) { - if (!null_indices.contains(i)) { - outputs[i] = fn_outputs[i - skipped_indices]; - } else { - skipped_indices += 1; - } - } - TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name)); - return Status::OK(); - } else { - return model(ctx, inputs, outputs, registry); - } -} - -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) { - std::unique_ptr status( - TF_NewStatus(), TF_DeleteStatus); - TFE_ContextOptions* opts = TFE_NewContextOptions(); - TFE_ContextOptionsSetTfrt(opts, use_tfrt); - *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get())); - 
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); - TFE_DeleteContextOptions(opts); - return Status::OK(); -} - -} // namespace internal -} // namespace gradients -} // namespace tensorflow diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h index c608636fe4d..fd61b026ab4 100644 --- a/tensorflow/c/eager/mnist_gradients_testutil.h +++ b/tensorflow/c/eager/mnist_gradients_testutil.h @@ -122,44 +122,16 @@ Status ScalarMulModel(AbstractContext* ctx, const GradientRegistry& registry); Status MatMulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); Status MulModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -Status SoftmaxModel(AbstractContext* ctx, - absl::Span inputs, - absl::Span outputs, - const GradientRegistry& registry); - -// Updates the weights for a neural network given incoming grads and learning -// rate -Status UpdateWeights(AbstractContext* ctx, - std::vector& grads, - std::vector& weights, - AbstractTensorHandle* learning_rate); - -AbstractContext* BuildFunction(const char* fn_name); - -Status CreateParamsForInputs(AbstractContext* ctx, - absl::Span inputs, - std::vector* params); - -using Model = std::function, - absl::Span, const GradientRegistry&)>; - -Status RunModel(Model model, AbstractContext* ctx, absl::Span inputs, - absl::Span outputs, bool use_function, + absl::Span outputs, const GradientRegistry& registry); -Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); - -} // namespace internal -} // namespace gradients -} // namespace tensorflow +Status SoftmaxModel(AbstractContext* ctx, + absl::Span inputs, + absl::Span outputs, + const GradientRegistry& registry); diff --git a/tensorflow/c/experimental/ops/math_ops.cc b/tensorflow/c/experimental/ops/math_ops.cc index c6f98a1fe31..2c6d01b5e21 100644 --- a/tensorflow/c/experimental/ops/math_ops.cc +++ b/tensorflow/c/experimental/ops/math_ops.cc @@ -87,7 +87,6 @@ Status Sub(AbstractContext* ctx, absl::Span inputs, return Status::OK(); } - Status MatMul(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name, @@ -125,24 +124,6 @@ Status Neg(AbstractContext* ctx, absl::Span inputs, return neg_op->Execute(outputs, &num_retvals); } -Status Prod(AbstractContext* ctx, absl::Span inputs, - absl::Span outputs, const char* name) { - AbstractOperationPtr prod_op(ctx->CreateOperation()); - TF_RETURN_IF_ERROR(prod_op->Reset("Prod", /*raw_device_name=*/nullptr)); - - if (isa(prod_op.get())) { - TF_RETURN_IF_ERROR( - dyn_cast(prod_op.get())->SetOpName(name)); - } - - TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[0])); // input_vals - TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[1])); // reduction_indices - - int num_retvals = 1; - TF_RETURN_IF_ERROR(prod_op->Execute(outputs, &num_retvals)); - return Status::OK(); -} - Status Sum(AbstractContext* ctx, absl::Span inputs, absl::Span outputs, const char* name) { AbstractOperationPtr sum_op(ctx->CreateOperation()); @@ -153,29 +134,31 @@ Status Sum(AbstractContext* ctx, absl::Span inputs, dyn_cast(sum_op.get())->SetOpName(name)); } - TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals - TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals + TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices 
   int num_retvals = 1;
   TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
   return Status::OK();
 }
 
-Status EuclideanNorm(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
-                     absl::Span<AbstractTensorHandle*> outputs, const char* name) {
-  AbstractOperationPtr norm_op(ctx->CreateOperation());
-  TF_RETURN_IF_ERROR(norm_op->Reset("EuclideanNorm", /*raw_device_name=*/nullptr));
+Status DivNoNan(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name) {
+  AbstractOperationPtr div_op(ctx->CreateOperation());
+  TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
 
-  if (isa<tracing::TracingOperation>(norm_op.get())) {
+  if (isa<tracing::TracingOperation>(div_op.get())) {
     TF_RETURN_IF_ERROR(
-        dyn_cast<tracing::TracingOperation>(norm_op.get())->SetOpName(name));
+        dyn_cast<tracing::TracingOperation>(div_op.get())->SetOpName(name));
   }
 
-  TF_RETURN_IF_ERROR(norm_op->AddInput(inputs[0]));  // input_vals
-  TF_RETURN_IF_ERROR(norm_op->AddInput(inputs[1]));  // reduction_indices
+  TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0]));  // x
+  TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1]));  // y
 
   int num_retvals = 1;
-  TF_RETURN_IF_ERROR(norm_op->Execute(outputs, &num_retvals));
+  TF_RETURN_IF_ERROR(div_op->Execute(
+      outputs, &num_retvals));  // z = x / y, (z_i = 0 if y_i = 0)
   return Status::OK();
 }
diff --git a/tensorflow/c/experimental/ops/math_ops.h b/tensorflow/c/experimental/ops/math_ops.h
index 14667ec3b9c..004b8f2bb4d 100644
--- a/tensorflow/c/experimental/ops/math_ops.h
+++ b/tensorflow/c/experimental/ops/math_ops.h
@@ -38,17 +38,15 @@ Status MatMul(AbstractContext* ctx,
 Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
-Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
-            absl::Span<AbstractTensorHandle*> outputs, const char* name);
-
 Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
 Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
-Status EuclideanNorm(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
-                     absl::Span<AbstractTensorHandle*> outputs, const char* name);
+Status DivNoNan(AbstractContext* ctx,
+                absl::Span<AbstractTensorHandle* const> inputs,
+                absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
 }  // namespace ops
 }  // namespace tensorflow
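Putting the pieces together, the renamed entry point is exercised the same way TestGradCheckMul above does. A condensed eager-mode sketch follows (error handling elided; the helper and model names are the ones declared in gradients_testutil.h and mnist_gradients_testutil.h):

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s = BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx_raw);
    if (!s.ok()) return;  // report the error in real code
    ctx.reset(ctx_raw);
  }

  // z = x * y at (x, y) = (2, 7); the numerical dz/dx should come out near 7.
  AbstractTensorHandlePtr x = GetScalarTensorHandleUtil(ctx.get(), 2.0f);
  AbstractTensorHandlePtr y = GetScalarTensorHandleUtil(ctx.get(), 7.0f);

  std::vector<AbstractTensorHandle*> inputs = {x.get(), y.get()};
  float dapprox[1] = {0};
  Status s = CalcNumericalGrad(ctx.get(), MulModel, inputs, dapprox,
                               /*input_index=*/0, /*use_function=*/false,
                               /*is_scalar_out=*/true);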