Comparación de dos matrices numpy para la igualdad, elemento-sabio
¿Cuál es la forma más sencilla de comparar dos matrices numpy para la igualdad (donde la igualdad se define como: A = B iff para todos los índices i: A[i] == B[i]
)?
Simplemente usando ==
me da una matriz booleana:
>>> numpy.array([1,1,1]) == numpy.array([1,1,1])
array([ True, True, True], dtype=bool)
¿Tengo que and
los elementos de esta matriz para determinar si las matrices son iguales, o hay una forma más simple de comparar?
4 answers
(A==B).all()
Pruebe si todos los valores de la matriz (A==B) son Verdaderos.
Edit (de la respuesta de dbaupp y el comentario de yoavram)
Cabe señalar que:
- esta solución puede tener un comportamiento extraño en un caso particular: si
A
oB
está vacía y la otra contiene un solo elemento, a continuación, volverTrue
. Por alguna razón, la comparaciónA==B
devuelve un array vacío, para el cual el operadorall
devuelveTrue
. - Otro riesgo es si
A
yB
no tienen la misma forma y no son broadcastable, entonces este enfoque generará un error.
En conclusión, la solución que propuse es la estándar, creo, pero si tiene alguna duda sobre A
y B
la forma o simplemente quiere estar seguro: use una de las funciones especializadas:
np.array_equal(A,B) # test if same shape, same elements values
np.array_equiv(A,B) # test if broadcastable shape, same elements values
np.allclose(A,B,...) # test if same shape, elements have close enough values
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2014-10-06 07:56:59
La solución (A==B).all()
es muy ordenada, pero hay algunas funciones integradas para esta tarea. Es decir,array_equal
, allclose
y array_equiv
.
(Aunque algunas pruebas rápidas con timeit
parecen indicar que el método (A==B).all()
es el más rápido, lo cual es un poco peculiar, dado que tiene que asignar una matriz completamente nueva.)
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2012-05-14 11:00:58
Vamos a medir el rendimiento utilizando el siguiente fragmento de código.
import numpy as np
import time
exec_time0 = []
exec_time1 = []
exec_time2 = []
sizeOfArray = 5000
numOfIterations = 200
for i in xrange(numOfIterations):
A = np.random.randint(0,255,(sizeOfArray,sizeOfArray))
B = np.random.randint(0,255,(sizeOfArray,sizeOfArray))
a = time.clock()
res = (A==B).all()
b = time.clock()
exec_time0.append( b - a )
a = time.clock()
res = np.array_equal(A,B)
b = time.clock()
exec_time1.append( b - a )
a = time.clock()
res = np.array_equiv(A,B)
b = time.clock()
exec_time2.append( b - a )
print 'Method: (A==B).all(), ', np.mean(exec_time0)
print 'Method: np.array_equal(A,B),', np.mean(exec_time1)
print 'Method: np.array_equiv(A,B),', np.mean(exec_time2)
Salida
Method: (A==B).all(), 0.03031857
Method: np.array_equal(A,B), 0.030025185
Method: np.array_equiv(A,B), 0.030141515
De acuerdo con los resultados anteriores, los métodos numpy parecen ser más rápidos que la combinación de == operador y el método all () y comparando los métodos numpy el más rápido parece ser el numpy.método array_equal.
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-02-23 13:42:04
Si desea comprobar si dos matrices tienen el mismo shape
Y elements
debe usar np.array_equal
ya que es el método recomendado en la documentación.
En cuanto al rendimiento, no espere que ninguna comprobación de igualdad supere a otra, ya que no hay mucho espacio para optimizar
comparing two elements
. Sólo por el bien, todavía hice algunas pruebas.
import numpy as np
import timeit
A = np.zeros((300, 300, 3))
B = np.zeros((300, 300, 3))
C = np.ones((300, 300, 3))
timeit.timeit(stmt='(A==B).all()', setup='from __main__ import A, B', number=10**5)
timeit.timeit(stmt='np.array_equal(A, B)', setup='from __main__ import A, B, np', number=10**5)
timeit.timeit(stmt='np.array_equiv(A, B)', setup='from __main__ import A, B, np', number=10**5)
> 51.5094
> 52.555
> 52.761
Así que prácticamente igual, no hay necesidad de hablar de la velocidad.
El (A==B).all()
se comporta más o menos como el siguiente fragmento de código:
x = [1,2,3]
y = [1,2,3]
print all([x[i]==y[i] for i in range(len(x))])
> True
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2018-04-30 00:44:52