pasos vs num epochs en tensorflow introducción tutorial


Estoy pasando por TensorFlow tutorial para empezar. En el ejemplo tf.contrib.learn, estas son dos líneas de código:

input_fn = tf.contrib.learn.io.numpy_input_fn({"x":x}, y, batch_size=4, num_epochs=1000)
estimator.fit(input_fn=input_fn, steps=1000)

Me pregunto cuál es la diferencia entre el argumento steps en la llamada a la función fit y num_epochs en la llamada numpy_input_fn. ¿No debería haber una sola discusión? Cómo están conectados?

He encontrado que el código de alguna manera está tomando el min de estos dos como el número de pasos en el ejemplo de juguete de la tutorial.


Editar

Gracias por todas las respuestas. En mi humilde opinión, al menos uno de los dos parámetros num_epochs o steps tiene que ser redundante. Podemos calcular uno a partir del otro. ¿Hay alguna manera de saber cuántos pasos (número de veces que se actualizan los parámetros) tomó realmente mi algoritmo?

Tengo curiosidad de cuál tiene prioridad. ¿Y depende de otros parámetros?

 34
Author: nbro, 2017-03-15

5 answers

El contrib.learn.io module no está muy bien documentado, pero parece que la función numpy_input_fn() toma algunas matrices numpy y las agrupa como entrada para un clasificador. Por lo tanto, el número de épocas probablemente significa "cuántas veces para ir a través de los datos de entrada que tengo antes de parar". En este caso, alimentan dos matrices de longitud 4 en 4 lotes de elementos, por lo que solo significará que la función de entrada hará esto como máximo 1000 veces antes de levantar una excepción "fuera de datos". El argumento steps en la función estimator fit() es cuántas veces debe estimator hacer el ciclo de entrenamiento. Este ejemplo en particular es algo perverso, así que permítanme inventar otro para hacer las cosas un poco más claras (espero).

Digamos que tiene dos matrices numpy (muestras y etiquetas) en las que desea entrenar. Son 100 elementos cada uno. Usted quiere que su entrenamiento tome lotes con 10 muestras por lote. Así que después de 10 lotes, revisará todos sus datos de entrenamiento. Esa es una época. Si tu generador de entrada a 10 épocas, pasará por su conjunto de entrenamiento 10 veces antes de detenerse, es decir, generará como máximo 100 lotes.

De nuevo, el módulo io no está documentado, pero teniendo en cuenta cómo funcionan otras API relacionadas con la entrada en tensorflow, debería ser posible hacer que genere datos para un número ilimitado de épocas, por lo que lo único que controla la duración del entrenamiento serán los pasos. Esto le da cierta flexibilidad adicional sobre cómo desea que progrese su entrenamiento. Usted puede ir un número de épocas a la vez o un número de pasos a la vez o ambos o lo que sea.

Edit: TL;DR Epoch es cuando tu modelo revisa todos tus datos de entrenamiento una vez. El paso es cuando su modelo entrena en un solo lote (o una sola muestra si envía muestras una por una). Formación para 5 épocas en un 1000 muestras 10 muestras por lote tomará 500 pasos.

 31
Author: Mad Wombat,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-06-14 14:45:09

Época: Se pasa a través de todos los datos.

Tamaño del lote: El número de ejemplos vistos en un lote.

Si hay 1000 ejemplos y el tamaño del lote es 100, entonces habrá 10 pasos por época.

Las épocas y el tamaño del lote definen completamente el número de pasos.

Steps_cal = (no de ex / batch_size) * no_of_epochs

estimator.fit(input_fn=input_fn)

Si solo escribe el código anterior, entonces el valor de 'steps' es el dado por 'steps_cal' en la fórmula anterior.

estimator.fit(input_fn=input_fn, steps  = steps_less)

Si usted da un valor (digamos 'steps_less') menor que 'steps_cal', entonces solo' steps_less ' no de pasos será executed.In en este caso, la capacitación no cubrirá todo el número de épocas que se mencionaron.

estimator.fit(input_fn=input_fn, steps  = steps_more)

Si da un valor(digamos steps_more) más que steps_cal, entonces también 'steps_cal' no se ejecutará ningún paso.

 17
Author: Himanshu Sanghi,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-07-12 12:28:34

Empecemos por el orden opuesto:

1) Pasos - número de veces que el bucle de entrenamiento en su algoritmo de aprendizaje se ejecutará para actualizar los parámetros en el modelo. En cada iteración de bucle, procesará un trozo de datos, que es básicamente un lote. Por lo general, este bucle se basa en el algoritmo de Descenso de Gradiente .

2) Tamaño del lote - el tamaño del fragmento de datos que alimenta en cada bucle del algoritmo de aprendizaje. Puedes alimentar a la conjunto de datos completo, en cuyo caso el tamaño del lote es igual al tamaño del conjunto de datos.También puede alimentar un ejemplo a la vez. O puedes alimentar algún número N de ejemplos.

3) Época - el número de veces que se ejecuta sobre el conjunto de datos extrayendo lotes para alimentar el algoritmo de aprendizaje.

