PySpark Machine Learning in Practice

I. SparkSession

from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('test').getOrCreate()

1. Regression

df = spark.read.csv('cruise_ship_info.csv',inferSchema=True,header=True)
df.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+
only showing top 5 rows

(1) Convert the categorical column to an integer index

A regression model needs numeric inputs, so the categorical (string) field Cruise_line has to be converted to numbers.

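## StringIndexer is an Estimator that encodes the values of a string column as numeric indices, so that downstream algorithms expecting numeric input can use them.
## Each distinct value (label) gets an index from 0 to (number of labels - 1); the more frequent the label, the smaller its index.
## StringIndexer therefore has to be fit on the existing training data first.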
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCol="Cruise_line", outputCol="cruise_cat")
## fit on the existing data, then transform it
indexed = indexer.fit(df).transform(df)
indexed.show(5)
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|  Ship_name|Cruise_line|Age|           Tonnage|passengers|length|cabins|passenger_density|crew|cruise_cat|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
|    Journey|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|      Quest|    Azamara|  6|30.276999999999997|      6.94|  5.94|  3.55|            42.64|3.55|      16.0|
|Celebration|   Carnival| 26|            47.262|     14.86|  7.22|  7.43|             31.8| 6.7|       1.0|
|   Conquest|   Carnival| 11|             110.0|     29.74|  9.53| 14.88|            36.99|19.1|       1.0|
|    Destiny|   Carnival| 17|           101.353|     26.42|  8.92| 13.21|            38.36|10.0|       1.0|
+-----------+-----------+---+------------------+----------+------+------+-----------------+----+----------+
only showing top 5 rows
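
The frequency-based encoding rule described in the comments above can be sketched in plain Python. This only illustrates the rule and is not Spark's implementation; the `frequency_index` helper and the sample list are hypothetical, and breaking ties alphabetically is an assumption of the sketch.

```python
from collections import Counter

def frequency_index(values):
    """Map each distinct string to an index; the most frequent value gets 0.0."""
    counts = Counter(values)
    # Sort by descending count; ties are broken alphabetically (sketch assumption).
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {label: float(i) for i, label in enumerate(ordered)}

# Hypothetical sample of cruise lines, not the real dataset.
sample = ["Carnival", "Azamara", "Carnival", "Costa", "Carnival", "Azamara"]
mapping = frequency_index(sample)
print(mapping)
print([mapping[v] for v in sample])
```

In the real output above, Carnival is indexed as 1.0 and Azamara as 16.0, which is consistent with Azamara being one of the less frequent cruise lines in the file.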

(2) Assemble multiple feature columns into a single vector column

## VectorAssembler is a Transformer that combines the given columns into a single vector column
from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(
    inputCols=['Age',
               'Tonnage',
               'passengers',
               'length',
               'cabins',
               'passenger_density',
               'cruise_cat'],
    outputCol="features")
output = assembler.transform(indexed)
output.select("features", "crew").show(5)
+--------------------+----+
|            features|crew|
+--------------------+----+
|[6.0,30.276999999...|3.55|
|[6.0,30.276999999...|3.55|
|[26.0,47.262,14.8...| 6.7|
|[11.0,110.0,29.74...|19.1|
|[17.0,101.353,26....|10.0|
+--------------------+----+
only showing top 5 rows

(3) Split the dataset into training and test sets

full_data = output.select("features", "crew")
train_data,test_data = full_data.randomSplit([0.8,0.2])
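
Because `randomSplit` is called without a seed here, the split (and every number printed below) changes from run to run; PySpark's `randomSplit` accepts an optional `seed` argument for reproducibility. The idea of a weighted random split can be sketched in plain Python as follows. The `random_split` helper is hypothetical and only mimics the behaviour: each row is assigned independently, so the part sizes are approximate rather than exact.

```python
import random

def random_split(rows, weights, seed=None):
    """Assign every row independently to one part, with probability proportional to its weight."""
    rng = random.Random(seed)
    total = sum(weights)
    # Cumulative boundaries in [0, 1], e.g. [0.8, 1.0] for weights [0.8, 0.2].
    bounds, acc = [], 0.0
    for w in weights:
        acc += w / total
        bounds.append(acc)
    parts = [[] for _ in weights]
    for row in rows:
        x = rng.random()
        for i, b in enumerate(bounds):
            # The last part also catches any floating-point rounding at the upper end.
            if x < b or i == len(bounds) - 1:
                parts[i].append(row)
                break
    return parts

rows = list(range(158))  # stand-in rows; the size is arbitrary
train, test = random_split(rows, [0.8, 0.2], seed=42)
print(len(train), len(test))
```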

(4) Train a linear regression model

from pyspark.ml.regression import LinearRegression
## build the linear regression model and train it
lr = LinearRegression(featuresCol = 'features',labelCol='crew',predictionCol='prediction')
lrModel = lr.fit(train_data)
## model coefficients and intercept
print(lrModel.coefficients)
print(lrModel.intercept)
[-0.01415926727704148,0.006120844210220613,-0.15060788148792473,0.4560453232842637,0.8690266483207997,-0.0006548166180796964,0.04433218409250203]
-1.1598221050189703
## training root mean squared error (RMSE) and R-squared
trainingSummary = lrModel.summary
print(trainingSummary.rootMeanSquaredError)
print(trainingSummary.r2)
1.0132983612553066
0.908675109913899
## training residuals
trainingSummary.residuals.show(5)
+--------------------+
|           residuals|
+--------------------+
| -1.3832551030447302|
|  0.5516827126047827|
|0.007265278652305085|
| -0.8206717806779125|
| -0.8206717806779125|
+--------------------+
only showing top 5 rows

(5) Evaluate the model

test_results = lrModel.evaluate(test_data)
## test-set root mean squared error (RMSE), mean squared error (MSE) and R-squared
print(test_results.rootMeanSquaredError)
print(test_results.meanSquaredError)
print(test_results.r2)
0.6355876682111081
0.40397168398203365
0.974274797868934
## show the true and predicted values for the test data
test_results.predictions.show(5)
+--------------------+----+------------------+
|            features|crew|        prediction|
+--------------------+----+------------------+
|[4.0,220.0,54.0,1...|21.0| 20.82479894863448|
|[5.0,115.0,35.74,...|12.2|11.886366778288501|
|[5.0,160.0,36.34,...|13.6|15.108232838175828|
|[6.0,113.0,37.82,...|12.0| 11.68772199562538|
|[9.0,90.09,25.01,...|8.69| 9.368189276414368|
+--------------------+----+------------------+
only showing top 5 rows
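
For reference, the three metrics printed above follow the usual definitions, sketched here in plain Python. The `regression_metrics` helper is hypothetical; it is applied to the five displayed rows only (rounded to four decimals), so its output will not match the full test-set figures.

```python
import math

def regression_metrics(y_true, y_pred):
    """Return MSE, RMSE and R-squared for paired true/predicted values."""
    n = len(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))  # residual sum of squares
    mean_y = sum(y_true) / n
    ss_tot = sum((t - mean_y) ** 2 for t in y_true)             # total sum of squares
    mse = ss_res / n
    return {"mse": mse, "rmse": math.sqrt(mse), "r2": 1 - ss_res / ss_tot}

# The five test rows displayed above (crew, prediction).
crew = [21.0, 12.2, 13.6, 12.0, 8.69]
pred = [20.8248, 11.8864, 15.1082, 11.6877, 9.3682]
metrics = regression_metrics(crew, pred)
print(metrics)
```

RMSE is the square root of MSE, which the printed test values confirm: 0.6356 squared is about 0.4040.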

(6) Predict with the model

predictions = lrModel.transform(test_data.select('features'))
predictions.show(5)
+--------------------+------------------+
|            features|        prediction|
+--------------------+------------------+
|[4.0,220.0,54.0,1...| 20.82479894863448|
|[5.0,115.0,35.74,...|11.886366778288501|
|[5.0,160.0,36.34,...|15.108232838175828|
|[6.0,113.0,37.82,...| 11.68772199562538|
|[9.0,90.09,25.01,...| 9.368189276414368|
+--------------------+------------------+
only showing top 5 rows

Supplement: correlation between individual features and the label

from pyspark.sql.functions import corr
df.select(corr('crew','passengers')).show()
+----------------------+
|corr(crew, passengers)|
+----------------------+
|    0.9152341306065384|
+----------------------+
df.select(corr('crew','cabins')).show()
+------------------+
|corr(crew, cabins)|
+------------------+
|0.9508226063578497|
+------------------+
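
`corr` computes the Pearson correlation coefficient by default. A plain-Python sketch of that formula is shown below; the `pearson_corr` helper is hypothetical, and the input is just the five rows displayed at the top of this section, so the result will differ from the full-dataset values above.

```python
import math

def pearson_corr(xs, ys):
    """Pearson correlation: covariance divided by the product of the standard deviations."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    var_x = sum((x - mx) ** 2 for x in xs)
    var_y = sum((y - my) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)

# crew and cabins from the first five rows shown by df.show(5).
crew = [3.55, 3.55, 6.7, 19.1, 10.0]
cabins = [3.55, 3.55, 7.43, 14.88, 13.21]
print(pearson_corr(crew, cabins))
```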

2. Classification

data = spark.read.csv('customer_churn.csv',inferSchema=True,header=True)
data.printSchema()
root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)
## Onboard_date and Location are not used here, so they are dropped to keep the output readable
data.drop('Onboard_date','Location').show(5)
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
|           Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|             Company|Churn|
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
|Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|          Harvey LLC|    1|
|   Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|          Wilson PLC|    1|
|     Eric Lozano|38.0|      12884.75|              0| 6.67|     12.0|Miller, Johnson a...|    1|
|   Phillip White|42.0|       8010.76|              0| 6.71|     10.0|           Smith Inc|    1|
|  Cynthia Norton|37.0|       9191.58|              0| 5.56|      9.0|          Love-Jones|    1|
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
only showing top 5 rows
data.drop('Onboard_date','Location').orderBy('Total_Purchase').show(5)
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
|           Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|             Company|Churn|
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
|    Kayla Reeves|38.0|         100.0|              0| 5.27|      5.0|       Stewart-Lopez|    0|
|   Justin Campos|53.0|        3263.0|              1| 2.77|      9.0|         Hall-Butler|    0|
|     Lori Medina|39.0|       3676.68|              1| 3.52|      9.0|Garcia, Hansen an...|    0|
|     Kelly Terry|45.0|       3689.95|              1| 5.01|     11.0|Ellis, Johnston a...|    0|
|Kathleen Marquez|35.0|        3825.7|              0| 4.28|      8.0|Steele, Nguyen an...|    0|
+----------------+----+--------------+---------------+-----+---------+--------------------+-----+
only showing top 5 rows

Tips: converting continuous features to categorical features

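This is a classification problem: predicting whether a customer will churn. The snippet below shows how continuous fields can be converted to categorical ones; note that the models trained later in this section use the original continuous columns.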

from pyspark.ml.feature import Binarizer, Bucketizer
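## Binarizer turns a numeric feature into a binary one; the threshold parameter is the cut-off used for binarization
## to choose a value for threshold, first run some exploratory statistics on the Total_Purchase field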
data.describe("Total_Purchase").show()
+-------+-----------------+
|summary|   Total_Purchase|
+-------+-----------------+
|  count|              900|
|   mean|10062.82403333334|
| stddev|2408.644531858096|
|    min|            100.0|
|    max|         18026.01|
+-------+-----------------+
data.drop('Names','Onboard_date','Location','Company','Account_Manager').summary().show()
+-------+-----------------+-----------------+-----------------+------------------+-------------------+
|summary|              Age|   Total_Purchase|            Years|         Num_Sites|              Churn|
+-------+-----------------+-----------------+-----------------+------------------+-------------------+
|  count|              900|              900|              900|               900|                900|
|   mean|41.81666666666667|10062.82403333334| 5.27315555555555| 8.587777777777777|0.16666666666666666|
| stddev|6.127560416916251|2408.644531858096|1.274449013194616|1.7648355920350969| 0.3728852122772358|
|    min|             22.0|            100.0|              1.0|               3.0|                  0|
|    25%|             38.0|          8480.93|             4.45|               7.0|                  0|
|    50%|             42.0|         10041.13|             5.21|               8.0|                  0|
|    75%|             46.0|         11758.69|             6.11|              10.0|                  0|
|    max|             65.0|         18026.01|             9.15|              14.0|                  1|
+-------+-----------------+-----------------+-----------------+------------------+-------------------+
data.drop('Onboard_date','Location','Company').filter("Churn==0").orderBy('Total_Purchase',ascending=False).show()
+------------------+----+--------------+---------------+-----+---------+-----+
|             Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|Churn|
+------------------+----+--------------+---------------+-----+---------+-----+
|     Ethan Cordova|39.0|      18026.01|              1| 3.82|      9.0|    0|
|      Kevin Powell|43.0|      16955.76|              0| 7.04|      8.0|    0|
|        Eric Terry|42.0|      16371.42|              1| 3.84|     10.0|    0|
|      Holly Flores|47.0|      15878.11|              1| 2.05|      8.0|    0|
|   Darin Alexander|43.0|      15858.91|              1| 4.51|      8.0|    0|
|  Michael Williams|35.0|      15571.26|              0| 6.45|      9.0|    0|
|     Kenneth James|41.0|      15516.52|              0| 3.53|     10.0|    0|
|Catherine Johnston|38.0|      15509.97|              0| 4.65|      8.0|    0|
|      Katie Wagner|43.0|      15423.03|              1| 2.41|      7.0|    0|
|    Brandon Hunter|45.0|      15188.65|              0| 6.17|      8.0|    0|
|       Erin Norris|37.0|      15070.32|              0| 6.91|      6.0|    0|
|    Phillip Spears|52.0|      14838.84|              0| 5.12|      8.0|    0|
|     Jessica Wells|41.0|      14738.09|              1|  6.5|      8.0|    0|
|       Wendy Moore|41.0|      14722.35|              0| 6.98|      6.0|    0|
|     Sharon Torres|36.0|      14715.53|              1| 5.73|      9.0|    0|
|    Jessica Flores|52.0|      14669.61|              0| 6.28|      9.0|    0|
|      Keith Bowman|46.0|       14664.0|              0| 6.54|      8.0|    0|
|       Manuel Hill|37.0|      14595.51|              1| 3.83|     12.0|    0|
|      Heidi Butler|39.0|      14425.74|              0| 5.91|      6.0|    0|
|     Lindsey Adams|46.0|      14361.38|              0| 4.52|      8.0|    0|
+------------------+----+--------------+---------------+-----+---------+-----+
only showing top 20 rows
data.drop('Onboard_date','Location','Company').filter("Churn==1").orderBy('Total_Purchase').show()
+-----------------+----+--------------+---------------+-----+---------+-----+
|            Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|Churn|
+-----------------+----+--------------+---------------+-----+---------+-----+
|      Amy Griffin|48.0|       4771.65|              0| 3.77|     12.0|    1|
| Brittany Hopkins|55.0|       5024.52|              0| 8.11|      9.0|    1|
|       David Hess|41.0|       5192.38|              1| 4.86|     11.0|    1|
|   Lindsay Martin|53.0|       5515.09|              0| 6.85|      8.0|    1|
|     Mary Aguilar|50.0|       6244.75|              0| 4.64|     11.0|    1|
|Mr. Jerome Dawson|36.0|       6330.43|              1| 5.43|      7.0|    1|
|      Alexis Hill|39.0|       6351.79|              0| 5.86|      6.0|    1|
|  Cheyenne Rogers|36.0|       6447.99|              1| 5.52|     11.0|    1|
|       Adam Gomez|48.0|       6495.01|              1| 5.57|     12.0|    1|
|   Harold Griffin|41.0|       6569.87|              1|  4.3|     11.0|    1|
| Stephen Callahan|42.0|       6635.19|              0| 6.68|     11.0|    1|
|      Randy Hayes|43.0|       6715.23|              0| 4.16|      8.0|    1|
|  Daniel Bartlett|45.0|       6749.49|              0| 5.88|     14.0|    1|
|   Jessica Horton|43.0|       6992.09|              1| 6.84|     11.0|    1|
|   Kenneth Bryant|47.0|       7222.35|              0| 6.41|     11.0|    1|
|    Russell Bauer|38.0|       7287.57|              1| 7.39|     11.0|    1|
|     David Montes|45.0|       7351.38|              0| 5.76|     11.0|    1|
| Patrick Robinson|47.0|        7396.1|              0| 4.11|     11.0|    1|
| Steven Stevenson|52.0|       7460.05|              0| 5.39|     12.0|    1|
|    Raymond Berry|41.0|       7777.37|              0| 4.81|     12.0|    1|
+-----------------+----+--------------+---------------+-----+---------+-----+
only showing top 20 rows

The median (10041.13) and the mean (10062.82) of Total_Purchase are both close to 10000, so the threshold is set to 10000.

binarizer = Binarizer(threshold=10000, inputCol='Total_Purchase', outputCol='Total_Purchase_cat')
# Bucketizer maps a continuous variable to discrete buckets defined by a list of split points
# five split points produce four buckets
bucketizer = Bucketizer(splits=[0, 10, 30, 50, 70], inputCol='Age', outputCol='age_cat')
# pipeline stages
from pyspark.ml import Pipeline
stages = [binarizer, bucketizer]
pipeline = Pipeline(stages=stages)
# fit the pipeline model and transform the data
result = pipeline.fit(data).transform(data)
result.drop('Onboard_date','Location','Company').show(5)
+----------------+----+--------------+---------------+-----+---------+-----+------------------+-------+
|           Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|Churn|Total_Purchase_cat|age_cat|
+----------------+----+--------------+---------------+-----+---------+-----+------------------+-------+
|Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|    1|               1.0|    2.0|
|   Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|    1|               1.0|    2.0|
|     Eric Lozano|38.0|      12884.75|              0| 6.67|     12.0|    1|               1.0|    2.0|
|   Phillip White|42.0|       8010.76|              0| 6.71|     10.0|    1|               0.0|    2.0|
|  Cynthia Norton|37.0|       9191.58|              0| 5.56|      9.0|    1|               0.0|    2.0|
+----------------+----+--------------+---------------+-----+---------+-----+------------------+-------+
only showing top 5 rows
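
What the two transformers do to a single value can be sketched in plain Python. The `binarize` and `bucketize` helpers are hypothetical stand-ins, not Spark code; they follow the documented rules that Binarizer maps values strictly greater than the threshold to 1.0, and that Bucketizer buckets are closed on the left and open on the right, except the last one, which also includes its upper bound.

```python
from bisect import bisect_right

def binarize(value, threshold):
    """1.0 if value is strictly greater than threshold, else 0.0."""
    return 1.0 if value > threshold else 0.0

def bucketize(value, splits):
    """Index of the bucket [splits[i], splits[i+1]) containing value."""
    if not splits[0] <= value <= splits[-1]:
        # Out-of-range values are treated as errors in this sketch.
        raise ValueError("value outside the split range")
    if value == splits[-1]:
        return float(len(splits) - 2)  # the last bucket includes its upper bound
    return float(bisect_right(splits, value) - 1)

splits = [0, 10, 30, 50, 70]
# Values taken from the rows displayed above.
print(binarize(11066.8, 10000), binarize(8010.76, 10000))  # 1.0 0.0
print(bucketize(42.0, splits))                             # 2.0
```

Both results agree with the `Total_Purchase_cat` and `age_cat` columns in the output above.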

(1) Assemble multiple feature columns into a single vector column

from pyspark.ml.feature import VectorAssembler
assembler = VectorAssembler(inputCols=['Age',
                                       'Total_Purchase',
                                       'Account_Manager',
                                       'Years',
                                       'Num_Sites'],outputCol='features')
output = assembler.transform(data)

(2) Split into training and test sets

final_data = output.select('features','churn')
final_data.show(5)
+--------------------+-----+
|            features|churn|
+--------------------+-----+
|[42.0,11066.8,0.0...|    1|
|[41.0,11916.22,0....|    1|
|[38.0,12884.75,0....|    1|
|[42.0,8010.76,0.0...|    1|
|[37.0,9191.58,0.0...|    1|
+--------------------+-----+
only showing top 5 rows
train_churn,test_churn = final_data.randomSplit([0.8,0.2])

(3) Choose a model and train it

Method 1: logistic regression

from pyspark.ml.classification import LogisticRegression
lr_churn = LogisticRegression(featuresCol = 'features',labelCol='churn')
model = lr_churn.fit(train_churn)
training_sum = model.summary
training_sum.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[22.0,11254.38,1....|  0.0|[4.29752479205010...|[0.98658035094082...|       0.0|
|[25.0,9672.03,0.0...|  0.0|[4.45078681080970...|[0.98846522196675...|       0.0|
|[26.0,8787.39,1.0...|  1.0|[0.42481734257158...|[0.60463542102943...|       0.0|
|[27.0,8628.8,1.0,...|  0.0|[5.16036305292277...|[0.99429313966103...|       0.0|
|[28.0,8670.98,0.0...|  0.0|[7.47283595367025...|[0.99943200850135...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows

(4) Evaluate the model

from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator
# evaluate on the test set
pred_and_labels = model.evaluate(test_churn)
pred_and_labels.predictions.show(5)
+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[26.0,8939.61,0.0...|    0|[6.09287719449442...|[0.99774619086945...|       0.0|
|[28.0,11245.38,0....|    0|[3.64367153236502...|[0.97451057057915...|       0.0|
|[29.0,9378.24,0.0...|    0|[4.57063641463622...|[0.98975468320583...|       0.0|
|[30.0,6744.87,0.0...|    0|[3.31883963587618...|[0.96506949627380...|       0.0|
|[31.0,10058.87,1....|    0|[4.30854259864759...|[0.98672544256430...|       0.0|
+--------------------+-----+--------------------+--------------------+----------+
only showing top 5 rows
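# evaluator for AUC; it should read the rawPrediction column rather than the hard 0/1 prediction (not used below)
churn_eval = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction',labelCol='churn')
churn_eval_multi = MulticlassClassificationEvaluator(predictionCol='prediction', labelCol='churn',metricName='accuracy')
# this value is accuracy (metricName='accuracy'), not AUC
accuracy = churn_eval_multi.evaluate(pred_and_labels.predictions)
accuracy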
0.885

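The result does not look very good: the mean of Churn in the summary above is about 0.167, so a model that always predicts 0 would already be right for roughly 83% of customers.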
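
Accuracy on an imbalanced label is easier to judge against a majority-class baseline. The summary table earlier reports 900 rows with a mean Churn of 0.1667, which corresponds to 150 churned customers. A minimal sketch of that baseline follows; the `accuracy` helper is hypothetical, and the figure is for the whole dataset, so the baseline on a particular random test split may differ slightly.

```python
def accuracy(labels, predictions):
    """Fraction of predictions that equal the true label."""
    return sum(1 for y, p in zip(labels, predictions) if y == p) / len(labels)

# 900 rows with mean Churn of 1/6 means 150 churned and 750 did not.
labels = [1] * 150 + [0] * 750
always_zero = [0] * len(labels)  # a "model" that never predicts churn
baseline = accuracy(labels, always_zero)
print(round(baseline, 4))  # 0.8333
```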

Method 2: decision tree

from pyspark.ml.classification import DecisionTreeClassifier
dtc = DecisionTreeClassifier(labelCol='churn',featuresCol='features')
dtc_model = dtc.fit(train_churn)
print(dtc_model.featureImportances)
(5,[0,1,3,4],[0.09120095289461654,0.08511254111869927,0.1432383156878679,0.6804481902988163])
predictions = dtc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.88

This is slightly worse (0.88 versus 0.885 for logistic regression).
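
The `featureImportances` output above is a sparse vector: its size, the indices that hold non-zero values, and the values themselves. A small sketch for pairing it with column names is shown below; it relies on the vector positions following the order of `inputCols` passed to VectorAssembler, and the variable names are illustrative only.

```python
feature_names = ['Age', 'Total_Purchase', 'Account_Manager', 'Years', 'Num_Sites']

# Sparse vector printed by the decision tree above: (size, indices, values).
size = 5
indices = [0, 1, 3, 4]
values = [0.09120095289461654, 0.08511254111869927,
          0.1432383156878679, 0.6804481902988163]

# Expand to a dense list; indices missing from the sparse form are zero.
dense = [0.0] * size
for i, v in zip(indices, values):
    dense[i] = v

# Rank features from most to least important.
ranked = sorted(zip(feature_names, dense), key=lambda pair: -pair[1])
for name, score in ranked:
    print(f"{name:16s} {score:.4f}")
```

Read this way, Num_Sites carries about 68% of the importance in the decision tree, while Account_Manager (index 2) is not used at all.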

Method 3: random forest

from pyspark.ml.classification import RandomForestClassifier
rfc = RandomForestClassifier(labelCol="churn", featuresCol="features", numTrees=20)
rfc_model = rfc.fit(train_churn)
print(rfc_model.featureImportances)
(5,[0,1,2,3,4],[0.10313914486830404,0.08820488003917278,0.022502521370334393,0.13157948168406858,0.6545739720381203])
predictions = rfc_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.88

What is going on? The random forest scores the same 0.88 as the single decision tree.

Method 4: gradient-boosted trees

from pyspark.ml.classification import GBTClassifier
gbt = GBTClassifier(labelCol="churn", featuresCol="features", maxIter=20)
gbt_model = gbt.fit(train_churn)
predictions = gbt_model.transform(test_churn)
accuracy = churn_eval_multi.evaluate(predictions)
accuracy
0.875

3. Clustering

data = spark.read.csv("hack_data.csv",header=True,inferSchema=True)
data.printSchema()
root
 |-- Session_Connection_Time: double (nullable = true)
 |-- Bytes Transferred: double (nullable = true)
 |-- Kali_Trace_Used: integer (nullable = true)
 |-- Servers_Corrupted: double (nullable = true)
 |-- Pages_Corrupted: double (nullable = true)
 |-- Location: string (nullable = true)
 |-- WPM_Typing_Speed: double (nullable = true)
data.drop('Location').show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|WPM_Typing_Speed|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
|                    8.0|           391.09|              1|             2.96|            7.0|           72.37|
|                   20.0|           720.99|              0|             3.04|            9.0|           69.08|
|                   31.0|           356.32|              1|             3.71|            8.0|           70.58|
|                    2.0|           228.08|              1|             2.48|            8.0|            70.8|
|                   20.0|            408.5|              0|             3.57|            8.0|           71.28|
+-----------------------+-----------------+---------------+-----------------+---------------+----------------+
only showing top 5 rows

(1) Assemble multiple feature columns into a single vector column

from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
feat_cols = ['Session_Connection_Time', 'Bytes Transferred', 'Kali_Trace_Used',
             'Servers_Corrupted', 'Pages_Corrupted','WPM_Typing_Speed']
vec_assembler = VectorAssembler(inputCols = feat_cols, outputCol='features')
final_data = vec_assembler.transform(data)
final_data.select('features').show(5)
+--------------------+
|            features|
+--------------------+
|[8.0,391.09,1.0,2...|
|[20.0,720.99,0.0,...|
|[31.0,356.32,1.0,...|
|[2.0,228.08,1.0,2...|
|[20.0,408.5,0.0,3...|
+--------------------+
only showing top 5 rows

(2) Standardize the features

from pyspark.ml.feature import StandardScaler
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False)
cluster_final_data = scaler.fit(final_data).transform(final_data)
cluster_final_data.select("scaledFeatures").show(5)
+--------------------+
|      scaledFeatures|
+--------------------+
|[0.56785108466505...|
|[1.41962771166263...|
|[2.20042295307707...|
|[0.14196277116626...|
|[1.41962771166263...|
+--------------------+
only showing top 5 rows
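
With `withStd=True` and `withMean=False`, each feature is divided by its standard deviation without being centered first. A plain-Python sketch for one column is given below; the `scale_by_std` helper and the toy column are hypothetical, and using the sample standard deviation (n - 1 in the denominator) is the assumption made here.

```python
import statistics

def scale_by_std(column):
    """Divide each value by the column's sample standard deviation (no centering)."""
    std = statistics.stdev(column)  # n - 1 in the denominator
    return [v / std for v in column]

toy = [2.0, 4.0, 6.0]  # hypothetical column; its sample std is 2.0
print(scale_by_std(toy))  # [1.0, 2.0, 3.0]
```

The output above fits this pattern: the first feature values 8.0, 20.0 and 31.0 become 0.5679, 1.4196 and 2.2004, all consistent with division by about 14.09 and no shift.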

(3) K-Means clustering

from pyspark.ml.clustering import KMeans
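kmeans = KMeans(featuresCol='scaledFeatures',k=3)
model = kmeans.fit(cluster_final_data)
# within-set sum of squared errors; computeCost was deprecated in Spark 2.4 and removed in 3.0 (model.summary.trainingCost is the replacement)
model.computeCost(cluster_final_data)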
434.75507308487647
model.clusterCenters()
[array([1.26023837, 1.31829808, 0.99280765, 1.36491885, 2.5625043 ,
        5.26676612]),
 array([3.05623261, 2.95754486, 1.99757683, 3.2079628 , 4.49941976,
        3.26738378]),
 array([2.93719177, 2.88492202, 0.        , 3.19938371, 4.52857793,
        3.30407351])]

(4) Predict with the model

model.transform(cluster_final_data).groupBy('prediction').count().show()
+----------+-----+
|prediction|count|
+----------+-----+
|         1|   88|
|         2|   79|
|         0|  167|
+----------+-----+
model.transform(cluster_final_data).show(5)
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
|Session_Connection_Time|Bytes Transferred|Kali_Trace_Used|Servers_Corrupted|Pages_Corrupted|            Location|WPM_Typing_Speed|            features|      scaledFeatures|prediction|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
|                    8.0|           391.09|              1|             2.96|            7.0|            Slovenia|           72.37|[8.0,391.09,1.0,2...|[0.56785108466505...|         0|
|                   20.0|           720.99|              0|             3.04|            9.0|British Virgin Is...|           69.08|[20.0,720.99,0.0,...|[1.41962771166263...|         0|
|                   31.0|           356.32|              1|             3.71|            8.0|             Tokelau|           70.58|[31.0,356.32,1.0,...|[2.20042295307707...|         0|
|                    2.0|           228.08|              1|             2.48|            8.0|             Bolivia|            70.8|[2.0,228.08,1.0,2...|[0.14196277116626...|         0|
|                   20.0|            408.5|              0|             3.57|            8.0|                Iraq|           71.28|[20.0,408.5,0.0,3...|[1.41962771166263...|         0|
+-----------------------+-----------------+---------------+-----------------+---------------+--------------------+----------------+--------------------+--------------------+----------+
only showing top 5 rows
model.transform(cluster_final_data).select('features','scaledFeatures','prediction').show(5)
+--------------------+--------------------+----------+
|            features|      scaledFeatures|prediction|
+--------------------+--------------------+----------+
|[8.0,391.09,1.0,2...|[0.56785108466505...|         0|
|[20.0,720.99,0.0,...|[1.41962771166263...|         0|
|[31.0,356.32,1.0,...|[2.20042295307707...|         0|
|[2.0,228.08,1.0,2...|[0.14196277116626...|         0|
|[20.0,408.5,0.0,3...|[1.41962771166263...|         0|
+--------------------+--------------------+----------+
only showing top 5 rows

4. Text mining with the TF-IDF algorithm

data = spark.read.csv("SMSSpamCollection",inferSchema=True,sep='\t')
data = data.withColumnRenamed('_c0','class').withColumnRenamed('_c1','text')
data.show(5)
+-----+--------------------+
|class|                text|
+-----+--------------------+
|  ham|Go until jurong p...|
|  ham|Ok lar... Joking ...|
| spam|Free entry in 2 a...|
|  ham|U dun say so earl...|
|  ham|Nah I don't think...|
+-----+--------------------+
only showing top 5 rows

As shown, the data consists of SMS messages and their class (ham or spam).

(1) Data preprocessing

from pyspark.sql.functions import length
# compute length of each text
data = data.withColumn('length',length(data['text']))
  • Tokenization
from pyspark.ml.feature import Tokenizer
tokenizer = Tokenizer(inputCol="text", outputCol="stop_tokens")
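  • Stop-word removal (commented out in this run: the Tokenizer above writes straight to stop_tokens, so stop words are kept)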
from pyspark.ml.feature import StopWordsRemover
# stopremove = StopWordsRemover(inputCol='token_text',outputCol='stop_tokens')
  • Term frequency
from pyspark.ml.feature import CountVectorizer
count_vec = CountVectorizer(inputCol='stop_tokens',outputCol='c_vec')
  • Inverse document frequency
from pyspark.ml.feature import IDF
idf = IDF(inputCol="c_vec", outputCol="tf_idf")
  • Map the class label from string to index
from pyspark.ml.feature import StringIndexer
ham_spam_to_num = StringIndexer(inputCol='class',outputCol='label')
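
The tokenize, count and IDF steps listed above can be sketched in plain Python on a few hypothetical messages. This is an illustration only, not Spark code: it assumes lowercase whitespace tokenization, raw term counts, and the smoothed formula idf(t) = log((m + 1) / (df(t) + 1)) that Spark's IDF documentation gives, where m is the number of documents and df(t) the number of documents containing term t.

```python
import math
from collections import Counter

# Hypothetical messages, not taken from SMSSpamCollection.
docs = ["free entry in a prize draw",
        "ok see you in a bit",
        "call now for a free prize"]

tokens = [d.lower().split() for d in docs]   # tokenization
tf = [Counter(t) for t in tokens]            # term counts per document
m = len(docs)
df = Counter(term for t in tokens for term in set(t))  # document frequency
idf = {term: math.log((m + 1) / (df[term] + 1)) for term in df}
tf_idf = [{term: c * idf[term] for term, c in counts.items()} for counts in tf]

print(idf["a"], idf["free"], idf["ok"])
```

A term that occurs in every message ("a" here) gets an IDF of zero, while rarer terms such as "ok" get the largest weights.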

(2) Convert the columns into model input features

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vector
clean_up = VectorAssembler(inputCols=['tf_idf','length'],outputCol='features')

(3) Build the model

from pyspark.ml.classification import NaiveBayes
nb = NaiveBayes()

(4) Split into training and test sets

from pyspark.ml import Pipeline
data_prep_pipe = Pipeline(stages=[ham_spam_to_num,tokenizer,count_vec,idf,clean_up])
cleaner = data_prep_pipe.fit(data)
clean_data = cleaner.transform(data)
full_data = clean_data.select(['label','features'])
(train_data,test_data) = full_data.randomSplit([0.8,0.2])

(5) Train the model

model = nb.fit(train_data)

(6) Evaluate the model

test_results = model.transform(test_data)
test_results.show(5)
+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(13588,[0,1,2,3,4...|[-3592.6536481156...|[1.0,3.5894638863...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-2823.9774728892...|[1.0,3.0240329484...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-3095.2908236727...|[1.0,8.8358217240...|       0.0|
|  0.0|(13588,[0,1,2,3,4...|[-1075.7111968609...|[1.0,2.8629849890...|       0.0|
|  0.0|(13588,[0,1,2,3,5...|[-1787.4043923033...|[1.0,2.0766777717...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
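# with no metricName, MulticlassClassificationEvaluator uses its default metric 'f1' (weighted F1), not accuracy
acc_eval = MulticlassClassificationEvaluator()
acc = acc_eval.evaluate(test_results)
print("F1 score of model at predicting spam was: {}".format(acc))
F1 score of model at predicting spam was: 0.9410145943960191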
spark.stop()
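
`MulticlassClassificationEvaluator()` created without a `metricName` reports the weighted F1 score by default, which is not the same number as accuracy. The difference can be sketched in plain Python on a hypothetical four-row example; the `accuracy` and `weighted_f1` helpers are illustrative, not Spark functions.

```python
def accuracy(labels, preds):
    """Fraction of predictions that equal the true label."""
    return sum(1 for y, p in zip(labels, preds) if y == p) / len(labels)

def weighted_f1(labels, preds):
    """Per-class F1, averaged with each class weighted by its share of the true labels."""
    n = len(labels)
    score = 0.0
    for c in set(labels):
        tp = sum(1 for y, p in zip(labels, preds) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, preds) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, preds) if y == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        score += f1 * (tp + fn) / n
    return score

# Hypothetical labels and predictions (0.0 = ham, 1.0 = spam).
labels = [0.0, 0.0, 0.0, 1.0]
preds = [0.0, 0.0, 1.0, 1.0]
print(accuracy(labels, preds), weighted_f1(labels, preds))
```

To get accuracy from the evaluator itself, pass `metricName='accuracy'`, as was done in the classification section.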