took out test-only functions from test_util and put them in gradients_util
This commit is contained in:
parent
0fdb700766
commit
a17c8e3f99
@ -249,13 +249,12 @@ tf_cuda_cc_test(
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "gradients_testutil",
|
||||
testonly = True,
|
||||
name = "gradients_util",
|
||||
srcs = [
|
||||
"gradients_testutil.cc",
|
||||
"gradients_util.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients_testutil.h",
|
||||
"gradients_util.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
@ -266,22 +265,20 @@ cc_library(
|
||||
":tape",
|
||||
":abstract_tensor_handle",
|
||||
":gradients",
|
||||
":c_api",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
"//tensorflow/c/experimental/ops:array_ops",
|
||||
"//tensorflow/c/experimental/ops:math_ops",
|
||||
"//tensorflow/c/experimental/ops:nn_ops",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -319,7 +316,6 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "gradient_checker",
|
||||
testonly = True,
|
||||
srcs = [
|
||||
"gradient_checker.cc",
|
||||
],
|
||||
@ -332,13 +328,10 @@ cc_library(
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":mnist_gradients_testutil",
|
||||
":gradients_testutil",
|
||||
":gradients_util",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/c/experimental/gradients:math_grad",
|
||||
"//tensorflow/c/experimental/gradients:nn_grad",
|
||||
@ -349,8 +342,6 @@ cc_library(
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
|
@ -19,12 +19,10 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
@ -32,7 +30,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
using namespace std;
|
||||
|
||||
@ -197,4 +197,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
@ -17,13 +17,11 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
@ -31,7 +29,9 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
@ -53,3 +53,6 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
||||
std::vector<AbstractTensorHandle*> inputs,
|
||||
float* dtheta_approx, int input_index,
|
||||
bool use_function, bool is_scalar_out = false);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@ -10,14 +10,13 @@ See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradient_checker.h"
|
||||
// #include "tensorflow/c/eager/gradients_testutil.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
@ -109,7 +108,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -117,7 +116,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
|
||||
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
@ -12,14 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
@ -31,28 +30,66 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
using namespace std;
|
||||
|
||||
// ================== TensorHandle generating functions =================
|
||||
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value) {
|
||||
float data[] = {value};
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
|
||||
int64_t dims[], int num_dims) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* t =
|
||||
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
|
||||
int64_t dims[], int num_dims) {
|
||||
TF_Status* status = TF_NewStatus();
|
||||
TF_Tensor* t =
|
||||
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
|
||||
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||
TF_DeleteTensor(t);
|
||||
TF_DeleteStatus(status);
|
||||
return th;
|
||||
}
|
||||
|
||||
// Get a scalar TensorHandle with given value
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
Status ScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
TFE_Context* eager_ctx =
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
||||
TFE_TensorHandle* input_eager = ScalarTensorHandleHelper(eager_ctx, value);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -61,14 +98,14 @@ Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
||||
TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
}
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
@ -77,7 +114,7 @@ Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_TensorHandle* input_eager =
|
||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
||||
TensorHandleWithDimsIntHelper(eager_ctx, data, dims, num_dims);
|
||||
*tensor =
|
||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||
return StatusFromTF_Status(status.get());
|
||||
@ -98,7 +135,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
||||
int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||
if (s.ok()) {
|
||||
A.reset(a_raw);
|
||||
}
|
||||
@ -109,7 +146,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
||||
int64_t dims[], int num_dims) {
|
||||
AbstractTensorHandlePtr A;
|
||||
AbstractTensorHandle* a_raw = nullptr;
|
||||
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||
if (s.ok()) {
|
||||
A.reset(a_raw);
|
||||
}
|
||||
@ -120,7 +157,7 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
||||
float val) {
|
||||
AbstractTensorHandlePtr y;
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
Status s = TestScalarTensorHandle(ctx, val, &y_raw);
|
||||
Status s = ScalarTensorHandle(ctx, val, &y_raw);
|
||||
if (s.ok()) {
|
||||
y.reset(y_raw);
|
||||
}
|
||||
@ -268,4 +305,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||
TFE_DeleteContextOptions(opts);
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
@ -18,7 +18,6 @@ limitations under the License.
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
@ -30,24 +29,37 @@ limitations under the License.
|
||||
#include "tensorflow/c/tf_tensor.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/c/eager/c_api.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
using namespace std;
|
||||
using namespace tensorflow;
|
||||
using namespace tensorflow::gradients;
|
||||
using namespace tensorflow::gradients::internal;
|
||||
|
||||
// Get a scalar TensorHandle with given value.
|
||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
// using namespace std;
|
||||
// using namespace tensorflow;
|
||||
// using namespace tensorflow::gradients;
|
||||
// using namespace tensorflow::gradients::internal;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
|
||||
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
|
||||
|
||||
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
|
||||
int64_t dims[], int num_dims);
|
||||
|
||||
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
|
||||
int64_t dims[], int num_dims);
|
||||
|
||||
// Get a scalar TensorHandle with given value
|
||||
Status ScalarTensorHandle(AbstractContext* ctx, float value,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a TensorHandle with given float values and dimensions.
|
||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
// Get a TensorHandle with given float values and dimensions
|
||||
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
// Get a TensorHandle with given int values and dimensions.
|
||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
// Get a TensorHandle with given int values and dimensions
|
||||
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||
int64_t dims[], int num_dims,
|
||||
AbstractTensorHandle** tensor);
|
||||
|
||||
@ -68,8 +80,8 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
||||
float val);
|
||||
|
||||
// Performs gradient update for each weight using given learning rate.
|
||||
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
|
||||
vector<AbstractTensorHandle*>& weights,
|
||||
Status UpdateWeights(AbstractContext* ctx, std::vector<AbstractTensorHandle*>& grads,
|
||||
std::vector<AbstractTensorHandle*>& weights,
|
||||
AbstractTensorHandle* learning_rate);
|
||||
|
||||
// Helper function for RunModel to build the function for graph mode.
|
||||
@ -78,7 +90,7 @@ AbstractContext* BuildFunction(const char* fn_name);
|
||||
// Helper function for RunModel to add params for graph mode.
|
||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
vector<AbstractTensorHandle*>* params);
|
||||
std::vector<AbstractTensorHandle*>* params);
|
||||
|
||||
using Model = std::function<Status(
|
||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||
@ -91,4 +103,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||
const GradientRegistry& registry);
|
||||
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user