rebase update
parent 347607e350
commit fc3cd94b3c
@@ -278,121 +278,8 @@ cc_library(
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/c/experimental/ops:array_ops",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "mnist_gradients",
    srcs = [
        "mnist_gradients.cc",
        "mnist_gradients.h",
    ],
    hdrs = [
        "gradients.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":tape",
        ":mnist_gradients_util",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

tf_cuda_cc_test(
    name = "mnist_gradients_test",
    size = "small",
    srcs = [
        "mnist_gradients_test.cc",
    ],
    args = ["--heap_check=local"],
    extra_copts = tfe_xla_copts(),
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["nomac"],
    deps = [
        ":abstract_tensor_handle",
        ":c_api_experimental",
        ":c_api_test_util",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":mnist_gradients_util",
        ":mnist_gradients",
        "//tensorflow/c:c_api",
        "//tensorflow/c:c_test_util",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/cc/profiler",
        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
        "//tensorflow/core:lib",
        "//tensorflow/core:protos_all_cc",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
    ],
)

cc_library(
    name = "mnist_gradients_util",
    srcs = [
        "mnist_gradients_util.cc",
        "mnist_gradients_util.h",
    ],
    hdrs = [
        "gradients.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":tape",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
)

cc_library(
    name = "mnist_gradients",
    srcs = [
        "mnist_gradients.cc",
        "mnist_gradients.h",
    ],
    hdrs = [
        "gradients.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        ":abstract_context",
        ":abstract_operation",
        ":abstract_tensor_handle",
        ":c_api_unified_internal",
        ":gradients_internal",
        ":tape",
        ":mnist_gradients_util",
        "//tensorflow/core/common_runtime/eager:attr_builder",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/c/experimental/gradients:math_grad",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "@com_google_absl//absl/container:flat_hash_map",
        "@com_google_absl//absl/strings",
    ],
@ -434,92 +321,6 @@ tf_cuda_cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mnist_gradients_util",
|
||||
srcs = [
|
||||
"mnist_gradients_util.cc",
|
||||
"mnist_gradients_util.h",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":tape",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mnist_gradients",
|
||||
srcs = [
|
||||
"mnist_gradients.cc",
|
||||
"mnist_gradients.h",
|
||||
],
|
||||
hdrs = [
|
||||
"gradients.h",
|
||||
],
|
||||
visibility = [
|
||||
"//tensorflow:internal",
|
||||
],
|
||||
deps = [
|
||||
":abstract_context",
|
||||
":abstract_operation",
|
||||
":abstract_tensor_handle",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":tape",
|
||||
":mnist_gradients_util",
|
||||
"//tensorflow/core/common_runtime/eager:attr_builder",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/strings",
|
||||
],
|
||||
)
|
||||
|
||||
tf_cuda_cc_test(
|
||||
name = "mnist_gradients_test",
|
||||
size = "small",
|
||||
srcs = [
|
||||
"mnist_gradients_test.cc",
|
||||
],
|
||||
args = ["--heap_check=local"],
|
||||
extra_copts = tfe_xla_copts(),
|
||||
linkstatic = tf_kernel_tests_linkstatic(),
|
||||
tags = tf_cuda_tests_tags() + ["nomac"],
|
||||
deps = [
|
||||
":abstract_tensor_handle",
|
||||
":c_api_experimental",
|
||||
":c_api_test_util",
|
||||
":c_api_unified_internal",
|
||||
":gradients_internal",
|
||||
":mnist_gradients_util",
|
||||
":mnist_gradients",
|
||||
"//tensorflow/c:c_api",
|
||||
"//tensorflow/c:c_test_util",
|
||||
"//tensorflow/c:tf_status_helper",
|
||||
"//tensorflow/cc/profiler",
|
||||
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/lib/llvm_rtti",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "abstract_tensor_handle",
|
||||
hdrs = ["abstract_tensor_handle.h"],
|
||||
|
@@ -49,49 +49,6 @@ class CppGradients
  }
};

// // Creates an Identity op.
// Status Identity(AbstractContext* ctx,
//                 absl::Span<AbstractTensorHandle* const> inputs,
//                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
//   AbstractOperationPtr identity_op(ctx->CreateOperation());
//   TF_RETURN_IF_ERROR(
//       identity_op->Reset("Identity", /*raw_device_name=*/nullptr));
//   if (isa<tracing::TracingOperation>(identity_op.get())) {
//     TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingOperation>(identity_op.get())
//                            ->SetOpName(name));
//   }
//   TF_RETURN_IF_ERROR(identity_op->AddInput(inputs[0]));
//   int num_retvals = 1;
//   TF_RETURN_IF_ERROR(identity_op->Execute(outputs, &num_retvals));
//   return Status::OK();
// }

// // =================== Register gradients for Add ============================
// class AddGradientFunction : public GradientFunction {
//  public:
//   explicit AddGradientFunction(AbstractContext* ctx) : ctx_(ctx) {}
//   Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
//                  std::vector<AbstractTensorHandle*>* grad_outputs) override {
//     grad_outputs->resize(2);
//     std::vector<AbstractTensorHandle*> identity_outputs(1);
//     TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
//                                 absl::MakeSpan(identity_outputs), "Id0"));
//     (*grad_outputs)[0] = identity_outputs[0];
//     TF_RETURN_IF_ERROR(Identity(ctx_, {grad_inputs[0]},
//                                 absl::MakeSpan(identity_outputs), "Id1"));
//     (*grad_outputs)[1] = identity_outputs[0];
//     return Status::OK();
//   }
//   ~AddGradientFunction() override {}

//  private:
//   AbstractContext* ctx_;
// };

// GradientFunction* AddRegisterer(const ForwardOperation& op) {
//   return new AddGradientFunction(op.ctx);
// }

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
@@ -99,9 +56,6 @@ Status RegisterGradients(GradientRegistry* registry) {
  return Status::OK();
}


// // =================== End gradient registrations ============================

// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
@ -1,21 +1,14 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/mnist_gradients_util.h"
|
||||
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/types/span.h"
|
||||
@ -48,6 +41,15 @@ class CppGradients
|
||||
}
|
||||
};
|
||||
|
||||
Status RegisterGradients(GradientRegistry* registry) {
|
||||
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
|
||||
TF_RETURN_IF_ERROR(registry->Register("SparseSoftmaxCrossEntropyWithLogits", SparseSoftmaxCrossEntropyLossRegisterer));
|
||||
return Status::OK();
|
||||
}
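// A minimal sketch of how the consolidated registry above is exercised by the
// tests in this file (illustrative only; `x` and `y` stand for any two handles
// built the way TestAddGrad builds them, and AddGradModel/RunModel are the
// helpers from mnist_gradients_util.cc shown later in this change):
//
//   GradientRegistry registry;
//   Status s = RegisterGradients(&registry);
//   ASSERT_EQ(errors::OK, s.code()) << s.error_message();
//
//   std::vector<AbstractTensorHandle*> outputs(2);
//   s = RunModel(AddGradModel, ctx.get(), {x.get(), y.get()},
//                absl::MakeSpan(outputs),
//                /*use_function=*/!std::get<2>(GetParam()), registry);
//   ASSERT_EQ(errors::OK, s.code()) << s.error_message();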
|
||||
|
||||
// ========================= Test Util Functions ==============================
|
||||
void printArr(float data[], int n) {
|
||||
std::cout << std::endl << "[";
|
||||
@ -185,7 +187,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
}
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientAdd(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
@ -207,7 +209,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[0]->Release();
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
result_tensor = nullptr;
|
||||
|
||||
@ -215,7 +217,7 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
result_value = static_cast<float*>(TF_TensorData(result_tensor));
|
||||
EXPECT_EQ(*result_value, 1.0);
|
||||
outputs[1]->Release();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(result_tensor);
|
||||
}
|
||||
|
||||
@ -243,7 +245,7 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
getMatrixTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientMatMul(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
@ -286,201 +288,12 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(dA_tensor);
|
||||
TF_DeleteTensor(dB_tensor);
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] * inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status MatMulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> mm_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
||||
"matmul0", /*transpose_a=*/false, /*transpose_b=*/false, registry)); // Compute x*y.
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto mm_output : mm_outputs) {
|
||||
mm_output->Release();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
// TODO: fix graph mode test by using RunModel to verify
|
||||
TEST_P(CppGradients, TestMatMulGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s = BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
float A_vals [] = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
int64_t A_dims [] = {2, 2};
|
||||
float B_vals [] = {.5f, -1.0f, 1.0f, 1.0f};
|
||||
int64_t B_dims [] = {2, 2};
|
||||
int num_dims = 2;
|
||||
|
||||
AbstractTensorHandlePtr A = getMatrixTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
|
||||
AbstractTensorHandlePtr B = getMatrixTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientMatMul(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
*
|
||||
* tape.watch(A)
|
||||
* tape.watch(B)
|
||||
* Y = AB
|
||||
* outputs = tape.gradient(Y, [A, B])
|
||||
*/
|
||||
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
|
||||
absl::MakeSpan(outputs),
|
||||
/*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* dA_tensor;
|
||||
s = getValue(outputs[0], &dA_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
float result_data[4] = {0};
|
||||
memcpy(&result_data[0], TF_TensorData(dA_tensor), TF_TensorByteSize(dA_tensor));
|
||||
|
||||
float expected_dA [4] = {-.5f, 2.0f, -.5f, 2.0f};
|
||||
float tolerance = 1e-3;
|
||||
for(int j = 0; j < 4; j++){
|
||||
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
TF_DeleteTensor(dA_tensor);
|
||||
}
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] * inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status MatMulGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(inputs[0])); // Watch x.
|
||||
tape->Watch(ToId(inputs[1])); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> mm_outputs(1);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, inputs, absl::MakeSpan(mm_outputs),
|
||||
"matmul0", /*transpose_a=*/false, /*transpose_b=*/false, registry)); // Compute x*y.
|
||||
|
||||
std::unordered_map<tensorflow::int64, TapeTensor>
|
||||
source_tensors_that_are_targets;
|
||||
|
||||
std::vector<AbstractTensorHandle*> out_grads;
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(
|
||||
vspace, /*target_tensor_ids=*/{ToId(mm_outputs[0])},
|
||||
/*source_tensor_ids=*/{ToId(inputs[0]), ToId(inputs[1])},
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto mm_output : mm_outputs) {
|
||||
mm_output->Release();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
||||
// TODO: fix graph mode test by using RunModel to verify
|
||||
TEST_P(CppGradients, TestMatMulGrad) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
AbstractContext* ctx_raw = nullptr;
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx.reset(ctx_raw);
|
||||
}
|
||||
|
||||
float A_vals [] = {1.0f, 2.0f, 3.0f, 4.0f};
|
||||
int64_t A_dims [] = {2, 2};
|
||||
float B_vals [] = {.5f, -1.0f, 1.0f, 1.0f};
|
||||
int64_t B_dims [] = {2, 2};
|
||||
int num_dims = 2;
|
||||
|
||||
AbstractTensorHandlePtr A = getMatrixTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
|
||||
AbstractTensorHandlePtr B = getMatrixTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientMatMul(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Pseudo-code:
|
||||
//
|
||||
// tape.watch(A)
|
||||
// tape.watch(B)
|
||||
// Y = AB
|
||||
// outputs = tape.gradient(Y, [A, B])
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
// s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
|
||||
// absl::MakeSpan(outputs),
|
||||
// /*use_function=*/!std::get<2>(GetParam()), registry);
|
||||
// ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
s = MatMulGradModel(ctx.get(), {A.get(), B.get()}, absl::MakeSpan(outputs), registry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// s = MatMulGradModel(ctx.get(), {A.get(), B.get()}, absl::MakeSpan(outputs), registry);
|
||||
// ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
TF_Tensor* dA_tensor;
|
||||
s = getValue(outputs[0], &dA_tensor);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
float result_data[4] = {0};
|
||||
memcpy(&result_data[0], TF_TensorData(dA_tensor), TF_TensorByteSize(dA_tensor));
|
||||
|
||||
float expected_dA [4] = {-.5f, 2.0f, -.5f, 2.0f};
|
||||
float tolerance = 1e-3;
|
||||
for(int j = 0; j < 4; j++){
|
||||
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
TF_DeleteTensor(dA_tensor);
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMNISTForward) {
|
||||
AbstractContextPtr ctx;
|
||||
{
|
||||
@ -551,8 +364,8 @@ TEST_P(CppGradients, TestMNISTForward) {
|
||||
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(scores_tensor);
|
||||
TF_DeleteTensor(loss_vals_tensor);
|
||||
}
|
||||
@ -628,37 +441,12 @@ TEST_P(CppGradients, TestMNISTForward2) {
|
||||
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(scores_tensor);
|
||||
TF_DeleteTensor(loss_vals_tensor);
|
||||
}
|
||||
|
||||
// Test Model to see if transpose attributes are working
|
||||
Status MatMulTransposeModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
AbstractTensorHandle* X = inputs[0];
|
||||
AbstractTensorHandle* W1 = inputs[1];
|
||||
|
||||
TapeVSpace vspace(ctx);
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(ToId(X));
|
||||
tape->Watch(ToId(W1)); // Watch W1.
|
||||
std::vector<AbstractTensorHandle*> temp_outputs(1);
|
||||
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx, tape, {X, W1}, absl::MakeSpan(temp_outputs),
|
||||
"matmul0",/*transpose_a=*/true,/*transpose_b=*/false, registry)); // Compute X*W1
|
||||
|
||||
outputs[0] = temp_outputs[0];
|
||||
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// TODO: fix graph mode test by using RunModel to verify
|
||||
TEST_P(CppGradients, TestMatMulTranspose) {
|
||||
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
|
||||
TF_NewStatus(), TF_DeleteStatus);
|
||||
@ -707,8 +495,7 @@ TEST_P(CppGradients, TestMatMulTranspose) {
|
||||
|
||||
float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
|
||||
float tolerance = 1e-3;
|
||||
|
||||
for(int j = 0; j < 6; j++){
|
||||
for (int j = 0; j < 6; j++) {
|
||||
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
|
||||
}
|
||||
}
|
||||
@ -734,7 +521,7 @@ TEST_P(CppGradients, TestReluGrad) {
|
||||
getMatrixTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientRelu(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
@ -762,7 +549,7 @@ TEST_P(CppGradients, TestReluGrad) {
|
||||
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(dX_tensor);
|
||||
}
|
||||
|
||||
@ -794,7 +581,7 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) {
|
||||
getMatrixTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
|
||||
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientSparseSoftmaxCrossEntropyLoss(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
@ -829,8 +616,8 @@ TEST_P(CppGradients, TestSoftmaxLossGrad) {
|
||||
ASSERT_NEAR(result_data[j], expected_dX[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
TF_DeleteTensor(dX_tensor);
|
||||
}
|
||||
|
||||
@ -873,9 +660,7 @@ TEST_P(CppGradients, TestMNISTGrad) {
|
||||
|
||||
// Register Grads
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientMatMul(®istry);
|
||||
s = RegisterGradientRelu(®istry);
|
||||
s = RegisterGradientSparseSoftmaxCrossEntropyLoss(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
/* Pseudo-code:
|
||||
@ -924,9 +709,9 @@ TEST_P(CppGradients, TestMNISTGrad) {
|
||||
ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[1]->Release();
|
||||
outputs[2]->Release();
|
||||
outputs[0]->Unref();
|
||||
outputs[1]->Unref();
|
||||
outputs[2]->Unref();
|
||||
TF_DeleteTensor(dW1_tensor);
|
||||
TF_DeleteTensor(dW2_tensor);
|
||||
}
|
||||
@ -980,7 +765,7 @@ TEST_P(CppGradients, TestScalarMul) {
|
||||
ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
|
||||
}
|
||||
|
||||
outputs[0]->Release();
|
||||
outputs[0]->Unref();
|
||||
TF_DeleteTensor(dA_tensor);
|
||||
}
|
||||
|
||||
@ -1024,9 +809,7 @@ TEST_P(CppGradients, TestMNIST_Training) {
|
||||
|
||||
// Register Grads
|
||||
GradientRegistry registry;
|
||||
Status s = RegisterGradientMatMul(®istry);
|
||||
s = RegisterGradientRelu(®istry);
|
||||
s = RegisterGradientSparseSoftmaxCrossEntropyLoss(®istry);
|
||||
Status s = RegisterGradients(®istry);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Prepare for training
|
||||
@ -1040,7 +823,7 @@ TEST_P(CppGradients, TestMNIST_Training) {
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
|
||||
// Train
|
||||
int num_iters = 100;
|
||||
int num_iters = 10;
|
||||
std::vector<AbstractTensorHandle*> mnist_outputs(3);
|
||||
std::vector<AbstractTensorHandle*> grads(2);
|
||||
for (int i = 0; i < num_iters; i++) {
|
||||
@ -1075,9 +858,9 @@ TEST_P(CppGradients, TestMNIST_Training) {
|
||||
TF_DeleteTensor(loss_tensor);
|
||||
}
|
||||
|
||||
grads[0]->Release();
|
||||
grads[1]->Release();
|
||||
mnist_outputs[2]->Release();
|
||||
grads[0]->Unref();
|
||||
grads[1]->Unref();
|
||||
mnist_outputs[2]->Unref();
|
||||
}
|
||||
|
||||
// TODO(b/160888630): Enable this test with mlir after AddInputList is
|
||||
@ -1098,4 +881,4 @@ INSTANTIATE_TEST_SUITE_P(
|
||||
} // namespace
|
||||
} // namespace internal
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
} // namespace tensorflow
|
@ -36,96 +36,119 @@ Status Add(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
// AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
// ForwardOperation forward_op;
|
||||
// forward_op.ctx = ctx;
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
// if (isa<tracing::TracingOperation>(add_op.get())) {
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
// }
|
||||
// TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
// TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
// int num_retvals = 1;
|
||||
// return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
// registry);
|
||||
// }
|
||||
AbstractOperationPtr add_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(add_op.get(), "Add", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(add_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName("my_add"));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(add_op.get(), inputs[1], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(add_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// // Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
||||
// Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
// absl::Span<AbstractTensorHandle* const> inputs,
|
||||
// absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
// bool transpose_a, bool transpose_b,
|
||||
// const GradientRegistry& registry) {
|
||||
// Computes `inputs[0] * inputs[1]` for matrices and records it on the tape.
|
||||
Status MatMul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
bool transpose_a, bool transpose_b,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
// AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
// ForwardOperation forward_op;
|
||||
// forward_op.ctx = ctx;
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// Reset(matmul_op.get(), "MatMul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
// if (isa<tracing::TracingOperation>(matmul_op.get())) {
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
|
||||
// }
|
||||
AbstractOperationPtr matmul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(matmul_op.get(), "MatMul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(matmul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
// TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
||||
// TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
||||
// matmul_op->SetAttrBool("transpose_a",transpose_a);
|
||||
// matmul_op->SetAttrBool("transpose_b",transpose_b);
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(matmul_op.get(), inputs[1], &forward_op));
|
||||
matmul_op->SetAttrBool("transpose_a",transpose_a);
|
||||
matmul_op->SetAttrBool("transpose_b",transpose_b);
|
||||
|
||||
// int num_retvals = 1;
|
||||
// return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
// registry);
|
||||
// }
|
||||
int num_retvals = 1;
|
||||
return Execute(matmul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// // Computes `Relu(inputs[0])` and records it on the tape.
|
||||
// Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
// absl::Span<AbstractTensorHandle* const> inputs,
|
||||
// absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
// const GradientRegistry& registry) {
|
||||
Status Mul(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
AbstractOperationPtr mul_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(mul_op.get(), "Mul", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(mul_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(mul_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[0], &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(mul_op.get(), inputs[1], &forward_op));
|
||||
|
||||
int num_retvals = 1;
|
||||
return Execute(mul_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
|
||||
// Computes `Relu(inputs[0])` and records it on the tape.
|
||||
Status Relu(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
// AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
// ForwardOperation forward_op;
|
||||
// forward_op.ctx = ctx;
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
||||
// if (isa<tracing::TracingOperation>(relu_op.get())) {
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// dyn_cast<tracing::TracingOperation>(relu_op.get())->SetOpName(name));
|
||||
// }
|
||||
// TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
||||
// int num_retvals = 1;
|
||||
// return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
// registry);
|
||||
// }
|
||||
AbstractOperationPtr relu_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(relu_op.get(), "Relu", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(relu_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(relu_op.get())->SetOpName(name));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(AddInput(relu_op.get(), inputs[0], &forward_op));
|
||||
int num_retvals = 1;
|
||||
return Execute(relu_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
// // Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the tape.
|
||||
// Status SparseSoftmaxCrossEntropyLoss(AbstractContext* ctx, Tape* tape,
|
||||
// absl::Span<AbstractTensorHandle* const> inputs,
|
||||
// absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
// const GradientRegistry& registry) {
|
||||
// Computes `SoftmaxLoss(scores, labels)` for matrices and records it on the tape.
|
||||
Status SparseSoftmaxCrossEntropyLoss(AbstractContext* ctx, Tape* tape,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs, const char* name,
|
||||
const GradientRegistry& registry) {
|
||||
|
||||
// AbstractTensorHandle* scores = inputs[0];
|
||||
// AbstractTensorHandle* labels = inputs[1];
|
||||
AbstractTensorHandle* scores = inputs[0];
|
||||
AbstractTensorHandle* labels = inputs[1];
|
||||
|
||||
// AbstractOperationPtr sm_op(ctx->CreateOperation());
|
||||
// ForwardOperation forward_op;
|
||||
// forward_op.ctx = ctx;
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", /*raw_device_name=*/nullptr, &forward_op));
|
||||
// if (isa<tracing::TracingOperation>(sm_op.get())) {
|
||||
// TF_RETURN_IF_ERROR(
|
||||
// dyn_cast<tracing::TracingOperation>(sm_op.get())->SetOpName(name));
|
||||
// }
|
||||
AbstractOperationPtr sm_op(ctx->CreateOperation());
|
||||
ForwardOperation forward_op;
|
||||
forward_op.ctx = ctx;
|
||||
TF_RETURN_IF_ERROR(
|
||||
Reset(sm_op.get(), "SparseSoftmaxCrossEntropyWithLogits", /*raw_device_name=*/nullptr, &forward_op));
|
||||
if (isa<tracing::TracingOperation>(sm_op.get())) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
dyn_cast<tracing::TracingOperation>(sm_op.get())->SetOpName(name));
|
||||
}
|
||||
|
||||
// TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
||||
// TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), scores, &forward_op));
|
||||
TF_RETURN_IF_ERROR(AddInput(sm_op.get(), labels, &forward_op));
|
||||
|
||||
// int num_retvals = 2; // returns loss values and backprop
|
||||
// return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
// registry);
|
||||
// }
|
||||
int num_retvals = 2; // returns loss values and backprop
|
||||
return Execute(sm_op.get(), ctx, outputs, &num_retvals, &forward_op, tape,
|
||||
registry);
|
||||
}
|
||||
|
||||
//===================== Test Models to run =========================
|
||||
|
||||
@ -153,7 +176,7 @@ Status AddGradModel(AbstractContext* ctx,
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Release();
|
||||
add_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
@ -187,7 +210,7 @@ Status MatMulGradModel(AbstractContext* ctx,
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
for (auto mm_output : mm_outputs) {
|
||||
mm_output->Release();
|
||||
mm_output->Unref();
|
||||
}
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
@ -297,7 +320,7 @@ Status ReluGradModel(AbstractContext* ctx,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
|
||||
for (auto relu_output : relu_outputs) {
|
||||
relu_output->Release();
|
||||
relu_output->Unref();
|
||||
}
|
||||
|
||||
outputs[0] = out_grads[0];
|
||||
@ -328,9 +351,9 @@ Status SoftmaxLossGradModel(AbstractContext* ctx,
|
||||
source_tensors_that_are_targets,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
|
||||
for (auto sm_output : sm_outputs) {
|
||||
sm_output->Release();
|
||||
}
|
||||
// for (auto sm_output : sm_outputs) {
|
||||
// sm_output->Unref();
|
||||
// }
|
||||
|
||||
outputs[0] = out_grads[0];
|
||||
outputs[1] = out_grads[1];
|
||||
@ -373,7 +396,6 @@ Status MNISTGradModel(AbstractContext* ctx,
|
||||
AbstractTensorHandle* scores = temp_outputs[0];
|
||||
|
||||
temp_outputs.resize(2);
|
||||
// std::vector<AbstractTensorHandle*> loss_outputs(2);
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
ctx, tape, {scores, y_labels}, absl::MakeSpan(temp_outputs),
|
||||
"softmaxloss", registry)); // W2*Relu(X*W1)
|
||||
@ -391,7 +413,7 @@ Status MNISTGradModel(AbstractContext* ctx,
|
||||
/*output_gradients=*/{}, &out_grads));
|
||||
|
||||
// Only release 2nd temp output as first holds loss values.
|
||||
temp_outputs[1]->Release();
|
||||
// temp_outputs[1]->Unref();
|
||||
|
||||
outputs[0] = out_grads[0]; // dW1
|
||||
outputs[1] = out_grads[1]; // dW2
|
||||
@ -499,7 +521,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
absl::MakeSpan(output_list.outputs), registry));
|
||||
|
||||
for (auto func_input : func_inputs) {
|
||||
func_input->Release();
|
||||
func_input->Unref();
|
||||
}
|
||||
AbstractFunction* func = nullptr;
|
||||
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
|
||||
@ -507,7 +529,7 @@ Status RunModel(Model model, AbstractContext* ctx,
|
||||
scoped_func.reset(func);
|
||||
|
||||
for (auto output : output_list.outputs) {
|
||||
output->Release();
|
||||
output->Unref();
|
||||
}
|
||||
|
||||
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
|
||||
|
@@ -32,32 +32,8 @@ using namespace tensorflow;
using namespace tensorflow::gradients;
using namespace tensorflow::gradients::internal;

// Creates an Identity op.
// Status Identity(AbstractContext* ctx,
//                 absl::Span<AbstractTensorHandle* const> inputs,
//                 absl::Span<AbstractTensorHandle*> outputs, const char* name);

// // Creates a MatMul op used for the MatMulGradient
// Status MatMul(AbstractContext* ctx,
//               absl::Span<AbstractTensorHandle* const> inputs,
//               absl::Span<AbstractTensorHandle*> outputs, const char* name,
//               bool transpose_a, bool transpose_b);

// // Creates a ReluGrad op used for the ReluGradient
// Status ReluGrad(AbstractContext* ctx,
//                 absl::Span<AbstractTensorHandle* const> inputs,
//                 absl::Span<AbstractTensorHandle*> outputs,
//                 const char* name);

// // Creates a SmCrossEntropyLoss op used for the SoftmaxLossGradient
// Status SparseSoftmaxCrossEntropyLoss(AbstractContext* ctx,
//                                      absl::Span<AbstractTensorHandle* const> inputs,
//                                      absl::Span<AbstractTensorHandle*> outputs, const char* name);


// ========================== tape ==============================


// Computes `inputs[0] + inputs[1]` and records it on the tape.
Status Add(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
@@ -71,6 +47,12 @@ Status MatMul(AbstractContext* ctx, Tape* tape,
              bool transpose_a, bool transpose_b,
              const GradientRegistry& registry);

// Computes `inputs[0] * inputs[1]` and records it on the tape.
Status Mul(AbstractContext* ctx, Tape* tape,
           absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name,
           const GradientRegistry& registry);

// Computes `Relu(inputs[0])` and records it on the tape.
Status Relu(AbstractContext* ctx, Tape* tape,
            absl::Span<AbstractTensorHandle* const> inputs,
@@ -37,6 +37,7 @@ cc_library(
        "//tensorflow/c/eager:gradients",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/c/experimental/ops:math_ops",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/lib/llvm_rtti",
    ],
)
@ -12,17 +12,197 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
// #include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
|
||||
// #include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
// #include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
// #include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
|
||||
// using std::vector;
|
||||
// using tensorflow::ops::Conj;
|
||||
// using tensorflow::ops::Identity;
|
||||
// using tensorflow::ops::Mul;
|
||||
// using tensorflow::ops::MatMul;
|
||||
// using tensorflow::ops::ReluGrad;
|
||||
// using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
|
||||
|
||||
// namespace tensorflow {
|
||||
// namespace gradients {
|
||||
// namespace {
|
||||
|
||||
// class AddGradientFunction : public GradientFunction {
|
||||
// public:
|
||||
// Status Compute(Context* ctx,
|
||||
// absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// grad_outputs->resize(2);
|
||||
// vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
// // TODO(b/145674566): Handle name unification in tracing code.
|
||||
// // TODO(b/161805092): Support broadcasting.
|
||||
|
||||
// std::string name = "Identity_A_" + std::to_string(counter);
|
||||
// TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
|
||||
// absl::MakeSpan(identity_outputs),
|
||||
// name.c_str()));
|
||||
// (*grad_outputs)[0] = identity_outputs[0];
|
||||
|
||||
// name = "Identity_B_" + std::to_string(counter);
|
||||
// TF_RETURN_IF_ERROR(ops::Identity(ctx->ctx, {grad_inputs[0]},
|
||||
// absl::MakeSpan(identity_outputs),
|
||||
// name.c_str()));
|
||||
// (*grad_outputs)[1] = identity_outputs[0];
|
||||
|
||||
// counter += 1;
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~AddGradientFunction() override {}
|
||||
|
||||
// private:
|
||||
// long counter;
|
||||
// };
|
||||
|
||||
|
||||
|
||||
// class MatMulGradientFunction : public GradientFunction {
|
||||
// public:
|
||||
// explicit MatMulGradientFunction(std::vector<AbstractTensorHandle*> f_inputs)
|
||||
// : forward_inputs(f_inputs) {}
|
||||
|
||||
// Status Compute(Context* ctx,
|
||||
// absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// /* Given upstream grad U and a matmul op A*B, the gradients are:
|
||||
// *
|
||||
// * dA = U * B.T
|
||||
// * dB = A.T * U
|
||||
// *
|
||||
// * where A.T means `transpose(A)`
|
||||
// */
|
||||
|
||||
// AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
// grad_outputs->resize(2);
|
||||
// std::vector<AbstractTensorHandle*> matmul_outputs(1);
|
||||
|
||||
// // Gradient for A
|
||||
// std::string name = "mm_A_" + std::to_string(counter);
|
||||
// TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, forward_inputs[1]},
|
||||
// absl::MakeSpan(matmul_outputs), name.c_str(),
|
||||
// /*transpose_a = */ false,
|
||||
// /*transpose_b = */ true));
|
||||
|
||||
// (*grad_outputs)[0] = matmul_outputs[0];
|
||||
|
||||
// // Gradient for B
|
||||
// name = "mm_B_" + std::to_string(counter);
|
||||
// TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {forward_inputs[0], upstream_grad},
|
||||
// absl::MakeSpan(matmul_outputs), name.c_str(),
|
||||
// /*transpose_a = */ true,
|
||||
// /*transpose_b = */ false));
|
||||
|
||||
// (*grad_outputs)[1] = matmul_outputs[0];
|
||||
|
||||
// counter += 1; // update counter for names
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~MatMulGradientFunction() override {}
|
||||
|
||||
// private:
|
||||
// long counter;
|
||||
// std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
// };
|
||||
|
||||
// class ReluGradientFunction : public GradientFunction {
|
||||
// public:
|
||||
// explicit ReluGradientFunction(std::vector<AbstractTensorHandle*> f_inputs)
|
||||
// : forward_inputs(f_inputs) {}
|
||||
|
||||
// Status Compute(Context* ctx,
|
||||
// absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
// AbstractTensorHandle* input_features = forward_inputs[0];
|
||||
// grad_outputs->resize(1);
|
||||
// std::vector<AbstractTensorHandle*> relugrad_outputs(1);
|
||||
|
||||
// // Calculate Grad
|
||||
// std::string name = "relu_grad" + std::to_string(counter);
|
||||
|
||||
// TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, input_features},
|
||||
// absl::MakeSpan(relugrad_outputs),
|
||||
// name.c_str()));
|
||||
|
||||
// (*grad_outputs)[0] = relugrad_outputs[0];
|
||||
|
||||
// counter += 1;
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~ReluGradientFunction() override {}
|
||||
|
||||
// private:
|
||||
// long counter;
|
||||
// std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
// };
|
||||
|
||||
// class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
// public:
|
||||
// explicit SparseSoftmaxCrossEntropyLossGradientFunction(
|
||||
// std::vector<AbstractTensorHandle*> f_inputs,
|
||||
// std::vector<AbstractTensorHandle*> f_outputs)
|
||||
// : forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
|
||||
// Status Compute(Context* ctx,
|
||||
// absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
// std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// // Forward Inputs : [scores, labels]
|
||||
|
||||
// grad_outputs->resize(2);
|
||||
// std::vector<AbstractTensorHandle*> sm_outputs(2);
|
||||
|
||||
// // Calculate Grad
|
||||
// std::string name = "sm_loss" + std::to_string(counter);
|
||||
|
||||
// TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
// ctx->ctx, {forward_inputs[0], forward_inputs[1]},
|
||||
// absl::MakeSpan(sm_outputs), name.c_str()));
|
||||
|
||||
// // TODO(amturati): fix error where we have to return the softmax loss as the
|
||||
// // 2nd grad for the labels to avoid mangled stack trace. Also avoid running
|
||||
// // forward operation again, check to see if forward_outputs are being
|
||||
// // passed.
|
||||
|
||||
// // SparseSoftmaxCrossEntropyLoss returns [loss_vals, grads], so return 2nd
|
||||
// // output.
|
||||
// (*grad_outputs)[0] = sm_outputs[1]; // return backprop for scores
|
||||
// (*grad_outputs)[1] = sm_outputs[0]; // nullptr causes Mangled Stack Trace
|
||||
|
||||
// counter += 1;
|
||||
// return Status::OK();
|
||||
// }
|
||||
// ~SparseSoftmaxCrossEntropyLossGradientFunction() override {}
|
||||
|
||||
// private:
|
||||
// long counter;
|
||||
// std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
// std::vector<AbstractTensorHandle*> forward_outputs;
|
||||
// };
|
||||
|
||||
// } // namespace
|
||||
|
||||
#include "tensorflow/c/experimental/gradients/math_grad.h"
|
||||
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
|
||||
using std::vector;
|
||||
using tensorflow::ops::Conj;
|
||||
using tensorflow::ops::Identity;
|
||||
using tensorflow::ops::Mul;
|
||||
using tensorflow::ops::MatMul;
|
||||
using tensorflow::ops::ReluGrad;
|
||||
using tensorflow::ops::SparseSoftmaxCrossEntropyLoss;
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
@ -33,7 +213,7 @@ class AddGradientFunction : public GradientFunction {
|
||||
Status Compute(Context* ctx, const IncomingGradients& grad_inputs,
|
||||
vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
grad_outputs->resize(2);
|
||||
vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
std::vector<AbstractTensorHandle*> identity_outputs(1);
|
||||
// TODO(b/145674566): Handle name unification in tracing code.
|
||||
// TODO(b/161805092): Support broadcasting.
|
||||
|
||||
@ -82,10 +262,11 @@ class ExpGradientFunction : public GradientFunction {
|
||||
|
||||
class MatMulGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit MatMulGradientFunction(AbstractContext* ctx, std::vector<AbstractTensorHandle*> f_inputs) :
|
||||
ctx_(ctx), forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
explicit MatMulGradientFunction(std::vector<AbstractTensorHandle*> f_inputs)
|
||||
: forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(Context* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
/* Given upstream grad U and a matmul op A*B, the gradients are:
|
||||
*
|
||||
@ -100,16 +281,20 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
std::vector<AbstractTensorHandle*> matmul_outputs(1);
|
||||
|
||||
// Gradient for A
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx_, {upstream_grad, forward_inputs[1]},
|
||||
absl::MakeSpan(matmul_outputs), "mm0",
|
||||
/*transpose_a = */false, /*transpose_b = */true));
|
||||
std::string name = "matm_A_" + std::to_string(counter);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {upstream_grad, forward_inputs[1]},
|
||||
absl::MakeSpan(matmul_outputs), name.c_str(),
|
||||
/*transpose_a = */ false,
|
||||
/*transpose_b = */ true));
|
||||
|
||||
(*grad_outputs)[0] = matmul_outputs[0];
|
||||
|
||||
// Gradient for B
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx_, {forward_inputs[0], upstream_grad},
|
||||
absl::MakeSpan(matmul_outputs), "mm1",
|
||||
/*transpose_a = */true, /*transpose_b = */false));
|
||||
name = "mm_B_" + std::to_string(counter);
|
||||
TF_RETURN_IF_ERROR(MatMul(ctx->ctx, {forward_inputs[0], upstream_grad},
|
||||
absl::MakeSpan(matmul_outputs), name.c_str(),
|
||||
/*transpose_a = */ true,
|
||||
/*transpose_b = */ false));
|
||||
|
||||
(*grad_outputs)[1] = matmul_outputs[0];
|
||||
|
||||
@ -119,27 +304,29 @@ class MatMulGradientFunction : public GradientFunction {
|
||||
~MatMulGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
long counter;
|
||||
std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
};
|
||||
|
||||
class ReluGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit ReluGradientFunction(AbstractContext* ctx, std::vector<AbstractTensorHandle*> f_inputs) :
|
||||
ctx_(ctx), forward_inputs(f_inputs) {}
|
||||
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
explicit ReluGradientFunction(std::vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(Context* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
AbstractTensorHandle* upstream_grad = grad_inputs[0];
|
||||
AbstractTensorHandle* input_features = forward_inputs[0];
|
||||
AbstractTensorHandle* activations = forward_outputs[0];
|
||||
grad_outputs->resize(1);
|
||||
std::vector<AbstractTensorHandle*> relugrad_outputs(1);
|
||||
|
||||
// Calculate Grad
|
||||
TF_RETURN_IF_ERROR(ReluGrad(ctx_, {upstream_grad, input_features},
|
||||
absl::MakeSpan(relugrad_outputs), "relu_grad"));
|
||||
std::string name = "relu_grad" + std::to_string(counter);
|
||||
|
||||
TF_RETURN_IF_ERROR(ReluGrad(ctx->ctx, {upstream_grad, activations},
|
||||
absl::MakeSpan(relugrad_outputs),
|
||||
name.c_str()));
|
||||
|
||||
(*grad_outputs)[0] = relugrad_outputs[0];
|
||||
|
||||
@ -149,33 +336,31 @@ class ReluGradientFunction : public GradientFunction {
|
||||
~ReluGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
long counter;
|
||||
std::vector<AbstractTensorHandle*> forward_outputs;
|
||||
};
|
||||
|
||||
class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
public:
|
||||
explicit SparseSoftmaxCrossEntropyLossGradientFunction(AbstractContext* ctx,
|
||||
std::vector<AbstractTensorHandle*> f_inputs, std::vector<AbstractTensorHandle*> f_outputs) :
|
||||
ctx_(ctx), forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
explicit SparseSoftmaxCrossEntropyLossGradientFunction(
|
||||
std::vector<AbstractTensorHandle*> f_inputs,
|
||||
std::vector<AbstractTensorHandle*> f_outputs)
|
||||
: forward_inputs(f_inputs), forward_outputs(f_outputs) {}
|
||||
|
||||
Status Compute(Context* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> grad_inputs,
|
||||
std::vector<AbstractTensorHandle*>* grad_outputs) override {
|
||||
// Forward Inputs : [scores, labels]
|
||||
|
||||
grad_outputs->resize(2);
|
||||
std::vector<AbstractTensorHandle*> sm_outputs(2);
|
||||
|
||||
// Calculate Grad
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(ctx_, {forward_inputs[0], forward_inputs[1]},
|
||||
absl::MakeSpan(sm_outputs), "softmax_loss"));
|
||||
// std::vector<AbstractTensorHandle*> sm_outputs(2);
|
||||
|
||||
// Calculate Grad
|
||||
std::string name = "sm_loss" + std::to_string(counter);
|
||||
// // Calculate Grad
|
||||
// std::string name = "sm_loss" + std::to_string(counter);
|
||||
|
||||
TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
ctx->ctx, {forward_inputs[0], forward_inputs[1]},
|
||||
absl::MakeSpan(sm_outputs), name.c_str()));
|
||||
// TF_RETURN_IF_ERROR(SparseSoftmaxCrossEntropyLoss(
|
||||
// ctx->ctx, {forward_inputs[0], forward_inputs[1]},
|
||||
// absl::MakeSpan(sm_outputs), name.c_str()));
|
||||
|
||||
// TODO(amturati): fix error where we have to return the softmax loss as the
|
||||
// 2nd grad for the labels to avoid mangled stack trace. Also avoid running
|
||||
@ -184,8 +369,8 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
|
||||
// SparseSoftmaxCrossEntropyLoss returns [loss_vals, grads], so return 2nd
|
||||
// output.
|
||||
(*grad_outputs)[0] = sm_outputs[1]; // return backprop for scores
|
||||
(*grad_outputs)[1] = sm_outputs[0]; // nullptr causes Mangled Stack Trace
|
||||
(*grad_outputs)[0] = forward_outputs[1]; // sm_outputs[1]; // return backprop for scores
|
||||
(*grad_outputs)[1] = forward_outputs[0]; // nullptr causes Mangled Stack Trace
|
||||
|
||||
counter += 1;
|
||||
return Status::OK();
|
||||
@ -193,7 +378,7 @@ class SparseSoftmaxCrossEntropyLossGradientFunction : public GradientFunction {
|
||||
~SparseSoftmaxCrossEntropyLossGradientFunction() override {}
|
||||
|
||||
private:
|
||||
AbstractContext* ctx_;
|
||||
long counter;
|
||||
std::vector<AbstractTensorHandle*> forward_inputs;
|
||||
std::vector<AbstractTensorHandle*> forward_outputs;
|
||||
};
|
||||
@ -218,5 +403,20 @@ BackwardFunction* ExpRegisterer(const ForwardOperation& op) {
|
||||
return new BackwardFunction(gradient_function, default_gradients);
|
||||
}
|
||||
|
||||
GradientFunction* MatMulRegisterer(const ForwardOperation& op) {
|
||||
return new MatMulGradientFunction(op.inputs);
|
||||
}
|
||||
|
||||
GradientFunction* ReluRegisterer(const ForwardOperation& op) {
|
||||
return new ReluGradientFunction(op.outputs);
|
||||
}
|
||||
|
||||
GradientFunction* SparseSoftmaxCrossEntropyLossRegisterer(
|
||||
const ForwardOperation& op) {
|
||||
return new SparseSoftmaxCrossEntropyLossGradientFunction(op.inputs,
|
||||
op.outputs);
|
||||
}
|
||||
|
||||
|
||||
} // namespace gradients
|
||||
} // namespace tensorflow
|
||||
|
@@ -15,7 +15,6 @@ cc_library(
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
@@ -36,12 +35,30 @@ cc_library(
        "//tensorflow:internal",
    ],
    deps = [
        ":array_ops",
        "//tensorflow/c/eager:abstract_context",
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core:framework_headers_lib",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
        "//tensorflow/c/experimental/ops:array_ops",
    ],
)

cc_library(
    name = "nn_ops",
    srcs = [
        "nn_ops.cc",
    ],
    hdrs = [
        "nn_ops.h",
    ],
    visibility = [
        "//tensorflow:internal",
    ],
    deps = [
        "//tensorflow/c/eager:abstract_operation",
        "//tensorflow/c/eager:abstract_tensor_handle",
        "//tensorflow/c/eager:c_api_unified_internal",
        "//tensorflow/core/lib/llvm_rtti",
        "//tensorflow/core/platform:errors",
    ],
@@ -51,5 +51,47 @@ Status Conj(AbstractContext* ctx,
  return Status::OK();
}

Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name) {
  AbstractOperationPtr add_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(add_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(add_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
  return Status::OK();
}


Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b) {
  AbstractOperationPtr matmul_op(ctx->CreateOperation());
  TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));

  if (isa<tracing::TracingOperation>(matmul_op.get())) {
    TF_RETURN_IF_ERROR(
        dyn_cast<tracing::TracingOperation>(matmul_op.get())->SetOpName(name));
  }

  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
  TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));

  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
  TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));

  int num_retvals = 1;
  TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
  return Status::OK();
}

}  // namespace ops
}  // namespace tensorflow
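A minimal sketch of how these new eager op wrappers are intended to be called
(assuming a valid AbstractContext* ctx and two compatible float matrix handles
A and B, as in the tests earlier in this commit; illustrative only, not part of
the change itself):

  std::vector<AbstractTensorHandle*> product(1);
  TF_RETURN_IF_ERROR(tensorflow::ops::MatMul(
      ctx, {A, B}, absl::MakeSpan(product), "example_matmul",
      /*transpose_a=*/false, /*transpose_b=*/false));

  std::vector<AbstractTensorHandle*> sum(1);
  TF_RETURN_IF_ERROR(tensorflow::ops::Add(
      ctx, {product[0], product[0]}, absl::MakeSpan(sum), "example_add"));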
@@ -25,6 +25,13 @@ Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
Status Conj(AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
           absl::Span<AbstractTensorHandle*> outputs, const char* name);
Status MatMul(AbstractContext* ctx,
              absl::Span<AbstractTensorHandle* const> inputs,
              absl::Span<AbstractTensorHandle*> outputs, const char* name,
              bool transpose_a, bool transpose_b);

}  // namespace ops
}  // namespace tensorflow