在Tensorflow中,获取图中所有张量的名称


118

我正在使用Tensorflow和创建神经网络skflow。由于某种原因,我想获得某种内在的张量的值给定的输入,所以我使用的myClassifier.get_layer_value(input, "tensorName")myClassifier作为一个skflow.estimators.TensorFlowEstimator

但是,我发现很难找到张量名称的正确语法,即使知道它的名称也很困难(而且我对操作和张量感到困惑),因此我使用张量板来绘制图形并寻找名称。

有没有一种方法可以在不使用张量板的情况下枚举图中的所有张量?

Answers:


189

你可以做

[n.name for n in tf.get_default_graph().as_graph_def().node]

另外,如果要在IPython笔记本中进行原型制作,则可以直接在笔记本中显示图形,请参见show_graphAlexander's Deep Dream 笔记本中的功能


2
您可以通过if "Variable" in n.op在理解的末尾添加来过滤例如变量。
Radu

如果知道名称,是否可以获取特定节点?
火箭Pingu

要了解有关图节点的更多信息,请访问:tensorflow.org/extend/tool_developers/#nodes
Ivan Talalaev

3
上面的命令产生所有操作/节点的名称。要获取所有张量的名称,请执行以下操作:tensors_per_node = [graph.get_operations()中的node的node.values()] tensor_names = [tensors_per_node中的张量的tensor.name,张量中的张量]
gebbissimo

25

有一种方法可以通过使用get_operations来比Yaroslav的回答中稍快一些。这是一个简单的示例:

import tensorflow as tf

a = tf.constant(1.3, name='const_a')
b = tf.Variable(3.1, name='variable_b')
c = tf.add(a, b, name='addition')
d = tf.multiply(c, a, name='multiply')

for op in tf.get_default_graph().get_operations():
    print(str(op.name))

1
您无法使用使用Tensors tf.get_operations()。只有您可以获得的操作。
Soulduck

14

我将尝试总结答案:

要获取所有节点(类型tensorflow.core.framework.node_def_pb2.NodeDef):

all_nodes = [n for n in tf.get_default_graph().as_graph_def().node]

要获取所有操作(类型tensorflow.python.framework.ops.Operation):

all_ops = tf.get_default_graph().get_operations()

要获取所有变量(类型tensorflow.python.ops.resource_variable_ops.ResourceVariable):

all_vars = tf.global_variables()

获取所有张量(类型tensorflow.python.framework.ops.Tensor

all_tensors = [tensor for op in tf.get_default_graph().get_operations() for tensor in op.values()]

11

tf.all_variables() 可以为您获取所需的信息。

此外,今天在TensorFlow Learn中所做的提交get_variable_names在estimator中提供了一个函数,您可以使用该函数轻松检索所有变量名称。


不建议使用此功能
CAFEBABE

8
...及其后继者是tf.global_variables()
bluenote10 '17

11
这仅获取变量,而不获取张量。
Rajarshee Mitra

在Tensorflow 1.9.0中显示all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02
stackoverYC

5

我认为这样做也可以:

print(tf.contrib.graph_editor.get_tensors(tf.get_default_graph()))

但是,与萨尔瓦多和雅罗斯拉夫的答案相比,我不知道哪个更好。


这是一张从tensorflow对象检测API中使用的Frozen_inference_graph.pb文件导入的图。谢谢
simo23 '18

4

接受的答案仅会为您提供带有名称的字符串列表。我更喜欢另一种方法,它使您(几乎)直接访问张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values() for op in graph.get_operations()]

list_of_tuples现在包含每个张量,每个张量都在一个元组中。您还可以对其进行调整以直接获得张量:

graph = tf.get_default_graph()
list_of_tuples = [op.values()[0] for op in graph.get_operations()]

这是获取操作的实际输出张量的方法,而不仅仅是操作。
Szabolcs

4

由于OP要求张量的列表而不是操作/节点的列表,因此代码应略有不同:

graph = tf.get_default_graph()    
tensors_per_node = [node.values() for node in graph.get_operations()]
tensor_names = [tensor.name for tensors in tensors_per_node for tensor in tensors]

3

先前的答案很好,我只想分享我编写的从图中选择张量的实用函数:

def get_graph_op(graph, and_conds=None, op='and', or_conds=None):
    """Selects nodes' names in the graph if:
    - The name contains all items in and_conds
    - OR/AND depending on op
    - The name contains any item in or_conds

    Condition starting with a "!" are negated.
    Returns all ops if no optional arguments is given.

    Args:
        graph (tf.Graph): The graph containing sought tensors
        and_conds (list(str)), optional): Defaults to None.
            "and" conditions
        op (str, optional): Defaults to 'and'. 
            How to link the and_conds and or_conds:
            with an 'and' or an 'or'
        or_conds (list(str), optional): Defaults to None.
            "or conditions"

    Returns:
        list(str): list of relevant tensor names
    """
    assert op in {'and', 'or'}

    if and_conds is None:
        and_conds = ['']
    if or_conds is None:
        or_conds = ['']

    node_names = [n.name for n in graph.as_graph_def().node]

    ands = {
        n for n in node_names
        if all(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in and_conds
        )}

    ors = {
        n for n in node_names
        if any(
            cond in n if '!' not in cond
            else cond[1:] not in n
            for cond in or_conds
        )}

    if op == 'and':
        return [
            n for n in node_names
            if n in ands.intersection(ors)
        ]
    elif op == 'or':
        return [
            n for n in node_names
            if n in ands.union(ors)
        ]

因此,如果您有带有操作图的图形:

['model/classifier/dense/kernel',
'model/classifier/dense/kernel/Assign',
'model/classifier/dense/kernel/read',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd',
'model/classifier/ArgMax/dimension',
'model/classifier/ArgMax']

然后跑步

get_graph_op(tf.get_default_graph(), ['dense', '!kernel'], 'or', ['Assign'])

返回:

['model/classifier/dense/kernel/Assign',
'model/classifier/dense/bias',
'model/classifier/dense/bias/Assign',
'model/classifier/dense/bias/read',
'model/classifier/dense/MatMul',
'model/classifier/dense/BiasAdd']

By using our site, you acknowledge that you have read and understand our Cookie Policy and Privacy Policy.
Licensed under cc by-sa 3.0 with attribution required.