rename ctx_
to immediate_execution_ctx_
This commit is contained in:
parent
1697abe35d
commit
0b41c93753
@ -93,7 +93,7 @@ class CppGradients
|
||||
Status s =
|
||||
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
|
||||
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
|
||||
ctx_.reset(ctx_raw);
|
||||
immediate_execution_ctx_.reset(ctx_raw);
|
||||
}
|
||||
|
||||
// Computing numerical gradients with TensorFloat-32 is numerically
|
||||
@ -102,7 +102,7 @@ class CppGradients
|
||||
enable_tensor_float_32_execution(false);
|
||||
}
|
||||
|
||||
AbstractContextPtr ctx_;
|
||||
AbstractContextPtr immediate_execution_ctx_;
|
||||
GradientRegistry registry_;
|
||||
Status status_;
|
||||
|
||||
@ -115,7 +115,8 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -123,7 +124,8 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
@ -132,15 +134,16 @@ TEST_P(CppGradients, TestAddGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
AddModel, BuildGradModel(AddModel, registry_), ctx_.get(),
|
||||
{x.get(), y.get()}, UseFunction()));
|
||||
AddModel, BuildGradModel(AddModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestExpGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -149,8 +152,8 @@ TEST_P(CppGradients, TestExpGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
ExpModel, BuildGradModel(ExpModel, registry_), ctx_.get(), {x.get()},
|
||||
UseFunction()));
|
||||
ExpModel, BuildGradModel(ExpModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMatMulGrad) {
|
||||
@ -159,8 +162,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
AbstractTensorHandlePtr A;
|
||||
{
|
||||
AbstractTensorHandle* A_raw;
|
||||
status_ =
|
||||
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
|
||||
status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
|
||||
A_vals, A_dims, 2, &A_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
A.reset(A_raw);
|
||||
}
|
||||
@ -170,8 +173,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
AbstractTensorHandlePtr B;
|
||||
{
|
||||
AbstractTensorHandle* B_raw;
|
||||
status_ =
|
||||
TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
|
||||
status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
|
||||
B_vals, B_dims, 2, &B_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
B.reset(B_raw);
|
||||
}
|
||||
@ -193,8 +196,9 @@ TEST_P(CppGradients, TestMatMulGrad) {
|
||||
// well with `MatMul` and remove `TestMatMul*` in
|
||||
// `mnist_gradients_test` when done.
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
MatMulModel, BuildGradModel(MatMulModel, registry_), ctx_.get(),
|
||||
{A.get(), B.get()}, UseFunction(), /*abs_error*/ 0.4));
|
||||
MatMulModel, BuildGradModel(MatMulModel, registry_),
|
||||
immediate_execution_ctx_.get(), {A.get(), B.get()}, UseFunction(),
|
||||
/*abs_error*/ 0.4f));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -203,7 +207,8 @@ TEST_P(CppGradients, TestSqrtGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -212,15 +217,16 @@ TEST_P(CppGradients, TestSqrtGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
SqrtModel, BuildGradModel(SqrtModel, registry_), ctx_.get(), {x.get()},
|
||||
UseFunction()));
|
||||
SqrtModel, BuildGradModel(SqrtModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestNegGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -229,15 +235,16 @@ TEST_P(CppGradients, TestNegGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
NegModel, BuildGradModel(NegModel, registry_), ctx_.get(), {x.get()},
|
||||
UseFunction()));
|
||||
NegModel, BuildGradModel(NegModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestSubGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -245,7 +252,8 @@ TEST_P(CppGradients, TestSubGrad) {
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
@ -254,15 +262,16 @@ TEST_P(CppGradients, TestSubGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
SubModel, BuildGradModel(SubModel, registry_), ctx_.get(),
|
||||
{x.get(), y.get()}, UseFunction()));
|
||||
SubModel, BuildGradModel(SubModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestMulGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -270,7 +279,8 @@ TEST_P(CppGradients, TestMulGrad) {
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
@ -279,15 +289,16 @@ TEST_P(CppGradients, TestMulGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
MulModel, BuildGradModel(MulModel, registry_), ctx_.get(),
|
||||
{x.get(), y.get()}, UseFunction()));
|
||||
MulModel, BuildGradModel(MulModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get(), y.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestLog1pGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -296,8 +307,8 @@ TEST_P(CppGradients, TestLog1pGrad) {
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
Log1pModel, BuildGradModel(Log1pModel, registry_), ctx_.get(), {x.get()},
|
||||
UseFunction()));
|
||||
Log1pModel, BuildGradModel(Log1pModel, registry_),
|
||||
immediate_execution_ctx_.get(), {x.get()}, UseFunction()));
|
||||
}
|
||||
|
||||
TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||
@ -309,7 +320,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||
AbstractTensorHandlePtr x;
|
||||
{
|
||||
AbstractTensorHandle* x_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
x.reset(x_raw);
|
||||
}
|
||||
@ -317,26 +329,29 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
|
||||
AbstractTensorHandlePtr y;
|
||||
{
|
||||
AbstractTensorHandle* y_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 2.0f, &y_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
y.reset(y_raw);
|
||||
}
|
||||
|
||||
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
|
||||
DivNoNanModel, DivNoNanGradModel, ctx_.get(), {x.get(), y.get()},
|
||||
UseFunction()));
|
||||
DivNoNanModel, DivNoNanGradModel, immediate_execution_ctx_.get(),
|
||||
{x.get(), y.get()}, UseFunction()));
|
||||
|
||||
// `DivNoNanGradModel` should return {`0`, `0`} when the denominator is `0`.
|
||||
AbstractTensorHandlePtr z;
|
||||
{
|
||||
AbstractTensorHandle* z_raw = nullptr;
|
||||
status_ = TestScalarTensorHandle(ctx_.get(), 0.0f, &z_raw);
|
||||
status_ =
|
||||
TestScalarTensorHandle(immediate_execution_ctx_.get(), 0.0f, &z_raw);
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
z.reset(z_raw);
|
||||
}
|
||||
std::vector<AbstractTensorHandle*> outputs(2);
|
||||
status_ = RunModel(DivNoNanGradModel, ctx_.get(), {x.get(), z.get()},
|
||||
absl::MakeSpan(outputs), UseFunction());
|
||||
status_ =
|
||||
RunModel(DivNoNanGradModel, immediate_execution_ctx_.get(),
|
||||
{x.get(), z.get()}, absl::MakeSpan(outputs), UseFunction());
|
||||
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
|
||||
ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
|
||||
/*abs_error*/ 0));
|
||||
|
Loading…
Reference in New Issue
Block a user