Acelerar el emparejamiento de cadenas en objetos en Python


Estoy tratando de encontrar una manera eficiente de emparejar filas de datos que contienen puntos enteros, y almacenarlos como objetos Python. Los datos se componen de X y Y puntos de coordenadas, representados como cadenas separadas por comas. Los puntos tienen que ser emparejados, como en (x_1, y_1), (x_2, y_2), ... etc. y luego se almacena como una lista de objetos, donde cada punto es un objeto. La siguiente función get_data genera este ejemplo de datos:

def get_data(N=100000, M=10):
    import random
    data = []
    for n in range(N):
        pair = [[str(random.randint(1, 10)) for x in range(M)],
                [str(random.randint(1, 10)) for x in range(M)]]
        row = [",".join(pair[0]),
               ",".join(pair[1])]
        data.append(row)
    return data

El código de análisis que tengo ahora es:

class Point:
    def __init__(self, a, b):
        self.a = a
        self.b = b

def test():
    import time
    data = get_data()
    all_point_sets = []
    time_start = time.time()
    for row in data:
        point_set = []
        first_points, second_points = row
        # Convert points from strings to integers
        first_points = map(int, first_points.split(","))
        second_points = map(int, second_points.split(","))
        paired_points = zip(first_points, second_points)
        curr_points = [Point(p[0], p[1]) \
                       for p in paired_points]
        all_point_sets.append(curr_points)
    time_end = time.time()
    print "total time: ", (time_end - time_start)

Actualmente, este toma casi 7 segundos para 100,000 puntos, lo que parece muy ineficiente. Parte de la ineficiencia parece provenir del cálculo de first_points, second_points y paired_points - y la conversión de estos en los objetos.

Otra parte de la ineficiencia parece ser la construcción de all_point_sets. ¡Sacar la línea all_point_sets.append(...) parece hacer que el código vaya de ~7 segundos a 2 segundos!

¿Cómo se puede acelerar esto? gracias.

FOLLOWUP Gracias por las grandes sugerencias de todos - ellos todos nos ayudaron. pero incluso con todas las mejoras, todavía son unos 3 segundos para procesar 100,000 entradas. No estoy seguro de por qué en este caso no es solo instantáneo, y si hay una representación alternativa que lo haría instantáneo. ¿La codificación de esto en Cython cambiaría las cosas? ¿Podría alguien ofrecer un ejemplo de eso? gracias de nuevo.

Author: John La Rooy, 2012-10-17

13 answers

Simplemente correr con pypy hace una gran diferencia

$ python pairing_strings.py 
total time:  2.09194397926
$ pypy pairing_strings.py 
total time:  0.764246940613

Desactivar gc no ayudó para pypy

$ pypy pairing_strings.py 
total time:  0.763386964798

Namedtuple for Point lo hace peor

$ pypy pairing_strings.py 
total time:  0.888827085495

Usando itertools.imap, y itertools.izip

$ pypy pairing_strings.py 
total time:  0.615751981735

Usando una versión memoizada de int y un iterador para evitar el zip

$ pypy pairing_strings.py 
total time:  0.423738002777 

Aquí está el código con el que terminé.

def test():
    import time
    def m_int(s, memo={}):
        if s in memo:
            return memo[s]
        else:
            retval = memo[s] = int(s)
            return retval
    data = get_data()
    all_point_sets = []
    time_start = time.time()
    for xs, ys in data:
        point_set = []
        # Convert points from strings to integers
        y_iter = iter(ys.split(","))
        curr_points = [Point(m_int(i), m_int(next(y_iter))) for i in xs.split(",")]
        all_point_sets.append(curr_points)
    time_end = time.time()
    print "total time: ", (time_end - time_start)
 15
Author: John La Rooy,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-22 10:08:55

Cuando se trata de la creación de grandes números de objetos, a menudo la mayor mejora de rendimiento que puede usar es apagar el recolector de basura. Cada "generación" de objetos, el recolector de basura atraviesa todos los objetos en vivo en la memoria, buscando objetos que son parte de ciclos pero no son apuntados por objetos en vivo, por lo tanto son elegibles para la recuperación de memoria. Véase El artículo de Doug Helmann sobre PyMOTW GC para obtener más información (quizás se pueda encontrar más con Google y algo de determinación). El recolector de basura se ejecuta de forma predeterminada cada 700 o más objetos creados, pero no reclamados, y las generaciones posteriores se ejecutan con un poco menos de frecuencia (me olvido de los detalles exactos).

