diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index 1052818c9f2..d259b32f339 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -260,6 +260,7 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/lib/llvm_rtti",
         "@com_google_absl//absl/container:flat_hash_map",
+        "@com_google_absl//absl/container:flat_hash_set",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/types:span",
     ],
@@ -343,6 +344,7 @@ tf_cuda_cc_test(
         ":c_api_unified_internal",
         ":gradient_checker",
         ":gradients_internal",
+        ":gradients_util",
        ":mnist_gradients_testutil",
         "//tensorflow/c:c_api",
         "//tensorflow/c:c_test_util",
diff --git a/tensorflow/c/eager/gradient_checker.cc b/tensorflow/c/eager/gradient_checker.cc
index 2ca028ad865..640edc7228a 100644
--- a/tensorflow/c/eager/gradient_checker.cc
+++ b/tensorflow/c/eager/gradient_checker.cc
@@ -39,18 +39,19 @@ using namespace std;
 // ================== Helper functions =================
 
 // Fills data with values [start,end) with given step size.
-void Range(int data[], int start, int end, int step = 1) {
+void Range(vector<int>* data, int start, int end, int step = 1) {
   for (int i = start; i < end; i += step) {
-    data[i] = i;
+    (*data)[i] = i;
   }
 }
 
 // Returns AbstractTensorHandlePtr containing [0, ..., n-1].
 AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
-  int vals[n];
+  vector<int> vals(n);
   int64_t vals_shape[] = {n};
-  Range(vals, 0, n);
-  AbstractTensorHandlePtr r = GetTensorHandleUtilInt(ctx, vals, vals_shape, 1);
+  Range(&vals, 0, n);
+  AbstractTensorHandlePtr r =
+      GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
   return r;
 }
 
@@ -118,21 +119,21 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
 
   // Get number of elements and fill data.
   int num_elems = TF_TensorElementCount(theta_tensor);
-  float theta_data[num_elems] = {0};
-  memcpy(&theta_data[0], TF_TensorData(theta_tensor),
+  vector<float> theta_data(num_elems);
+  memcpy(theta_data.data(), TF_TensorData(theta_tensor),
          TF_TensorByteSize(theta_tensor));
 
   // Initialize space for the numerical gradient.
-  float dtheta_approx[num_elems];
+  vector<float> dtheta_approx(num_elems);
 
   // Get theta shape and store in theta_dims.
   int num_dims = TF_NumDims(theta_tensor);
-  int64_t theta_dims[num_dims];
-  GetDims(theta_tensor, theta_dims);
+  vector<int64_t> theta_dims(num_dims);
+  GetDims(theta_tensor, theta_dims.data());
 
   // Initialize auxilary data structures.
-  float thetaPlus_data[num_elems];
-  float thetaMinus_data[num_elems];
+  vector<float> thetaPlus_data(num_elems);
+  vector<float> thetaMinus_data(num_elems);
   std::vector<AbstractTensorHandle*> f_outputs(1);
 
   // Numerical Grad Check
@@ -144,18 +145,18 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
         GetScalarTensorHandleUtil(ctx, 2 * epsilon);
 
     // Initialize theta[i] + epsilon.
-    memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor),
+    memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaPlus_data[i] += epsilon;
-    AbstractTensorHandlePtr thetaPlus =
-        GetTensorHandleUtilFloat(ctx, thetaPlus_data, theta_dims, num_dims);
+    AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
+        ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
 
     // Initialize theta[i] - epsilon.
     memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
            TF_TensorByteSize(theta_tensor));
     thetaMinus_data[i] -= epsilon;
-    AbstractTensorHandlePtr thetaMinus =
-        GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims);
+    AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
+        ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
 
     // Get f(theta + eps):
     inputs[input_index] = thetaPlus.get();
@@ -191,10 +192,10 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
   }
 
   // Populate *numerical_grad with the data from dtheta_approx.
-  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(ctx, dtheta_approx, theta_dims,
-                                               num_dims, numerical_grad));
+  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
+      ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
   return Status::OK();
 }
 
 }  // namespace gradients
-}  // namespace tensorflow
\ No newline at end of file
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/gradient_checker_test.cc b/tensorflow/c/eager/gradient_checker_test.cc
index c907ad8b9ce..98581a79f1f 100644
--- a/tensorflow/c/eager/gradient_checker_test.cc
+++ b/tensorflow/c/eager/gradient_checker_test.cc
@@ -51,7 +51,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
   TF_RETURN_IF_ERROR(
       registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyLossRegisterer));
+                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
   return Status::OK();
 }
 
@@ -89,7 +89,8 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   TF_Tensor* gt;
-  GetValue(grad_approx, &gt);
+  s = GetValue(grad_approx, &gt);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   float result_data[4] = {0};
   memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
 
