- PySpark's multi-process architecture;
- Calling Java/Scala interfaces from the Python side;
- The RDD and SQL interfaces on the Python driver side;
- Inter-process communication and serialization on the executor side;
- Pandas UDF;
- Summary.
PySpark project: https://github.com/apache/spark/tree/master/python
The code lives in python/pyspark/context.py:

```python
def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
    """
    Checks whether a SparkContext is initialized or not.
    Throws error if a SparkContext is already running.
    """
    with SparkContext._lock:
        if not SparkContext._gateway:
            SparkContext._gateway = gateway or launch_gateway(conf)
            SparkContext._jvm = SparkContext._gateway.jvm
```

In launch_gateway (python/pyspark/java_gateway.py), the JVM process is launched first:

```python
SPARK_HOME = _find_spark_home()
# Launch the Py4j gateway using Spark's run command so that we pick up the
# proper classpath and settings from spark-env.sh
on_windows = platform.system() == "Windows"
script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
command = [os.path.join(SPARK_HOME, script)]
```

Then a JavaGateway is created and a few key classes are imported:

```python
gateway = JavaGateway(
    gateway_parameters=GatewayParameters(
        port=gateway_port, auth_token=gateway_secret, auto_convert=True))

# Import the classes used by PySpark
java_import(gateway.jvm, "org.apache.spark.SparkConf")
java_import(gateway.jvm, "org.apache.spark.api.java.*")
java_import(gateway.jvm, "org.apache.spark.api.python.*")
java_import(gateway.jvm, "org.apache.spark.ml.python.*")
java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
# TODO(davies): move into sql
java_import(gateway.jvm, "org.apache.spark.sql.*")
java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
java_import(gateway.jvm, "scala.Tuple2")
```

With the JavaGateway object in hand, Java classes can be called through its jvm attribute, for example:

```python
gateway = JavaGateway()
jvm = gateway.jvm
l = jvm.java.util.ArrayList()
```

The SparkContext object inside the JVM is then created:

```python
def _initialize_context(self, jconf):
    """
    Initialize SparkContext in function to allow subclass specific initialization
    """
    return self._jvm.JavaSparkContext(jconf)

# Create the Java SparkContext through Py4J
self._jsc = jsc or self._initialize_context(self._conf._jconf)
```
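To make the bridge concrete, here is a minimal sketch (my own example, not from the Spark sources) that reaches the same Py4J gateway from a running PySpark application through SparkContext._jvm; it assumes a local Spark installation:

```python
from pyspark.sql import SparkSession

# Starting the session launches the JVM via launch_gateway as described above.
spark = SparkSession.builder.master("local[1]").appName("py4j-demo").getOrCreate()
sc = spark.sparkContext

# java.util.ArrayList lives in the JVM; the Python object is only a Py4J proxy.
jlist = sc._jvm.java.util.ArrayList()
jlist.add("hello")
jlist.add("world")
print(jlist.size())   # 2, computed on the JVM side
print(list(jlist))    # Py4J exposes java.util.List as an iterable proxy

spark.stop()
```

Every attribute access on sc._jvm and every method call on the proxy object is turned into a Py4J command sent over the gateway socket, which is exactly the mechanism the driver-side APIs below rely on.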
For example, on the Python driver side an RDD API such as newAPIHadoopFile simply forwards the call to the JVM and wraps the returned Java RDD:

```python
def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
                     valueConverter=None, conf=None, batchSize=0):
    jconf = self._dictToJavaMap(conf)
    jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
                                                valueClass, keyConverter, valueConverter,
                                                jconf, batchSize)
    return RDD(jrdd, self)
```
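A hedged usage sketch follows (the input path and the Hadoop classes below are illustrative choices, not taken from the article): from the Python side this looks like an ordinary method call, while the file reading itself is planned and executed inside the JVM.

```python
# Hypothetical HDFS path; the InputFormat and key/value classes are the
# standard Hadoop text classes.
rdd = sc.newAPIHadoopFile(
    "hdfs:///tmp/input",
    "org.apache.hadoop.mapreduce.lib.input.TextInputFormat",
    "org.apache.hadoop.io.LongWritable",
    "org.apache.hadoop.io.Text")
print(rdd.take(3))   # (offset, line) pairs materialized back in Python
```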
```scala
object PythonEvals extends Strategy {
  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
    case ArrowEvalPython(udfs, output, child, evalType) =>
      ArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
    case BatchEvalPython(udfs, output, child) =>
      BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
    case _ =>
      Nil
  }
}
```

This creates either an ArrowEvalPythonExec or a BatchEvalPythonExec. Internally, both create instances of classes such as ArrowPythonRunner and PythonUDFRunner and call their compute methods. Since both inherit from BasePythonRunner, the base class's compute method launches the Python worker process:

```scala
def compute(
    inputIterator: Iterator[IN],
    partitionIndex: Int,
    context: TaskContext): Iterator[OUT] = {
  // ......
  val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
  // Start a thread to feed the process input from our parent's iterator
  val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
  writerThread.start()
  val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
  val stdoutIterator = newReaderIterator(
    stream, writerThread, startTime, env, worker, releasedOrClosed, context)
  new InterruptibleIterator(context, stdoutIterator)
```
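The following standalone Python sketch (my own simplification, not Spark's code) mimics the shape of compute(): the parent launches a worker process, a writer thread streams rows to it over a socket, and the parent consumes the worker's results as an iterator.

```python
import socket
import subprocess
import sys
import threading

# Toy worker: reads integers line by line, "evaluates the UDF", writes results back.
WORKER_SOURCE = r"""
import socket, sys
sock = socket.create_connection(("127.0.0.1", int(sys.argv[1])))
rf, wf = sock.makefile("rb"), sock.makefile("wb")
for line in rf:                                    # rows streamed in by the parent
    wf.write(str(int(line) ** 2).encode() + b"\n") # the "UDF": square the input
    wf.flush()
"""

server = socket.socket()
server.bind(("127.0.0.1", 0))
server.listen(1)
port = server.getsockname()[1]

# Launch the worker process (Spark does this through createPythonWorker).
proc = subprocess.Popen([sys.executable, "-c", WORKER_SOURCE, str(port)])
conn, _ = server.accept()
rf, wf = conn.makefile("rb"), conn.makefile("wb")

def feed(rows):
    # Counterpart of Spark's writer thread: push partition data to the worker.
    for n in rows:
        wf.write(f"{n}\n".encode())
    wf.flush()
    conn.shutdown(socket.SHUT_WR)   # signal end of input

threading.Thread(target=feed, args=(range(5),), daemon=True).start()

# Counterpart of the reader iterator: consume results as they are produced.
for line in rf:
    print(int(line))                # 0 1 4 9 16

proc.wait()
```

The real implementation adds a worker daemon, authentication, batching and Arrow framing on top of this basic layout, but the writer-thread / reader-iterator structure is the same.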
```scala
val arrowWriter = ArrowWriter.create(root)
val writer = new ArrowStreamWriter(root, null, dataOut)
writer.start()

while (inputIterator.hasNext) {
  val nextBatch = inputIterator.next()

  while (nextBatch.hasNext) {
    arrowWriter.write(nextBatch.next())
  }

  arrowWriter.finish()
  writer.writeBatch()
  arrowWriter.reset()
```

As you can see, one batch at a time is taken from the iterator and fed into the ArrowWriter; the actual data is held in the root object, and the ArrowStreamWriter then writes the whole batch from root to the socket's DataOutputStream. ArrowStreamWriter calls writeBatch to serialize the message and write the data; see ArrowWriter.java#L131.

```java
protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
  ArrowBlock block = MessageSerializer.serialize(out, batch, option);
  LOGGER.debug("RecordBatch at {}, metadata: {}, body: {}",
      block.getOffset(), block.getMetadataLength(), block.getBodyLength());
  return block;
}
```

MessageSerializer uses FlatBuffers to serialize the data. FlatBuffers is an efficient serialization protocol; its main advantage is that deserialization requires no decoding step: fields can be read directly from the raw buffer, so the deserialization overhead is effectively zero. Let's look at how the Python process deserializes the messages it receives.

The Python worker process actually runs the main function of worker.py (python/pyspark/worker.py):

```python
if __name__ == '__main__':
    # Read information about how to connect back to the JVM from the environment.
    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
    main(sock_file, sock_file)
```

Here the worker connects back to the JVM and reads commands and data from the socket. Which serializer and deserializer to use is decided by the UDF's eval type:

```python
eval_type = read_int(infile)
if eval_type == PythonEvalType.NON_UDF:
    func, profiler, deserializer, serializer = read_command(pickleSer, infile)
else:
    func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
```

In read_udfs, a PANDAS-type UDF gets an ArrowStreamPandasUDFSerializer, while the remaining UDF types get a BatchedSerializer. Let's look at ArrowStreamPandasUDFSerializer (python/pyspark/serializers.py):

```python
def dump_stream(self, iterator, stream):
    import pyarrow as pa
    writer = None
    try:
        for batch in iterator:
            if writer is None:
                writer = pa.RecordBatchStreamWriter(stream, batch.schema)
            writer.write_batch(batch)
    finally:
        if writer is not None:
            writer.close()

def load_stream(self, stream):
    import pyarrow as pa
    reader = pa.ipc.open_stream(stream)
    for batch in reader:
        yield batch
```

Serialization and deserialization in both directions go through PyArrow's IPC methods, which mirror exactly what we saw on the Scala side: data is read and written batch by batch. For a Pandas UDF, once a batch has been read, the Arrow batch is converted into Pandas Series.

```python
def arrow_to_pandas(self, arrow_column):
    from pyspark.sql.types import _check_series_localize_timestamps

    # If the given column is a date type column, creates a series of datetime.date directly
    # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
    # datetime64[ns] type handling.
    s = arrow_column.to_pandas(date_as_object=True)

    s = _check_series_localize_timestamps(s, self._timezone)
    return s

def load_stream(self, stream):
    """
    Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
    """
    batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
    import pyarrow as pa
    for batch in batches:
        yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
```
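To see the same IPC pattern outside of Spark, here is a small self-contained sketch (my own example; requires pyarrow and pandas) that writes Arrow record batches into an in-memory stream and reads them back as pandas Series, mirroring dump_stream and load_stream above:

```python
import io

import pyarrow as pa

# "JVM side": write two record batches into a byte stream.
batch = pa.RecordBatch.from_arrays(
    [pa.array([1, 2, 3]), pa.array(["a", "b", "c"])], names=["x", "y"])
sink = io.BytesIO()
writer = pa.RecordBatchStreamWriter(sink, batch.schema)
writer.write_batch(batch)
writer.write_batch(batch)
writer.close()

# "Python worker side": read the batches back and turn every column into a
# pandas Series, as ArrowStreamPandasSerializer.load_stream does.
sink.seek(0)
reader = pa.ipc.open_stream(sink)
for b in reader:
    series = [col.to_pandas() for col in pa.Table.from_batches([b]).itercolumns()]
    print([s.tolist() for s in series])   # [[1, 2, 3], ['a', 'b', 'c']], printed twice
```

In Spark the BytesIO object is replaced by the worker's socket file, but the framing of the stream is identical, which is why the two sides interoperate without any extra conversion layer.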
For example, a scalar Pandas UDF is defined and used like this:

```python
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

def multiply_func(a, b):
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df.select(multiply(col("x"), col("x"))).show()
```
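For completeness, here is a hedged end-to-end version of the example (the SparkSession and the sample DataFrame are my additions, not from the article); it requires pyarrow so that the Arrow path described above is taken:

```python
import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

spark = SparkSession.builder.master("local[1]").appName("pandas-udf-demo").getOrCreate()
df = spark.createDataFrame(pd.DataFrame({"x": [1, 2, 3]}))

def multiply_func(a: pd.Series, b: pd.Series) -> pd.Series:
    # Runs in the Python worker; a and b arrive as pandas Series decoded from Arrow batches.
    return a * b

multiply = pandas_udf(multiply_func, returnType=LongType())

df.select(multiply(col("x"), col("x")).alias("x_squared")).show()
spark.stop()
```

Unlike a row-at-a-time Python UDF, multiply_func is invoked once per Arrow batch with whole Series, which is where the vectorized speedup comes from. That said, a few limitations remain: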
- inter-process communication consumes extra CPU resources;
- the programming interface still requires an understanding of Spark's distributed computing model;
- Pandas UDFs place some restrictions on return values, which makes returning multiple columns awkward.
Chen Xu is a senior algorithm scientist at Mobvista, responsible for the development of Mobvista's large-scale data-intelligence computing engines and platforms. Before that, he was a senior technical expert at Alibaba, where he led the development of Alibaba Group's large-scale machine learning platform.