TensorFlow: How does tf.get_variable work?


Question


I have read about tf.get_variable from this question and also a bit of the documentation on the TensorFlow website. However, it is still not clear to me, and I was unable to find an answer online.

How does tf.get_variable work? For example:

var1 = tf.Variable(3., dtype=tf.float64)
var2 = tf.get_variable("var1",[],dtype=tf.float64)

Does it mean that var2 is another variable with an initialization similar to var1? Or is var2 an alias for var1 (I tried, and it doesn't seem to be)?

How are var1 and var2 related?

How is a variable constructed when the variable we are getting doesn't exist yet?


Answer 1:


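tf.get_variable(name) creates a new variable called name in the TensorFlow graph. If the graph already contains a node with that name (for example, one created with tf.Variable), the new node's name gets a suffix such as _1. If a variable with that name was already created with tf.get_variable in the current variable scope, it raises a ValueError instead, unless reuse is enabled for that scope.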

In your example, you're creating a Python variable called var1.

The name of that variable in the **TensorFlow graph is not** var1, but Variable:0.

Every node you define has its own name, which you can either specify yourself or let TensorFlow assign as a default (and always unique) one. You can see the name by accessing the name property of the Python variable (i.e. print(var1.name)).
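To see the default naming in action, here is a minimal sketch, assuming TensorFlow 1.13+ or 2.x, where the TF 1.x-style API is reachable as tf.compat.v1. The variables a, b and c are illustrative names only; the explicit tf.Graph() keeps the example in graph mode under TF 2.x as well.

```python
import tensorflow as tf

v1 = tf.compat.v1  # TF 1.x-style API (available in TF 1.13+ and 2.x)

graph = tf.Graph()
with graph.as_default():
    a = v1.Variable(1.0)            # no name given: TensorFlow uses "Variable"
    b = v1.Variable(2.0)            # default name requested again, so it is made unique
    c = v1.Variable(3.0, name="c")  # explicitly named node
    for var in (a, b, c):
        print(var.name)
```

In a fresh graph this prints Variable:0, Variable_1:0 and c:0: the second unnamed variable does not clash with the first, it simply receives a suffixed graph name.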

On your second line, you're defining a Python variable var2 whose name in the TensorFlow graph is var1.

The script

import tensorflow as tf

var1 = tf.Variable(3.,dtype=tf.float64)
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)

indeed prints:

Variable:0
var1:0

If you instead want to define a variable (node) called var1 in the TensorFlow graph and then get a reference to that node, you cannot simply use tf.get_variable("var1"), because it will create a new, different variable called var1_1.

This script

import tensorflow as tf

var1 = tf.Variable(3., dtype=tf.float64, name="var1")
print(var1.name)
var2 = tf.get_variable("var1",[],dtype=tf.float64)
print(var2.name)

prints:

var1:0
var1_1:0
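One way to see why var1_1 appears is to model the two separate name tables involved: the graph's node names, which every variable uses, and the store that only tf.get_variable consults. The sketch below is a hypothetical, simplified model, not TensorFlow source code: the class ToyGraph and its methods are invented for illustration, it ignores variable scopes, and reuse="auto" merely stands in for tf.AUTO_REUSE.

```python
# Hypothetical, simplified model (not TensorFlow source code): it tracks
# graph node names and the get_variable store separately, ignoring scopes.

class ToyGraph:
    def __init__(self):
        self._names_in_use = {}  # graph node name -> times it was requested
        self._store = {}         # requested name -> variable created by get_variable

    def _unique_name(self, name):
        # The first request keeps the name; later ones get _1, _2, ... suffixes.
        count = self._names_in_use.get(name, 0)
        self._names_in_use[name] = count + 1
        return name if count == 0 else f"{name}_{count}"

    def variable(self, name="Variable"):
        # Models tf.Variable: never looks at the store, only takes a unique node name.
        return self._unique_name(name) + ":0"

    def get_variable(self, name, reuse=False):
        # Models tf.get_variable: the store is consulted first, keyed by `name`.
        # reuse=False -> must not exist yet; reuse=True -> must already exist;
        # reuse="auto" stands in for tf.AUTO_REUSE (get if present, else create).
        if name in self._store:
            if reuse is False:
                raise ValueError(f"Variable {name} already exists")
            return self._store[name]
        if reuse is True:
            raise ValueError(f"Variable {name} does not exist")
        var = self._unique_name(name) + ":0"
        self._store[name] = var
        return var


g = ToyGraph()
print(g.variable())                        # Variable:0
print(g.variable("var1"))                  # var1:0
print(g.get_variable("var1"))              # var1_1:0
print(g.get_variable("var1", reuse=True))  # var1_1:0
```

The third call reproduces the var1_1:0 result above: the store has never seen "var1", so a new variable is created, but the graph name "var1" is already taken by the tf.Variable node.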

To create a reference to the node var1, you:

  1. Have to replace tf.Variable with tf.get_variable. Variables created with tf.Variable can't be shared this way, while those created with tf.get_variable can.

  2. Need to know the scope of var1 and allow reuse of that scope when declaring the reference.

The code makes this easier to understand:

import tensorflow as tf

# Instead of: var1 = tf.Variable(3., dtype=tf.float64, name="var1")
var1 = tf.get_variable("var1", shape=(), dtype=tf.float64,
                       initializer=tf.constant_initializer(3.))
# The variable scope var1 was created in (here, the root scope)
current_scope = tf.get_variable_scope()
print(var1.name)
# Re-enter that scope with reuse=True to get a reference to var1
with tf.variable_scope(current_scope, reuse=True):
    var2 = tf.get_variable("var1", [], dtype=tf.float64)
    print(var2.name)

outputs:

var1:0
var1:0
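A common variant is to give the shared variables an explicit scope name, so there is no need to look the scope up. This is a minimal sketch, assuming TensorFlow 1.13+ or 2.x via tf.compat.v1; the scope name "model" and the variable name "weights" are hypothetical.

```python
import tensorflow as tf

v1 = tf.compat.v1  # TF 1.x-style API (available in TF 1.13+ and 2.x)

graph = tf.Graph()
with graph.as_default():
    # First use of the scope: creates the variable "model/weights".
    with v1.variable_scope("model"):
        w = v1.get_variable("weights", shape=[2], dtype=tf.float64,
                            initializer=v1.zeros_initializer())
    # Re-entering the same scope with reuse=True returns the stored variable.
    with v1.variable_scope("model", reuse=True):
        w_again = v1.get_variable("weights", shape=[2], dtype=tf.float64)
    print(w.name, w_again.name, w is w_again)
```

Both names are model/weights:0, and w_again is the very same Python object as w, so it really is a reference rather than a copy.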



Answer 2:


If you call tf.get_variable() with a name that has already been defined in the same variable scope, TensorFlow throws an exception unless reuse is enabled, whereas tf.Variable() silently creates a new, uniquely named variable. Hence, it is convenient to use tf.get_variable() instead of tf.Variable() when variables need to be shared. With reuse=tf.AUTO_REUSE on the enclosing tf.variable_scope, tf.get_variable() returns the existing variable with the same name if it exists, and creates the variable with the specified shape and initializer if it does not.
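The get-or-create behaviour can be sketched as follows, assuming TensorFlow 1.13+ or 2.x via tf.compat.v1. The helper get_bias and the names "layer" and "bias" are hypothetical; the last block shows the exception raised when the same name is requested again without reuse.

```python
import tensorflow as tf

v1 = tf.compat.v1  # TF 1.x-style API (available in TF 1.13+ and 2.x)

def get_bias():
    # With AUTO_REUSE, the first call creates "layer/bias"; later calls return it.
    with v1.variable_scope("layer", reuse=v1.AUTO_REUSE):
        return v1.get_variable("bias", shape=[], dtype=tf.float64,
                               initializer=v1.constant_initializer(3.0))

graph = tf.Graph()
with graph.as_default():
    first = get_bias()
    second = get_bias()
    print(first.name, second.name, first is second)

    # Without reuse, requesting the same name again raises ValueError.
    try:
        with v1.variable_scope("layer"):
            v1.get_variable("bias", shape=[], dtype=tf.float64)
        raised = False
    except ValueError:
        raised = True
    print("raised ValueError:", raised)
```

Wrapping the lookup in a function like this lets the same code path both create the variable on the first call and share it on every later call.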



Source: https://stackoverflow.com/questions/45074049/tensorflow-how-does-tf-get-variable-work


