Took the test-only functions out of gradients_testutil and moved them into a new gradients_util library

This commit is contained in:
amturati 2020-09-02 20:34:49 +00:00
parent 0fdb700766
commit a17c8e3f99
6 changed files with 108 additions and 57 deletions

View File

@@ -249,13 +249,12 @@ tf_cuda_cc_test(
) )
cc_library( cc_library(
name = "gradients_testutil", name = "gradients_util",
testonly = True,
srcs = [ srcs = [
"gradients_testutil.cc", "gradients_util.cc",
], ],
hdrs = [ hdrs = [
"gradients_testutil.h", "gradients_util.h",
], ],
visibility = [ visibility = [
"//tensorflow:internal", "//tensorflow:internal",
@@ -266,22 +265,20 @@ cc_library(
":tape", ":tape",
":abstract_tensor_handle", ":abstract_tensor_handle",
":gradients", ":gradients",
":c_api",
":c_api_experimental", ":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal", ":c_api_unified_internal",
":gradients_internal", ":gradients_internal",
"//tensorflow/c/experimental/ops:array_ops", "//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops", "//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops", "//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/cc/profiler", "//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:framework",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
@@ -319,7 +316,6 @@ cc_library(
cc_library( cc_library(
name = "gradient_checker", name = "gradient_checker",
testonly = True,
srcs = [ srcs = [
"gradient_checker.cc", "gradient_checker.cc",
], ],
@@ -332,13 +328,10 @@ cc_library(
deps = [ deps = [
":abstract_tensor_handle", ":abstract_tensor_handle",
":c_api_experimental", ":c_api_experimental",
":c_api_test_util",
":c_api_unified_internal", ":c_api_unified_internal",
":gradients_internal", ":gradients_internal",
":mnist_gradients_testutil", ":gradients_util",
":gradients_testutil",
"//tensorflow/c:c_api", "//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad", "//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad", "//tensorflow/c/experimental/gradients:nn_grad",
@@ -349,8 +342,6 @@ cc_library(
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration", "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc", "//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti", "//tensorflow/core/lib/llvm_rtti",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",

View File

@@ -19,12 +19,10 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
@@ -32,7 +30,9 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using namespace std; using namespace std;
@@ -197,4 +197,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
} }
return Status::OK(); return Status::OK();
} }
} // namespace gradients
} // namespace tensorflow

View File

@@ -17,13 +17,11 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h" #include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_testutil.h" #include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h" #include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h" #include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h" #include "tensorflow/c/experimental/ops/array_ops.h"
@@ -31,7 +29,9 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using Model = std::function<Status( using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>, AbstractContext*, absl::Span<AbstractTensorHandle* const>,
@@ -53,3 +53,6 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
std::vector<AbstractTensorHandle*> inputs, std::vector<AbstractTensorHandle*> inputs,
float* dtheta_approx, int input_index, float* dtheta_approx, int input_index,
bool use_function, bool is_scalar_out = false); bool use_function, bool is_scalar_out = false);
} // namespace gradients
} // namespace tensorflow

View File

@@ -10,14 +10,13 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/eager/gradient_checker.h" #include "tensorflow/c/eager/gradient_checker.h"
// #include "tensorflow/c/eager/gradients_testutil.h" #include "tensorflow/c/eager/gradients_util.h"
#include <memory> #include <memory>
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
@@ -109,7 +108,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
AbstractTensorHandlePtr x; AbstractTensorHandlePtr x;
{ {
AbstractTensorHandle* x_raw = nullptr; AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw); Status s = ScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw); x.reset(x_raw);
} }
@@ -117,7 +116,7 @@ TEST_P(GradientCheckerTest, TestGradCheckMul) {
AbstractTensorHandlePtr y; AbstractTensorHandlePtr y;
{ {
AbstractTensorHandle* y_raw = nullptr; AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 7.0f, &y_raw); Status s = ScalarTensorHandle(ctx.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message(); ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw); y.reset(y_raw);
} }

View File

@@ -12,14 +12,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/c/eager/gradients_testutil.h" #include "tensorflow/c/eager/gradients_util.h"
#include <memory> #include <memory>
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
@@ -31,28 +30,66 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
using namespace std; using namespace std;
// ================== TensorHandle generating functions ================= TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value) {
float data[] = {value};
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status);
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
// Get a scalar TensorHandle with given value // Get a scalar TensorHandle with given value
Status TestScalarTensorHandle(AbstractContext* ctx, float value, Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) { AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus); TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx = TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value); TFE_TensorHandle* input_eager = ScalarTensorHandleHelper(eager_ctx, value);
*tensor = *tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get()); return StatusFromTF_Status(status.get());
} }
// Get a TensorHandle with given float values and dimensions // Get a TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims, int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) { AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -61,14 +98,14 @@ Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims); TensorHandleWithDimsFloatHelper(eager_ctx, data, dims, num_dims);
*tensor = *tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get()); return StatusFromTF_Status(status.get());
} }
// Get a TensorHandle with given int values and dimensions // Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims, int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) { AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status( std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
@@ -77,7 +114,7 @@ Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[],
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get()); TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims); TensorHandleWithDimsIntHelper(eager_ctx, data, dims, num_dims);
*tensor = *tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get())); unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get()); return StatusFromTF_Status(status.get());
@@ -98,7 +135,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
int num_dims) { int num_dims) {
AbstractTensorHandlePtr A; AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr; AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw); Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) { if (s.ok()) {
A.reset(a_raw); A.reset(a_raw);
} }
@@ -109,7 +146,7 @@ AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) { int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A; AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr; AbstractTensorHandle* a_raw = nullptr;
Status s = TestTensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw); Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) { if (s.ok()) {
A.reset(a_raw); A.reset(a_raw);
} }
@@ -120,7 +157,7 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) { float val) {
AbstractTensorHandlePtr y; AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr; AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx, val, &y_raw); Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) { if (s.ok()) {
y.reset(y_raw); y.reset(y_raw);
} }
@@ -268,4 +305,7 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get())); TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts); TFE_DeleteContextOptions(opts);
return Status::OK(); return Status::OK();
} }
} // namespace gradients
} // namespace tensorflow

