2015-05-15 13 views
21

ho alcuni dati nel seguente formato (sia RDD o una scintilla dataframe):Rimodellare/dati di snodo Spark RDD e/o scintille DataFrames

from pyspark.sql import SQLContext 
sqlContext = SQLContext(sc) 

rdd = sc.parallelize([('X01',41,'US',3), 
         ('X01',41,'UK',1), 
         ('X01',41,'CA',2), 
         ('X02',72,'US',4), 
         ('X02',72,'UK',6), 
         ('X02',72,'CA',7), 
         ('X02',72,'XX',8)]) 

# convert to a Spark DataFrame      
schema = StructType([StructField('ID', StringType(), True), 
        StructField('Age', IntegerType(), True), 
        StructField('Country', StringType(), True), 
        StructField('Score', IntegerType(), True)]) 

df = sqlContext.createDataFrame(rdd, schema) 

Quello che vorrei fare è quello di 'rimodellare' il dati, convertire alcune righe in Paese (in particolare Stati Uniti, Regno Unito e CA) nelle colonne:

ID Age US UK CA 
'X01' 41 3 1 2 
'X02' 72 4 6 7 

in sostanza, ho bisogno di qualcosa lungo le linee di pivot flusso di lavoro di Python:

categories = ['US', 'UK', 'CA'] 
new_df = df[df['Country'].isin(categories)].pivot(index = 'ID', 
                columns = 'Country', 
                values = 'Score') 

Il mio set di dati è piuttosto grande, quindi non posso davvero fare il collect() e ingerire i dati in memoria per eseguire il rimodellamento in Python stesso. C'è un modo per convertire Python .pivot() in una funzione invokable mentre si mappano o un RDD o Spark DataFrame? Qualsiasi aiuto sarebbe apprezzato!

+0

I dati mi sembrano funky. I dati per 'X02' non corrispondono all'output. –

+0

@David Griffin, non sono sicuro di cosa intendi quando dici "I dati X02 non corrispondono all'output". I dati nella mia domanda sono il mio costrutto ma rappresentano i registri delle chiamate Intelligent Voice Recognition (IVR). Spesso, vedrai migliaia di queste coppie KV (Paese, Punteggio dal mio esempio). Parte della pulizia di questi dati comporta l'estrazione solo delle chiavi e dei valori necessari. Se vedi il mio codice Python, fa esattamente questo (mentre filtra la chiave Country = 'XX'. – Jason

+1

Ad esempio, nella definizione RDD hai ('X02', 72, 'US', 7) che significherebbe un 7 nella colonna degli Stati Uniti, ma nei risultati si ha 'X02' con un 4 nella colonna degli Stati Uniti –

risposta

13

Da Spark 1.6 è possibile utilizzare pivot funzione GroupedData e fornire espressione di aggregazione.

pivoted = (df 
    .groupBy("ID", "Age") 
    .pivot(
     "Country", 
     ['US', 'UK', 'CA']) # Optional list of levels 
    .sum("Score")) # alternatively you can use .agg(expr)) 
pivoted.show() 

## +---+---+---+---+---+ 
## | ID|Age| US| UK| CA| 
## +---+---+---+---+---+ 
## |X01| 41| 3| 1| 2| 
## |X02| 72| 4| 6| 7| 
## +---+---+---+---+---+ 

I livelli possono essere omessi ma, se forniti, possono entrambi aumentare le prestazioni e fungere da filtro interno.

Questo metodo è ancora relativamente lento ma certamente batte manualmente i dati di passaggio manuale tra JVM e Python.

1

Quindi, prima di tutto, ho dovuto fare la correzione al RDD (che corrisponde al tuo uscita effettiva):

rdd = sc.parallelize([('X01',41,'US',3), 
         ('X01',41,'UK',1), 
         ('X01',41,'CA',2), 
         ('X02',72,'US',4), 
         ('X02',72,'UK',6), 
         ('X02',72,'CA',7), 
         ('X02',72,'XX',8)]) 

Una volta ho fatto che la correzione, questo ha fatto il trucco:

df.select($"ID", $"Age").groupBy($"ID").agg($"ID", first($"Age") as "Age") 
.join(
    df.select($"ID" as "usID", $"Country" as "C1",$"Score" as "US"), 
    $"ID" === $"usID" and $"C1" === "US" 
) 
.join(
    df.select($"ID" as "ukID", $"Country" as "C2",$"Score" as "UK"), 
    $"ID" === $"ukID" and $"C2" === "UK" 
) 
.join(
    df.select($"ID" as "caID", $"Country" as "C3",$"Score" as "CA"), 
    $"ID" === $"caID" and $"C3" === "CA" 
) 
.select($"ID",$"Age",$"US",$"UK",$"CA") 

Non quasi elegante come il tuo perno, di sicuro.

+0

David, non sono riuscito a far funzionare tutto questo: in primo luogo, Spark non accettava '$' come metodo per fare riferimento alle colonne.Dopo aver rimosso tutti i simboli '$', ricevo comunque un errore di sintassi che punta all'espressione .select nell'ultima riga del codice sopra – Jason

