Remodelación / pivote de datos en marcos de datos Spark RDD y / o Spark


Tengo algunos datos en el siguiente formato (ya sea RDD o Spark DataFrame):

from pyspark.sql import SQLContext
sqlContext = SQLContext(sc)

 rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

# convert to a Spark DataFrame                    
schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlContext.createDataFrame(rdd, schema)

Lo que me gustaría hacer es 'remodelar' los datos, convertir ciertas filas en el País (específicamente EE. UU., Reino Unido y CA) en columnas:

ID    Age  US  UK  CA  
'X01'  41  3   1   2  
'X02'  72  4   6   7   

Esencialmente, necesito algo en la línea del flujo de trabajo de Python pivot:

categories = ['US', 'UK', 'CA']
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                                                  columns = 'Country',
                                                  values = 'Score')

Mi conjunto de datos es bastante grande, así que realmente no puedo collect() e ingerir los datos en la memoria para hacer la remodelación en Python. Hay una manera de convertir Python .pivot() ¿en una función invocable mientras se mapea un RDD o un DataFrame Spark? Cualquier ayuda sería apreciada!

Author: Joshua Taylor, 2015-05-15

6 answers

Desde Spark 1.6 puede usar pivot función en GroupedData y proporcionar expresión agregada.

pivoted = (df
    .groupBy("ID", "Age")
    .pivot(
        "Country",
        ['US', 'UK', 'CA'])  # Optional list of levels
    .sum("Score"))  # alternatively you can use .agg(expr))
pivoted.show()

## +---+---+---+---+---+
## | ID|Age| US| UK| CA|
## +---+---+---+---+---+
## |X01| 41|  3|  1|  2|
## |X02| 72|  4|  6|  7|
## +---+---+---+---+---+

Los niveles se pueden omitir, pero si se proporcionan pueden aumentar el rendimiento y servir como un filtro interno.

Este método todavía es relativamente lento, pero ciertamente supera el paso manual de datos manualmente entre JVM y Python.

 14
Author: zero323,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-06-08 21:53:30

En primer lugar, esto probablemente no es una buena idea, porque no está recibiendo ninguna información adicional, pero se está vinculando a sí mismo con un esquema fijo (es decir, debe necesitar saber cuántos países está esperando, y por supuesto, país adicional significa cambio en el código)

Dicho esto, este es un problema SQL, que se muestra a continuación. Pero en caso de que suponga que no es demasiado "software como" (en serio, he oído esto!!), entonces usted puede referir la primera solución.

Solución 1:

def reshape(t):
    out = []
    out.append(t[0])
    out.append(t[1])
    for v in brc.value:
        if t[2] == v:
            out.append(t[3])
        else:
            out.append(0)
    return (out[0],out[1]),(out[2],out[3],out[4],out[5])
def cntryFilter(t):
    if t[2] in brc.value:
        return t
    else:
        pass

def addtup(t1,t2):
    j=()
    for k,v in enumerate(t1):
        j=j+(t1[k]+t2[k],)
    return j

def seq(tIntrm,tNext):
    return addtup(tIntrm,tNext)

def comb(tP,tF):
    return addtup(tP,tF)


countries = ['CA', 'UK', 'US', 'XX']
brc = sc.broadcast(countries)
reshaped = calls.filter(cntryFilter).map(reshape)
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1)
for i in pivot.collect():
    print i

Ahora, Solución 2: Por supuesto mejor como SQL es la herramienta correcta para esto

callRow = calls.map(lambda t:   

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3]))
callsDF = ssc.createDataFrame(callRow)
callsDF.printSchema()
callsDF.registerTempTable("calls")
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\
                    from (select userid,age,\
                                  case when country='CA' then nbrCalls else 0 end ca,\
                                  case when country='UK' then nbrCalls else 0 end uk,\
                                  case when country='US' then nbrCalls else 0 end us,\
                                  case when country='XX' then nbrCalls else 0 end xx \
                             from calls) x \
                     group by userid,age")
res.show()

Configuración de datos:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)]
 calls = sc.parallelize(data,1)
countries = ['CA', 'UK', 'US', 'XX']

Resultado:

A partir de la 1a solución

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0))

De la 2a solución:

root  |-- age: long (nullable = true)  
      |-- country: string (nullable = true)  
      |-- nbrCalls: long (nullable = true)  
      |-- userid: string (nullable = true)

userid age ca uk us xx 
 X02    72  7  6  4  8  
 X01    41  2  1  3  0

Amablemente hazme saber si esto funciona, o no:)

Mejor Ayan

 7
Author: ayan guha,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-05-16 17:19:42

