Alcance de la variable Tensorflow: reutilizar si la variable existe


Quiero un fragmento de código que cree una variable dentro de un ámbito si no existe, y acceder a la variable si ya existe. Necesito que sea el mismo código ya que se llamará varias veces.

Sin embargo, Tensorflow me necesita para especificar si quiero crear o reutilizar la variable, así:

with tf.variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with tf.variable_scope("foo", reuse=True): #reuse the second time
    v = tf.get_variable("v", [1])

¿Cómo puedo hacer que averigüe si crear o reutilizar automáticamente? Es decir, quiero que los dos bloques de código anteriores sean iguales y tengan el programa se ejecuta.

 25
Author: holdenlee, 2016-07-23

4 answers

A ValueError se genera en get_variable() cuando se crea una nueva variable y no se declara la forma, o cuando se viola la reutilización durante la creación de la variable. Por lo tanto, puedes probar esto:

def get_scope_variable(scope_name, var, shape=None):
    with tf.variable_scope(scope_name) as scope:
        try:
            v = tf.get_variable(var, shape)
        except ValueError:
            scope.reuse_variables()
            v = tf.get_variable(var)
    return v

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v')
assert v1 == v2

Tenga en cuenta que lo siguiente también funciona:

v1 = get_scope_variable('foo', 'v', [1])
v2 = get_scope_variable('foo', 'v', [1])
assert v1 == v2

ACTUALIZACIÓN. La nueva API admite la reutilización automática ahora:

def get_scope_variable(scope, var, shape=None):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        v = tf.get_variable(var, shape)
    return v
 24
Author: rvinas,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-03-11 23:19:17

Aunque usando " try...excepto..."la cláusula funciona, creo que una forma más elegante y mantenible sería separar el proceso de inicialización de variables con el proceso de "reutilización".

def initialize_variable(scope_name, var_name, shape):
    with tf.variable_scope(scope_name) as scope:
        v = tf.get_variable(var_name, shape)
        scope.reuse_variable()

def get_scope_variable(scope_name, var_name):
    with tf.variable_scope(scope_name, reuse=True):
        v = tf.get_variable(var_name)
    return v

Dado que a menudo solo necesitamos inicializar variables, pero reutilizarlas/compartirlas muchas veces, separar los dos procesos hace que el código sea más limpio. También de esta manera, no necesitaremos pasar por la cláusula "try" cada vez para comprobar si la variable ya ha sido creada o no.

 12
Author: Zhongyu Kuang,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-10-17 18:45:21

La nueva opción AUTO_REUSE hace el truco.

Del tf.variable_scope API docs : si reuse=tf.AUTO_REUSE, creamos variables si no existen, y las devolvemos de lo contrario.

Ejemplo básico de compartir una variable AUTO_REUSE:

def foo():
  with tf.variable_scope("foo", reuse=tf.AUTO_REUSE):
    v = tf.get_variable("v", [1])
  return v

v1 = foo()  # Creates v.
v2 = foo()  # Gets the same, existing v.
assert v1 == v2
 6
Author: Mikhail Mishin,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-12-04 21:08:18

Podemos escribir nuestra abstracción sobre tf.varaible_scope que usa reuse=None en la primera llamada y usa reuse=True en las llamadas consecuentes:

def variable_scope(name_or_scope, *args, **kwargs):
  if isinstance(name_or_scope, str):
    scope_name = tf.get_variable_scope().name + '/' + name_or_scope
  elif isinstance(name_or_scope, tf.Variable):
    scope_name = name_or_scope.name

  if scope_name in variable_scope.scopes:
    kwargs['reuse'] = True
  else:
    variable_scope.scopes.add(scope_name)

  return tf.variable_scope(name_or_scope, *args, **kwargs)
variable_scope.scopes = set()

Uso:

with variable_scope("foo"): #create the first time
    v = tf.get_variable("v", [1])

with variable_scope("foo"): #reuse the second time
    v = tf.get_variable("v", [1])
 1
Author: AlexP,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-06-28 13:51:44