Usar una tupla estándar en lugar de una clase de Punto puede ahorrarte algo de tiempo (usar un namedtuple está en algún punto intermedio), y desempaquetar de forma inteligente puede ahorrarte algo de tiempo, pero la mayor ganancia se puede obtener apagando el gc antes de crear muchos objetos que sepa que no necesita ser gc'd, y luego volver a encenderlo después.

Algún código:

def orig_test_gc_off():
    import time
    data = get_data()
    all_point_sets = []
    import gc
    gc.disable()
    time_start = time.time()
    for row in data:
        point_set = []
        first_points, second_points = row
        # Convert points from strings to integers
        first_points = map(int, first_points.split(","))
        second_points = map(int, second_points.split(","))
        paired_points = zip(first_points, second_points)
        curr_points = [Point(p[0], p[1]) \
                       for p in paired_points]
        all_point_sets.append(curr_points)
    time_end = time.time()
    gc.enable()
    print "gc off total time: ", (time_end - time_start)

def test1():
    import time
    import gc
    data = get_data()
    all_point_sets = []
    time_start = time.time()
    gc.disable()
    for index, row in enumerate(data):
        first_points, second_points = row
        curr_points = map(
            Point,
            [int(i) for i in first_points.split(",")],
            [int(i) for i in second_points.split(",")])
        all_point_sets.append(curr_points)
    time_end = time.time()
    gc.enable()
    print "variant 1 total time: ", (time_end - time_start)

def test2():
    import time
    import gc
    data = get_data()
    all_point_sets = []
    gc.disable()
    time_start = time.time()
    for index, row in enumerate(data):
        first_points, second_points = row
        first_points = [int(i) for i in first_points.split(",")]
        second_points = [int(i) for i in second_points.split(",")]
        curr_points = [(x, y) for x, y in zip(first_points, second_points)]
        all_point_sets.append(curr_points)
    time_end = time.time()
    gc.enable()
    print "variant 2 total time: ", (time_end - time_start)

orig_test()
orig_test_gc_off()
test1()
test2()

Algunos resultados:

>>> %run /tmp/flup.py
total time:  6.90738511086
gc off total time:  4.94075202942
variant 1 total time:  4.41632509232
variant 2 total time:  3.23905301094
 20
Author: Matt Anderson,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-17 04:26:16

Me gustaría

  • use numpy arrays para este problema (Cython sería una opción, si esto todavía no es lo suficientemente rápido).
  • almacene los puntos como un vector no como instancias simples Point.
  • confíe en analizadores existentes
  • (si es posible) analizar los datos una vez y luego almacenarlos en un formato binario como hdf5 para cálculos adicionales, que será la opción más rápida (ver más abajo)

Numpy ha incorporado funciones para leer archivos de texto, por ejemplo loadtxt. Si tiene los datos almacenados en una matriz estructurada, no necesariamente necesita convertirlos a otro tipo de datos. Usaré Pandas que es una compilación de biblioteca sobre numpy. Es un poco más conveniente para el manejo y procesamiento de datos estructurados. Pandas tiene su propio analizador de archivos read_csv.

Para cronometrarlo, escribí los datos en un archivo, como en su problema original (se basa en su get_data):

import numpy as np
import pandas as pd

def create_example_file(n=100000, m=20):
    ex1 = pd.DataFrame(np.random.randint(1, 10, size=(10e4, m)),
                       columns=(['x_%d' % x for x in range(10)] +
                                ['y_%d' % y for y in range(10)]))
    ex1.to_csv('example.csv', index=False, header=False)
    return

Este es el código que utilicé para leer el los datos en un pandas.DataFrame:

def with_read_csv(csv_file):
    df = pd.read_csv(csv_file, header=None,
                     names=(['x_%d' % x for x in range(10)] +
                            ['y_%d' % y for y in range(10)]))
    return df

(Tenga en cuenta que asumí, que no hay encabezado en su archivo y por lo que tuve que crear los nombres de columna.)

La lectura de los datos es rápida, debería ser más eficiente en memoria (ver esta pregunta) y los datos se almacenan en una estructura de datos con la que puede trabajar más de una manera rápida y vectorizada:

In [18]: %timeit string_to_object.with_read_csv('example.csv')
1 loops, best of 3: 553 ms per loop

