took out test-only functions from test_util and put them in gradients_util

This commit is contained in:
amturati 2020-09-02 20:34:49 +00:00
parent 0fdb700766
commit a17c8e3f99
6 changed files with 108 additions and 57 deletions

View File

@ -249,13 +249,12 @@ tf_cuda_cc_test(
)
cc_library(
name = "gradients_testutil",
testonly = True,
name = "gradients_util",
srcs = [
"gradients_testutil.cc",
"gradients_util.cc",
],
hdrs = [
"gradients_testutil.h",
"gradients_util.h",
],
visibility = [
"//tensorflow:internal",
@ -266,22 +265,20 @@ cc_library(
":tape",
":abstract_tensor_handle",
":gradients",
":c_api",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -319,7 +316,6 @@ cc_library(
cc_library(
name = "gradient_checker",
testonly = True,
srcs = [
"gradient_checker.cc",
],
@ -332,13 +328,10 @@ cc_library(
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal",
":gradients_internal",
":mnist_gradients_testutil",
":gradients_testutil",
":gradients_util",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
@ -349,8 +342,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",

View File

@ -19,12 +19,10 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
@ -32,7 +30,9 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using namespace std;
@ -197,4 +197,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
}
return Status::OK();
}
}
} // namespace gradients
} // namespace tensorflow

View File

@ -17,13 +17,11 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_testutil.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
@ -31,7 +29,9 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
@ -53,3 +53,6 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs,
float* dtheta_approx, int input_index,
bool use_function, bool is_scalar_out = false);
} // namespace gradients
} // namespace tensorflow

View File

@ -10,14 +10,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h"
// #include "tensorflow/c/eager/gradients_testutil.h"
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
@ -109,7 +108,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
@ -117,7 +116,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}

View File

@ -12,14 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_testutil.h"
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
@ -31,28 +30,66 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using namespace std;
// ================== TensorHandle generating functions =================
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value) {
float data[] = {value};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
// Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
TFE_TensorHandle* input_eager = ScalarTensorHandleHelper(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -61,14 +98,14 @@ Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@ -77,7 +114,7 @@ Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
TensorHandleWithDimsIntHelper(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
@ -98,7 +135,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
@ -109,7 +146,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
@ -120,7 +157,7 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) {
AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx, val, &y_raw);
Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) {
y.reset(y_raw);
}
@ -268,4 +305,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
}
} // namespace gradients
} // namespace tensorflow

View File

@ -18,7 +18,6 @@ limitations under the License.
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
@ -30,24 +29,37 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h"
using namespace std;
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
// Get a scalar TensorHandle with given value.
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
// using namespace std;
// using namespace tensorflow;
// using namespace tensorflow::gradients;
// using namespace tensorflow::gradients::internal;
namespace tensorflow {
namespace gradients {
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims);
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions.
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
@ -68,8 +80,8 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val);
// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
Status UpdateWeights(AbstractContext* ctx, std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
// Helper function for RunModel to build the function for graph mode.
@ -78,7 +90,7 @@ AbstractContext* BuildFunction(const char* fn_name);
// Helper function for RunModel to add params for graph mode.
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params);
std::vector<AbstractTensorHandle*>* params);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
@ -91,4 +103,7 @@ Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow