Rename SparseSoftmaxCrossEntropyLoss to SparseSoftmaxCrossEntropyWithLogits to match up with the op name.

PiperOrigin-RevId: 332353265
Change-Id: I76cf1dec99438a1661707006a9496aadd8226fc5
Saurabh Saxena 2020-09-17 17:43:19 -07:00 committed by TensorFlower Gardener
parent c643fbcc96
commit 3d8020e442
13 changed files with 56 additions and 52 deletions
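For context on the rename: gradients are looked up by the forward op's name string, so the registerer (and the tape helper that wraps the op) now carry the op name "SparseSoftmaxCrossEntropyWithLogits" instead of the older "...Loss" spelling. A minimal sketch of that registration, mirroring the RegisterGradients hunks below (include paths and namespaces are assumed, not part of this commit):

// --- Illustrative sketch, not part of this commit ---
// Assumes GradientRegistry / Status come from the gradients headers used in
// the files below; the key string must equal the forward op's name.
Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}
// --- End sketch ---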


@@ -260,6 +260,7 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -343,6 +344,7 @@ tf_cuda_cc_test(
":c_api_unified_internal",
":gradient_checker",
":gradients_internal",
":gradients_util",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",


@@ -39,18 +39,19 @@ using namespace std;
// ================== Helper functions =================
// Fills data with values [start,end) with given step size.
void Range(int data[], int start, int end, int step = 1) {
void Range(vector<int>* data, int start, int end, int step = 1) {
for (int i = start; i < end; i += step) {
data[i] = i;
(*data)[i] = i;
}
}
// Returns AbstractTensorHandlePtr containing [0, ..., n-1].
AbstractTensorHandlePtr GetRangeTensorHandleUtil(AbstractContext* ctx, int n) {
int vals[n];
vector<int> vals(n);
int64_t vals_shape[] = {n};
Range(vals, 0, n);
AbstractTensorHandlePtr r = GetTensorHandleUtilInt(ctx, vals, vals_shape, 1);
Range(&vals, 0, n);
AbstractTensorHandlePtr r =
GetTensorHandleUtilInt(ctx, vals.data(), vals_shape, 1);
return r;
}
@@ -118,21 +119,21 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
// Get number of elements and fill data.
int num_elems = TF_TensorElementCount(theta_tensor);
float theta_data[num_elems] = {0};
memcpy(&theta_data[0], TF_TensorData(theta_tensor),
vector<float> theta_data(num_elems);
memcpy(theta_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
// Initialize space for the numerical gradient.
float dtheta_approx[num_elems];
vector<float> dtheta_approx(num_elems);
// Get theta shape and store in theta_dims.
int num_dims = TF_NumDims(theta_tensor);
int64_t theta_dims[num_dims];
GetDims(theta_tensor, theta_dims);
vector<int64_t> theta_dims(num_dims);
GetDims(theta_tensor, theta_dims.data());
// Initialize auxilary data structures.
float thetaPlus_data[num_elems];
float thetaMinus_data[num_elems];
vector<float> thetaPlus_data(num_elems);
vector<float> thetaMinus_data(num_elems);
std::vector<AbstractTensorHandle*> f_outputs(1);
// Numerical Grad Check
@@ -144,18 +145,18 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
GetScalarTensorHandleUtil(ctx, 2 * epsilon);
// Initialize theta[i] + epsilon.
memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor),
memcpy(thetaPlus_data.data(), TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus =
GetTensorHandleUtilFloat(ctx, thetaPlus_data, theta_dims, num_dims);
AbstractTensorHandlePtr thetaPlus = GetTensorHandleUtilFloat(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims);
// Initialize theta[i] - epsilon.
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus =
GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims);
AbstractTensorHandlePtr thetaMinus = GetTensorHandleUtilFloat(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims);
// Get f(theta + eps):
inputs[input_index] = thetaPlus.get();
@@ -191,10 +192,10 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(ctx, dtheta_approx, theta_dims,
num_dims, numerical_grad));
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloat(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow
} // namespace tensorflow
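The buffers converted to std::vector above feed a central-difference estimate of the gradient. As a point of reference, a standalone sketch of that estimate in plain C++ (function and variable names here are illustrative; the real CalcNumericalGrad goes through the TF C API and AbstractTensorHandle):

// --- Illustrative sketch, not part of this commit ---
// Central difference:
//   dtheta_approx[i] ~= (f(theta + eps*e_i) - f(theta - eps*e_i)) / (2*eps)
#include <functional>
#include <vector>

std::vector<float> NumericalGrad(
    const std::function<float(const std::vector<float>&)>& f,
    std::vector<float> theta, float epsilon) {
  std::vector<float> dtheta_approx(theta.size());
  for (size_t i = 0; i < theta.size(); ++i) {
    const float saved = theta[i];
    theta[i] = saved + epsilon;   // theta[i] + epsilon
    const float f_plus = f(theta);
    theta[i] = saved - epsilon;   // theta[i] - epsilon
    const float f_minus = f(theta);
    theta[i] = saved;             // restore before the next element
    dtheta_approx[i] = (f_plus - f_minus) / (2 * epsilon);
  }
  return dtheta_approx;
}
// --- End sketch ---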


@@ -51,7 +51,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
@@ -89,7 +89,8 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
GetValue(grad_approx, &gt);
s = GetValue(grad_approx, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
@@ -136,7 +137,6 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(x.get());
inputs.push_back(y.get());
float dapprox[1] = {0};
AbstractTensorHandle* g;
Status s = CalcNumericalGrad(ctx.get(), MulModel, absl::MakeSpan(inputs),
@@ -145,11 +145,12 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
GetValue(g, &gt);
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[1] = {0};
memcpy(&result_data[0], TF_TensorData(gt), TF_TensorByteSize(gt));
ASSERT_NEAR(result_data[0], 7.0f, /*tolerance=*/1e-2);
ASSERT_NEAR(result_data[0], 7.0f, /*abs_error=*/1e-2);
TF_DeleteTensor(gt);
}
@@ -223,13 +224,14 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* gt;
GetValue(g, &gt);
s = GetValue(g, &gt);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float dnumerical[9] = {0};
memcpy(&dnumerical[0], TF_TensorData(gt), TF_TensorByteSize(gt));
// Now compare the two implementations:
for (int j = 0; j < 9; j++) {
ASSERT_NEAR(dnumerical[j], danalytical[j], /*tolerance=*/1e-2);
ASSERT_NEAR(dnumerical[j], danalytical[j], /*abs_error=*/1e-2);
}
// Only Unref() first output as 2nd is nullptr grad for labels


@@ -85,4 +85,4 @@ Status RunModel(Model model, AbstractContext* ctx,
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow
} // namespace tensorflow


@@ -53,7 +53,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}


@@ -130,9 +130,9 @@ Status Relu(AbstractContext* ctx, Tape* tape,
registry);
}
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
// Computes `SoftmaxLoss(scores, labels)` where labels are categorical (not
// one-hot) and records it on the tape.
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
@@ -277,7 +277,7 @@ Status MNISTForwardModel(AbstractContext* ctx,
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss", registry)); // Compute Softmax(Scores,labels)
@@ -351,7 +351,7 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
tape->Watch(ToId(inputs[0])); // Watch scores.
tape->Watch(ToId(inputs[1])); // Watch labels.
vector<AbstractTensorHandle*> sm_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, inputs, absl::MakeSpan(sm_outputs), "softmax0", registry));
std::unordered_map<tensorflow::int64, TapeTensor>
@@ -406,7 +406,7 @@ Status MNISTGradModel(AbstractContext* ctx,
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss", registry)); // W2*Relu(X*W1)
@@ -501,9 +501,9 @@ Status SoftmaxModel(AbstractContext* ctx,
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels},
absl::MakeSpan(temp_outputs),
"sm_loss", registry));
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyWithLogits(
ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs), "sm_loss",
registry));
outputs[0] = temp_outputs[0]; // loss values
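The updated comment in the first hunk of this file notes that the labels are categorical (sparse), not one-hot. A small illustration of that convention (shapes and values below are made up, not taken from this commit):

// --- Illustrative sketch, not part of this commit ---
// Sparse labels: one class index per example.
float scores[2][3] = {{1.0f, 2.0f, 0.5f},   // logits for example 0
                      {0.1f, 0.3f, 2.2f}};  // logits for example 1
int sparse_labels[2] = {1, 2};              // class index per example
// The one-hot equivalent of those labels would be:
//   {{0, 1, 0},
//    {0, 0, 1}}
// --- End sketch ---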


@@ -26,8 +26,6 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
@@ -63,7 +61,7 @@ Status Relu(AbstractContext* ctx, Tape* tape,
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the
// tape.
Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, Tape* tape,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
@@ -142,4 +140,4 @@ Status SoftmaxModel(AbstractContext* ctx,
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_


@@ -19,12 +19,8 @@ limitations under the License.
#include "tensorflow/c/experimental/ops/nn_ops.h"
using std::vector;
using tensorflow::ops::Conj;
using tensorflow::ops::Identity;
using tensorflow::ops::Mul;
using tensorflow::ops::ReluGrad;
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
using tensorflow::ops::ZerosLike;
namespace tensorflow {
namespace gradients {
@@ -99,7 +95,7 @@ BackwardFunction* ReluRegisterer(const ForwardOperation& op) {
return new BackwardFunction(gradient_function, default_gradients);
}
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op) {
auto gradient_function =
new SparseSoftmaxCrossEntropyLossGradientFunction(op.outputs);


@@ -20,9 +20,9 @@ limitations under the License.
namespace tensorflow {
namespace gradients {
BackwardFunction* ReluRegisterer(const ForwardOperation& op);
BackwardFunction* SparseSoftmaxCrossEntropyLossRegisterer(
BackwardFunction* SparseSoftmaxCrossEntropyWithLogitsRegisterer(
const ForwardOperation& op);
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_
#endif // TENSORFLOW_C_EXPERIMENTAL_GRADIENTS_NN_GRAD_H_


@@ -80,6 +80,11 @@ cc_library(
":array_ops",
":math_ops",
":nn_ops",
"//tensorflow/c/eager:abstract_context",
"//tensorflow/c/eager:abstract_operation",
"//tensorflow/c/eager:abstract_tensor_handle",
"//tensorflow/c/eager:c_api_unified_internal",
"//tensorflow/core/lib/llvm_rtti",
],
)


@@ -21,7 +21,7 @@ namespace tensorflow {
namespace ops {
// Softmax Loss given scores and labels, used by the SoftMaxLossGradient
Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sm_loss_op(ctx->CreateOperation());
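A hypothetical call site for the renamed ops:: wrapper, with the signature taken from the declaration in this commit; the tensor handles, output count, and node name are assumptions, not from this commit:

// --- Illustrative sketch, not part of this commit ---
// Assumes `ctx` is a valid AbstractContext* and `scores` / `sparse_labels`
// are valid AbstractTensorHandle*; the op emits loss and backprop outputs.
std::vector<AbstractTensorHandle*> outputs(2);
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
    ctx, /*inputs=*/{scores, sparse_labels}, absl::MakeSpan(outputs),
    /*name=*/"sparse_softmax_loss"));
// --- End sketch ---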


@@ -23,7 +23,7 @@ limitations under the License.
namespace tensorflow {
namespace ops {
Status SparseSoftmaxCrossEntropyLoss(
Status SparseSoftmaxCrossEntropyWithLogits(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);


@@ -33,7 +33,7 @@ Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}