@@ -136,7 +137,6 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
   std::vector<AbstractTensorHandle*> inputs;
   inputs.push_back(x.get());
   inputs.push_back(y.get());
-  float dapprox[1] = {0};
   AbstractTensorHandle* g;
 
   Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
@@ -145,11 +145,12 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   TF_Tensor* gt;
-  GetValue(g, &gt);
+  s = GetValue(g, &gt);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   float result_data[1] = {0};
   memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
 
-  ASSERT_NEAR(result_data[0], 7.0f, /*tolerance=*/1e-2);
+  ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
   TF_DeleteTensor(gt);
 }
 
@@ -223,13 +224,14 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
 
   TF_Tensor* gt;
-  GetValue(g, &gt);
+  s = GetValue(g, &gt);
+  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
   float dnumerical[9] = {0};
   memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
 
   // Now compare the two implementations:
   for (int j = 0; j < 9; j++) {
-    ASSERT_NEAR(dnumerical[j], danalytical[j], /*tolerance=*/1e-2);
+    ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
   }
 
   // Only Unref() first output as 2nd is nullptr grad for labels
diff --git a/tensorflow/c/eager/gradients_util.h b/tensorflow/c/eager/gradients_util.h
index 3489a5b370b..cd0bbc0720d 100644
--- a/tensorflow/c/eager/gradients_util.h
+++ b/tensorflow/c/eager/gradients_util.h
@@ -85,4 +85,4 @@ Status RunModel(Model model, AbstractContext* ctx,
 Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
 
 }  // namespace gradients
-}  // namespace tensorflow
\ No newline at end of file
+}  // namespace tensorflow
diff --git a/tensorflow/c/eager/mnist_gradients_test.cc b/tensorflow/c/eager/mnist_gradients_test.cc
index 322f747aa93..78549297edc 100644
--- a/tensorflow/c/eager/mnist_gradients_test.cc
+++ b/tensorflow/c/eager/mnist_gradients_test.cc
@@ -53,7 +53,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
   TF_RETURN_IF_ERROR(
       registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyLossRegisterer));
+                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
   return Status::OK();
 }
 
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.cc b/tensorflow/c/eager/mnist_gradients_testutil.cc
index be9ce5a86da..932605ab8e0 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.cc
+++ b/tensorflow/c/eager/mnist_gradients_testutil.cc
@@ -130,9 +130,9 @@ Status Relu(AbstractContext* ctx, Tape* tape,
                  registry);
 }
 
-// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
-// tape.
-Status SparseSoftmaxCrossEntropyLoss(
+// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
+// one-hot) and records it on the tape.
+Status SparseSoftmaxCrossEntropyWithLogits(
     AbstractContext* ctx, Tape* tape,
     absl::Span<AbstractTensorHandle* const> inputs,
     absl::Span<AbstractTensorHandle*> outputs, const char* name,
@@ -277,7 +277,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
       ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
       "softmax_loss", registry));  // Compute Softmax(Scores,labels)
 
@@ -351,7 +351,7 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
   tape->Watch(ToId(inputs[0]));  // Watch scores.
   tape->Watch(ToId(inputs[1]));  // Watch labels.
   vector<AbstractTensorHandle*> sm_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
       ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
 
   std::unordered_map<tensorflow::int64, TapeTensor>
@@ -406,7 +406,7 @@ Status MNISTGradModel(AbstractContext* ctx,
   AbstractTensorHandle* scores = temp_outputs[0];
 
   temp_outputs.resize(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
      ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
       "softmaxloss", registry));  // W2*Relu(X*W1)
 
@@ -501,9 +501,9 @@ Status SoftmaxModel(AbstractContext* ctx,
   TapeVSpace vspace(ctx);
   auto tape = new Tape(/*persistent=*/false);
   std::vector<AbstractTensorHandle*> temp_outputs(2);
-  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels},
-                                                   absl::MakeSpan(temp_outputs),
-                                                   "sm_loss", registry));
+  TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
+      ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
+      registry));
 
   outputs[0] = temp_outputs[0];  // loss values
 