Hay un nuevo analizador basado en C en una rama de desarrollo que toma 414 ms en mi sistema. Su prueba toma 2.29 s en mi sistema, pero no es realmente comparable, ya que los datos no se leen de un archivo y creó las instancias Point.

Si ha leído una vez los datos, puede almacenarlos en un archivo hdf5:

In [19]: store = pd.HDFStore('example.h5')

In [20]: store['data'] = df

In [21]: store.close()

La próxima vez que necesite los datos, puede leerlos desde este archivo, que es realmente rápido:

In [1]: store = pd.HDFStore('example.h5')

In [2]: %timeit df = store['data']
100 loops, best of 3: 16.5 ms per loop

Sin embargo, solo será aplicable si necesita los mismos datos más de una vez.

Usar matrices basadas en numpy con grandes conjuntos de datos tendrá ventajas cuando están haciendo más cálculos. Cython no sería necesariamente más rápido si puedes usar funciones vectorizadas numpy e indexación, será más rápido si realmente necesitas iteración (ver también esta respuesta).

 9
Author: bmu,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2017-05-23 11:53:25

Método más rápido, usando Numpy (aceleración de aproximadamente 7x):

import numpy as np
txt = ','.join(','.join(row) for row in data)
arr = np.fromstring(txt, dtype=int, sep=',')
return arr.reshape(100000, 2, 10).transpose((0,2,1))

Comparación de rendimiento:

def load_1(data):
    all_point_sets = []
    gc.disable()
    for xs, ys in data:
        all_point_sets.append(zip(map(int, xs.split(',')), map(int, ys.split(','))))
    gc.enable()
    return all_point_sets

def load_2(data):
    txt = ','.join(','.join(row) for row in data)
    arr = np.fromstring(txt, dtype=int, sep=',')
    return arr.reshape(100000, 2, 10).transpose((0,2,1))

load_1 se ejecuta en 1.52 segundos en mi máquina; load_2 se ejecuta en 0.20 segundos, una mejora de 7 veces. La gran advertencia aquí es que requiere que (1) sepa las longitudes de todo por adelantado, y (2) que cada fila contenga exactamente el mismo número de puntos. Esto es cierto para su salida get_data, pero puede no ser cierto para su conjunto de datos real.

 8
Author: nneonneo,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-26 21:01:05

Obtuve una mejora del 50% al usar arrays y un objeto holder que construye perezosamente objetos Puntuales cuando se accede. Yo también "ranurado" el objeto de Punto de mejor eficiencia del almacenamiento. Sin embargo, una tupla probablemente sería mejor.

Cambiar la estructura de datos también puede ayudar, si eso es posible. Pero esto nunca será instantáneo.

from array import array

class Point(object):
    __slots__ = ["a", "b"]
    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __repr__(self):
        return "Point(%d, %d)" % (self.a, self.b)

class Points(object):
    def __init__(self, xs, ys):
        self.xs = xs
        self.ys = ys

    def __getitem__(self, i):
        return Point(self.xs[i], self.ys[i])

def test3():
    xs = array("i")
    ys = array("i")
    time_start = time.time()
    for row in data:
        xs.extend([int(val) for val in row[0].split(",")])
        ys.extend([int(val) for val in row[1].split(",")])
    print ("total time: ", (time.time() - time_start))
    return Points(xs, ys)

Pero cuando se trata de grandes cantidades de datos que normalmente utilizaría numpy N matrices dimensionales (ndarray). Si la estructura de datos original pudiera ser alterado entonces que probablemente sería más rápido de todos. Si se pudiera estructurar para leer x, y pares en linealmente y luego remodelar el ndarray.

 7
Author: Keith,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-21 01:41:26
  1. Make Point a namedtuple (~10% speedup):

    from collections import namedtuple
    Point = namedtuple('Point', 'a b')
    
  2. Desempaquetar durante la iteración (~2-4% speedup):

    for xs, ys in data:
    
  3. Use n - forma de argumento de map para evitar zip (~10% speedup):

    curr_points = map(Point,
        map(int, xs.split(',')),
        map(int, ys.split(',')),
    )
    

Dado que los conjuntos de puntos son cortos, los generadores son probablemente excesivos, ya que tienen una sobrecarga fija más alta.

 6
Author: nneonneo,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-17 04:01:37

