crude gradient check working, need to decompose functionality

amturati 2020-08-11 22:37:37 +00:00
parent f23b2d2361
commit 713bed7cee
10 changed files with 729 additions and 44 deletions

View File

@ -276,6 +276,84 @@ cc_library(
],
)
cc_library(
name = "gradient_checker",
testonly = True,
srcs = [
"gradient_checker.cc",
],
hdrs = [
"gradient_checker.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
":mnist_gradients_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "gradient_checker_test",
size = "small",
srcs = [
"gradient_checker_test.cc",
],
args = ["--heap_check=local"],
extra_copts = tfe_xla_copts(),
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["nomac"],
deps = [
":gradient_checker",
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
":mnist_gradients_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
tf_cuda_cc_test(
name = "mnist_gradients_test",
size = "small",

View File

@ -0,0 +1,258 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include <iostream>
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
// ================== TensorHandle generating functions =================
// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
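// GetValue resolves an abstract tensor handle into a concrete TF_Tensor so its
// bytes can be read on the host (GradientCheck below memcpys from it).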
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
// Prints the elements of data as "[d0, d1, ..., d(n-1)]".
template <typename T>
void printArr(T data[], int n) {
std::cout << "[";
for (int i = 0; i < n - 1; i++) {
std::cout << data[i] << ", ";
}
std::cout << data[n - 1] << "]" << std::endl;
}
// Fills out_dims with the dimensions of the given tensor
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
int num_dims = TF_NumDims(t);
for (int i = 0; i < num_dims; i++) {
out_dims[i] = TF_Dim(t, i);
}
}
// Fills data with the values in [start, end) taken with the given step size,
// e.g. range(data, 0, 4) yields {0, 1, 2, 3}.
void range(int data[], int start, int end, int step = 1) {
int j = 0;
for (int i = start; i < end; i += step) {
data[j++] = i;
}
}
// ====================================================================
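// GradientCheck numerically estimates the gradient of `forward` with respect
// to inputs[gradIndex] using central differences: for each element i of theta,
//   dtheta_approx[i] ~= (f(theta + eps*e_i) - f(theta - eps*e_i)) / (2*eps),
// where e_i is the i-th standard basis vector and f is the model output
// reduced to a scalar with ops::Sum.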
Status GradientCheck(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs,
int gradIndex,
AbstractTensorHandle* dtheta){
float epsilon = 1e-6;
GradientRegistry registry;
Status s;
AbstractTensorHandle* theta = inputs[gradIndex]; // parameter we are grad checking
// Convert from AbstractTensor to TF_Tensor
TF_Tensor* theta_tensor;
s = GetValue(theta, &theta_tensor);
// Get number of elements
int num_elems = TF_TensorElementCount(theta_tensor);
// Get theta shape
int num_dims = TF_NumDims(theta_tensor);
int64_t theta_dims [num_dims];
GetDims(theta_tensor, theta_dims);
// Initialize data structures
float thetaPlus_data [num_elems];
float thetaMinus_data [num_elems];
float dtheta_approx[num_elems];
std::vector<AbstractTensorHandle*> sum_inputs(2);
std::vector<AbstractTensorHandle*> sum_outputs(1);
std::vector<AbstractTensorHandle*> model_outputs(1);
// make this a helper function
int dims_to_sum [num_dims];
int64_t dims_shape[] = {num_dims};
range(dims_to_sum, 0, num_dims);
//printArr(dims_to_sum, num_dims);
AbstractTensorHandlePtr sum_dims =
GetTensorHandleUtilInt(ctx, dims_to_sum, dims_shape, 1);
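// sum_dims holds the reduction indices {0, ..., num_dims - 1}; passing it to
// ops::Sum below collapses the model output over every axis into a scalar.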
for (int i = 0; i < num_elems; i++) {
// initialize theta[i] + epsilon
memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaPlus_data[i] += epsilon;
AbstractTensorHandlePtr thetaPlus =
GetTensorHandleUtilFloat(ctx, thetaPlus_data, theta_dims, num_dims);
// initialize theta[i] - epsilon
memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
TF_TensorByteSize(theta_tensor));
thetaMinus_data[i] -= epsilon;
AbstractTensorHandlePtr thetaMinus =
GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims);
// Get f(theta + eps)
inputs[gradIndex] = thetaPlus.get();
s = RunModel(forward, ctx, absl::MakeSpan(inputs),
absl::MakeSpan(model_outputs),
/*use_function=*/false, registry);
AbstractTensorHandle* fPlus_toSum = model_outputs[0];
sum_inputs[0] = fPlus_toSum;
sum_inputs[1] = sum_dims.get();
s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs), "sum_output");
AbstractTensorHandle* fPlus = sum_outputs[0];
// Get f(theta - eps)
inputs[gradIndex] = thetaMinus.get();
s = RunModel(forward, ctx, absl::MakeSpan(inputs),
absl::MakeSpan(model_outputs),
/*use_function=*/false, registry);
AbstractTensorHandle* fMinus_toSum = model_outputs[0];
sum_inputs[0] = fMinus_toSum;
sum_inputs[1] = sum_dims.get();
s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs), "sum_output");
AbstractTensorHandle* fMinus = sum_outputs[0];
// Difference Quotient
sum_inputs[0] = fPlus;
sum_inputs[1] = fMinus;
s = ops::Sub(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs), "sub_top");
AbstractTensorHandle* fDiff = sum_outputs[0];
TF_Tensor* fDiff_tensor;
s = GetValue(fDiff, &fDiff_tensor);
// ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float fDiff_data[1];
memcpy(&fDiff_data[0], TF_TensorData(fDiff_tensor),
TF_TensorByteSize(fDiff_tensor));
float diff = fDiff_data[0];
float grad_approx = diff / (2.0*epsilon);
dtheta_approx[i] = grad_approx;
}
printArr(dtheta_approx, num_elems);
return Status::OK();
}

View File

@ -19,32 +19,26 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
Status GradientCheck(Model forward,
std::vector<AbstractTensorHandle*>& inputs,
int gradIndex, AbstractTensorHandle* dtheta){
}
Status GradientCheck(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs,
int gradIndex,
AbstractTensorHandle* dtheta);

View File

@ -0,0 +1,271 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
class GradientCheckerTest
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_SetTracingImplementation(std::get<0>(GetParam()));
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyLossRegisterer));
return Status::OK();
}
// ========================= Test Util Functions ==============================
// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return Status::OK();
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
A.reset(a_raw);
return A;
}
// =========================== Start Tests ================================
TEST_P(GradientCheckerTest, TestMatMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(A)
* tape.watch(B)
* Y = AB
* outputs = tape.gradient(Y, [A, B])
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
}
TF_Tensor* dB_tensor;
s = GetValue(outputs[1], &dB_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dB_tensor),
TF_TensorByteSize(dB_tensor));
float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(dA_tensor);
TF_DeleteTensor(dB_tensor);
}
TEST_P(GradientCheckerTest, TestGradCheck) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* Y = AB
* dtheta_approx = numerical gradient of sum(Y) with respect to A
*/
std::vector<AbstractTensorHandle*> inputs;
inputs.push_back(A.get());
inputs.push_back(B.get());
s = GradientCheck(ctx.get(), MatMulModel, inputs, 0, B.get());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, GradientCheckerTest,
::testing::Combine(::testing::Values("graphdef"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -454,6 +454,27 @@ Status ScalarMulModel(AbstractContext* ctx,
return Status::OK();
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TapeVSpace vspace(ctx);
auto tape = new Tape(/*persistent=*/false);
std::vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0", /*transpose_a=*/false,
/*transpose_b=*/false, registry)); // Compute X*W1
outputs[0] = temp_outputs[0];
delete tape;
return Status::OK();
}
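// MatMulModel builds a single MatMul (X * W1) so it can serve as the forward
// model handed to GradientCheck in gradient_checker_test.cc.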
// ============================= End Models ================================
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,

View File

@ -121,6 +121,11 @@ Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Updates the weights for a neural network given incoming grads and learning
// rate
Status UpdateWeights(AbstractContext* ctx,

View File

@ -50,35 +50,22 @@ Status ZerosLike(AbstractContext* ctx,
}
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr shape_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(shape_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));
int num_retvals = 1;
return shape_op->Execute(outputs, &num_retvals);
}
Status Prod(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr prod_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(
prod_op->Reset("Prod", /*raw_device_name=*/nullptr));
if (isa<tensorflow::tracing::TracingOperation>(prod_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input vals
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // dims
int num_retvals = 1;
return shape_op->Execute(outputs, &num_retvals);
}
TF_RETURN_IF_ERROR(
shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(shape_op.get())) {
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
->SetOpName(name));
}
TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0])); // input
int num_retvals = 1;
TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
return Status::OK();
}
} // namespace ops
} // namespace tensorflow

View File

@ -31,6 +31,10 @@ Status ZerosLike(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Shape(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow

View File

@ -69,6 +69,25 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return Status::OK();
}
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sub_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sub_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sub_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
int num_retvals = 1;
TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
@ -106,5 +125,41 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
return neg_op->Execute(outputs, &num_retvals);
}
Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr prod_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(prod_op->Reset("Prod", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(prod_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(prod_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[0])); // input_vals
TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[1])); // reduction_indices
int num_retvals = 1;
TF_RETURN_IF_ERROR(prod_op->Execute(outputs, &num_retvals));
return Status::OK();
}
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
AbstractOperationPtr sum_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
if (isa<tracing::TracingOperation>(sum_op.get())) {
TF_RETURN_IF_ERROR(
dyn_cast<tracing::TracingOperation>(sum_op.get())->SetOpName(name));
}
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
int num_retvals = 1;
TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
return Status::OK();
}
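// Example: with input_vals of shape {2, 2} and reduction_indices = {0, 1}, Sum
// returns a scalar holding the sum of all four elements; gradient_checker.cc
// uses this to reduce a model's output to a scalar before differencing.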
} // namespace ops
} // namespace tensorflow

View File

@ -22,18 +22,30 @@ namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Conj(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name,
bool transpose_a, bool transpose_b);
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, const char* name);
} // namespace ops
} // namespace tensorflow