Enable half for resource scatter
Call each type. Enable sparse adam test in eager mode. Use self.evaluate(). Update tests for adam. Enable half for scatter. Use assertAllCloseAccordingToType to pass float16 tests. Make linter happy. Run pylint. Relax tolerance for half.
This commit is contained in:
parent
7e279d6b0f
commit
5e28f0c5c8
tensorflow
core/kernels
python/kernel_tests
@ -998,8 +998,8 @@ REGISTER_SCATTER_KERNEL(Variant, CPU, "ResourceScatterUpdate",
|
||||
|
||||
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_MINMAX_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("ResourceScatterUpdate")
|
||||
.Device(DEVICE_GPU)
|
||||
|
@ -55,7 +55,7 @@ namespace functor {
|
||||
DECLARE_GPU_SPECS_INDEX(T, int32); \
|
||||
DECLARE_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(DECLARE_GPU_SPECS);
|
||||
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPECS);
|
||||
|
||||
#undef DECLARE_GPU_SPECS
|
||||
#undef DECLARE_GPU_SPECS_INDEX
|
||||
|
@ -40,13 +40,11 @@ typedef Eigen::GpuDevice GPUDevice;
|
||||
DEFINE_GPU_SPECS_INDEX(T, int32); \
|
||||
DEFINE_GPU_SPECS_INDEX(T, int64);
|
||||
|
||||
DEFINE_GPU_SPECS(Eigen::half);
|
||||
DEFINE_GPU_SPECS(float);
|
||||
DEFINE_GPU_SPECS(double);
|
||||
DEFINE_GPU_SPECS_OP(bool, int32, scatter_op::UpdateOp::ASSIGN);
|
||||
DEFINE_GPU_SPECS_OP(bool, int64, scatter_op::UpdateOp::ASSIGN);
|
||||
// TODO(b/27222123): The following fails to compile due to lack of support for
|
||||
// fp16.
|
||||
// TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
|
||||
|
||||
#undef DEFINE_GPU_SPECS
|
||||
#undef DEFINE_GPU_SPECS_INDEX
|
||||
|
@ -286,9 +286,9 @@ TF_CALL_ALL_TYPES(REGISTER_SCATTER_UPDATE_CPU);
|
||||
|
||||
#define REGISTER_SCATTER_UPDATE_GPU(type) REGISTER_SCATTER_UPDATE(type, GPU);
|
||||
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ARITHMETIC_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_ARITHMETIC_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_MINMAX_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
|
||||
TF_CALL_GPU_NUMBER_TYPES(REGISTER_SCATTER_UPDATE_GPU);
|
||||
|
||||
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
|
||||
|
@ -182,13 +182,14 @@ class ScatterTest(test.TestCase):
|
||||
ref = variables.Variable(old)
|
||||
self.evaluate(ref.initializer)
|
||||
self.evaluate(tf_scatter(ref, indices, updates))
|
||||
self.assertAllClose(self.evaluate(ref), new)
|
||||
self.assertAllCloseAccordingToType(
|
||||
self.evaluate(ref), new, half_rtol=5e-3, half_atol=5e-3)
|
||||
|
||||
def _VariableRankTests(self,
|
||||
tf_scatter,
|
||||
repeat_indices=False,
|
||||
updates_are_scalar=False):
|
||||
vtypes = [np.float32, np.float64]
|
||||
vtypes = [np.float16, np.float32, np.float64]
|
||||
if tf_scatter != state_ops.scatter_div:
|
||||
vtypes.append(np.int32)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user