Cython es capaz de acelerar las cosas por un factor de 5.5

$ python split.py
total time:  2.16252303123
total time:  0.393486022949

Aquí está el código que usé

Split.py

import time
import pyximport; pyximport.install()
from split_ import test_


def get_data(N=100000, M=10):
    import random
    data = []
    for n in range(N):
        pair = [[str(random.randint(1, 100)) for x in range(M)],
                [str(random.randint(1, 100)) for x in range(M)]]
        row = [",".join(pair[0]),
               ",".join(pair[1])]
        data.append(row)
    return data

class Point:
    def __init__(self, a, b):
        self.a = a
        self.b = b

def test(data):
    all_point_sets = []
    for row in data:
        point_set = []
        first_points, second_points = row
        # Convert points from strings to integers
        first_points = map(int, first_points.split(","))
        second_points = map(int, second_points.split(","))
        paired_points = zip(first_points, second_points)
        curr_points = [Point(p[0], p[1]) \
                       for p in paired_points]
        all_point_sets.append(curr_points)
    return all_point_sets

data = get_data()
for func in test, test_:
    time_start = time.time()
    res = func(data)
    time_end = time.time()
    print "total time: ", (time_end - time_start)

Split_.pyx

from libc.string cimport strsep
from libc.stdlib cimport atoi

cdef class Point:
    cdef public int a,b

    def __cinit__(self, a, b):
        self.a = a
        self.b = b

def test_(data):
    cdef char *xc, *yc, *xt, *yt
    cdef char **xcp, **ycp
    all_point_sets = []
    for xs, ys in data:
        xc = xs
        xcp = &xc
        yc = ys
        ycp = &yc
        point_set = []
        while True:
            xt = strsep(xcp, ',')
            if xt is NULL:
                break
            yt = strsep(ycp, ",")
            point_set.append(Point(atoi(xt), atoi(yt)))
        all_point_sets.append(point_set)
    return all_point_sets

Husmeando más lejos, aproximadamente puedo descomponer algunos de los recursos de la cpu

         5% strsep()
         9% atoi()
        23% creating Point instances
        35% all_point_sets.append(point_set)

Esperaría que hubiera una mejora si el cython fuera capaz de leer desde un archivo csv(o lo que sea) directamente en lugar de tener que rastrear a través de un objeto Python.

 6
Author: John La Rooy,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-24 23:46:09

Puedes afeitarte unos segundos:

class Point2(object):
    __slots__ = ['a','b']
    def __init__(self, a, b):
        self.a = a
        self.b = b

def test_new(data):
    all_point_sets = []
    for row in data:
        first_points, second_points = row
        r0 = map(int, first_points.split(","))
        r1 = map(int, second_points.split(","))
        cp = map(Point2, r0, r1)
        all_point_sets.append(cp)

Que me dio

In [24]: %timeit test(d)
1 loops, best of 3: 5.07 s per loop

In [25]: %timeit test_new(d)
1 loops, best of 3: 3.29 s per loop

De forma intermitente fui capaz de afeitar otros 0.3 s mediante la pre-asignación de espacio en all_point_sets pero eso podría ser solo ruido. Y, por supuesto, está la forma anticuada de hacer las cosas más rápido:

localhost-2:coding $ pypy pointexam.py
1.58351397514
 2
Author: DSM,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-17 04:02:34

¿Qué tan apegado estás a tener tus coordenadas accesibles como atributos .x y .y? Para mi sorpresa, mis pruebas muestran que el mayor sumidero de tiempo no fueron las llamadas a list.append(), sino la construcción de los objetos Point. Tardan cuatro veces más en construirse que una tupla, y hay muchos de ellos. Simplemente reemplazando Point(int(x), int(y)) con una tupla (int(x), int(y)) en tu código afeitado con más de un 50% de descuento en el tiempo de ejecución total (Python 2.6 en Win XP). Tal vez su código actual todavía tiene espacio para optimizar esto?

Si realmente está decidido a acceder a las coordenadas con .x y .y, puede intentar usar collections.namedtuple. No es tan rápido como las tuplas simples, pero parece ser mucho más rápido que la clase de Par en su código (estoy cubriendo porque un punto de referencia de tiempo separado me dio resultados extraños).

Pair = namedtuple("Pair", "x y")  # instead of the Point class
...
curr_points = [ Pair(x, y) for x, y in paired_points ]