Digamos que tienes 1000 ejemplos. Establecer batch size = 100, epoch = 1 y steps = 200 da un proceso con una pasada (una época) sobre todo el conjunto de datos. En cada pasarlo alimentará al algoritmo un lote con 100 ejemplos. El algoritmo ejecutará 200 pasos en cada lote. En total, se ven 10 lotes. Si cambias la época a 25, entonces hará esto 25 veces, y obtienes 25x10 lotes vistos en conjunto.

¿por Qué necesitamos esto? Hay muchas variaciones en el descenso del gradiente (lote, estocástico, mini-lote), así como otros algoritmos para optimizar los parámetros de aprendizaje (por ejemplo, L-BFGS). Algunos de ellos necesitan ver los datos en lotes, mientras que otros ven un dato a la vez. Además, algunos de ellos incluyen factores/pasos aleatorios, por lo tanto, es posible que necesite varias pasadas de los datos para obtener una buena convergencia.

 14
Author: Manuel,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-06-14 11:30:46

Esta respuesta se basa en la experimentación que he hecho en el código tutorial de introducción.

Mad Wombat ha dado una explicación detallada de los términos num_epochs, batch_size y pasos. Esta respuesta es una extensión de su respuesta.

num_epochs - El número máximo de veces que el programa puede iterar sobre todo el conjunto de datos en un train(). Usando este argumento, podemos restringir el número de lotes que se pueden procesar durante la ejecución de un método train().

batch_size - El número de ejemplos en un solo lote emitido por la input_fn

pasos - Número de lotes que el método LinearRegressor.train() puede procesar en una ejecución

max_steps es otro argumento para el método LinearRegressor.train(). Este argumento define el número máximo de pasos (lotes) que se pueden procesar en la vida útil de los objetos LinearRegressor().

Vamos a lo que esto significa. Los siguientes experimentos cambian dos líneas del código proporcionado por el tutorial. El resto del código permanece como está.

Nota: Para todos los ejemplos, supongamos que el número de entrenamiento, es decir, la longitud de x_train es igual a 4.

Ex 1:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=4, num_epochs=2, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

En este ejemplo, se define el batch_size = 4 y num_epochs = 2. Por lo tanto, el input_fn puede emitir solo 2 lotes de datos de entrada para una ejecución de train(). A pesar de que definimos steps = 10, el método train() se detiene después de 2 pasos.

Ahora, ejecute el estimator.train(input_fn=input_fn, steps=10) de nuevo. Podemos ver que se han ejecutado 2 pasos más. Podemos seguir ejecutando el método train() una y otra vez. Si ejecutamos train() 50 veces, se han ejecutado un total de 100 pasos.

Ex 2:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=2, num_epochs=2, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

En este ejemplo, el valor de batch_size se cambia a 2 (era igual a 4 en Ex 1). Ahora, en cada ejecución de train() método, se procesan 4 pasos. Después del 4to paso, no hay lotes para correr. Si el método train() se ejecuta de nuevo, se procesan otros 4 pasos, lo que hace que sea un total de 8 pasos.

Aquí, el valor de steps no importa porque el método train() puede obtener un máximo de 4 lotes. Si el valor de pasos es menor que (num_epochs x training_size) / batch_size, véase ex 3.

Ex 3:

input_fn = tf.estimator.inputs.numpy_input_fn( {"x": x_train}, y_train, batch_size=2, num_epochs=8, shuffle=True)

estimator.train(input_fn=input_fn, steps=10)

Ahora, vamos a batch_size = 2, num_epochs = 8 y pasos = 10. El input_fn puede emitir un total de 16 lotes en una ejecución del método train(). Sin embargo, steps se establece en 10. Esto significa que aunque input_fn puede proporcionar 16 lotes para la ejecución, train() debe detenerse después de 10 pasos. Por supuesto, el método train() se puede volver a ejecutar para más pasos acumulativamente.


De los ejemplos 1, 2 y 3, podemos claramente ver cómo los valores de pasos, num_epoch y batch_size afectan al número de pasos que se pueden ejecutar con el método train() en una sola ejecución.

El argumento max_steps del método train() restringe el número total de pasos que se pueden ejecutar acumulativamente por train()

Ex 4:

Si batch_size = 4, num_epochs = 2, el input_fn puede emitir 2 lotes para una train() ejecución. Pero, si max_steps se establece en 20, no importa cuántas veces se ejecute train() solo se ejecutarán 20 pasos en la optimización. Esto contrasta con el ejemplo 1, donde el optimizador puede ejecutar 200 pasos si el método train() se exuta 100 veces.

Espero que esto dé una comprensión detallada de lo que significan estos argumentos.

 8
Author: pbskumar,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-11-06 16:57:33

Num_epochs: el número máximo de épocas (ver cada punto de datos).

Pasos: el número de actualizaciones (de parámetros).

Puede actualizar varias veces, en una época cuando el tamaño del lote es menor que el número de datos de entrenamiento.

 3
Author: Change-the-world,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-04-21 11:50:51