PySpark源碼解析,教你用Python調(diào)用高效Scala接口,搞定大規(guī)模數(shù)據(jù)分析
相較于Scala語言而言,Python具有其獨有的優(yōu)勢及廣泛應用性,因此Spark也推出了PySpark,在框架上提供了利用Python語言的接口,為數(shù)據(jù)科學家使用該框架提供了便利。
眾所周知,Spark 框架主要是由 Scala 語言實現(xiàn),同時也包含少量 Java 代碼。Spark 面向用戶的編程接口,也是 Scala。然而,在數(shù)據(jù)科學領域,Python 一直占據(jù)比較重要的地位,仍然有大量的數(shù)據(jù)工程師在使用各類 Python 數(shù)據(jù)處理和科學計算的庫,例如 numpy、Pandas、scikit-learn 等。同時,Python 語言的入門門檻也顯著低于 Scala。
為此,Spark 推出了 PySpark,在 Spark 框架上提供一套 Python 的接口,方便廣大數(shù)據(jù)科學家使用。本文主要從源碼實現(xiàn)層面解析 PySpark 的實現(xiàn)原理,包括以下幾個方面:
- PySpark 的多進程架構(gòu);
- Python 端調(diào)用 Java、Scala 接口;
- Python Driver 端 RDD、SQL 接口;
- Executor 端進程間通信和序列化;
- Pandas UDF;
- 總結(jié)。
PySpark項目地址:https://github.com/apache/spark/tree/master/python
1、PySpark 的多進程架構(gòu)
PySpark 采用了 Python、JVM 進程分離的多進程架構(gòu),在 Driver、Executor 端均會同時有 Python、JVM 兩個進程。當通過 spark-submit 提交一個 PySpark 的 Python 腳本時,Driver 端會直接運行這個 Python 腳本,并從 Python 中啟動 JVM;而在 Python 中調(diào)用的 RDD 或者 DataFrame 的操作,會通過 Py4j 調(diào)用到 Java 的接口。
在 Executor 端恰好是反過來,首先由 Driver 啟動了 JVM 的 Executor 進程,然后在 JVM 中去啟動 Python 的子進程,用以執(zhí)行 Python 的 UDF,這其中是使用了 socket 來做進程間通信??傮w的架構(gòu)圖如下所示:
2、Python Driver 如何調(diào)用 Java 的接口
上面提到,通過 spark-submit 提交 PySpark 作業(yè)后,Driver 端首先是運行用戶提交的 Python 腳本,然而 Spark 提供的大多數(shù) API 都是 Scala 或者 Java 的,那么就需要能夠在 Python 中去調(diào)用 Java 接口。這里 PySpark 使用了 Py4j 這個開源庫。當創(chuàng)建 Python 端的 SparkContext 對象時,實際會啟動 JVM,并創(chuàng)建一個 Scala 端的 SparkContext 對象。代碼實現(xiàn)在 python/pyspark/context.py:
- def _ensure_initialized(cls, instance=None, gateway=None, conf=None):
- """
- Checks whether a SparkContext is initialized or not.
- Throws error if a SparkContext is already running.
- """
- with SparkContext._lock:
- if not SparkContext._gateway:
- SparkContext._gateway = gateway or launch_gateway(conf)
- SparkContext._jvm = SparkContext._gateway.jvm
在 launch_gateway (python/pyspark/java_gateway.py) 中,首先啟動 JVM 進程:
- SPARK_HOME = _find_spark_home()
- # Launch the Py4j gateway using Spark's run command so that we pick up the
- # proper classpath and settings from spark-env.sh
- on_windows = platform.system() == "Windows"
- script = "./bin/spark-submit.cmd" if on_windows else "./bin/spark-submit"
- command = [os.path.join(SPARK_HOME, script)]
然后創(chuàng)建 JavaGateway 并 import 一些關鍵的 class:
- gateway = JavaGateway(
- gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
- auto_convert=True))
- # Import the classes used by PySpark
- java_import(gateway.jvm, "org.apache.spark.SparkConf")
- java_import(gateway.jvm, "org.apache.spark.api.java.*")
- java_import(gateway.jvm, "org.apache.spark.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.ml.python.*")
- java_import(gateway.jvm, "org.apache.spark.mllib.api.python.*")
- # TODO(davies): move into sql
- java_import(gateway.jvm, "org.apache.spark.sql.*")
- java_import(gateway.jvm, "org.apache.spark.sql.api.python.*")
- java_import(gateway.jvm, "org.apache.spark.sql.hive.*")
- java_import(gateway.jvm, "scala.Tuple2")
- 拿到 JavaGateway 對象,即可以通過它的 jvm 屬性,去調(diào)用 Java 的類了,例如:
- gateway = JavaGateway()
- gateway = JavaGateway()
- jvm = gateway.jvm
- l = jvm.java.util.ArrayList()
然后會繼續(xù)創(chuàng)建 JVM 中的 SparkContext 對象:
- def _initialize_context(self, jconf):
- """
- Initialize SparkContext in function to allow subclass specific initialization
- """
- return self._jvm.JavaSparkContext(jconf)
- # Create the Java SparkContext through Py4J
- self._jsc = jsc or self._initialize_context(self._conf._jconf)
3、Python Driver 端的 RDD、SQL 接口
在 PySpark 中,繼續(xù)初始化一些 Python 和 JVM 的環(huán)境后,Python 端的 SparkContext 對象就創(chuàng)建好了,它實際是對 JVM 端接口的一層封裝。和 Scala API 類似,SparkContext 對象也提供了各類創(chuàng)建 RDD 的接口,和 Scala API 基本一一對應,我們來看一些例子。
- def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=0):
- jconf = self._dictToJavaMap(conf)
- jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
- valueClass, keyConverter, valueConverter,
- jconf, batchSize)
- return RDD(jrdd, self)
可以看到,這里 Python 端基本就是直接調(diào)用了 Java/Scala 接口。而 PythonRDD (core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala),則是一個 Scala 中封裝的伴生對象,提供了常用的 RDD IO 相關的接口。另外一些接口會通過 self._jsc 對象去創(chuàng)建 RDD。其中 self._jsc 就是 JVM 中的 SparkContext 對象。拿到 RDD 對象之后,可以像 Scala、Java API 一樣,對 RDD 進行各類操作,這些大部分都封裝在 python/pyspark/rdd.py 中。
這里的代碼中出現(xiàn)了 jrdd 這樣一個對象,這實際上是 Scala 為提供 Java 互操作的 RDD 的一個封裝,用來提供 Java 的 RDD 接口,具體實現(xiàn)在 core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala 中??梢钥吹矫總€ Python 的 RDD 對象需要用一個 JavaRDD 對象去創(chuàng)建。
對于 DataFrame 接口,Python 層也同樣提供了 SparkSession、DataFrame 對象,它們也都是對 Java 層接口的封裝,這里不一一贅述。
4、Executor 端進程間通信和序列化
對于 Spark 內(nèi)置的算子,在 Python 中調(diào)用 RDD、DataFrame 的接口后,從上文可以看出會通過 JVM 去調(diào)用到 Scala 的接口,最后執(zhí)行和直接使用 Scala 并無區(qū)別。而對于需要使用 UDF 的情形,在 Executor 端就需要啟動一個 Python worker 子進程,然后執(zhí)行 UDF 的邏輯。那么 Spark 是怎樣判斷需要啟動子進程的呢?
在 Spark 編譯用戶的 DAG 的時候,Catalyst Optimizer 會創(chuàng)建 BatchEvalPython 或者 ArrowEvalPython 這樣的 Logical Operator,隨后會被轉(zhuǎn)換成 PythonEvals 這個 Physical Operator。在 PythonEvals(sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala)中:
- object PythonEvals extends Strategy {
- override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case ArrowEvalPython(udfs, output, child, evalType) =>
- ArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
- case BatchEvalPython(udfs, output, child) =>
- BatchEvalPythonExec(udfs, output, planLater(child)) :: Nil
- case _ =>
- Nil
- }
- }
創(chuàng)建了 ArrowEvalPythonExec 或者 BatchEvalPythonExec,而這二者內(nèi)部會創(chuàng)建 ArrowPythonRunner、PythonUDFRunner 等類的對象實例,并調(diào)用了它們的 compute 方法。由于它們都繼承了 BasePythonRunner,基類的 compute 方法中會去啟動 Python 子進程:
- def compute(
- inputIterator: Iterator[IN],
- partitionIndex: Int,
- context: TaskContext): Iterator[OUT] = {
- // ......
- val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
- // Start a thread to feed the process input from our parent's iterator
- val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
- writerThread.start()
- val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
- val stdoutIterator = newReaderIterator(
- stream, writerThread, startTime, env, worker, releasedOrClosed, context)
- new InterruptibleIterator(context, stdoutIterator)
這里 env.createPythonWorker 會通過 PythonWorkerFactory(core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala)去啟動 Python 進程。Executor 端啟動 Python 子進程后,會創(chuàng)建一個 socket 與 Python 建立連接。所有 RDD 的數(shù)據(jù)都要序列化后,通過 socket 發(fā)送,而結(jié)果數(shù)據(jù)需要同樣的方式序列化傳回 JVM。
對于直接使用 RDD 的計算,或者沒有開啟 spark.sql.execution.arrow.enabled 的 DataFrame,是將輸入數(shù)據(jù)按行發(fā)送給 Python,可想而知,這樣效率極低。
在 Spark 2.2 后提供了基于 Arrow 的序列化、反序列化的機制(從 3.0 起是默認開啟),從 JVM 發(fā)送數(shù)據(jù)到 Python 進程的代碼在 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala。這個類主要是重寫了 newWriterThread 這個方法,使用了 ArrowWriter 向 socket 發(fā)送數(shù)據(jù):
- val arrowWriter = ArrowWriter.create(root)
- val writer = new ArrowStreamWriter(root, null, dataOut)
- writer.start()
- while (inputIterator.hasNext) {
- val nextBatch = inputIterator.next()
- while (nextBatch.hasNext) {
- arrowWriter.write(nextBatch.next())
- }
- arrowWriter.finish()
- writer.writeBatch()
- arrowWriter.reset()
可以看到,每次取出一個 batch,填充給 ArrowWriter,實際數(shù)據(jù)會保存在 root 對象中,然后由 ArrowStreamWriter 將 root 對象中的整個 batch 的數(shù)據(jù)寫入到 socket 的 DataOutputStream 中去。ArrowStreamWriter 會調(diào)用 writeBatch 方法去序列化消息并寫數(shù)據(jù),代碼參考 ArrowWriter.java#L131。
- protected ArrowBlock writeRecordBatch(ArrowRecordBatch batch) throws IOException {
- ArrowBlock block = MessageSerializer.serialize(out, batch, option);
- LOGGER.debug("RecordBatch at {}, metadata: {}, body: {}",
- block.getOffset(), block.getMetadataLength(), block.getBodyLength());
- return block;
- }
在 MessageSerializer 中,使用了 flatbuffer 來序列化數(shù)據(jù)。flatbuffer 是一種比較高效的序列化協(xié)議,它的主要優(yōu)點是反序列化的時候,不需要解碼,可以直接通過裸 buffer 來讀取字段,可以認為反序列化的開銷為零。我們來看看 Python 進程收到消息后是如何反序列化的。
Python 子進程實際上是執(zhí)行了 worker.py 的 main 函數(shù) (python/pyspark/worker.py):
- if __name__ == '__main__':
- # Read information about how to connect back to the JVM from the environment.
- java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
- auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
- (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
- main(sock_file, sock_file)
這里會去向 JVM 建立連接,并從 socket 中讀取指令和數(shù)據(jù)。對于如何進行序列化、反序列化,是通過 UDF 的類型來區(qū)分:
- eval_type = read_int(infile)
- if eval_type == PythonEvalType.NON_UDF:
- func, profiler, deserializer, serializer = read_command(pickleSer, infile)
- else:
- func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
在 read_udfs 中,如果是 PANDAS 類的 UDF,會創(chuàng)建 ArrowStreamPandasUDFSerializer,其余的 UDF 類型創(chuàng)建 BatchedSerializer。我們來看看 ArrowStreamPandasUDFSerializer(python/pyspark/serializers.py):
- def dump_stream(self, iterator, stream):
- import pyarrow as pa
- writer = None
- try:
- for batch in iterator:
- if writer is None:
- writer = pa.RecordBatchStreamWriter(stream, batch.schema)
- writer.write_batch(batch)
- finally:
- if writer is not None:
- writer.close()
- def load_stream(self, stream):
- import pyarrow as pa
- reader = pa.ipc.open_stream(stream)
- for batch in reader:
- yield batch
可以看到,這里雙向的序列化、反序列化,都是調(diào)用了 PyArrow 的 ipc 的方法,和前面看到的 Scala 端是正好對應的,也是按 batch 來讀寫數(shù)據(jù)。對于 Pandas 的 UDF,讀到一個 batch 后,會將 Arrow 的 batch 轉(zhuǎn)換成 Pandas Series。
- def arrow_to_pandas(self, arrow_column):
- from pyspark.sql.types import _check_series_localize_timestamps
- # If the given column is a date type column, creates a series of datetime.date directly
- # instead of creating datetime64[ns] as intermediate data to avoid overflow caused by
- # datetime64[ns] type handling.
- s = arrow_column.to_pandas(date_as_object=True)
- s = _check_series_localize_timestamps(s, self._timezone)
- return s
- def load_stream(self, stream):
- """
- Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
- """
- batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
- import pyarrow as pa
- for batch in batches:
- yield [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
5、Pandas UDF
前面我們已經(jīng)看到,PySpark 提供了基于 Arrow 的進程間通信來提高效率,那么對于用戶在 Python 層的 UDF,是不是也能直接使用到這種高效的內(nèi)存格式呢?答案是肯定的,這就是 PySpark 推出的 Pandas UDF。區(qū)別于以往以行為單位的 UDF,Pandas UDF 是以一個 Pandas Series 為單位,batch 的大小可以由 spark.sql.execution.arrow.maxRecordsPerBatch 這個參數(shù)來控制。這是一個來自官方文檔的示例:
- def multiply_func(a, b):
- return a * b
- multiply = pandas_udf(multiply_func, returnType=LongType())
- df.select(multiply(col("x"), col("x"))).show()
上文已經(jīng)解析過,PySpark 會將 DataFrame 以 Arrow 的方式傳遞給 Python 進程,Python 中會轉(zhuǎn)換為 Pandas Series,傳遞給用戶的 UDF。在 Pandas UDF 中,可以使用 Pandas 的 API 來完成計算,在易用性和性能上都得到了很大的提升。
6、總結(jié)
PySpark 為用戶提供了 Python 層對 RDD、DataFrame 的操作接口,同時也支持了 UDF,通過 Arrow、Pandas 向量化的執(zhí)行,對提升大規(guī)模數(shù)據(jù)處理的吞吐是非常重要的,一方面可以讓數(shù)據(jù)以向量的形式進行計算,提升 cache 命中率,降低函數(shù)調(diào)用的開銷,另一方面對于一些 IO 的操作,也可以降低網(wǎng)絡延遲對性能的影響。
然而 PySpark 仍然存在著一些不足,主要有:
- 進程間通信消耗額外的 CPU 資源;
- 編程接口仍然需要理解 Spark 的分布式計算原理;
- Pandas UDF 對返回值有一定的限制,返回多列數(shù)據(jù)不太方便。
Databricks 提出了新的 Koalas 接口來使得用戶可以以接近單機版 Pandas 的形式來編寫分布式的 Spark 計算作業(yè),對數(shù)據(jù)科學家會更加友好。而 Vectorized Execution 的推進,有望在 Spark 內(nèi)部一切數(shù)據(jù)都是用 Arrow 的格式來存放,對跨語言支持將會更加友好。同時也能看到,在這里仍然有很大的性能、易用性的優(yōu)化空間,這也是我們平臺近期的主要發(fā)力方向之一。
作者介紹
陳緒,匯量科技(Mobvista)高級算法科學家,負責匯量科技大規(guī)模數(shù)據(jù)智能計算引擎和平臺的研發(fā)工作。在此之前陳緒是阿里巴巴高級技術專家,負責阿里集團大規(guī)模機器學習平臺的研發(fā)。