Aquí hay un enfoque nativo de Spark que no conecta los nombres de las columnas. Se basa en aggregateByKey, y utiliza un diccionario para recopilar las columnas que aparecen para cada clave. Luego reunimos todos los nombres de columna para crear el dataframe final. [La versión anterior usaba jsonRDD después de emitir un diccionario para cada registro, pero esto es más eficiente.] Restringir a una lista específica de columnas, o excluir algunas como XX sería una modificación fácil.

El rendimiento parece bueno incluso en bastante mesas grandes. Estoy usando una variación que cuenta el número de veces que cada uno de un número variable de eventos ocurre para cada ID, generando una columna por tipo de evento. El código es básicamente el mismo excepto que utiliza una colección.Contador en lugar de un dict en el seqFn para contar las ocurrencias.

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    return u

def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    df
    .rdd
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)
columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)
result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]),
    schema=StructType(
        [StructField('ID', StringType())] + 
        [StructField(c, IntegerType()) for c in columns]
    )
)
result.show()

Produce:

ID  CA UK US XX  
X02 7  6  4  8   
X01 2  1  3  null
 4
Author: patricksurry,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-06-24 13:29:48

Así que en primer lugar, tuve que hacer esta corrección a su RDD (que coincide con su salida real):

rdd = sc.parallelize([('X01',41,'US',3),
                      ('X01',41,'UK',1),
                      ('X01',41,'CA',2),
                      ('X02',72,'US',4),
                      ('X02',72,'UK',6),
                      ('X02',72,'CA',7),
                      ('X02',72,'XX',8)])

Una vez que hice esa corrección, esto hizo el truco:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age")
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"),
    $"ID" === $"usID" and $"C1" === "US"
)
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"),
    $"ID" === $"ukID" and $"C2" === "UK"
)
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA"
)
.select($"ID",$"Age",$"US",$"UK",$"CA")

No es tan elegante como tu pivote, seguro.

 1
Author: David Griffin,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-05-15 16:48:31

Solo algunos comentarios sobre la muy útil respuesta de patricksurry :

  • falta la edad de la columna, así que simplemente agregue u ["Age"] = v. Age a la función seqPivot
  • resultó que ambos bucles sobre los elementos de las columnas daban los elementos en un orden diferente. Los valores de las columnas eran correctos, pero no los nombres de las mismas. Para evitar este comportamiento, simplemente ordene la lista de columnas.

Aquí está el código ligeramente modificado:

from pyspark.sql.types import *

rdd = sc.parallelize([('X01',41,'US',3),
                       ('X01',41,'UK',1),
                       ('X01',41,'CA',2),
                       ('X02',72,'US',4),
                       ('X02',72,'UK',6),
                       ('X02',72,'CA',7),
                       ('X02',72,'XX',8)])

schema = StructType([StructField('ID', StringType(), True),
                     StructField('Age', IntegerType(), True),
                     StructField('Country', StringType(), True),
                     StructField('Score', IntegerType(), True)])

df = sqlCtx.createDataFrame(rdd, schema)

# u is a dictionarie
# v is a Row
def seqPivot(u, v):
    if not u:
        u = {}
    u[v.Country] = v.Score
    # In the original posting the Age column was not specified
    u["Age"] = v.Age
    return u

# u1
# u2
def cmbPivot(u1, u2):
    u1.update(u2)
    return u1

pivot = (
    rdd
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2],  Score=row[3]))
    .keyBy(lambda row: row.ID)
    .aggregateByKey(None, seqPivot, cmbPivot)
)

columns = (
    pivot
    .values()
    .map(lambda u: set(u.keys()))
    .reduce(lambda s,t: s.union(t))
)

columns_ord = sorted(columns)

result = sqlCtx.createDataFrame(
    pivot
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]),
        schema=StructType(
            [StructField('ID', StringType())] + 
            [StructField(c, IntegerType()) for c in columns_ord]
        )
    )

print result.show()

Finalmente, la salida debe ser

+---+---+---+---+---+----+
| ID|Age| CA| UK| US|  XX|
+---+---+---+---+---+----+
|X02| 72|  7|  6|  4|   8|
|X01| 41|  2|  1|  3|null|
+---+---+---+---+---+----+
 1
Author: rolpat,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-09-24 09:31:12

Hay una JIRA en Hive for PIVOT para hacer esto de forma nativa, sin una gran sentencia CASE para cada valor:

Https://issues.apache.org/jira/browse/HIVE-3776

Por favor voten para que JIRA se implemente antes. Una vez que esté en Hive SQL, Spark generalmente no carece demasiado de respaldo y, finalmente, también se implementará en Spark.

 0
Author: Tagar,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2015-09-01 19:12:39