+0

Spiacente, sto usando Scala. È stato tagliato e incollato direttamente dalla conchiglia. Se togli l'ultima select(), dovresti ottenere i risultati corretti solo con troppe colonne. Puoi farlo e pubblicare i risultati? –

+0

Grazie David, questo funziona ... – Jason

7

Per prima cosa, probabilmente non è una buona idea, perché non si ottengono ulteriori informazioni, ma si sta vincolando con uno schema fisso (ovvero è necessario conoscere quanti paesi ci si aspetta e, naturalmente, , Paese aggiuntivo significa modifica del codice)

Detto questo, questo è un problema SQL, che viene mostrato di seguito. Ma nel caso in cui si supponga che non è troppo "software come" (sul serio, ho sentito questo !!), quindi è possibile fare riferimento alla prima soluzione.

Soluzione 1:

def reshape(t): 
    out = [] 
    out.append(t[0]) 
    out.append(t[1]) 
    for v in brc.value: 
     if t[2] == v: 
      out.append(t[3]) 
     else: 
      out.append(0) 
    return (out[0],out[1]),(out[2],out[3],out[4],out[5]) 
def cntryFilter(t): 
    if t[2] in brc.value: 
     return t 
    else: 
     pass 

def addtup(t1,t2): 
    j=() 
    for k,v in enumerate(t1): 
     j=j+(t1[k]+t2[k],) 
    return j 

def seq(tIntrm,tNext): 
    return addtup(tIntrm,tNext) 

def comb(tP,tF): 
    return addtup(tP,tF) 


countries = ['CA', 'UK', 'US', 'XX'] 
brc = sc.broadcast(countries) 
reshaped = calls.filter(cntryFilter).map(reshape) 
pivot = reshaped.aggregateByKey((0,0,0,0),seq,comb,1) 
for i in pivot.collect(): 
    print i 

Ora, Soluzione 2: Certo meglio come SQL è strumento giusto per questo

callRow = calls.map(lambda t: 

Row(userid=t[0],age=int(t[1]),country=t[2],nbrCalls=t[3])) 
callsDF = ssc.createDataFrame(callRow) 
callsDF.printSchema() 
callsDF.registerTempTable("calls") 
res = ssc.sql("select userid,age,max(ca),max(uk),max(us),max(xx)\ 
        from (select userid,age,\ 
            case when country='CA' then nbrCalls else 0 end ca,\ 
            case when country='UK' then nbrCalls else 0 end uk,\ 
            case when country='US' then nbrCalls else 0 end us,\ 
            case when country='XX' then nbrCalls else 0 end xx \ 
          from calls) x \ 
        group by userid,age") 
res.show() 

dati impostati:

data=[('X01',41,'US',3),('X01',41,'UK',1),('X01',41,'CA',2),('X02',72,'US',4),('X02',72,'UK',6),('X02',72,'CA',7),('X02',72,'XX',8)] 
calls = sc.parallelize(data,1) 
countries = ['CA', 'UK', 'US', 'XX'] 

Risultato:

Dalla prima soluzione

(('X02', 72), (7, 6, 4, 8)) 
(('X01', 41), (2, 1, 3, 0)) 

Dal 2 soluzione:

root |-- age: long (nullable = true) 
     |-- country: string (nullable = true) 
     |-- nbrCalls: long (nullable = true) 
     |-- userid: string (nullable = true) 

userid age ca uk us xx 
X02 72 7 6 4 8 
X01 41 2 1 3 0 

lasciato gentilmente sapere se questo funziona, o no :)

Miglior Ayan

+0

grazie ... le tue soluzioni funzionano e, cosa più importante, sono scalabili! – Jason

+1

Sei in grado di estenderlo a un caso più generico? Ad esempio, una volta nei miei dati potrei avere 3 paesi. Un'altra volta potrei avere 5. Quello che hai sopra sembra essere codificato in 4 paesi specifici. Ho capito che ho bisogno di sapere quali paesi ho in anticipo, ma potrebbe cambiare col passare del tempo. Come posso passare un elenco di paesi come parametro e continuare a farlo funzionare? Questa è una cosa abbastanza comune da fare nel lavorare con i dati, quindi spero che questo possa essere integrato nelle funzionalità molto presto. –

+0

Come ho notato, questo è un problema con la progettazione dello schema. "Non puoi" solo passare un elenco di paesi, perché lo schema cambierà a valle. Tuttavia, potresti ** solo ** ottenere con la restituzione di una tupla generalizzata da rimodellare e impostare valori zero per aggregateByKey. Nel metodo SQL, è necessario "generare" in modo programmatico uno sql seguendo lo schema qui descritto. –

4

Ecco un approccio Spark nativa che non lo fa cablare i nomi delle colonne. È basato su aggregateByKey e utilizza un dizionario per raccogliere le colonne visualizzate per ogni chiave. Quindi raccogliamo tutti i nomi delle colonne per creare il dataframe finale. [La versione precedente utilizzava jsonRDD dopo l'emissione di un dizionario per ogni record, ma questo è più efficiente.] Limitare a un elenco specifico di colonne o escludere quelli come XX sarebbe una modifica semplice.

Le prestazioni sembrano buone anche su tavoli abbastanza grandi. Sto usando una variazione che conta il numero di volte in cui ognuno di un numero variabile di eventi si verifica per ciascun ID, generando una colonna per tipo di evento. Il codice è fondamentalmente lo stesso eccetto che usa una collection.Counter invece di un dict nel seqFn per contare le occorrenze.

from pyspark.sql.types import * 

rdd = sc.parallelize([('X01',41,'US',3), 
         ('X01',41,'UK',1), 
         ('X01',41,'CA',2), 
         ('X02',72,'US',4), 
         ('X02',72,'UK',6), 
         ('X02',72,'CA',7), 
         ('X02',72,'XX',8)]) 

schema = StructType([StructField('ID', StringType(), True), 
        StructField('Age', IntegerType(), True), 
        StructField('Country', StringType(), True), 
        StructField('Score', IntegerType(), True)]) 

df = sqlCtx.createDataFrame(rdd, schema) 

def seqPivot(u, v): 
    if not u: 
     u = {} 
    u[v.Country] = v.Score 
    return u 

def cmbPivot(u1, u2): 
    u1.update(u2) 
    return u1 

pivot = (
    df 
    .rdd 
    .keyBy(lambda row: row.ID) 
    .aggregateByKey(None, seqPivot, cmbPivot) 
) 
columns = (
    pivot 
    .values() 
    .map(lambda u: set(u.keys())) 
    .reduce(lambda s,t: s.union(t)) 
) 
result = sqlCtx.createDataFrame(
    pivot 
    .map(lambda (k, u): [k] + [u.get(c) for c in columns]), 
    schema=StructType(
     [StructField('ID', StringType())] + 
     [StructField(c, IntegerType()) for c in columns] 
    ) 
) 
result.show() 

produce:

ID CA UK US XX 
X02 7 6 4 8 
X01 2 1 3 null 
+0

Nice writeup - btw spark 1.6 dataframes support pivots https: //github.com/apache/spark/pull/7841 – meyerson

+0

Cool: la scintilla sta migliorando rapidamente. – patricksurry

0

C'è un JIRA in Hive per PIVOT per fare questo in modo nativo, senza una dichiarazione CASE enorme per ogni valore:

https://issues.apache.org/jira/browse/HIVE-3776

Si prega di votare che JIRA up quindi verrà implementato prima. Una volta in Hive SQL, Spark di solito non gli manca troppo e alla fine verrà implementato anche in Spark.

1

Solo alcune osservazioni sulla risposta molto utile di patricksurry:

  • colonna Età manca, quindi basta aggiungere u [ "Age"] = v.Age alla funzione seqPivot
  • esso si è scoperto che entrambi i loop sugli elementi delle colonne davano gli elementi in un ordine diverso. I valori delle colonne erano corretti, ma non i loro nomi. Per evitare questo comportamento basta ordinare l'elenco delle colonne.

Ecco il codice leggermente modificata:

from pyspark.sql.types import * 

rdd = sc.parallelize([('X01',41,'US',3), 
         ('X01',41,'UK',1), 
         ('X01',41,'CA',2), 
         ('X02',72,'US',4), 
         ('X02',72,'UK',6), 
         ('X02',72,'CA',7), 
         ('X02',72,'XX',8)]) 

schema = StructType([StructField('ID', StringType(), True), 
        StructField('Age', IntegerType(), True), 
        StructField('Country', StringType(), True), 
        StructField('Score', IntegerType(), True)]) 

df = sqlCtx.createDataFrame(rdd, schema) 

# u is a dictionarie 
# v is a Row 
def seqPivot(u, v): 
    if not u: 
     u = {} 
    u[v.Country] = v.Score 
    # In the original posting the Age column was not specified 
    u["Age"] = v.Age 
    return u 

# u1 
# u2 
def cmbPivot(u1, u2): 
    u1.update(u2) 
    return u1 

pivot = (
    rdd 
    .map(lambda row: Row(ID=row[0], Age=row[1], Country=row[2], Score=row[3])) 
    .keyBy(lambda row: row.ID) 
    .aggregateByKey(None, seqPivot, cmbPivot) 
) 

columns = (
    pivot 
    .values() 
    .map(lambda u: set(u.keys())) 
    .reduce(lambda s,t: s.union(t)) 
) 

columns_ord = sorted(columns) 

result = sqlCtx.createDataFrame(
    pivot 
    .map(lambda (k, u): [k] + [u.get(c, None) for c in columns_ord]), 
     schema=StructType(
      [StructField('ID', StringType())] + 
      [StructField(c, IntegerType()) for c in columns_ord] 
     ) 
    ) 

print result.show() 

Infine, l'uscita dovrebbe essere

+---+---+---+---+---+----+ 
| ID|Age| CA| UK| US| XX| 
+---+---+---+---+---+----+ 
|X02| 72| 7| 6| 4| 8| 
|X01| 41| 2| 1| 3|null| 
+---+---+---+---+---+----+ 
Problemi correlati