Rolling back PR #44894 because it caused type check errors (got bfloat16 when expecting float).

PiperOrigin-RevId: 346214899
Change-Id: I291592f14a7e3b087e1c26d29d7ecdaef4bc2fed
Author: Penporn Koanantakool
Committed by: TensorFlower Gardener
Date: 2020-12-07 17:24:24 -08:00
Parent: 4764f167c2
Commit: bb59188213
3 changed files with 28 additions and 49 deletions
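The failing case is not part of the commit message, but the kind of mismatch it describes can be sketched in a few lines of Python. This is a hypothetical repro, not the actual breakage, and it assumes a TensorFlow build in which the bfloat16 CPU kernel from PR #44894 is registered: with bfloat16 logits the op returns a bfloat16 loss, so downstream code written for a float32 loss fails its type check.

import tensorflow as tf

# Hypothetical sketch (not the failing test): with the bfloat16 kernel registered,
# the returned loss inherits the logits dtype.
logits = tf.constant([[1., 2., 3., 4.]], dtype=tf.bfloat16)
labels = tf.constant([3], dtype=tf.int64)
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

# Code that assumes a float32 loss then trips a dtype check, roughly
# "was expected to be a float tensor but is a bfloat16 tensor":
total = tf.constant(0.0, dtype=tf.float32) + loss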

tensorflow/core/kernels/sparse_xent_op.cc

@@ -18,7 +18,6 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 #include "tensorflow/core/kernels/sparse_xent_op.h"
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor.h"
@@ -123,8 +122,6 @@ REGISTER(CPU, float, int32)
 REGISTER(CPU, float, int64)
 REGISTER(CPU, double, int32)
 REGISTER(CPU, double, int64)
-REGISTER(CPU, bfloat16, int32)
-REGISTER(CPU, bfloat16, int64)
 REGISTER(CPU, Eigen::half, int32)
 REGISTER(CPU, Eigen::half, int64)
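Not part of the diff: with the two bfloat16 registrations above rolled back, this op has no CPU kernel for bfloat16 logits, so a caller that wants bfloat16 end to end has to route through float32. A minimal sketch of that workaround against the public TF 2.x API (the values below are illustrative, not from this commit):

import tensorflow as tf

features = tf.constant([[1., 1., 1., 1.], [1., 2., 3., 4.]], dtype=tf.bfloat16)
labels = tf.constant([0, 3], dtype=tf.int64)

# Compute in float32, one of the dtypes that remains registered, then cast back.
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=labels, logits=tf.cast(features, tf.float32))
loss_bf16 = tf.cast(loss, tf.bfloat16)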

tensorflow/core/kernels/sparse_xent_op_test.cc

@@ -23,9 +23,9 @@ limitations under the License.
 namespace tensorflow {
-static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
+static Graph* SparseXent(int batch_size, int num_classes) {
   Graph* g = new Graph(OpRegistry::Global());
-  Tensor logits(value_type, TensorShape({batch_size, num_classes}));
+  Tensor logits(DT_FLOAT, TensorShape({batch_size, num_classes}));
   logits.flat<float>().setRandom();
   Tensor labels(DT_INT64, TensorShape({batch_size}));
   std::random_device rd;
@@ -41,45 +41,44 @@ static Graph* SparseXent(int batch_size, int num_classes, DataType value_type) {
   return g;
 }
-#define BM_SparseXentDev(BATCH, CLASS, DEVICE, DTYPE) \
-  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE( \
+#define BM_SparseXentDev(BATCH, CLASS, DEVICE) \
+  static void BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE( \
       ::testing::benchmark::State& state) { \
-    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS, DTYPE), \
+    test::Benchmark(#DEVICE, SparseXent(BATCH, CLASS), \
                     /*old_benchmark_api*/ false) \
         .Run(state); \
     state.SetItemsProcessed(static_cast<int64>(state.iterations()) * BATCH * \
                             CLASS); \
   } \
-  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE##_##DTYPE);
-#define BM_SPARSE_XENT_DEV_CPU(DTYPE) \
-  BM_SparseXentDev(8, 1000000, cpu, DTYPE); \
-  BM_SparseXentDev(16, 10000, cpu, DTYPE); \
-  BM_SparseXentDev(16, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(32, 10000, cpu, DTYPE); \
-  BM_SparseXentDev(32, 100000, cpu, DTYPE); \
-  BM_SparseXentDev(64, 10000, cpu, DTYPE); \
-  BM_SparseXentDev(64, 100000, cpu, DTYPE);
-// CPU
-BM_SPARSE_XENT_DEV_CPU(DT_FLOAT);
-BM_SPARSE_XENT_DEV_CPU(DT_BFLOAT16);
+  BENCHMARK(BM_SparseXent##_##BATCH##_##CLASS##_##DEVICE);
 /// The representative tests for ptb_word on GPU
 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
-BM_SparseXentDev(8, 1000000, gpu, DT_FLOAT);
+BM_SparseXentDev(8, 1000000, gpu);
-BM_SparseXentDev(16, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(16, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(16, 10000, gpu);
+BM_SparseXentDev(16, 30000, gpu);
+BM_SparseXentDev(16, 100000, gpu);
-BM_SparseXentDev(32, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(32, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(32, 10000, gpu);
+BM_SparseXentDev(32, 30000, gpu);
+BM_SparseXentDev(32, 100000, gpu);
-BM_SparseXentDev(64, 10000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 30000, gpu, DT_FLOAT);
-BM_SparseXentDev(64, 100000, gpu, DT_FLOAT);
+BM_SparseXentDev(64, 10000, gpu);
+BM_SparseXentDev(64, 30000, gpu);
+BM_SparseXentDev(64, 100000, gpu);
 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+// CPU
+BM_SparseXentDev(8, 1000000, cpu);
+BM_SparseXentDev(16, 10000, cpu);
+BM_SparseXentDev(16, 100000, cpu);
+BM_SparseXentDev(32, 10000, cpu);
+BM_SparseXentDev(32, 100000, cpu);
+BM_SparseXentDev(64, 10000, cpu);
+BM_SparseXentDev(64, 100000, cpu);
 }  // end namespace tensorflow
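For reference only, a rough Python equivalent of one point in the benchmark grid above (batch 64, 10000 classes, float32). It is not part of the benchmark file and only illustrates what BM_SparseXentDev measures: items processed per second, with one op invocation per step.

import time

import tensorflow as tf

batch, num_classes = 64, 10000
logits = tf.random.uniform([batch, num_classes], dtype=tf.float32)
labels = tf.random.uniform([batch], maxval=num_classes, dtype=tf.int64)

@tf.function
def xent():
  return tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)

xent()  # warm up and trace once
iters = 100
start = time.perf_counter()
for _ in range(iters):
  xent()
elapsed = time.perf_counter() - start
print("items/sec:", batch * num_classes * iters / elapsed)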

tensorflow/python/kernel_tests/sparse_xent_op_test.py

@ -182,23 +182,6 @@ class SparseXentTest(test.TestCase):
np.array([[1., 1., 1., 1.], [1., 2., 3., 4.]]).astype(np.float64),
np.array([0, 3]).astype(label_dtype))
def testBfloat16(self):
for label_dtype in np.int32, np.int64:
np_features = np.array([[1., 1., 1., 1.], [1., 2., 3.,
4.]]).astype(np.float32)
np_labels = np.array([0, 3]).astype(label_dtype)
np_loss, np_backprop = self._npXent(np_features, np_labels)
np_features_bf16 = math_ops.cast(np_features, dtypes.bfloat16)
np_loss_bf16 = math_ops.cast(np_loss, dtypes.bfloat16)
np_backprop_bf16 = math_ops.cast(np_backprop, dtypes.bfloat16)
with self.cached_session(use_gpu=False):
loss, backprop = gen_nn_ops.sparse_softmax_cross_entropy_with_logits(
np_features_bf16, np_labels)
tf_loss, tf_backprop = self.evaluate([loss, backprop])
self.assertAllCloseAccordingToType(np_loss_bf16, tf_loss)
self.assertAllCloseAccordingToType(np_backprop_bf16, tf_backprop)
def testHalf(self):
for label_dtype in np.int32, np.int64:
self._testXent(