From d0201408290c48bf3e6b677a396ded2d594d22c2 Mon Sep 17 00:00:00 2001 From: "Xiaoming (Jason) Cui" Date: Mon, 20 Apr 2020 10:35:14 -0700 Subject: [PATCH] [Intel MKL] Fixes an issue with tensorflow bfloat16 scope that creates node names with '//', which causes unknown node name problem when run inference from a trained checkpoint --- tensorflow/python/tpu/bfloat16.py | 2 +- tensorflow/python/tpu/bfloat16_test.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/tpu/bfloat16.py b/tensorflow/python/tpu/bfloat16.py index 70f71815e51..b4ab1f8731f 100644 --- a/tensorflow/python/tpu/bfloat16.py +++ b/tensorflow/python/tpu/bfloat16.py @@ -76,5 +76,5 @@ def bfloat16_scope(): This enables variables to be read as bfloat16 type when using get_variable. """ with variable_scope.variable_scope( - '', custom_getter=_get_custom_getter()) as varscope: + 'bfloat16', custom_getter=_get_custom_getter()) as varscope: yield varscope diff --git a/tensorflow/python/tpu/bfloat16_test.py b/tensorflow/python/tpu/bfloat16_test.py index 78157ea86c2..be7b014a3af 100644 --- a/tensorflow/python/tpu/bfloat16_test.py +++ b/tensorflow/python/tpu/bfloat16_test.py @@ -31,7 +31,7 @@ class BFloat16ScopeTest(test.TestCase): def testScopeName(self): """Test if name for the variable scope is propagated correctly.""" with bfloat16.bfloat16_scope() as bf: - self.assertEqual(bf.name, "") + self.assertEqual(bf.name, "bfloat16") @test_util.run_deprecated_v1 def testRequestedDType(self):