Fixing bug in cast_op test: A case study in the danger of default arguments.
PiperOrigin-RevId: 306345213 Change-Id: I2dfdaf0565774f22766b5b0f1633ce3c99d416e4
This commit is contained in:
parent
20d0fd0d81
commit
15aeef74da
@ -41,7 +41,7 @@ static Graph* Cast(int num) {
|
|||||||
|
|
||||||
class CastOpTest : public OpsTestBase {
|
class CastOpTest : public OpsTestBase {
|
||||||
protected:
|
protected:
|
||||||
void MakeOp(DataType src, DataType dst, bool trunc = false) {
|
void MakeOp(DataType src, DataType dst, bool trunc) {
|
||||||
if (trunc) {
|
if (trunc) {
|
||||||
TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
|
TF_EXPECT_OK(NodeDefBuilder("cast_op", "Cast")
|
||||||
.Input(FakeInput(src))
|
.Input(FakeInput(src))
|
||||||
@ -61,10 +61,10 @@ class CastOpTest : public OpsTestBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename INPUT, typename OUTPUT>
|
template <typename INPUT, typename OUTPUT>
|
||||||
void CheckCast(bool trunc = false) {
|
void CheckCast(bool trunc) {
|
||||||
DataType in_type = DataTypeToEnum<INPUT>::v();
|
DataType in_type = DataTypeToEnum<INPUT>::v();
|
||||||
DataType out_type = DataTypeToEnum<OUTPUT>::v();
|
DataType out_type = DataTypeToEnum<OUTPUT>::v();
|
||||||
MakeOp(in_type, out_type);
|
MakeOp(in_type, out_type, trunc);
|
||||||
AddInputFromArray<INPUT>(TensorShape({1, 2, 2, 1}),
|
AddInputFromArray<INPUT>(TensorShape({1, 2, 2, 1}),
|
||||||
{INPUT(1), INPUT(2), INPUT(3), INPUT(4)});
|
{INPUT(1), INPUT(2), INPUT(3), INPUT(4)});
|
||||||
TF_ASSERT_OK(RunOpKernel());
|
TF_ASSERT_OK(RunOpKernel());
|
||||||
@ -75,9 +75,11 @@ class CastOpTest : public OpsTestBase {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
#define TEST_CAST(in, out) \
|
#define TEST_CAST(in, out) \
|
||||||
TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(); } \
|
TEST_F(CastOpTest, TestCast##_##in##_##out) { CheckCast<in, out>(false); } \
|
||||||
TEST_F(CastOpTest, TestCast2##_##in##_##out) { CheckCast<in, out>(true); }
|
TEST_F(CastOpTest, TestCastTruncate_##_##in##_##out) { \
|
||||||
|
CheckCast<in, out>(true); \
|
||||||
|
}
|
||||||
|
|
||||||
#define TEST_ALL_CASTS_FROM(in) \
|
#define TEST_ALL_CASTS_FROM(in) \
|
||||||
TEST_CAST(in, uint8); \
|
TEST_CAST(in, uint8); \
|
||||||
|
Loading…
x
Reference in New Issue
Block a user