added gradients_testutil files and fixed Status
parent e2b9889c8e
commit 0fdb700766
@ -248,6 +248,47 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_testutil",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"gradients_testutil.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients_testutil.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":tape",
|
||||
":abstract_tensor_handle",
|
||||
":gradients",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mnist_gradients_testutil",
|
||||
srcs = [
|
||||
@ -295,6 +336,7 @@ cc_library(
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":mnist_gradients_testutil",
|
||||
":gradients_testutil",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
|
@ -36,82 +36,7 @@ limitations under the License.
|
||||
|
||||
using namespace std;
|
||||
|
||||
// ================== TensorHandle generating functions =================
|
||||
|
||||
// Get a scalar TensorHandle with given value
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
||||
float vals[], int64_t dims[],
|
||||
int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||
A.reset(a_raw);
|
||||
return A;
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
||||
int64_t dims[], int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||
A.reset(a_raw);
|
||||
return A;
|
||||
}
|
||||
// ================== Helper functions =================
|
||||
|
||||
// Fills data with values [start,end) with given step size.
|
||||
void Range(int data[], int start, int end, int step = 1) {
|
||||
@ -146,12 +71,13 @@ Status RunModelAndSum(AbstractContext* ctx, Model forward,
|
||||
std::vector<AbstractTensorHandle*> model_outputs(1);
|
||||
|
||||
// Run the model.
|
||||
Status s = RunModel(forward, ctx, inputs, absl::MakeSpan(model_outputs),
|
||||
use_function, registry);
|
||||
TF_RETURN_IF_ERROR(RunModel(forward, ctx, inputs,
|
||||
absl::MakeSpan(model_outputs), use_function,
|
||||
registry));
|
||||
AbstractTensorHandle* f_toSum = model_outputs[0];
|
||||
|
||||
TF_Tensor* model_out_tensor;
|
||||
s = GetValue(f_toSum, &model_out_tensor);
|
||||
TF_RETURN_IF_ERROR(GetValue(f_toSum, &model_out_tensor));
|
||||
int num_dims_out = TF_NumDims(model_out_tensor);
|
||||
|
||||
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
|
||||
@ -163,8 +89,8 @@ Status RunModelAndSum(AbstractContext* ctx, Model forward,
|
||||
sum_inputs[0] = f_toSum;
|
||||
sum_inputs[1] = sum_dims.get();
|
||||
|
||||
s = ops::Sum(ctx, absl::MakeSpan(sum_inputs), absl::MakeSpan(model_outputs),
|
||||
"sum_output");
|
||||
TF_RETURN_IF_ERROR(ops::Sum(ctx, absl::MakeSpan(sum_inputs),
|
||||
absl::MakeSpan(model_outputs), "sum_output"));
|
||||
outputs[0] = model_outputs[0];
|
||||
return Status::OK();
|
||||
}
|
||||
@ -182,21 +108,20 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
|
||||
} else {
|
||||
s = RunModelAndSum(ctx, forward, inputs, outputs, use_function);
|
||||
}
|
||||
return Status::OK();
|
||||
return s;
|
||||
}
|
||||
// ========================= End Util Functions==============================
|
||||
// ========================= End Helper Functions==============================
|
||||
|
||||
Status GradientCheck(AbstractContext* ctx, Model forward,
|
||||
std::vector<AbstractTensorHandle*> inputs,
|
||||
float* dtheta_approx, int gradIndex, bool use_function,
|
||||
bool is_scalar_out) {
|
||||
Status s;
|
||||
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
||||
std::vector<AbstractTensorHandle*> inputs,
|
||||
float* dtheta_approx, int input_index,
|
||||
bool use_function, bool is_scalar_out) {
|
||||
AbstractTensorHandle* theta =
|
||||
inputs[gradIndex]; // parameter we are grad checking
|
||||
inputs[input_index]; // parameter we are grad checking
|
||||
|
||||
// Convert from AbstractTensor to TF_Tensor.
|
||||
TF_Tensor* theta_tensor;
|
||||
s = GetValue(theta, &theta_tensor);
|
||||
TF_RETURN_IF_ERROR(GetValue(theta, &theta_tensor));
|
||||
|
||||
// Get number of elements and fill data.
|
||||
int num_elems = TF_TensorElementCount(theta_tensor);
|
||||
@ -219,6 +144,8 @@ Status GradientCheck(AbstractContext* ctx, Model forward,
|
||||
// Get relative epsilon value
|
||||
float epsilon =
|
||||
std::abs(theta_data[i] * 1e-4 + 1e-4); // add 1e-4 to prevent div by 0
|
||||
AbstractTensorHandlePtr two_eps =
|
||||
GetScalarTensorHandleUtil(ctx, 2 * epsilon);
|
||||
|
||||
// Initialize theta[i] + epsilon.
|
||||
memcpy(&thetaPlus_data[0], TF_TensorData(theta_tensor),
|
||||
@ -235,33 +162,39 @@ Status GradientCheck(AbstractContext* ctx, Model forward,
|
||||
GetTensorHandleUtilFloat(ctx, thetaMinus_data, theta_dims, num_dims);
|
||||
|
||||
// Get f(theta + eps):
|
||||
inputs[gradIndex] = thetaPlus.get();
|
||||
s = RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
|
||||
absl::MakeSpan(f_outputs), use_function, is_scalar_out);
|
||||
inputs[input_index] = thetaPlus.get();
|
||||
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
|
||||
absl::MakeSpan(f_outputs), use_function,
|
||||
is_scalar_out));
|
||||
AbstractTensorHandle* fPlus = f_outputs[0];
|
||||
|
||||
// Get f(theta - eps):
|
||||
inputs[gradIndex] = thetaMinus.get();
|
||||
s = RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
|
||||
absl::MakeSpan(f_outputs), use_function, is_scalar_out);
|
||||
inputs[input_index] = thetaMinus.get();
|
||||
TF_RETURN_IF_ERROR(RunAndMaybeSum(ctx, forward, absl::MakeSpan(inputs),
|
||||
absl::MakeSpan(f_outputs), use_function,
|
||||
is_scalar_out));
|
||||
AbstractTensorHandle* fMinus = f_outputs[0];
|
||||
|
||||
// Take Difference of both estimates: (f(x + eps) - f(x - eps)).
|
||||
s = ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top");
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Sub(ctx, {fPlus, fMinus}, absl::MakeSpan(f_outputs), "sub_top"));
|
||||
AbstractTensorHandle* fDiff = f_outputs[0];
|
||||
|
||||
// Get difference value for calculation.
|
||||
TF_Tensor* fDiff_tensor;
|
||||
s = GetValue(fDiff, &fDiff_tensor);
|
||||
float fDiff_data[1];
|
||||
memcpy(&fDiff_data[0], TF_TensorData(fDiff_tensor),
|
||||
TF_TensorByteSize(fDiff_tensor));
|
||||
|
||||
// Calculate using the difference quotient definition:
|
||||
// (f(x + eps) - f(x - eps)) / (2 * eps).
|
||||
float grad_approx = fDiff_data[0] / (2.0 * epsilon);
|
||||
dtheta_approx[i] = grad_approx;
|
||||
TF_RETURN_IF_ERROR(ops::DivNoNan(ctx, {fDiff, two_eps.get()},
|
||||
absl::MakeSpan(f_outputs),
|
||||
"diff_quotient"));
|
||||
AbstractTensorHandle* diff_quotient = f_outputs[0];
|
||||
|
||||
TF_Tensor* grad_tensor;
|
||||
TF_RETURN_IF_ERROR(GetValue(diff_quotient, &grad_tensor));
|
||||
float grad_data[1];
|
||||
memcpy(&grad_data[0], TF_TensorData(grad_tensor),
|
||||
TF_TensorByteSize(grad_tensor));
|
||||
|
||||
dtheta_approx[i] = grad_data[0];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
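For reference, the loop above implements the plain central (symmetric) difference quotient with a per-element relative step; this just summarizes the code, it adds nothing to it:

\[
  \epsilon_i = \left|\theta_i \cdot 10^{-4} + 10^{-4}\right|, \qquad
  \mathtt{dtheta\_approx}[i] \;\approx\; \frac{f(\theta + \epsilon_i e_i) - f(\theta - \epsilon_i e_i)}{2\,\epsilon_i},
\]

where \(e_i\) is the i-th standard basis vector and \(f\) is the scalar output of the forward model (reduce_summed first when `is_scalar_out` is false).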
|
@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
@ -24,6 +22,8 @@ limitations under the License.
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
@ -37,17 +37,19 @@ using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;

/** Returns numerical grad inside `dtheta_approx` given `forward` model and parameter
* specified by `gradIndex`
*
* `use_function` indicates whether to use graph mode(true) or eager(false)
*
* `is_scalar_out` should be true when `forward` returns a scalar TensorHandle;
* else GradientCheck will reduce_sum the tensor to get a scalar to estimate
* the gradient with. Default is false.
*/
Status GradientCheck(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs,
float* dtheta_approx,
int gradIndex, bool use_function,
bool is_scalar_out=false);
/** Returns numerical grad inside `dtheta_approx` given `forward` model and
* parameter specified by `input_index`.
*
* I.e. if y = <output of the forward model> and w = inputs[input_index],
* this will calculate dy/dw numerically.
*
* `use_function` indicates whether to use graph mode(true) or eager(false).
*
* `is_scalar_out` should be true when `forward` returns a scalar TensorHandle;
* else this function will reduce_sum the tensor to get a scalar to estimate
* the gradient with. Default is false.
*/
Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs,
float* dtheta_approx, int input_index,
bool use_function, bool is_scalar_out = false);
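A minimal calling sketch, mirroring the `TestGradCheckMul` test further below (the context, the input handles and `MulModel` are assumed to be set up as in that test; illustrative only):

std::vector<AbstractTensorHandle*> inputs = {x.get(), y.get()};
float dtheta_approx[1] = {0};
Status s = CalcNumericalGrad(ctx.get(), MulModel, inputs, dtheta_approx,
                             /*input_index=*/0, /*use_function=*/false,
                             /*is_scalar_out=*/true);
// On success, dtheta_approx[0] holds the finite-difference estimate of
// d(x*y)/dx, i.e. approximately the value of y.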
@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradient_checker.h"
|
||||
// #include "tensorflow/c/eager/gradients_testutil.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
@ -52,85 +53,6 @@ Status RegisterGradients(GradientRegistry* registry) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ========================= Test Util Functions ==============================
|
||||
|
||||
// Get a scalar TensorHandle with given value
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
||||
float vals[], int64_t dims[],
|
||||
int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||
A.reset(a_raw);
|
||||
return A;
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
||||
int64_t dims[], int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||
A.reset(a_raw);
|
||||
return A;
|
||||
}
|
||||
|
||||
// =========================== Start Tests ================================
|
||||
|
||||
TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -159,9 +81,9 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||
inputs.push_back(B.get());
|
||||
|
||||
float dapprox[4] = {0};
|
||||
Status s =
|
||||
GradientCheck(ctx.get(), MatMulModel, inputs, dapprox, /*gradIndex=*/0,
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
Status s = CalcNumericalGrad(ctx.get(), MatMulModel, inputs, dapprox,
|
||||
/*input_index=*/0,
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
|
||||
@ -171,21 +93,66 @@ TEST_P(GradientCheckerTest, TestGradCheckMatMul) {
|
||||
}
|
||||
}
|
||||
|
||||
TEST_P(GradientCheckerTest, TestGradCheckMul) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
// Will perform z = x*y.
|
||||
// dz/dx = y
|
||||
|
||||
std::vector<AbstractTensorHandle*> inputs;
|
||||
inputs.push_back(x.get());
|
||||
inputs.push_back(y.get());
|
||||
float dapprox[1] = {0};
|
||||
Status s =
|
||||
CalcNumericalGrad(ctx.get(), MulModel, inputs, dapprox, /*input_index=*/0,
|
||||
/*use_function=*/!std::get<2>(GetParam()),
|
||||
/*is_scalar_out=*/true);
|
||||
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ASSERT_NEAR(dapprox[0], 7.0f, /*tolerance=*/1e-3);
|
||||
}
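Here z = x * y with (x, y) = (2, 7), so the analytical gradient dz/dx = y = 7; the numerical estimate in dapprox[0] is expected to match it within the 1e-3 tolerance asserted above.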
|
||||
|
||||
TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
|
||||
/** Test to show how to use this API with analytical gradients:
*
* We have `SoftmaxLossGradModel`, which is a wrapper for the
* Softmax analytical gradient found in c/experimental/nn_grads.
*
* We will use the GradientChecker by applying finite differences
* to the forward pass wrapped in `SoftmaxModel` and verify that
* both the analytical and numerical gradients are relatively
* close.
*
*/
|
||||
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
@ -235,8 +202,9 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
|
||||
|
||||
// Run numerical gradient approximation using the GradientChecker API.
|
||||
float dapprox[9] = {0}; // Will contain numerical approximation data.
|
||||
s = GradientCheck(ctx.get(), SoftmaxModel, inputs, dapprox, /*gradIndex=*/0,
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
s = CalcNumericalGrad(ctx.get(), SoftmaxModel, inputs, dapprox,
|
||||
/*input_index=*/0,
|
||||
/*use_function=*/!std::get<2>(GetParam()));
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Now compare the two implementations:
|
||||
@ -249,8 +217,6 @@ TEST_P(GradientCheckerTest, TestGradCheckSoftmax) {
|
||||
TF_DeleteTensor(dX_tensor);
|
||||
}
|
||||
|
||||
// TODO(b/160888630): Enable this test with mlir after AddInputList is
|
||||
// supported. It is needed for AddN op which is used for gradient aggregation.
|
||||
#ifdef PLATFORM_GOOGLE
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
UnifiedCAPI, GradientCheckerTest,
|
||||
|
271
tensorflow/c/eager/gradients_testutil.cc
Normal file
@ -0,0 +1,271 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
using namespace std;
|
||||
|
||||
// ================== TensorHandle generating functions =================
|
||||
|
||||
// Get a scalar TensorHandle with given value
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_TensorHandle* result_t =
|
||||
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
||||
float vals[], int64_t dims[],
|
||||
int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||
if (s.ok()) {
|
||||
A.reset(a_raw);
|
||||
}
|
||||
return A;
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
||||
int64_t dims[], int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||
if (s.ok()) {
|
||||
A.reset(a_raw);
|
||||
}
|
||||
return A;
|
||||
}
|
||||
|
||||
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
||||
float val) {
|
||||
AbstractTensorHandlePtr y;
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx, val, &y_raw);
|
||||
if (s.ok()) {
|
||||
y.reset(y_raw);
|
||||
}
|
||||
return y;
|
||||
}
|
||||
|
||||
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
|
||||
vector<AbstractTensorHandle*>& weights,
|
||||
AbstractTensorHandle* learning_rate) {
|
||||
/* Update weights one by one using gradient update rule:
|
||||
*
|
||||
* w -= lr*grad[w]
|
||||
*
|
||||
* NOTE: assuming learning rate is positive
|
||||
*/
|
||||
|
||||
int num_grads = grads.size();
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
std::string update_str;
|
||||
|
||||
// Negate learning rate for gradient descent
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"neg_lr")); // Compute -lr
|
||||
learning_rate = temp_outputs[0];
|
||||
|
||||
for (int i = 0; i < num_grads; i++) {
|
||||
// Compute dW = -lr * grad(w[i])
|
||||
update_str = "update_mul_" + std::to_string(i);
|
||||
TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
update_str.c_str()));
|
||||
|
||||
AbstractTensorHandle* dW = temp_outputs[0];
|
||||
|
||||
// Compute temp = weights[i] + dW
|
||||
update_str = "update_add_" + std::to_string(i);
|
||||
TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
update_str.c_str()));
|
||||
|
||||
// Update the weights
|
||||
weights[i] = temp_outputs[0];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
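A short sketch of driving this helper for one descent step (W1, W2, dW1 and dW2 are hypothetical stand-ins for weight and gradient handles produced elsewhere; `GetScalarTensorHandleUtil` is defined above):

AbstractTensorHandlePtr lr = GetScalarTensorHandleUtil(ctx, 0.01f);
std::vector<AbstractTensorHandle*> weights = {W1.get(), W2.get()};
std::vector<AbstractTensorHandle*> grads = {dW1, dW2};
Status s = UpdateWeights(ctx, grads, weights, lr.get());
// On success each weights[i] now points at a new handle holding
// the value weights[i] - 0.01f * grads[i].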
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of indices in the model's outputs that are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
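// Example (hypothetical): if the traced model produces {out0, nullptr, out2},
// the FunctionDef exposes only two outputs, null_indices ends up as {1}, and
// the two handles returned by Execute() are mapped back onto positions 0 and 2
// of `outputs` further below.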
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs, registry);
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
94
tensorflow/c/eager/gradients_testutil.h
Normal file
@ -0,0 +1,94 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/c/tf_status_helper.h"
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace tensorflow;
|
||||
using namespace tensorflow::gradients;
|
||||
using namespace tensorflow::gradients::internal;
|
||||
|
||||
// Get a scalar TensorHandle with given value.
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions.
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions.
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Places data from `t` into *result_tensor.
|
||||
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
|
||||
|
||||
// Util function that wraps an AbstractTensorHandle* with given data and dims.
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
||||
float vals[], int64_t dims[],
|
||||
int num_dims);
|
||||
|
||||
// Util function that wraps an AbstractTensorHandle* with given data and dims.
|
||||
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
||||
int64_t dims[], int num_dims);
|
||||
|
||||
// Util function that wraps an AbstractTensorHandle* with given data.
|
||||
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
||||
float val);
|
||||
|
||||
// Performs gradient update for each weight using given learning rate.
|
||||
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
|
||||
vector<AbstractTensorHandle*>& weights,
|
||||
AbstractTensorHandle* learning_rate);
|
||||
|
||||
// Helper function for RunModel to build the function for graph mode.
|
||||
AbstractContext* BuildFunction(const char* fn_name);
|
||||
|
||||
// Helper function for RunModel to add params for graph mode.
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
vector<AbstractTensorHandle*>* params);
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
|
||||
|
||||
// Runs given model in either graph or eager mode depending on value of
|
||||
// use_function.
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
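Taken together, a minimal sketch of how these helpers compose inside a test body (illustrative only; `MulModel` comes from mnist_gradients_testutil and the Status checks a real test would make are elided):

AbstractContextPtr ctx;
{
  AbstractContext* ctx_raw = nullptr;
  Status s = BuildImmediateExecutionContext(/*use_tfrt=*/false, &ctx_raw);
  ctx.reset(ctx_raw);
}

AbstractTensorHandlePtr x = GetScalarTensorHandleUtil(ctx.get(), 2.0f);
AbstractTensorHandlePtr y = GetScalarTensorHandleUtil(ctx.get(), 7.0f);

GradientRegistry registry;
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(MulModel, ctx.get(), {x.get(), y.get()},
                    absl::MakeSpan(outputs), /*use_function=*/false, registry);

TF_Tensor* result = nullptr;
s = GetValue(outputs[0], &result);  // expect a scalar TF_Tensor holding 14.0f
TF_DeleteTensor(result);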
|
@ -242,7 +242,8 @@ Status MNISTForwardModel(AbstractContext* ctx,
|
||||
* hidden_layer = tf.nn.relu(mm_out_1)
|
||||
* scores = tf.matmul(hidden_layer,W2)
|
||||
* softmax =
|
||||
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,y_labels)
|
||||
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
|
||||
* y_labels)
|
||||
* return scores, softmax
|
||||
*
|
||||
* Use this convention for inputs:
|
||||
@ -455,10 +456,9 @@ Status ScalarMulModel(AbstractContext* ctx,
|
||||
}
|
||||
|
||||
Status MatMulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* X = inputs[0];
|
||||
AbstractTensorHandle* W1 = inputs[1];
|
||||
|
||||
@ -476,10 +476,9 @@ Status MatMulModel(AbstractContext* ctx,
|
||||
}
|
||||
|
||||
Status MulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* x = inputs[0];
|
||||
AbstractTensorHandle* y = inputs[1];
|
||||
|
||||
@ -487,7 +486,7 @@ Status MulModel(AbstractContext* ctx,
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
TF_RETURN_IF_ERROR(Mul(ctx, tape, {x, y}, absl::MakeSpan(temp_outputs),
|
||||
"mul0", registry)); // Compute x*y
|
||||
"mul0", registry)); // Compute x*y
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
@ -496,169 +495,23 @@ Status MulModel(AbstractContext* ctx,
|
||||
}
|
||||
|
||||
Status SoftmaxModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractTensorHandle* x = inputs[0];
|
||||
AbstractTensorHandle* labels = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels}, absl::MakeSpan(temp_outputs),
|
||||
"sm_loss", registry));
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx, tape, {x, labels},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"sm_loss", registry));
|
||||
|
||||
outputs[0] = temp_outputs[0]; // loss values
|
||||
outputs[0] = temp_outputs[0]; // loss values
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ============================= End Models ================================
|
||||
|
||||
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
|
||||
vector<AbstractTensorHandle*>& weights,
|
||||
AbstractTensorHandle* learning_rate) {
|
||||
/* Update weights one by one using gradient update rule:
|
||||
*
|
||||
* w -= lr*grad[w]
|
||||
*
|
||||
* NOTE: assuming learning rate is positive
|
||||
*/
|
||||
|
||||
Status s;
|
||||
int num_grads = grads.size();
|
||||
vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
std::string update_str;
|
||||
|
||||
// Negate learning rate for gradient descent
|
||||
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
|
||||
absl::MakeSpan(temp_outputs),
|
||||
"neg_lr")); // Compute -lr
|
||||
learning_rate = temp_outputs[0];
|
||||
|
||||
for (int i = 0; i < num_grads; i++) {
|
||||
// Compute dW = -lr * grad(w[i])
|
||||
update_str = "update_mul_" + std::to_string(i);
|
||||
s = ops::Mul(ctx, {learning_rate, grads[i]}, absl::MakeSpan(temp_outputs),
|
||||
update_str.c_str());
|
||||
|
||||
AbstractTensorHandle* dW = temp_outputs[0];
|
||||
|
||||
// Compute temp = weights[i] + dW
|
||||
update_str = "update_add_" + std::to_string(i);
|
||||
s = ops::Add(ctx, {weights[i], dW}, absl::MakeSpan(temp_outputs),
|
||||
update_str.c_str());
|
||||
|
||||
// Update the weights
|
||||
weights[i] = temp_outputs[0];
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
|
||||
return unwrap(graph_ctx);
|
||||
}
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
vector<AbstractTensorHandle*>* params) {
|
||||
tracing::TracingTensorHandle* handle = nullptr;
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
|
||||
input->DataType(), &handle));
|
||||
params->emplace_back(handle);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry) {
|
||||
if (use_function) {
|
||||
const char* fn_name = "test_fn";
|
||||
std::unique_ptr<AbstractFunction> scoped_func;
|
||||
// Returning null tensors from a tf.function is not supported, so we keep
|
||||
// track of indices in the model's outputs that are nullptr in this set.
|
||||
// The FunctionDef only outputs the non-null tensors. We later pad the
|
||||
// function op outputs to have nullptrs at the `null_indices`.
|
||||
absl::flat_hash_set<int> null_indices;
|
||||
{
|
||||
AbstractContextPtr func_ctx(BuildFunction(fn_name));
|
||||
vector<AbstractTensorHandle*> func_inputs;
|
||||
func_inputs.reserve(inputs.size());
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
|
||||
vector<AbstractTensorHandle*> model_outputs;
|
||||
model_outputs.resize(outputs.size());
|
||||
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
|
||||
absl::MakeSpan(model_outputs), registry));
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
OutputList output_list;
|
||||
output_list.expected_num_outputs = 0;
|
||||
output_list.outputs.reserve(outputs.size());
|
||||
for (int i = 0; i < model_outputs.size(); i++) {
|
||||
if (model_outputs[i]) {
|
||||
output_list.outputs.emplace_back(model_outputs[i]);
|
||||
output_list.expected_num_outputs += 1;
|
||||
} else {
|
||||
null_indices.insert(i);
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
->Finalize(&output_list, &func));
|
||||
scoped_func.reset(func);
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Unref();
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
}
|
||||
|
||||
AbstractOperationPtr fn_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
|
||||
for (auto input : inputs) {
|
||||
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
|
||||
}
|
||||
int retvals = outputs.size() - null_indices.size();
|
||||
vector<AbstractTensorHandle*> fn_outputs(retvals);
|
||||
TF_RETURN_IF_ERROR(fn_op->Execute(
|
||||
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
|
||||
&retvals));
|
||||
int skipped_indices = 0;
|
||||
for (int i = 0; i < outputs.size(); i++) {
|
||||
if (!null_indices.contains(i)) {
|
||||
outputs[i] = fn_outputs[i - skipped_indices];
|
||||
} else {
|
||||
skipped_indices += 1;
|
||||
}
|
||||
}
|
||||
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
|
||||
return Status::OK();
|
||||
} else {
|
||||
return model(ctx, inputs, outputs, registry);
|
||||
}
|
||||
}
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_ContextOptions* opts = TFE_NewContextOptions();
|
||||
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
|
||||
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -122,44 +122,16 @@ Status ScalarMulModel(AbstractContext* ctx,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status MatMulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status MulModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status SoftmaxModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
// Updates the weights for a neural network given incoming grads and learning
|
||||
// rate
|
||||
Status UpdateWeights(AbstractContext* ctx,
|
||||
std::vector<AbstractTensorHandle*>& grads,
|
||||
std::vector<AbstractTensorHandle*>& weights,
|
||||
AbstractTensorHandle* learning_rate);
|
||||
|
||||
AbstractContext* BuildFunction(const char* fn_name);
|
||||
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
std::vector<AbstractTensorHandle*>* params);
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
|
||||
|
||||
Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
Status SoftmaxModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry);
|
||||
|
@ -87,7 +87,6 @@ Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
Status MatMul(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
@ -125,24 +124,6 @@ Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
return neg_op->Execute(outputs, &num_retvals);
|
||||
}
|
||||
|
||||
Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr prod_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(prod_op->Reset("Prod", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(prod_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(prod_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[0])); // input_vals
|
||||
TF_RETURN_IF_ERROR(prod_op->AddInput(inputs[1])); // reduction_indices
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(prod_op->Execute(outputs, &num_retvals));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr sum_op(ctx->CreateOperation());
|
||||
@ -153,29 +134,31 @@ Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
dyn_cast<tracing::TracingOperation>(sum_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0])); // input_vals
|
||||
TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1])); // reduction_indices
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status EuclideanNorm(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr norm_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(norm_op->Reset("EuclideanNorm", /*raw_device_name=*/nullptr));
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name) {
|
||||
AbstractOperationPtr div_op(ctx->CreateOperation());
|
||||
TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
|
||||
|
||||
if (isa<tracing::TracingOperation>(norm_op.get())) {
|
||||
if (isa<tracing::TracingOperation>(div_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(norm_op.get())->SetOpName(name));
|
||||
dyn_cast<tracing::TracingOperation>(div_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(norm_op->AddInput(inputs[0])); // input_vals
|
||||
TF_RETURN_IF_ERROR(norm_op->AddInput(inputs[1])); // reduction_indices
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0])); // x
|
||||
TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1])); // y
|
||||
|
||||
int num_retvals = 1;
|
||||
TF_RETURN_IF_ERROR(norm_op->Execute(outputs, &num_retvals));
|
||||
TF_RETURN_IF_ERROR(div_op->Execute(
|
||||
outputs, &num_retvals)); // z = x / y, (z_i = 0 if y_i = 0)
|
||||
return Status::OK();
|
||||
}
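A quick worked example of the semantics (values are illustrative): with x = [1, 2, 3] and y = [2, 0, 4], DivNoNan yields z = [0.5, 0, 0.75]. Returning 0 instead of dividing by a zero entry is what lets the gradient checker above divide the forward difference by `two_eps` without a separate zero-step guard.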
|
||||
|
||||
|
@ -38,17 +38,15 @@ Status MatMul(AbstractContext* ctx,
|
||||
Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Prod(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
|
||||
Status EuclideanNorm(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
Status DivNoNan(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name);
|
||||
} // namespace ops
|
||||
} // namespace tensorflow
|
||||
|
||||
|