Cómo obtener un elemento por índice en Spark RDD (Java)


Conozco el método rdd.first () que me da el primer elemento en un RDD.

También existe el método rdd.toma (num) Que me da los primeros elementos "num".

¿Pero no existe la posibilidad de obtener un elemento por índice?

Gracias.

Author: Community, 2014-11-09

3 answers

Esto debería ser posible indexando primero el RDD. La transformación zipWithIndex proporciona una indexación estable, numerando cada elemento en su orden original.

Dado: rdd = (a,b,c)

val withIndex = rdd.zipWithIndex // ((a,0),(b,1),(c,2))

Para buscar un elemento por índice, este formulario no es útil. Primero necesitamos usar el índice como clave:

val indexKey = withIndex.map{case (k,v) => (v,k)}  //((0,a),(1,b),(2,c))

Ahora, es posible usar la acción lookup en PairRDD para encontrar un elemento por clave:

val b = indexKey.lookup(1) // Array(b)

Si está esperando usar lookup a menudo en el mismo RDD, le recomendaría almacenar en caché el indexKey RDD para mejorar el rendimiento.

Cómo hacer esto usando la API de Java es un ejercicio que le queda al lector.

 53
Author: maasg,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-07-11 17:40:15

He intentado esta clase para obtener un elemento por índice. Primero, cuando construyes new IndexedFetcher(rdd, itemClass), cuenta el número de elementos en cada partición del RDD. Luego, cuando llama a indexedFetcher.get(n), ejecuta un trabajo solo en la partición que contiene ese índice.

Tenga en cuenta que necesitaba compilar esto usando Java 1.7 en lugar de 1.8; a partir de Spark 1.1.0, la org incluida.objectweb.asm dentro de com.software esotérico.reflectasm aún no puede leer las clases Java 1.8 (lanza IllegalStateException cuando intenta ejecutar un Java 1.8 función).

import java.io.Serializable;

import org.apache.spark.SparkContext;
import org.apache.spark.TaskContext;
import org.apache.spark.rdd.RDD;

import scala.reflect.ClassTag;

public static class IndexedFetcher<E> implements Serializable {
    private static final long serialVersionUID = 1L;
    public final RDD<E> rdd;
    public Integer[] elementsPerPartitions;
    private Class<?> clazz;
    public IndexedFetcher(RDD<E> rdd, Class<?> clazz){
        this.rdd = rdd;
        this.clazz = clazz;
        SparkContext context = this.rdd.context();
        ClassTag<Integer> intClassTag = scala.reflect.ClassTag$.MODULE$.<Integer>apply(Integer.class);
        elementsPerPartitions = (Integer[]) context.<E, Integer>runJob(rdd, IndexedFetcher.<E>countFunction(), intClassTag);
    }
    public static class IteratorCountFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, Integer> implements Serializable {
        private static final long serialVersionUID = 1L;
        @Override public Integer apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                count++;
                iterator.next();
            }
            return count;
        }
    }
    static <E> scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> countFunction() {
        scala.Function2<TaskContext, scala.collection.Iterator<E>, Integer> function = new IteratorCountFunction<E>();
        return function;
    }
    public E get(long index) {
        long remaining = index;
        long totalCount = 0;
        for (int partition = 0; partition < elementsPerPartitions.length; partition++) {
            if (remaining < elementsPerPartitions[partition]) {
                return getWithinPartition(partition, remaining);
            }
            remaining -= elementsPerPartitions[partition];
            totalCount += elementsPerPartitions[partition];
        }
        throw new IllegalArgumentException(String.format("Get %d within RDD that has only %d elements", index, totalCount));
    }
    public static class FetchWithinPartitionFunction<E> extends scala.runtime.AbstractFunction2<TaskContext, scala.collection.Iterator<E>, E> implements Serializable {
        private static final long serialVersionUID = 1L;
        private final long indexWithinPartition;
        public FetchWithinPartitionFunction(long indexWithinPartition) {
            this.indexWithinPartition = indexWithinPartition;
        }
        @Override public E apply(TaskContext taskContext, scala.collection.Iterator<E> iterator) {
            int count = 0;
            while (iterator.hasNext()) {
                E element = iterator.next();
                if (count == indexWithinPartition)
                    return element;
                count++;
            }
            throw new IllegalArgumentException(String.format("Fetch %d within partition that has only %d elements", indexWithinPartition, count));
        }
    }
    public E getWithinPartition(int partition, long indexWithinPartition) {
        System.out.format("getWithinPartition(%d, %d)%n", partition, indexWithinPartition);
        SparkContext context = rdd.context();
        scala.Function2<TaskContext, scala.collection.Iterator<E>, E> function = new FetchWithinPartitionFunction<E>(indexWithinPartition);
        scala.collection.Seq<Object> partitions = new scala.collection.mutable.WrappedArray.ofInt(new int[] {partition});
        ClassTag<E> classTag = scala.reflect.ClassTag$.MODULE$.<E>apply(this.clazz);
        E[] result = (E[]) context.<E, E>runJob(rdd, function, partitions, true, classTag);
        return result[0];
    }
}
 2
Author: yonran,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2014-11-17 20:58:01

También me quedé atascado en esto por un tiempo, así que para ampliar la respuesta de Maasg, pero respondiendo para buscar un rango de valores por índice para Java (necesitará definir las 4 variables en la parte superior):

DataFrame df;
SQLContext sqlContext;
Long start;
Long end;

JavaPairRDD<Row, Long> indexedRDD = df.toJavaRDD().zipWithIndex();
JavaRDD filteredRDD = indexedRDD.filter((Tuple2<Row,Long> v1) -> v1._2 >= start && v1._2 < end);
DataFrame filteredDataFrame = sqlContext.createDataFrame(filteredRDD, df.schema());

Recuerde que cuando ejecute este código, su clúster necesitará tener Java 8 (ya que una expresión lambda está en uso).

Además, zipWithIndex es probablemente caro!

 2
Author: Luke W,
Warning: date(): Invalid date.timezone value 'Europe/Kyiv', we selected the timezone 'UTC' for now. in /var/www/agent_stack/data/www/ajaxhispano.com/template/agent.layouts/content.php on line 61
2016-08-02 11:01:17