
Implement a Java UDF and call it from PySpark

I need to create a UDF that can be used from Python but that uses a Java object for its internal calculations.

If this were plain Python I would do something like:

import pyspark.sql.functions
import pyspark.sql.types

def f(x):
    return 7
fudf = pyspark.sql.functions.udf(f, pyspark.sql.types.IntegerType())

and call it using:

df = sqlContext.range(0,5) 
df2 = df.withColumn("a",fudf(df.id)).show() 

However, the implementation of the function I need is in Java, not Python. I need to wrap it somehow so that I can call it from Python in a similar way.

My first attempt was to instantiate the Java object, pass it into PySpark, and convert it into a UDF. That failed with a serialization error.

Java:

package com.test1.test2; 

public class TestClass1 { 
    Integer internalVal; 
    public TestClass1(Integer val1) { 
     internalVal = val1; 
    } 
    public Integer do_something(Integer val) { 
     return internalVal; 
    }  
} 

PySpark:

from py4j.java_gateway import java_import 
from pyspark.sql.functions import udf 
from pyspark.sql.types import IntegerType 
java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
audf = udf(a,IntegerType()) 

error:

--------------------------------------------------------------------------- 
Py4JError         Traceback (most recent call last) 
<ipython-input-2-9756772ab14f> in <module>() 
     4 java_import(sc._gateway.jvm, "com.test1.test2.TestClass1") 
     5 a = sc._gateway.jvm.com.test1.test2.TestClass1(7) 
----> 6 audf = udf(a,IntegerType()) 

/usr/local/spark/python/pyspark/sql/functions.py in udf(f, returnType) 
    1595  [Row(slen=5), Row(slen=3)] 
    1596  """ 
-> 1597  return UserDefinedFunction(f, returnType) 
    1598 
    1599 blacklist = ['map', 'since', 'ignore_unicode_prefix'] 

/usr/local/spark/python/pyspark/sql/functions.py in __init__(self, func, returnType, name) 
    1556   self.returnType = returnType 
    1557   self._broadcast = None 
-> 1558   self._judf = self._create_judf(name) 
    1559 
    1560  def _create_judf(self, name): 

/usr/local/spark/python/pyspark/sql/functions.py in _create_judf(self, name) 
    1565   command = (func, None, ser, ser) 
    1566   sc = SparkContext.getOrCreate() 
-> 1567   pickled_command, broadcast_vars, env, includes = _prepare_for_python_RDD(sc, command, self) 
    1568   ctx = SQLContext.getOrCreate(sc) 
    1569   jdt = ctx._ssql_ctx.parseDataType(self.returnType.json()) 

/usr/local/spark/python/pyspark/rdd.py in _prepare_for_python_RDD(sc, command, obj) 
    2297  # the serialized command will be compressed by broadcast 
    2298  ser = CloudPickleSerializer() 
-> 2299  pickled_command = ser.dumps(command) 
    2300  if len(pickled_command) > (1 << 20): # 1M 
    2301   # The broadcast will have same life cycle as created PythonRDD 

/usr/local/spark/python/pyspark/serializers.py in dumps(self, obj) 
    426 
    427  def dumps(self, obj): 
--> 428   return cloudpickle.dumps(obj, 2) 
    429 
    430 

/usr/local/spark/python/pyspark/cloudpickle.py in dumps(obj, protocol) 
    644 
    645  cp = CloudPickler(file,protocol) 
--> 646  cp.dump(obj) 
    647 
    648  return file.getvalue() 

/usr/local/spark/python/pyspark/cloudpickle.py in dump(self, obj) 
    105   self.inject_addons() 
    106   try: 
--> 107    return Pickler.dump(self, obj) 
    108   except RuntimeError as e: 
    109    if 'recursion' in e.args[0]: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in dump(self, obj) 
    222   if self.proto >= 2: 
    223    self.write(PROTO + chr(self.proto)) 
--> 224   self.save(obj) 
    225   self.write(STOP) 
    226 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    566   write(MARK) 
    567   for element in obj: 
--> 568    save(element) 
    569 
    570   if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function(self, obj, name) 
    191   if islambda(obj) or obj.__code__.co_filename == '<stdin>' or themodule is None: 
    192    #print("save global", islambda(obj), obj.__code__.co_filename, modname, themodule) 
--> 193    self.save_function_tuple(obj) 
    194    return 
    195   else: 

/usr/local/spark/python/pyspark/cloudpickle.py in save_function_tuple(self, func) 
    234   # create a skeleton function object and memoize it 
    235   save(_make_skel_func) 
--> 236   save((code, closure, base_globals)) 
    237   write(pickle.REDUCE) 
    238   self.memoize(func) 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_tuple(self, obj) 
    552   if n <= 3 and proto >= 2: 
    553    for element in obj: 
--> 554     save(element) 
    555    # Subtle. Same as in the big comment below. 
    556    if id(obj) in memo: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    284   f = self.dispatch.get(t) 
    285   if f: 
--> 286    f(self, obj) # Call unbound method with explicit self 
    287    return 
    288 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save_list(self, obj) 
    604 
    605   self.memoize(obj) 
--> 606   self._batch_appends(iter(obj)) 
    607 
    608  dispatch[ListType] = save_list 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in _batch_appends(self, items) 
    637     write(MARK) 
    638     for x in tmp: 
--> 639      save(x) 
    640     write(APPENDS) 
    641    elif n: 

/home/mendea3/anaconda2/lib/python2.7/pickle.pyc in save(self, obj) 
    304    reduce = getattr(obj, "__reduce_ex__", None) 
    305    if reduce: 
--> 306     rv = reduce(self.proto) 
    307    else: 
    308     reduce = getattr(obj, "__reduce__", None) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    811   answer = self.gateway_client.send_command(command) 
    812   return_value = get_return_value(
--> 813    answer, self.gateway_client, self.target_id, self.name) 
    814 
    815   for temp_arg in temp_args: 

/usr/local/spark/python/pyspark/sql/utils.py in deco(*a, **kw) 
    43  def deco(*a, **kw): 
    44   try: 
---> 45    return f(*a, **kw) 
    46   except py4j.protocol.Py4JJavaError as e: 
    47    s = e.java_exception.toString() 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/protocol.py in get_return_value(answer, gateway_client, target_id, name) 
    310     raise Py4JError(
    311      "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n". 
--> 312      format(target_id, ".", name, value)) 
    313   else: 
    314    raise Py4JError(

Py4JError: An error occurred while calling o18.__getnewargs__. Trace: 
py4j.Py4JException: Method __getnewargs__([]) does not exist 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:335) 
    at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:344) 
    at py4j.Gateway.invoke(Gateway.java:252) 
    at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:133) 
    at py4j.commands.CallCommand.execute(CallCommand.java:79) 
    at py4j.GatewayConnection.run(GatewayConnection.java:209) 
    at java.lang.Thread.run(Thread.java:745) 

EDIT: I also tried making the Java class serializable, but to no avail.

My second attempt was to define the UDF in Java to begin with, but that failed too, since I am not sure how to wrap it correctly:

Java:

package com.test1.test2; 

import org.apache.spark.sql.api.java.UDF1; 

public class TestClassUdf implements UDF1<Integer, Integer> { 

    Integer retval; 

    public TestClassUdf(Integer val) { 
     retval = val; 
    } 

    @Override 
    public Integer call(Integer arg0) throws Exception { 
     return retval; 
    } 
} 

But how do I use it? I tried:

from py4j.java_gateway import java_import 
java_import(sc._gateway.jvm, "com.test1.test2.TestClassUdf") 
a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
dfint = sqlContext.range(0,15) 
df = dfint.withColumn("a",a(dfint.id)) 

but I get:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input-5-514811090b5f> in <module>() 
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
     4 dfint = sqlContext.range(0,15) 
----> 5 df = dfint.withColumn("a",a(dfint.id)) 

TypeError: 'JavaObject' object is not callable 

and I tried using a.call instead of a:

df = dfint.withColumn("a",a.call(dfint.id)) 

but got:

--------------------------------------------------------------------------- 
TypeError         Traceback (most recent call last) 
<ipython-input> in <module>() 
     3 a = sc._gateway.jvm.com.test1.test2.TestClassUdf(7) 
     4 dfint = sqlContext.range(0,15) 
----> 5 df = dfint.withColumn("a", a.call(dfint.id)) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in __call__(self, *args) 
    796  def __call__(self, *args): 
    797   if self.converters is not None and len(self.converters) > 0: 
--> 798    (new_args, temp_args) = self._get_args(args) 
    799   else: 
    800    new_args = args 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_gateway.py in _get_args(self, args) 
    783     for converter in self.gateway_client.converters: 
    784      if converter.can_convert(arg): 
--> 785       temp_arg = converter.convert(arg, self.gateway_client) 
    786       temp_args.append(temp_arg) 
    787       new_args.append(temp_arg) 

/usr/local/spark/python/lib/py4j-0.9-src.zip/py4j/java_collections.py in convert(self, object, gateway_client) 
    510   HashMap = JavaClass("java.util.HashMap", gateway_client) 
    511   java_map = HashMap() 
--> 512   for key in object.keys(): 
    513    java_map[key] = object[key] 
    514   return java_map 

TypeError: 'Column' object is not callable 

Any help would be appreciated.

Answer


I got this working with the help of another question (and answer) of your own about UDAFs.

Spark provides a udf() method for wrapping Scala FunctionN, so we can wrap the Java function in Scala and use that. Your Java method needs to be static, or live on a class that implements Serializable.

package com.example 

import org.apache.spark.sql.UserDefinedFunction 
import org.apache.spark.sql.functions.udf 

class MyUdf extends Serializable { 
    def getUdf: UserDefinedFunction = udf(() => MyJavaClass.MyJavaMethod()) 
} 
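
For reference, here is a minimal sketch of the Java side this wrapper assumes; MyJavaClass and MyJavaMethod are placeholder names (not from any library), and keeping the method static means the closure captured by udf() has no object instance to serialize:

package com.example; 

// Hypothetical Java helper called by the Scala wrapper above. A static method 
// sidesteps the serialization problems hit in the py4j attempts in the question. 
public class MyJavaClass { 
    public static int MyJavaMethod() { 
        return 7; 
    } 
} 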

Usage in PySpark:

def my_udf(): 
    from pyspark.sql.column import Column, _to_java_column, _to_seq 
    pcls = "com.example.MyUdf" 
    # load the Scala wrapper class on the JVM and grab the apply method of the 
    # UserDefinedFunction returned by getUdf() 
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply 
    # no input columns are passed in this example, hence the empty list 
    return Column(jc(_to_seq(sc, [], _to_java_column))) 

rdd1 = sc.parallelize([{'c1': 'a'}, {'c1': 'b'}, {'c1': 'c'}]) 
df1 = rdd1.toDF() 
df2 = df1.withColumn('mycol', my_udf()) 

As with the UDAF in your other question and answer, you can pass columns into it with return Column(jc(_to_seq(sc, ["col1", "col2"], _to_java_column))).
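
For example, a column-taking variant might look like the sketch below (assuming the Scala side is changed to wrap a Java method that accepts matching arguments; my_udf_with_cols is just an illustrative name):

def my_udf_with_cols(col1, col2): 
    from pyspark.sql.column import Column, _to_java_column, _to_seq 
    pcls = "com.example.MyUdf" 
    jc = sc._jvm.java.lang.Thread.currentThread() \
        .getContextClassLoader().loadClass(pcls).newInstance().getUdf().apply 
    # the columns (or column names) are converted and passed through to the JVM UDF 
    return Column(jc(_to_seq(sc, [col1, col2], _to_java_column))) 

df3 = df1.withColumn('mycol', my_udf_with_cols(df1.c1, df1.c1)) 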
