crude gradient check working, need to decompose functionality
parent f23b2d2361
commit 713bed7cee
@@ -276,6 +276,84 @@ cc_library(
    ],
)

cc_library(
    name = "gradient_checker",
    testonly = True,
    srcs = [
        "gradient_checker.cc",
    ],
    hdrs = [
        "gradient_checker.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":mnist_gradients_util",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/gradients:nn_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

tf_cuda_cc_test(
    name = "gradient_checker_test",
    size = "small",
    srcs = [
        "gradient_checker_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":gradient_checker",
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":mnist_gradients_util",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/gradients:nn_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)


tf_cuda_cc_test(
    name = "mnist_gradients_test",
    size = "small",
tensorflow/c/eager/gradient_checker.cc (new file, 258 lines)
@@ -0,0 +1,258 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

// ================== TensorHandle generating functions =================

// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                     int64_t dims[], int num_dims,
                                     AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
                                   int64_t dims[], int num_dims,
                                   AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return Status::OK();
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

// Prints the first n values of data as "[v0, v1, ..., v(n-1)]".
template <typename T>
void printArr(T data[], int n) {
  std::cout << "[";
  for (int i = 0; i < n - 1; i++) {
    std::cout << data[i] << ", ";
  }
  std::cout << data[n - 1] << "]" << std::endl;
}

// Fills out_dims with the dimensions of the given tensor
void GetDims(const TF_Tensor* t, int64_t* out_dims) {
  int num_dims = TF_NumDims(t);
  for (int i = 0; i < num_dims; i++) {
    out_dims[i] = TF_Dim(t, i);
  }
}

// Fills data with the values [start, end) taken with the given step size
void range(int data[], int start, int end, int step = 1) {
  int pos = 0;
  for (int i = start; i < end; i += step) {
    data[pos++] = i;
  }
}

// ====================================================================

Status GradientCheck(AbstractContext* ctx, Model forward,
                     std::vector<AbstractTensorHandle*> inputs,
                     int gradIndex,
                     AbstractTensorHandle* dtheta) {
  float epsilon = 1e-6;
  GradientRegistry registry;

  Status s;
  AbstractTensorHandle* theta =
      inputs[gradIndex];  // parameter we are grad checking

  // Convert from AbstractTensor to TF_Tensor
  TF_Tensor* theta_tensor;
  s = GetValue(theta, &theta_tensor);

  // Get number of elements
  int num_elems = TF_TensorElementCount(theta_tensor);

  // Get theta shape
  int num_dims = TF_NumDims(theta_tensor);
  int64_t theta_dims[num_dims];
  GetDims(theta_tensor, theta_dims);

  // Initialize data structures
  float thetaPlus_data[num_elems];
  float thetaMinus_data[num_elems];
  float dtheta_approx[num_elems];

  std::vector<AbstractTensorHandle*> sum_inputs(2);
  std::vector<AbstractTensorHandle*> sum_outputs(1);
  std::vector<AbstractTensorHandle*> model_outputs(1);

  // make this a helper function
  int dims_to_sum[num_dims];
  int64_t dims_shape[] = {num_dims};
  range(dims_to_sum, 0, num_dims);
  // printArr(dims_to_sum, num_dims);
  AbstractTensorHandlePtr sum_dims =
      GetTensorHandleUtilInt(ctx, dims_to_sum, dims_shape, 1);

  for (int i = 0; i < num_elems; i++) {
    // initialize theta[i] + epsilon
    memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor),
           TF_TensorByteSize(theta_tensor));
    thetaPlus_data[i] += epsilon;

    AbstractTensorHandlePtr thetaPlus =
        GetTensorHandleUtilFloat(ctx, thetaPlus_data, theta_dims, num_dims);

    // initialize theta[i] - epsilon
    memcpy(&thetaMinus_data[0], TF_TensorData(theta_tensor),
           TF_TensorByteSize(theta_tensor));
    thetaMinus_data[i] -= epsilon;

    AbstractTensorHandlePtr thetaMinus =
        GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims);

    // Get f(theta + eps)
    inputs[gradIndex] = thetaPlus.get();

    s = RunModel(forward, ctx, absl::MakeSpan(inputs),
                 absl::MakeSpan(model_outputs),
                 /*use_function=*/false, registry);

    AbstractTensorHandle* fPlus_toSum = model_outputs[0];
    sum_inputs[0] = fPlus_toSum;
    sum_inputs[1] = sum_dims.get();

    s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs),
                 "sum_output");

    AbstractTensorHandle* fPlus = sum_outputs[0];

    // Get f(theta - eps)
    inputs[gradIndex] = thetaMinus.get();

    s = RunModel(forward, ctx, absl::MakeSpan(inputs),
                 absl::MakeSpan(model_outputs),
                 /*use_function=*/false, registry);

    AbstractTensorHandle* fMinus_toSum = model_outputs[0];
    sum_inputs[0] = fMinus_toSum;
    sum_inputs[1] = sum_dims.get();

    s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs),
                 "sum_output");

    AbstractTensorHandle* fMinus = sum_outputs[0];

    // Difference Quotient
    sum_inputs[0] = fPlus;
    sum_inputs[1] = fMinus;

    s = ops::Sub(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(sum_outputs),
                 "sub_top");
    AbstractTensorHandle* fDiff = sum_outputs[0];

    TF_Tensor* fDiff_tensor;
    s = GetValue(fDiff, &fDiff_tensor);
    // ASSERT_EQ(errors::OK, s.code()) << s.error_message();

    float fDiff_data[1];
    memcpy(&fDiff_data[0], TF_TensorData(fDiff_tensor),
           TF_TensorByteSize(fDiff_tensor));

    float diff = fDiff_data[0];
    float grad_approx = diff / (2.0 * epsilon);

    dtheta_approx[i] = grad_approx;
  }

  printArr(dtheta_approx, num_elems);

  return Status::OK();
}
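
Note: the value stored in dtheta_approx[i] by the loop above is the central
difference approximation of the gradient of the summed model output with
respect to theta[i],

    \frac{\partial f}{\partial \theta_i} \approx
        \frac{\sum f(\theta + \epsilon e_i) - \sum f(\theta - \epsilon e_i)}{2 \epsilon},

with epsilon = 1e-6 and the sum taken over every element of the model output
(the ops::Sum over all dimensions).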
@@ -19,32 +19,26 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

Status UpdateWeights(AbstractContext* ctx,
                     std::vector<AbstractTensorHandle*>& grads,
                     std::vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate);

Status GradientCheck(Model forward,
                     std::vector<AbstractTensorHandle*>& inputs,
                     int gradIndex, AbstractTensorHandle* dtheta){

}
Status GradientCheck(AbstractContext* ctx, Model forward,
                     std::vector<AbstractTensorHandle*> inputs,
                     int gradIndex,
                     AbstractTensorHandle* dtheta);
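
A minimal usage sketch for the declaration above (names are illustrative:
MyModel stands for any forward function matching the Model alias, and
analytical_dx for an analytically computed gradient handle; error handling
omitted):

  // Sketch only: MyModel and analytical_dx are hypothetical placeholders.
  std::vector<AbstractTensorHandle*> inputs = {x.get(), w.get()};
  Status s = GradientCheck(ctx.get(), MyModel, inputs,
                           /*gradIndex=*/0, /*dtheta=*/analytical_dx.get());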
tensorflow/c/eager/gradient_checker_test.cc (new file, 271 lines)
@@ -0,0 +1,271 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {

class GradientCheckerTest
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_SetTracingImplementation(std::get<0>(GetParam()));
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyLossRegisterer));
  return Status::OK();
}

// ========================= Test Util Functions ==============================

// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                     int64_t dims[], int num_dims,
                                     AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

// Get a Matrix TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
                                   int64_t dims[], int num_dims,
                                   AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return Status::OK();
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  A.reset(a_raw);
  return A;
}

// =========================== Start Tests ================================

TEST_P(GradientCheckerTest, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
  }

  TF_Tensor* dB_tensor;
  s = GetValue(outputs[1], &dB_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dB_tensor),
         TF_TensorByteSize(dB_tensor));

  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(dA_tensor);
  TF_DeleteTensor(dB_tensor);
}

TEST_P(GradientCheckerTest, TestGradCheck) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */
  std::vector<AbstractTensorHandle*> inputs;
  inputs.push_back(A.get());
  inputs.push_back(B.get());

  s = GradientCheck(ctx.get(), MatMulModel, inputs, 0, B.get());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}


// TODO(b/160888630): Enable this test with mlir after AddInputList is
// supported. It is needed for AddN op which is used for gradient aggregation.
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, GradientCheckerTest,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, GradientCheckerTest,
    ::testing::Combine(::testing::Values("graphdef"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
@@ -454,6 +454,27 @@ Status ScalarMulModel(AbstractContext* ctx,
  return Status::OK();
}

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TapeVSpace vspace(ctx);
  auto tape = new Tape(/*persistent=*/false);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
                            "matmul0", /*transpose_a=*/false,
                            /*transpose_b=*/false, registry));  // Compute X*W1

  outputs[0] = temp_outputs[0];

  delete tape;
  return Status::OK();
}

// ============================= End Models ================================

Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
@@ -121,6 +121,11 @@ Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry);

// Updates the weights for a neural network given incoming grads and learning
// rate
Status UpdateWeights(AbstractContext* ctx,
@@ -50,35 +50,22 @@ Status ZerosLike(AbstractContext* ctx,
}

Status Shape(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name) {
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr shape_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      shape_op->Reset("Shape", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(shape_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));
  int num_retvals = 1;
  return shape_op->Execute(outputs, &num_retvals);
}

Status Prod(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr prod_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(
      prod_op->Reset("Prod", /*raw_device_name=*/nullptr));
  if (isa<tensorflow::tracing::TracingOperation>(prod_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
                           ->SetOpName(name));
  }
  TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));  // input vals
  TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));  // dims
  int num_retvals = 1;
  return shape_op->Execute(outputs, &num_retvals);
}
  TF_RETURN_IF_ERROR(
      shape_op->Reset("Shape", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(shape_op.get())) {
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(shape_op.get())
                           ->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(shape_op->AddInput(inputs[0]));  // input
  int num_retvals = 1;
  TF_RETURN_IF_ERROR(shape_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow
@@ -31,6 +31,10 @@ Status ZerosLike(AbstractContext* ctx,
                 absl::Span<AbstractTensorHandle* const> inputs,
                 absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Shape(AbstractContext* ctx,
             absl::Span<AbstractTensorHandle* const> inputs,
             absl::Span<AbstractTensorHandle*> outputs, const char* name);

}  // namespace ops
}  // namespace tensorflow

@@ -69,6 +69,25 @@ Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
  return Status::OK();
}

Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sub_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(sub_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(sub_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
  return Status::OK();
}


Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
@@ -106,5 +125,41 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
  return neg_op->Execute(outputs, &num_retvals);
}

Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr prod_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(prod_op->Reset("Prod", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(prod_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(prod_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[0]));  // input_vals
  TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[1]));  // reduction_indices

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(prod_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr sum_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(sum_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(sum_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0]));  // input_vals
  TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1]));  // reduction_indices

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow

@@ -22,18 +22,30 @@ namespace tensorflow {
namespace ops {
Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Conj(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b);

Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);

Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);
}  // namespace ops
}  // namespace tensorflow