diff --git a/tensorflow/c/eager/mnist_gradients_testutil.h b/tensorflow/c/eager/mnist_gradients_testutil.h
index b72163b9808..1cf87bb9dee 100644
--- a/tensorflow/c/eager/mnist_gradients_testutil.h
+++ b/tensorflow/c/eager/mnist_gradients_testutil.h
@@ -26,8 +26,6 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/math_ops.h" #include "tensorflow/c/experimental/ops/nn_ops.h" -#include "tensorflow/c/tf_status_helper.h" -#include "tensorflow/c/tf_tensor.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/platform/status.h" @@ -63,7 +61,7 @@ Status Relu(AbstractContext* ctx, Tape* tape, // Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the // tape. -Status SparseSoftmaxCrossEntropyLoss( +Status SparseSoftmaxCrossEntropyWithLogits( AbstractContext* ctx, Tape* tape, absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle*> outputs, const char* name, @@ -142,4 +140,4 @@ Status SoftmaxModel(AbstractContext* ctx, } // namespace gradients } // namespace tensorflow -#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ \ No newline at end of file +#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_ diff --git a/tensorflow/c/experimental/gradients/nn_grad.cc b/tensorflow/c/experimental/gradients/nn_grad.cc index 3da1e0dc153..26e44d38578 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.cc +++ b/tensorflow/c/experimental/gradients/nn_grad.cc @@ -19,12 +19,8 @@ limitations under the License. #include "tensorflow/c/experimental/ops/nn_ops.h" using std::vector; -using tensorflow::ops::Conj; -using tensorflow::ops::Identity; using tensorflow::ops::Mul; using tensorflow::ops::ReluGrad; -using tensorflow::ops::SparseSoftmaxCrossEntropyLoss; -using tensorflow::ops::ZerosLike; namespace tensorflow { namespace gradients { @@ -99,7 +95,7 @@ BackwardFunction* ReluRegisterer(const ForwardOperation& op) { return new BackwardFunction(gradient_function, default_gradients); } -BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer( +BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer( const ForwardOperation& op) { auto gradient_function = new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs); diff --git a/tensorflow/c/experimental/gradients/nn_grad.h b/tensorflow/c/experimental/gradients/nn_grad.h index d002725847f..034f20d7325 100644 --- a/tensorflow/c/experimental/gradients/nn_grad.h +++ b/tensorflow/c/experimental/gradients/nn_grad.h @@ -20,9 +20,9 @@ limitations under the License. 
 namespace tensorflow {
 namespace gradients {
 BackwardFunction* ReluRegisterer(const ForwardOperation& op);
-BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
+BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
     const ForwardOperation& op);
 }  // namespace gradients
 }  // namespace tensorflow
 
-#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
\ No newline at end of file
+#endif  // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
diff --git a/tensorflow/c/experimental/ops/BUILD b/tensorflow/c/experimental/ops/BUILD
index 9c8dbfcbfb8..d2c22e65f80 100644
--- a/tensorflow/c/experimental/ops/BUILD
+++ b/tensorflow/c/experimental/ops/BUILD
@@ -80,6 +80,11 @@ cc_library(
         ":array_ops",
         ":math_ops",
         ":nn_ops",
+        "//tensorflow/c/eager:abstract_context",
+        "//tensorflow/c/eager:abstract_operation",
+        "//tensorflow/c/eager:abstract_tensor_handle",
+        "//tensorflow/c/eager:c_api_unified_internal",
+        "//tensorflow/core/lib/llvm_rtti",
     ],
 )
 
diff --git a/tensorflow/c/experimental/ops/nn_ops.cc b/tensorflow/c/experimental/ops/nn_ops.cc
index 8f5f550bb8b..df18e9352a2 100644
--- a/tensorflow/c/experimental/ops/nn_ops.cc
+++ b/tensorflow/c/experimental/ops/nn_ops.cc
@@ -21,7 +21,7 @@ namespace tensorflow {
 namespace ops {
 
 // Softmax Loss given scores and labels, used by the SoftMaxLossGradient
-Status SparseSoftmaxCrossEntropyLoss(
+Status SparseSoftmaxCrossEntropyWithLogits(
     AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
     absl::Span<AbstractTensorHandle*> outputs, const char* name) {
   AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
diff --git a/tensorflow/c/experimental/ops/nn_ops.h b/tensorflow/c/experimental/ops/nn_ops.h
index 3e618b00869..276a4398a71 100644
--- a/tensorflow/c/experimental/ops/nn_ops.h
+++ b/tensorflow/c/experimental/ops/nn_ops.h
@@ -23,7 +23,7 @@ limitations under the License.
 namespace tensorflow {
 namespace ops {
 
-Status SparseSoftmaxCrossEntropyLoss(
+Status SparseSoftmaxCrossEntropyWithLogits(
     AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
     absl::Span<AbstractTensorHandle*> outputs, const char* name);
 
diff --git a/tensorflow/python/framework/experimental/tape.cc b/tensorflow/python/framework/experimental/tape.cc
index 003a70141dd..8b5445db562 100644
--- a/tensorflow/python/framework/experimental/tape.cc
+++ b/tensorflow/python/framework/experimental/tape.cc
@@ -33,7 +33,7 @@ Status RegisterGradients(GradientRegistry* registry) {
   TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
   TF_RETURN_IF_ERROR(
       registry->Register("SparseSoftmaxCrossEntropyWithLogits",
-                         SparseSoftmaxCrossEntropyLossRegisterer));
+                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
   return Status::OK();
 }
 