took out test-only functions from test_util and put them in gradients_util
This commit is contained in:
parent
0fdb700766
commit
a17c8e3f99
@ -249,13 +249,12 @@ tf_cuda_cc_test(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gradients_testutil",
|
name = "gradients_util",
|
||||||
testonly = True,
|
|
||||||
srcs = [
|
srcs = [
|
||||||
"gradients_testutil.cc",
|
"gradients_util.cc",
|
||||||
],
|
],
|
||||||
hdrs = [
|
hdrs = [
|
||||||
"gradients_testutil.h",
|
"gradients_util.h",
|
||||||
],
|
],
|
||||||
visibility = [
|
visibility = [
|
||||||
"//tensorflow:internal",
|
"//tensorflow:internal",
|
||||||
@ -266,22 +265,20 @@ cc_library(
|
|||||||
":tape",
|
":tape",
|
||||||
":abstract_tensor_handle",
|
":abstract_tensor_handle",
|
||||||
":gradients",
|
":gradients",
|
||||||
|
":c_api",
|
||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
":c_api_test_util",
|
|
||||||
":c_api_unified_internal",
|
":c_api_unified_internal",
|
||||||
":gradients_internal",
|
":gradients_internal",
|
||||||
"//tensorflow/c/experimental/ops:array_ops",
|
"//tensorflow/c/experimental/ops:array_ops",
|
||||||
"//tensorflow/c/experimental/ops:math_ops",
|
"//tensorflow/c/experimental/ops:math_ops",
|
||||||
"//tensorflow/c/experimental/ops:nn_ops",
|
"//tensorflow/c/experimental/ops:nn_ops",
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_test_util",
|
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/cc/profiler",
|
"//tensorflow/cc/profiler",
|
||||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
@ -319,7 +316,6 @@ cc_library(
|
|||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "gradient_checker",
|
name = "gradient_checker",
|
||||||
testonly = True,
|
|
||||||
srcs = [
|
srcs = [
|
||||||
"gradient_checker.cc",
|
"gradient_checker.cc",
|
||||||
],
|
],
|
||||||
@ -332,13 +328,10 @@ cc_library(
|
|||||||
deps = [
|
deps = [
|
||||||
":abstract_tensor_handle",
|
":abstract_tensor_handle",
|
||||||
":c_api_experimental",
|
":c_api_experimental",
|
||||||
":c_api_test_util",
|
|
||||||
":c_api_unified_internal",
|
":c_api_unified_internal",
|
||||||
":gradients_internal",
|
":gradients_internal",
|
||||||
":mnist_gradients_testutil",
|
":gradients_util",
|
||||||
":gradients_testutil",
|
|
||||||
"//tensorflow/c:c_api",
|
"//tensorflow/c:c_api",
|
||||||
"//tensorflow/c:c_test_util",
|
|
||||||
"//tensorflow/c:tf_status_helper",
|
"//tensorflow/c:tf_status_helper",
|
||||||
"//tensorflow/c/experimental/gradients:math_grad",
|
"//tensorflow/c/experimental/gradients:math_grad",
|
||||||
"//tensorflow/c/experimental/gradients:nn_grad",
|
"//tensorflow/c/experimental/gradients:nn_grad",
|
||||||
@ -349,8 +342,6 @@ cc_library(
|
|||||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:protos_all_cc",
|
"//tensorflow/core:protos_all_cc",
|
||||||
"//tensorflow/core:test",
|
|
||||||
"//tensorflow/core:test_main",
|
|
||||||
"//tensorflow/core/lib/llvm_rtti",
|
"//tensorflow/core/lib/llvm_rtti",
|
||||||
"@com_google_absl//absl/strings",
|
"@com_google_absl//absl/strings",
|
||||||
"@com_google_absl//absl/types:span",
|
"@com_google_absl//absl/types:span",
|
||||||
|
@ -19,12 +19,10 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
#include "tensorflow/c/eager/gradients_internal.h"
|
#include "tensorflow/c/eager/gradients_internal.h"
|
||||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
|
||||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||||
@ -32,7 +30,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace gradients {
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
@ -197,4 +197,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
|||||||
}
|
}
|
||||||
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace gradients
|
||||||
|
} // namespace tensorflow
|
@ -17,13 +17,11 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
#include "tensorflow/c/eager/gradients_internal.h"
|
#include "tensorflow/c/eager/gradients_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
#include "tensorflow/c/eager/gradients_util.h"
|
||||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
|
||||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||||
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
#include "tensorflow/c/experimental/gradients/nn_grad.h"
|
||||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||||
@ -31,7 +29,9 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace gradients {
|
||||||
|
|
||||||
using Model = std::function<Status(
|
using Model = std::function<Status(
|
||||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||||
@ -53,3 +53,6 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
|
|||||||
std::vector<AbstractTensorHandle*> inputs,
|
std::vector<AbstractTensorHandle*> inputs,
|
||||||
float* dtheta_approx, int input_index,
|
float* dtheta_approx, int input_index,
|
||||||
bool use_function, bool is_scalar_out = false);
|
bool use_function, bool is_scalar_out = false);
|
||||||
|
|
||||||
|
} // namespace gradients
|
||||||
|
} // namespace tensorflow
|
||||||
|
@ -10,14 +10,13 @@ See the License for the specific language governing permissions and
|
|||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/c/eager/gradient_checker.h"
|
#include "tensorflow/c/eager/gradient_checker.h"
|
||||||
// #include "tensorflow/c/eager/gradients_testutil.h"
|
#include "tensorflow/c/eager/gradients_util.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
@ -109,7 +108,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
|
|||||||
AbstractTensorHandlePtr x;
|
AbstractTensorHandlePtr x;
|
||||||
{
|
{
|
||||||
AbstractTensorHandle* x_raw = nullptr;
|
AbstractTensorHandle* x_raw = nullptr;
|
||||||
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
|
||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
x.reset(x_raw);
|
x.reset(x_raw);
|
||||||
}
|
}
|
||||||
@ -117,7 +116,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
|
|||||||
AbstractTensorHandlePtr y;
|
AbstractTensorHandlePtr y;
|
||||||
{
|
{
|
||||||
AbstractTensorHandle* y_raw = nullptr;
|
AbstractTensorHandle* y_raw = nullptr;
|
||||||
Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
|
Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
|
||||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||||
y.reset(y_raw);
|
y.reset(y_raw);
|
||||||
}
|
}
|
||||||
|
@ -12,14 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||||||
See the License for the specific language governing permissions and
|
See the License for the specific language governing permissions and
|
||||||
limitations under the License.
|
limitations under the License.
|
||||||
==============================================================================*/
|
==============================================================================*/
|
||||||
#include "tensorflow/c/eager/gradients_testutil.h"
|
#include "tensorflow/c/eager/gradients_util.h"
|
||||||
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
@ -31,28 +30,66 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace gradients {
|
||||||
|
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
// ================== TensorHandle generating functions =================
|
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value) {
|
||||||
|
float data[] = {value};
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
|
||||||
|
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||||
|
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return th;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
|
||||||
|
int64_t dims[], int num_dims) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_Tensor* t =
|
||||||
|
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
|
||||||
|
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||||
|
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return th;
|
||||||
|
}
|
||||||
|
|
||||||
|
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
|
||||||
|
int64_t dims[], int num_dims) {
|
||||||
|
TF_Status* status = TF_NewStatus();
|
||||||
|
TF_Tensor* t =
|
||||||
|
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
|
||||||
|
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
|
||||||
|
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
|
||||||
|
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
|
||||||
|
TF_DeleteTensor(t);
|
||||||
|
TF_DeleteStatus(status);
|
||||||
|
return th;
|
||||||
|
}
|
||||||
|
|
||||||
// Get a scalar TensorHandle with given value
|
// Get a scalar TensorHandle with given value
|
||||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
Status ScalarTensorHandle(AbstractContext* ctx, float value,
|
||||||
AbstractTensorHandle** tensor) {
|
AbstractTensorHandle** tensor) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
TF_NewStatus(), TF_DeleteStatus);
|
TF_NewStatus(), TF_DeleteStatus);
|
||||||
TFE_Context* eager_ctx =
|
TFE_Context* eager_ctx =
|
||||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
|
TFE_TensorHandle* input_eager = ScalarTensorHandleHelper(eager_ctx, value);
|
||||||
*tensor =
|
*tensor =
|
||||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||||
return StatusFromTF_Status(status.get());
|
return StatusFromTF_Status(status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a TensorHandle with given float values and dimensions
|
// Get a TensorHandle with given float values and dimensions
|
||||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||||
int64_t dims[], int num_dims,
|
int64_t dims[], int num_dims,
|
||||||
AbstractTensorHandle** tensor) {
|
AbstractTensorHandle** tensor) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
@ -61,14 +98,14 @@ Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
|||||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
TFE_TensorHandle* input_eager =
|
TFE_TensorHandle* input_eager =
|
||||||
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
|
TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, num_dims);
|
||||||
*tensor =
|
*tensor =
|
||||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||||
return StatusFromTF_Status(status.get());
|
return StatusFromTF_Status(status.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get a TensorHandle with given int values and dimensions
|
// Get a TensorHandle with given int values and dimensions
|
||||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||||
int64_t dims[], int num_dims,
|
int64_t dims[], int num_dims,
|
||||||
AbstractTensorHandle** tensor) {
|
AbstractTensorHandle** tensor) {
|
||||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||||
@ -77,7 +114,7 @@ Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
|||||||
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
|
||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
TFE_TensorHandle* input_eager =
|
TFE_TensorHandle* input_eager =
|
||||||
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
|
TensorHandleWithDimsIntHelper(eager_ctx, data, dims, num_dims);
|
||||||
*tensor =
|
*tensor =
|
||||||
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
|
||||||
return StatusFromTF_Status(status.get());
|
return StatusFromTF_Status(status.get());
|
||||||
@ -98,7 +135,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
|
|||||||
int num_dims) {
|
int num_dims) {
|
||||||
AbstractTensorHandlePtr A;
|
AbstractTensorHandlePtr A;
|
||||||
AbstractTensorHandle* a_raw = nullptr;
|
AbstractTensorHandle* a_raw = nullptr;
|
||||||
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
A.reset(a_raw);
|
A.reset(a_raw);
|
||||||
}
|
}
|
||||||
@ -109,7 +146,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
|
|||||||
int64_t dims[], int num_dims) {
|
int64_t dims[], int num_dims) {
|
||||||
AbstractTensorHandlePtr A;
|
AbstractTensorHandlePtr A;
|
||||||
AbstractTensorHandle* a_raw = nullptr;
|
AbstractTensorHandle* a_raw = nullptr;
|
||||||
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
A.reset(a_raw);
|
A.reset(a_raw);
|
||||||
}
|
}
|
||||||
@ -120,7 +157,7 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
|||||||
float val) {
|
float val) {
|
||||||
AbstractTensorHandlePtr y;
|
AbstractTensorHandlePtr y;
|
||||||
AbstractTensorHandle* y_raw = nullptr;
|
AbstractTensorHandle* y_raw = nullptr;
|
||||||
Status s = TestScalarTensorHandle(ctx, val, &y_raw);
|
Status s = ScalarTensorHandle(ctx, val, &y_raw);
|
||||||
if (s.ok()) {
|
if (s.ok()) {
|
||||||
y.reset(y_raw);
|
y.reset(y_raw);
|
||||||
}
|
}
|
||||||
@ -268,4 +305,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
|
|||||||
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
|
||||||
TFE_DeleteContextOptions(opts);
|
TFE_DeleteContextOptions(opts);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
} // namespace gradients
|
||||||
|
} // namespace tensorflow
|
@ -18,7 +18,6 @@ limitations under the License.
|
|||||||
#include "absl/types/span.h"
|
#include "absl/types/span.h"
|
||||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_test_util.h"
|
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||||
#include "tensorflow/c/eager/gradients.h"
|
#include "tensorflow/c/eager/gradients.h"
|
||||||
@ -30,24 +29,37 @@ limitations under the License.
|
|||||||
#include "tensorflow/c/tf_tensor.h"
|
#include "tensorflow/c/tf_tensor.h"
|
||||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||||
#include "tensorflow/core/platform/errors.h"
|
#include "tensorflow/core/platform/errors.h"
|
||||||
#include "tensorflow/core/platform/test.h"
|
#include "tensorflow/c/eager/c_api.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
|
||||||
using namespace std;
|
|
||||||
using namespace tensorflow;
|
|
||||||
using namespace tensorflow::gradients;
|
|
||||||
using namespace tensorflow::gradients::internal;
|
|
||||||
|
|
||||||
// Get a scalar TensorHandle with given value.
|
// using namespace std;
|
||||||
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
|
// using namespace tensorflow;
|
||||||
|
// using namespace tensorflow::gradients;
|
||||||
|
// using namespace tensorflow::gradients::internal;
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
namespace gradients {
|
||||||
|
|
||||||
|
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
|
||||||
|
|
||||||
|
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
|
||||||
|
int64_t dims[], int num_dims);
|
||||||
|
|
||||||
|
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
|
||||||
|
int64_t dims[], int num_dims);
|
||||||
|
|
||||||
|
// Get a scalar TensorHandle with given value
|
||||||
|
Status ScalarTensorHandle(AbstractContext* ctx, float value,
|
||||||
AbstractTensorHandle** tensor);
|
AbstractTensorHandle** tensor);
|
||||||
|
|
||||||
// Get a TensorHandle with given float values and dimensions.
|
// Get a TensorHandle with given float values and dimensions
|
||||||
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
|
||||||
int64_t dims[], int num_dims,
|
int64_t dims[], int num_dims,
|
||||||
AbstractTensorHandle** tensor);
|
AbstractTensorHandle** tensor);
|
||||||
|
|
||||||
// Get a TensorHandle with given int values and dimensions.
|
// Get a TensorHandle with given int values and dimensions
|
||||||
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
|
||||||
int64_t dims[], int num_dims,
|
int64_t dims[], int num_dims,
|
||||||
AbstractTensorHandle** tensor);
|
AbstractTensorHandle** tensor);
|
||||||
|
|
||||||
@ -68,8 +80,8 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
|
|||||||
float val);
|
float val);
|
||||||
|
|
||||||
// Performs gradient update for each weight using given learning rate.
|
// Performs gradient update for each weight using given learning rate.
|
||||||
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
|
Status UpdateWeights(AbstractContext* ctx, std::vector<AbstractTensorHandle*>& grads,
|
||||||
vector<AbstractTensorHandle*>& weights,
|
std::vector<AbstractTensorHandle*>& weights,
|
||||||
AbstractTensorHandle* learning_rate);
|
AbstractTensorHandle* learning_rate);
|
||||||
|
|
||||||
// Helper function for RunModel to build the function for graph mode.
|
// Helper function for RunModel to build the function for graph mode.
|
||||||
@ -78,7 +90,7 @@ AbstractContext* BuildFunction(const char* fn_name);
|
|||||||
// Helper function for RunModel to add params for graph mode.
|
// Helper function for RunModel to add params for graph mode.
|
||||||
Status CreateParamsForInputs(AbstractContext* ctx,
|
Status CreateParamsForInputs(AbstractContext* ctx,
|
||||||
absl::Span<AbstractTensorHandle* const> inputs,
|
absl::Span<AbstractTensorHandle* const> inputs,
|
||||||
vector<AbstractTensorHandle*>* params);
|
std::vector<AbstractTensorHandle*>* params);
|
||||||
|
|
||||||
using Model = std::function<Status(
|
using Model = std::function<Status(
|
||||||
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
|
||||||
@ -91,4 +103,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
|||||||
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
|
||||||
const GradientRegistry& registry);
|
const GradientRegistry& registry);
|
||||||
|
|
||||||
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
|
||||||
|
|
||||||
|
} // namespace gradients
|
||||||
|
} // namespace tensorflow
|
Loading…
Reference in New Issue
Block a user