2016-03-08 16 views
5

Sto provando a creare una funzione aggregata definita dall'utente che posso chiamare da python. Ho provato a seguire la risposta alla domanda this. Io fondamentalmente implementate le seguenti (tratto da here):Avvolgimento di una funzione java in pyspark

package com.blu.bla; 
import java.util.ArrayList; 
import java.util.List; 
import org.apache.spark.sql.expressions.MutableAggregationBuffer; 
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction; 
import org.apache.spark.sql.types.StructField; 
import org.apache.spark.sql.types.StructType; 
import org.apache.spark.sql.types.DataType; 
import org.apache.spark.sql.types.DataTypes; 
import org.apache.spark.sql.Row; 

public class MySum extends UserDefinedAggregateFunction { 
    private StructType _inputDataType; 
    private StructType _bufferSchema; 
    private DataType _returnDataType; 

    public MySum() { 
     List<StructField> inputFields = new ArrayList<StructField>(); 
     inputFields.add(DataTypes.createStructField("inputDouble", DataTypes.DoubleType, true)); 
     _inputDataType = DataTypes.createStructType(inputFields); 

     List<StructField> bufferFields = new ArrayList<StructField>(); 
     bufferFields.add(DataTypes.createStructField("bufferDouble", DataTypes.DoubleType, true)); 
     _bufferSchema = DataTypes.createStructType(bufferFields); 

     _returnDataType = DataTypes.DoubleType; 
    } 

    @Override public StructType inputSchema() { 
     return _inputDataType; 
    } 

    @Override public StructType bufferSchema() { 
     return _bufferSchema; 
    } 

    @Override public DataType dataType() { 
     return _returnDataType; 
    } 

    @Override public boolean deterministic() { 
     return true; 
    } 

    @Override public void initialize(MutableAggregationBuffer buffer) { 
     buffer.update(0, null); 
    } 

    @Override public void update(MutableAggregationBuffer buffer, Row input) { 
     if (!input.isNullAt(0)) { 
      if (buffer.isNullAt(0)) { 
       buffer.update(0, input.getDouble(0)); 
      } else { 
       Double newValue = input.getDouble(0) + buffer.getDouble(0); 
       buffer.update(0, newValue); 
      } 
     } 
    } 

    @Override public void merge(MutableAggregationBuffer buffer1, Row buffer2) { 
     if (!buffer2.isNullAt(0)) { 
      if (buffer1.isNullAt(0)) { 
       buffer1.update(0, buffer2.getDouble(0)); 
      } else { 
       Double newValue = buffer2.getDouble(0) + buffer1.getDouble(0); 
       buffer1.update(0, newValue); 
      } 
     } 
    } 

    @Override public Object evaluate(Row buffer) { 
     if (buffer.isNullAt(0)) { 
      return null; 
     } else { 
      return buffer.getDouble(0); 
     } 
    } 
} 

ho poi compilato con tutte le dipendenze e pyspark corsa con --jars myjar.jar

In pyspark ho fatto:

df = sqlCtx.createDataFrame([(1.0, "a"), (2.0, "b"), (3.0, "C")], ["A", "B"]) 
from pyspark.sql.column import Column, _to_java_column, _to_seq 
from pyspark.sql import Row 

def myCol(col): 
    _f = sc._jvm.com.blu.bla.MySum.apply 
    return Column(_f(_to_seq(sc,[col], _to_java_column))) 
b = df.agg(myCol("A")) 

ho ottenuto il seguente errore:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-24-f45b2a367e67> in <module>() 
----> 1 b = df.agg(myCol("A")) 

<ipython-input-22-afcb8884e1db> in myCol(col) 
     4 def myCol(col): 
     5  _f = sc._jvm.com.blu.bla.MySum.apply 
----> 6  return Column(_f(_to_seq(sc,[col], _to_java_column))) 

TypeError: 'JavaPackage' object is not callable 

anche io ho provato ad aggiungere --driver-class-path alla chiamata pyspark ma ho ottenuto lo stesso risultato.

anche tentato di accedere alla classe Java tramite l'importazione java:

from py4j.java_gateway import java_import 
jvm = sc._gateway.jvm 
java_import(jvm, "com.bla.blu.MySum") 
def myCol2(col): 
    _f = jvm.bla.blu.MySum.apply 
    return Column(_f(_to_seq(sc,[col], _to_java_column))) 

ha anche cercato di creare semplicemente la classe (come suggerito here):

a = jvm.com.bla.blu.MySum() 

Tutti sono sempre lo stesso messaggio di errore.

Non riesco a capire quale sia il problema.

risposta

4

Quindi sembra che il problema principale era che tutte le opzioni per aggiungere il jar (--jars, percorso classe driver, SPARK_CLASSPATH) non funzionano correttamente se si fornisce un percorso relativo. Questo è probabilmente a causa di problemi con la directory di lavoro all'interno di ipython rispetto a dove ho eseguito pyspark.

Una volta modificato questo percorso in assoluto, funziona (non l'ho ancora testato su un cluster ma almeno funziona su un'installazione locale).

Inoltre, non sono sicuro se questo è un bug anche nella risposta here come quella risposta utilizza un'implementazione scala, tuttavia nel implementazione Java avevo bisogno di fare

def myCol(col): 
    _f = sc._jvm.com.blu.bla.MySum().apply 
    return Column(_f(_to_seq(sc,[col], _to_java_column))) 

Questo probabilmente non è realmente efficace poiché crea _f ogni volta, probabilmente dovrei definire _f al di fuori della funzione (di nuovo, ciò richiederebbe il test sul cluster) ma almeno ora fornisce la risposta funzionale corretta

+0

un'ultima cosa, per riferimento futuro questo è stato testato su installazione spark 1.6.0 local (single node) –

+2

Testato su un cluster a e funziona. Usato --jars AND --driver-class-path insieme (apparentemente --jars non imposta il classpath sul driver) –

+0

Ciao! Sto cercando di fare una cosa simile ma sto ottenendo lo stesso errore, non sono in grado di capirlo, potresti fornirci un po 'di informazioni su come farlo funzionare. – StarLord