Change global step to int64 again post update.

PiperOrigin-RevId: 167913112
This commit is contained in:
Jacques Pienaar 2017-09-07 14:35:41 -07:00 committed by TensorFlower Gardener
parent 90dad32968
commit eb75ded6d7
2 changed files with 15 additions and 3 deletions

View File

@ -151,6 +151,19 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
ComputeAndCompareR0<int32>(&builder, -3, {});
}
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
ComputationBuilder builder(client_, TestName());
auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
builder.ConvertElementType(a, F32);
int64 value = 3LL << 32;
std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
std::unique_ptr<GlobalData> a_data =
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
{a_data.get()});
}
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
ComputationBuilder builder(client_, TestName());
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),

View File

@ -68,12 +68,11 @@ def _create_global_step(graph):
return variable_scope.get_variable(
ops.GraphKeys.GLOBAL_STEP,
shape=[],
dtype=dtypes.int32,
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
use_resource=True,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
def _sync_variables_ops():