¿Cómo contar el número total de parámetros entrenables en un modelo tensorflow?


¿Hay una llamada a una función u otra forma de contar el número total de parámetros en un gráfico de tensorflow?

Por parámetros quiero decir: un vector N dim de variables entrenables tiene N parámetros, una matriz NxM tiene N*M parámetros, etc. Así que esencialmente me gustaría sumar el producto de las dimensiones de forma de todas las variables entrenables en una sesión de tensorflow.

Author: j314erre, 2016-07-02

6 answers

Bucle sobre la forma de cada variable en tf.trainable_variables().

total_parameters = 0
for variable in tf.trainable_variables():
    # shape is an array of tf.Dimension
    shape = variable.get_shape()
    print(shape)
    print(len(shape))
    variable_parameters = 1
    for dim in shape:
        print(dim)
        variable_parameters *= dim.value
    print(variable_parameters)
    total_parameters += variable_parameters
print(total_parameters)
 59
Author: nessuno,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-08-24 06:00:35

Tengo una versión aún más corta, una solución de línea usando usando numpy:

np.sum([np.prod(v.get_shape().as_list()) for v in tf.trainable_variables()])
 27
Author: Michael Gygli,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-05-06 11:34:16

No estoy seguro de si la respuesta dada realmente se ejecuta (descubrí que necesita convertir el objeto dim a un int para que funcione). Aquí es uno que funciona y solo puede copiar y pegar las funciones y llamarlas (también se agregaron algunos comentarios):

def count_number_trainable_params():
    '''
    Counts the number of trainable variables.
    '''
    tot_nb_params = 0
    for trainable_variable in tf.trainable_variables():
        shape = trainable_variable.get_shape() # e.g [D,F] or [W,H,C]
        current_nb_params = get_nb_params_shape(shape)
        tot_nb_params = tot_nb_params + current_nb_params
    return tot_nb_params

def get_nb_params_shape(shape):
    '''
    Computes the total number of params for a given shap.
    Works for any number of shapes etc [D,F] or [W,H,C] computes D*F and W*H*C.
    '''
    nb_params = 1
    for dim in shape:
        nb_params = nb_params*int(dim)
    return nb_params 
 8
Author: Pinocchio,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-11-21 23:34:29

Las dos respuestas existentes son buenas si estás buscando calcular el número de parámetros tú mismo. Si su pregunta fue más en la línea de "¿hay una manera fácil de perfilar mis modelos TensorFlow?", Yo recomendaría encarecidamente buscar en tfprof . Perfila su modelo, incluyendo el cálculo del número de parámetros.

 6
Author: Gabriel Parent,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-12-21 06:45:03

Voy a lanzar en mi equivalente, pero la implementación más corta:

def count_params():
    "print number of trainable variables"
    size = lambda v: reduce(lambda x, y: x*y, v.get_shape().as_list())
    n = sum(size(v) for v in tf.trainable_variables())
    print "Model size: %dK" % (n/1000,)
 2
Author: Gregor Mitscha-Baude,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-12-29 15:00:15

Si uno prefiere evitar numpy (se puede dejar fuera para muchos proyectos), entonces:

all_trainable_vars = tf.reduce_sum([tf.reduce_prod(v.shape) for v in tf.trainable_variables()])

Esta es una traducción TF de la respuesta anterior de Julius Kunze.

Como cualquier operación TF, requiere una sesión ejecutada para evaluar:

print(sess.run(all_trainable_vars))
 0
Author: Ran Feldesh,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-08-07 02:44:45