Change global step to int64 again post update.
PiperOrigin-RevId: 167913112
This commit is contained in:
parent
90dad32968
commit
eb75ded6d7
@ -151,6 +151,19 @@ XLA_TEST_F(ScalarComputationsTest, SubtractTwoScalarsS32) {
|
||||
ComputeAndCompareR0<int32>(&builder, -3, {});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScalarComputationsTest, CastS64ToF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
auto a = builder.Parameter(0, ShapeUtil::MakeShape(S64, {}), "a");
|
||||
builder.ConvertElementType(a, F32);
|
||||
|
||||
int64 value = 3LL << 32;
|
||||
std::unique_ptr<Literal> a_literal = Literal::CreateR0<int64>(value);
|
||||
std::unique_ptr<GlobalData> a_data =
|
||||
client_->TransferToServer(*a_literal).ConsumeValueOrDie();
|
||||
ComputeAndCompareR0<float>(&builder, static_cast<float>(value),
|
||||
{a_data.get()});
|
||||
}
|
||||
|
||||
XLA_TEST_F(ScalarComputationsTest, MulThreeScalarsF32) {
|
||||
ComputationBuilder builder(client_, TestName());
|
||||
builder.Mul(builder.Mul(builder.ConstantR0<float>(2.1f),
|
||||
|
@ -68,12 +68,11 @@ def _create_global_step(graph):
|
||||
return variable_scope.get_variable(
|
||||
ops.GraphKeys.GLOBAL_STEP,
|
||||
shape=[],
|
||||
dtype=dtypes.int32,
|
||||
dtype=dtypes.int64,
|
||||
initializer=init_ops.zeros_initializer(),
|
||||
trainable=False,
|
||||
use_resource=True,
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
|
||||
ops.GraphKeys.GLOBAL_STEP])
|
||||
collections=[ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.GLOBAL_STEP])
|
||||
|
||||
|
||||
def _sync_variables_ops():
|
||||
|
Loading…
Reference in New Issue
Block a user