如何为TensorFlow变量分配值?


78

我正在尝试为python中的tensorflow变量分配一个新值。

import tensorflow as tf
import numpy as np

x = tf.Variable(0)
init = tf.initialize_all_variables()
sess = tf.InteractiveSession()
sess.run(init)

print(x.eval())

x.assign(1)
print(x.eval())

但是我得到的输出是

0
0

因此该值没有改变。我想念什么?

Answers:


123

在TF1中,该语句x.assign(1)实际上并未将值分配1x,而是创建了tf.Operation必须显式运行以更新变量的a。*调用Operation.run()Session.run()可用于运行该操作:

assign_op = x.assign(1)
sess.run(assign_op)  # or `assign_op.op.run()`
print(x.eval())
# ==> 1

(*实际上,它返回tf.Tensor与变量的更新值相对应的a ,以便更轻松地链接分配。)

但是,x.assign(1)现在在TF2中会急切地分配值:

x.assign(1)
print(x.numpy())
# ==> 1

谢谢!Assign_op.run()给出错误:AttributeError:“ Tensor”对象没有属性“ run”。但是sess.run(assign_op)运行得很好。
abora 2015年

在此示例中,操作/可变张量运行Variable x之前存储在内存中的数据是否assign被覆盖,还是创建了一个新张量来存储更新后的值?
dannygoldstein'6

3
当前的实现将assign()覆盖现有值。
mrry

1
有没有一种方法可以将新值分配给,Variable而无需在图形中创建任何其他操作?似乎每个变量都已经为其创建了一个Assign操作,但是调用my_var.assign()tf.assign()创建一个新操作而不是使用现有操作。
内森

我不认为这与这里是否相关,但是您可以给定assign一个张量参数,例如数学运算。并以此方式创建一个计数器,该计数器在每次评估赋值操作时都会更新:op = t.assign(tf.add(t, 1))
Eliel Van Hojman

40

您也可以为分配新值,tf.Variable而无需在图形上添加操作:tf.Variable.load(value, session)。从图形外部分配值时,此功能还可以节省添加占位符的功能,这在图形完成时很有用。

import tensorflow as tf
x = tf.Variable(0)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(x))  # Prints 0.
x.load(1, sess)
print(sess.run(x))  # Prints 1.

更新:在TF2中对此进行了描述,因为默认情况下急切执行,并且在面向用户的API中不再显示图形。


2
注意:您不能使用形状与变量初始值不同的数组来加载它!
Rajarshee Mitra

1
不推荐使用Variable.load(来自tensorflow.python.ops.variables),并将在以后的版本中将其删除。更新说明:建议Variable.assign在2.X中具有相同的行为。不知道如何改变Tensorflow 2.0的变量的值,而不会增加运算到图形
若昂阿布兰特什

15

首先,您可以通过使用占位符的相同方式将值输入变量/常量来为它们分配值。因此,这样做完全合法:

import tensorflow as tf
x = tf.Variable(0)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x, feed_dict={x: 3})

关于您与tf.assign()运算符的混淆。在TF中,在会话内部运行它之前不会执行任何操作。因此,您始终必须执行以下操作:op_name = tf.some_function_that_create_op(params)然后在会话中运行sess.run(op_name)。以assign为例,您将执行以下操作:

import tensorflow as tf
x = tf.Variable(0)
y = tf.assign(x, 1)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print sess.run(x)
    print sess.run(y)
    print sess.run(x)

@RobinDinse,确实如此。在上面的示例中,您将获得0,1,1作为标准输出。
Rajarshee Mitra '18

4
请注意,通过馈送值feed_dict不会将值永久分配给变量,而只是针对特定的运行调用。
罗宾丁斯

@RobinDinse如何永久分配该值?如果可以,请在此处查看我的问题stackoverflow.com/questions/53141762/…–
volperossa

3

另外,必须注意,如果您使用your_tensor.assign(),则tf.global_variables_initializer不需要显式调用该需求,因为assign操作会在后台为您执行此操作。

例:

In [212]: w = tf.Variable(12)
In [213]: w_new = w.assign(34)

In [214]: with tf.Session() as sess:
     ...:     sess.run(w_new)
     ...:     print(w_new.eval())

# output
34 

但是,这不会初始化所有变量,而只会初始化对其assign执行的变量。


1

在这里回答了类似的问题。我到很多地方都发现了同样的问题。基本上,我不想为权重分配一个值,而只是更改权重。上述答案的简短版本是:

tf.keras.backend.set_value(tf_var, numpy_weights)


0

这是完整的工作示例:

import numpy as np
import tensorflow as tf

w= tf.Variable(0, dtype=tf.float32) #good practice to set the type of the variable
cost = 10 + 5*w + w*w
train = tf.train.GradientDescentOptimizer(0.01).minimize(cost)

init = tf.global_variables_initializer()
session = tf.Session()
session.run(init)

print(session.run(w))

session.run(train)
print(session.run(w)) # runs one step of gradient descent

for i in range(10000):
  session.run(train)

print(session.run(w))

注意输出将是:

0.0
-0.049999997
-2.499994

这意味着,从一开始,变量就是所定义的0,然后在经过仅1个步骤的渐变体之后,变量就为-0.049999997,再经过10.000步,我们就达到了-2.499994(基于成本函数)。

注意:您最初使用交互式会话。当需要在同一脚本中运行多个不同的会话时,交互式会话很有用。但是,为了简单起见,我使用了非交互式会话。


0

使用最新的Tensorflow渴望执行模式。

import tensorflow as tf
tf.enable_eager_execution()
my_int_variable = tf.get_variable("my_int_variable", [1, 2, 3])
print(my_int_variable)

-1

因此,在不同情况下,我需要在运行会话之前分配值,因此这是最简单的方法:

other_variable = tf.get_variable("other_variable", dtype=tf.int32,
  initializer=tf.constant([23, 42]))

在这里,我正在创建一个变量并同时为其赋值


-10

有一个更简单的方法:

x = tf.Variable(0)
x = x + 1
print x.eval()

3
操作人员正在检查的用法tf.assign,而不是加法。
vega
By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.