View File

@@ -18,7 +18,6 @@ limitations under the License.
#include "absl/types/span.h" #include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h" #include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h" #include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h" #include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h" #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h" #include "tensorflow/c/eager/gradients.h"
@@ -30,24 +29,37 @@ limitations under the License.
#include "tensorflow/c/tf_tensor.h" #include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h" #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h" #include "tensorflow/c/eager/c_api.h"
#include "tensorflow/core/platform/types.h"
using namespace std;
using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;
// Get a scalar TensorHandle with given value. // using namespace std;
Status TestScalarTensorHandle(AbstractContext* ctx, float value, // using namespace tensorflow;
// using namespace tensorflow::gradients;
// using namespace tensorflow::gradients::internal;
namespace tensorflow {
namespace gradients {
TFE_TensorHandle* ScalarTensorHandleHelper(TFE_Context* ctx, float value);
TFE_TensorHandle* TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims);
TFE_TensorHandle* TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor); AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions. // Get a TensorHandle with given float values and dimensions
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float data[], Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims, int64_t dims[], int num_dims,
AbstractTensorHandle** tensor); AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions. // Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int data[], Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[],
int64_t dims[], int num_dims, int64_t dims[], int num_dims,
AbstractTensorHandle** tensor); AbstractTensorHandle** tensor);
@@ -68,8 +80,8 @@ AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val); float val);
// Performs gradient update for each weight using given learning rate. // Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads, Status UpdateWeights(AbstractContext* ctx, std::vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights, std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate); AbstractTensorHandle* learning_rate);
// Helper function for RunModel to build the function for graph mode. // Helper function for RunModel to build the function for graph mode.
@@ -78,7 +90,7 @@ AbstractContext* BuildFunction(const char* fn_name);
// Helper function for RunModel to add params for graph mode. // Helper function for RunModel to add params for graph mode.
Status CreateParamsForInputs(AbstractContext* ctx, Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs, absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params); std::vector<AbstractTensorHandle*>* params);
using Model = std::function<Status( using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>, AbstractContext*, absl::Span<AbstractTensorHandle* const>,
@@ -91,4 +103,7 @@ Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle*> outputs, bool use_function, absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry); const GradientRegistry& registry);
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx); Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow