2014-11-09 16 views

risposta

48

Ciò dovrebbe essere possibile innanzitutto indicizzando l'RDD. La trasformazione zipWithIndex fornisce un'indicizzazione stabile, numerando ogni elemento nel suo ordine originale.

Dato: rdd = (a,b,c)

val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2)) 

di ricercare un elemento per indice, questo modulo non è utile. In primo luogo abbiamo bisogno di utilizzare l'indice come chiave:

val indexKey = withIndex.map{case (k,v) => (v,k)} //((0,a),(1,b),(2,c)) 

Ora, è possibile utilizzare l'azione lookup in PairRDD per trovare un elemento con chiave:

val b = indexKey.lookup(1) // Array(b) 

Se ci si aspetta di utilizzare lookup spesso sullo stesso RDD, consiglierei di memorizzare nella cache il RDD indexKey per migliorare le prestazioni.

Come fare questo utilizzando Java API è un esercizio rimasto per il lettore.

+10

Ricordare inoltre che l'ordine del RDD non è definito a meno che non sia ordinato. Può cambiare ogni volta che viene valutato diversamente. Puoi anche 'mappare (_. Swap)' per cambiare chiavi e valori. –

+2

Quello che hai lasciato "come esercizio per il lettore" è la principale complessità in questa domanda. – dirtv

+0

@ user1843507 Cosa considerereste mancanti? 'zipWithIndex' e' lookup' sono gli stessi in Java, quindi solo la funzione 'map' richiede un piccolo sforzo per scambiare i valori di (chiave, valore) in (valore, chiave) che non è il nucleo di questa domanda. – maasg

2

Ho provato questa classe a recuperare un oggetto per indice. Innanzitutto, quando si costruisce new IndexedFetcher(rdd, itemClass), viene conteggiato il numero di elementi in ciascuna partizione dell'RDD. Quindi, quando si chiama indexedFetcher.get(n), viene eseguito un processo solo sulla partizione che contiene tale indice.

Nota che dovevo compilarlo usando Java 1.7 anziché 1.8; a partire da Spark 1.1.0, il pacchetto org.objectweb.asm all'interno di com.esotericsoftware.reflectasm non può ancora leggere le classi di Java 1.8 (genera IllegalStateException quando si tenta di eseguire una funzione Java 1.8 di Java).

import java.io.Serializable; 

import org.apache.spark.SparkContext; 
import org.apache.spark.TaskContext; 
import org.apache.spark.rdd.RDD; 

import scala.reflect.ClassTag; 

public static class IndexedFetcher<E> implements Serializable { 
    private static final long serialVersionUID = 1L; 
    public final RDD<E> rdd; 
    public Integer[] elementsPerPartitions; 
    private Class<?> clazz; 
    public IndexedFetcher(RDD<E> rdd, Class<?> clazz){ 
     this.rdd = rdd; 
     this.clazz = clazz; 
     SparkContext context = this.rdd.context(); 
     ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class); 
     elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag); 
    } 
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable { 
     private static final long serialVersionUID = 1L; 
     @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) { 
      int count = 0; 
      while (iterator.hasNext()) { 
       count++; 
       iterator.next(); 
      } 
      return count; 
     } 
    } 
    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() { 
     scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>(); 
     return function; 
    } 
    public E get(long index) { 
     long remaining = index; 
     long totalCount = 0; 
     for (int partition = 0; partition < elementsPerPartitions.length; partition++) { 
      if (remaining < elementsPerPartitions[partition]) { 
       return getWithinPartition(partition, remaining); 
      } 
      remaining -= elementsPerPartitions[partition]; 
      totalCount += elementsPerPartitions[partition]; 
     } 
     throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount)); 
    } 
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable { 
     private static final long serialVersionUID = 1L; 
     private final long indexWithinPartition; 
     public FetchWithinPartitionFunction(long indexWithinPartition) { 
      this.indexWithinPartition = indexWithinPartition; 
     } 
     @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) { 
      int count = 0; 
      while (iterator.hasNext()) { 
       E element = iterator.next(); 
       if (count == indexWithinPartition) 
        return element; 
       count++; 
      } 
      throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count)); 
     } 
    } 
    public E getWithinPartition(int partition, long indexWithinPartition) { 
     System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition); 
     SparkContext context = rdd.context(); 
     scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition); 
     scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition}); 
     ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz); 
     E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag); 
     return result[0]; 
    } 
} 
2

sono rimasto bloccato su questo per un po 'e, in modo di espandere sulla risposta Maasg ma rispondendo a cercare un intervallo di valori per indice per Java (è necessario definire le 4 variabili in alto):

DataFrame df; 
SQLContext sqlContext; 
Long start; 
Long end; 

JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex(); 
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end); 
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema()); 

Ricordate che quando si esegue questo codice il cluster avrà bisogno di avere Java 8 (come un'espressione lambda è in uso).

Inoltre, zipWithIndex è probabilmente costoso!

+0

ciao, puoi guidare una soluzione simile per Java 7? – tortuga

+0

'Codice delle rdd.filter (nuova funzione , booleano>() { chiamata booleano pubblico (Tuple2 v1) { ritorno v1._2> = start && v1._2

Problemi correlati