Si necesita seguir esta ruta, también vale la pena derivar una clase de tupla (costo mínimo sobre la tupla simple). Puedo proporcionar detalles si solicitar.

PS Veo que @MattAnderson mencionó el problema de la tupla-objeto hace mucho tiempo. Pero es un efecto importante (al menos en mi caja), incluso antes de deshabilitar la recolección de basura.

               Original code: total time:  15.79
      tuple instead of Point: total time:  7.328
 namedtuple instead of Point: total time:  9.140
 2
Author: alexis,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-25 22:43:03

Los datos son un archivo separado por tabulación, que consiste en listas de comas enteros separados.

Usando la muestra get_data() hice un archivo .csv como este:

1,6,2,8,2,3,5,9,6,6     10,4,10,5,7,9,6,1,9,5
6,2,2,5,2,2,1,7,7,9     7,6,7,1,3,7,6,2,10,5
8,8,9,2,6,10,10,7,8,9   4,2,10,3,4,4,1,2,2,9
...

Luego abusé del análisis optimizado para C a través de JSON:

def test2():
    import json
    import time
    time_start = time.time()
    with open('data.csv', 'rb') as f:
        data = f.read()
    data = '[[[' + ']],[['.join(data.splitlines()).replace('\t', '],[') + ']]]'
    all_point_sets = [Point(*xy) for row in json.loads(data) for xy in zip(*row)]
    time_end = time.time()
    print "total time: ", (time_end - time_start)

Resultados en mi caja: su test() ~8s original, con gc desactivado ~6s, mientras que mi versión (E/S incluida) da ~6s y ~4s respectivamente. Es decir, aproximadamente ~50% de velocidad. Pero mirando los datos del perfilador, es obvio que el mayor cuello de botella está en la instanciación de objetos en sí mismo, por lo que Matt Anderson's respuesta le neto la mayor ganancia en CPython.

 2
Author: Zart,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-10-30 09:56:17

No se si hay mucho que puedas hacer.

Puede usar generator para evitar las asignaciones de memoria adicionales. Esto me da un aumento de velocidad del 5%.

first_points  = (int(p) for p in first_points .split(","))
second_points = (int(p) for p in second_points.split(","))
paired_points = itertools.izip(first_points, second_points)
curr_points   = [Point(x, y) for x,y in paired_points]

Incluso colapsar todo el bucle en una comprensión de lista masiva no hace mucho.

all_point_sets = [
    [Point(int(x), int(y)) for x, y in itertools.izip(xs.split(','), ys.split(','))]
    for xs, ys in data
]

Si continúa iterando sobre esta gran lista, entonces podría convertirla en un generador. Eso distribuiría el costo de analizar los datos CSV para que no obtenga un gran éxito por adelantado.

all_point_sets = (
    [Point(int(x), int(y)) for x, y in itertools.izip(xs.split(','), ys.split(','))]
    for xs, ys in data
)
 1
Author: John Kugelman,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-17 03:38:41

Hay muchas buenas respuestas aquí. Un lado de este problema no abordado hasta ahora, sin embargo, es las diferencias de costo de tiempo de lista a cadena entre las diversas implementaciones de iteradores en python.

Hay un ensayo que prueba la eficiencia de diferentes iteradores con respecto a la conversión de lista a cadena en Python.org ensayos: list2str . Tenga en cuenta que cuando me encontré con problemas de optimización similares, pero con diferentes estructuras y tamaños de datos, los resultados presentados en el ensayo no todos escalar a un ritmo igual, por lo que vale la pena probar las diferentes implementaciones de iteradores para su caso de uso particular.

 0
Author: Nisan.H,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-10-25 16:48:37

Como el tiempo necesario para funciones integradas como zip(a,b) o map(int, string.split(",")) para matrices de longitud 2000000 es insignificante, tengo que suponer que la operación que consume más tiempo es append.

Por lo tanto, la forma correcta de abordar el problema es concatenar recursivamente las cadenas:
10 cadenas de 10 elementos a una cadena más grande
10 cadenas de 100 elementos
10 cadenas de 1000 elementos

Y finalmente a zip(map(int,huge_string_a.split(",")),map(int,huge_string_b.split(",")));

Es entonces solo ajuste fino para encontrar el óptimo base N para el método anexar y conquistar.

 0
Author: Aki Suihkonen,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-10-30 10:31:59