基于Scala开发Spark ML的ALS推荐模型实战

推荐系统,广泛应用到电商,营销行业。本文通过Scala,开发Spark ML的ALS算法训练推荐模型,用于电影评分预测推荐。

算法简介

ALS算法是Spark ML中实现协同过滤的矩阵分解方法。

ALS,即交替最小二乘法(Alternating Least Squares),是协同过滤技术中的一种经典算法。它通过对用户和物品的潜在特征进行建模,来预测用户对未知物品的评分或偏好。具体介绍如下:

  1. 矩阵分解模型:在推荐系统中,我们通常有一个用户-物品的评分矩阵,其中行表示用户,列表示物品,矩阵中的值代表用户对物品的评分。然而,这个矩阵通常是非常稀疏的,因为用户只给少数物品评分。ALS算法就是在这样的不完整评分矩阵上操作,通过矩阵分解来补全缺失值,进而产生推荐。
  2. 算法原理:ALS算法的核心思想是通过迭代过程更新用户和物品的潜在因子向量。在每次迭代中,一个评分被建模为用户潜在特征向量和物品潜在特征向量的点积,加上一个偏差项。通过最小化实际评分和预测评分之间的差异来不断优化这些潜在特征向量。
  3. Spark ML实现:在Spark ML库中,ALS算法被用于处理大规模的数据集,并提供了多种参数以适应不同的数据特性和需求。例如,可以设置潜在因子的数量、正则化参数、迭代次数等。此外,Spark ML的ALS还支持隐式反馈数据的变体,这对于无法获取明确评分的数据非常有用。

总的来说,ALS是一种强大的推荐系统算法,尤其适用于处理大规模稀疏数据集。通过合理地选择和调整参数,可以在保持高效计算的同时获得良好的推荐质量。

代码实战

pom.xml文件更新,加入相关依赖

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.example</groupId>
    <artifactId>sparkGNU2023</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <maven.compiler.source>8</maven.compiler.source>
        <maven.compiler.target>8</maven.compiler.target>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <scala.version>2.13</scala.version>
        <spark.version>3.4.1</spark.version>
        <log4j.version>1.2.17</log4j.version>
        <slf4j.version>1.7.22</slf4j.version>
    </properties>


    <dependencies>

        <!--日志相关依赖-->
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>jcl-over-slf4j</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-api</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>org.slf4j</groupId>
            <artifactId>slf4j-log4j12</artifactId>
            <version>${slf4j.version}</version>
        </dependency>
        <dependency>
            <groupId>log4j</groupId>
            <artifactId>log4j</artifactId>
            <version>${log4j.version}</version>
        </dependency>


        <dependency>
            <groupId>com.thoughtworks.paranamer</groupId>
            <artifactId>paranamer</artifactId>
            <version>2.8</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-core_2.13</artifactId>
            <version>3.4.1</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-sql_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-hive_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-10_2.13</artifactId>
            <version>3.4.1</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-mllib_2.13</artifactId>
            <version>${spark.version}</version>
        </dependency>

        <dependency>
            <groupId>org.apache.spark</groupId>
            <artifactId>spark-streaming-kafka-0-8_2.11</artifactId>
            <version>2.4.8</version>
        </dependency>

        <dependency>
            <groupId>mysql</groupId>
            <artifactId>mysql-connector-java</artifactId>
            <version>8.0.30</version>
        </dependency>

        <dependency>
            <groupId>org.apache.flume.flume-ng-clients</groupId>
            <artifactId>flume-ng-log4jappender</artifactId>
            <version>1.11.0</version>
        </dependency>

        <!--        flume 拦截器相关依赖-->
        <dependency>
            <groupId>org.apache.flume</groupId>
            <artifactId>flume-ng-core</artifactId>
            <version>1.9.0</version>
            <scope>provided</scope>
        </dependency>

        <dependency>
            <groupId>com.alibaba</groupId>
            <artifactId>fastjson</artifactId>
            <version>1.2.62</version>
        </dependency>

    </dependencies>

    <build>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-assembly-plugin</artifactId>
                <version>3.6.0</version>
                <configuration>
                    <descriptorRefs>
                        <descriptorRef>jar-with-dependencies</descriptorRef>
                    </descriptorRefs>
                </configuration>
                <executions>
                    <execution>
                        <id>make-assembly</id>
                        <phase>package</phase>
                        <goals>
                            <goal>single</goal>
                        </goals>
                    </execution>
                </executions>
            </plugin>
        </plugins>
    </build>
    
</project>

训练ALS模型

基于scala训练ALS模型

package base.charpter10

import breeze.linalg.sum
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.recommendation.ALS
import org.apache.spark.sql.functions.{col, count, explode, when}
import org.apache.spark.sql.{DataFrame, SparkSession}

/**
 * @projectName sparkGNU2023  
 * @package base.charpter10  
 * @className base.charpter10.MovieRecommender  
 * @description ${description}  
 * @author pblh123
 * @date 2024/3/29 15:18
 * @version 1.0
 *
 */
    
object MovieRecommender {

  def main(args: Array[String]): Unit = {
    // 创建Spark会话
    val spark = SparkSession.builder()
      .appName("MovieRecommender")
      .master("local[*]")
      .getOrCreate()

    import spark.implicits._

    // 假设我们有一个用户-物品评分数据集,格式为(userId, itemId, rating)
    /**
     * UserID,MovieID,Rating,Timestamp
     *  1,1193,5,978300760
     *  1,661,3,978302109
     */
    // 指定CSV文件的路径,以及解析选项
    val csvFilePath = "data/ratings.csv"
    val csvOptions = Map(
      "header" -> "true", // 是否有列名头
      "inferSchema" -> "true", // 是否自动推断数据类型
    "encoding" -> "UTF-8", // 如果有特定的编码格式,例如对于包含中文的CSV文件:
    )

    // 读取CSV文件并创建DataFrame
    val ratingsDF = spark.read.format("csv")
      .options(csvOptions)
      .load(csvFilePath)

    // 显示DataFrame的前几行以验证数据是否正确加载
    println("查看原始据数据样例:")
    ratingsDF.show(5)

    val ratings: DataFrame = ratingsDF.select("UserID", "MovieID", "Rating")
      .withColumnRenamed("UserID", "userId")
      .withColumnRenamed("MovieID", "itemId")
      .withColumnRenamed("Rating", "rating")


    // 将数据集分割为训练集和测试集
    val Array(training, test) = ratings.randomSplit(Array(0.8, 0.2))

    println("查看训练集数据")
    training.show(5)
    println("查看测试集数据")
    test.show(5)


    // 设置ALS参数
    // 创建一个ALS实例并配置参数
    val als = new ALS()
      .setMaxIter(10) // 设置最大迭代次数为5,10,本地测试时,设置过大,会报错
      .setRegParam(0.01) // 设置正则化参数为0.01
      .setUserCol("userId") // 设置用户列名为"userId"
      .setItemCol("itemId") // 设置物品列名为"itemId"
      .setRatingCol("rating") // 设置评分列名为"rating"

    /**
     * ALS(Alternating Least Squares)是一种基于矩阵分解的协同过滤算法,用于处理用户和物品之间的评分数据。各参数说明如下:
     *  setMaxIter: 设置最大迭代次数,决定模型训练的精细程度。迭代次数越多,模型通常越精确,但训练时间也可能更长。
     *  setRegParam: 设置正则化参数,用于控制模型的复杂度和过拟合程度。较小的正则化参数值可能导致模型过复杂,容易过拟合;较大的值则可能导致模型过于简单,欠拟合。
     *  setUserCol, setItemCol, setRatingCol: 分别设置用户ID列、物品ID列和评分列的名称。这些列名根据实际的数据结构来确定,用于告诉ALS算法在哪些列中查找用户、物品和评分信息。
     */


    // 训练ALS模型
    println("开始训练模型")
    val model = als.fit(training)

    // 对测试集进行预测
    val predictions = model.transform(test)

    predictions.show()
    predictions.filter($"rating".isNotNull && $"prediction".isNotNull).count() // 确认有非空的评分和预测值


    // 评估模型
    val evaluator = new RegressionEvaluator()
      .setMetricName("rmse")
      .setLabelCol("rating")
      .setPredictionCol("prediction")
    val rmse = evaluator.evaluate(predictions)
    println(s"Root-mean-square error = $rmse")

    // 为用户生成推荐
    // 该函数是基于一个模型(model)为所有用户推荐项目的函数。它将为每个用户推荐5个项目
    /**
     * +------+--------------------------------------------------------------------------------------------+
     *  |userId|recommendations[{itemid,pred_rating},{itemid,pred_rating},...]                                                                             |
     *  +------+--------------------------------------------------------------------------------------------+
     *  |12    |[{1864, 9.721167}, {2964, 8.815781}, {3867, 8.480173}, {1539, 7.8904114}, {563, 7.8829007}] |
     *  |22    |[{2964, 6.090676}, {3215, 5.6165895}, {1534, 5.4731245}, {718, 5.462125}, {2632, 5.4482727}]|
     */

    val userRecs = model.recommendForAllUsers(5)
    userRecs.show(5,false)
    println("保存预测结果")
//    userRecs.write.mode("overwrite").parquet("models/recomALSmodel") // 保存为parquet格式,一般用于集群中
// userRecs是一个DataFrame,其中"recommendations"列是数组类型
    val explodedUserRecs = userRecs.withColumn("recommendations", explode($"recommendations"))
      .select($"userId", $"recommendations.itemId".as("itemId"), $"recommendations.rating".as("PredRating"))
    explodedUserRecs.write.mode("overwrite").format("csv").save("predictRes/recomALS")  // PC 调试使用

    // 保存模型到指定路径
    val modelPath = "models/recomALSmodel"
    model.write.overwrite().save(modelPath)
    println(s"Model saved to $modelPath")

    // 停止Spark会话
    spark.stop()

    /*
    当程序试图停止Spark会话时,可能会触发清理临时文件的操作,
    从而导致出现NoSuchFileException异常。通常情况下,这不是代码逻辑的问题,
    而是Spark内部在清理资源时可能出现的问题。
    可以尝试重启Spark环境或者适当增大Spark的临时目录空间来避免此类问题。
     */

  }

}

运行代码,效果图如下

TodoList:目前RMSE计算出问题,原数据清洗没有做,模型参数还可以调整。后期调整更新后,再发一篇文章。

使用训练的模型预测新数据

scala开发应用模型demo代码

package base.charpter10

import org.apache.spark.ml.recommendation.ALSModel
import org.apache.spark.sql.SparkSession

/**
 * @projectName sparkGNU2023  
 * @package base.charpter10  
 * @className base.charpter10.RecommendationModelLoadDemo  
 * @description ${description}  
 * @author pblh123
 * @date 2024/3/29 15:36
 * @version 1.0
 *
 */
    
object RecommendationModelLoadDemo {
  def main(args: Array[String]): Unit = {
    // 创建Spark会话
    val spark = SparkSession.builder().master("local[*]")
      .appName("RecommendationModelUsageDemo")
      .getOrCreate()

    import spark.implicits._

    // 加载之前保存的ALS模型
    val modelPath = "models/recomALSmodel"
    val loadedModel: ALSModel = ALSModel.load(modelPath)

    // 假设我们有一些新的用户-物品对,我们想要预测它们的评分
    val userItemPairs = Seq(
      (1, 4), // 用户1对物品4的评分预测
      (2, 2) // 用户2对物品2的评分预测
    ).toDF("userId", "itemId")

    // 使用模型进行评分预测
    val predictions = loadedModel.transform(userItemPairs)
    predictions.show()

    // 现在,假设我们想要为用户1生成前N个推荐物品
    val numRecommendations = 5 // 为用户推荐的物品数量
    val userRecs = loadedModel.recommendForAllUsers(numRecommendations)
    userRecs.show(5,false)

    // 停止Spark会话
    spark.stop()
  }

}

运行效果如下

评估效果说明:目前的预测评分不合理,是因为模型没有经过精挑,优化,预测的记过会依据预测评分高低排序,选取得分高的前5个结果返回。后期模型调优后,结果就正常了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/511744.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

golang语言系列:Web框架+路由 之 Echo

云原生学习路线导航页&#xff08;持续更新中&#xff09; 本文是golang语言系列文章&#xff0c;本篇主要对 Echo 框架 的基本使用方法 进行学习 1.Echo是什么 Go 有众多Web框架&#xff0c;Echo 是其中的一个&#xff0c;官网介绍Echo有高性能、可扩展性、极简的特点。使用E…

java发送请求-cookie有关代码

在初始化后添加cookie的代码 用这个httpclients类调custom方法&#xff0c;进行代码定制化 找和cookie有关的方法&#xff0c;设置默认的cookie存储信息 入参是接口 将入参粘贴后找方法&#xff0c;用new实现这个接口 这个方法是无参空构造&#xff0c;可以使用 设置了cookie …

【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案

【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案 大家好 我是寸铁&#x1f44a; 总结了一篇【Redis】MISCONF Redis is configured to save RDB snapshots报错解决方案✨ 喜欢的小伙伴可以点点关注 &#x1f49d; 前言 今天在登录redis时&#xff0c…

入门级深度学习主机组装过程

一 配置 先附上电脑配置图&#xff0c;如下&#xff1a; 利用公司的办公电脑对配置进行升级改造完成。除了显卡和电源&#xff0c;其他硬件都是公司电脑原装。 二 显卡 有钱直接上 RTX4090&#xff0c;也不能复用公司的电脑&#xff0c;其他配置跟不上。 进行深度学习&…

路由和远程访问是什么?

路由和远程访问在现代互联网时代中&#xff0c;扮演着至关重要的角色。它们为我们提供了便捷的信息传递途径&#xff0c;让不同地区的电脑、设备以及人们之间能够轻松进行通信和交流。 对于路由来说&#xff0c;它是连接互联网上的各个网络的核心设备。一台路由器可以将来自不同…

2013年认证杯SPSSPRO杯数学建模B题(第二阶段)流行音乐发展简史全过程文档及程序

2013年认证杯SPSSPRO杯数学建模 B题 流行音乐发展简史 原题再现&#xff1a; 随着互联网的发展&#xff0c;流行音乐的主要传播媒介从传统的电台和唱片逐渐过渡到网络下载和网络电台等。网络电台需要根据收听者的已知喜好&#xff0c;自动推荐并播放其它音乐。由于每个人喜好…

CCIE-03-Layer2-LAN-TS

目录 实验条件网络拓朴实验目标 开始排错问题1. SW2上的DHCP中继没有配置正确问题2. SW1/SW2的SVI接口被关闭问题3. 安全端口配置了不同的MAC地址 实验条件 网络拓朴 Output1 Output2 实验目标 排除故障使得PC101访问Server1时符合图片中给出的Output 开始排错 根据要求…

认识什么是Webpack

目录 1. 认识Webpack 1.1. 什么是Webpack?&#xff08;定义&#xff09; 1.2. 使用Webpack 1.2.1. 需求 1.2.2. 步骤 1.3. 入口和出口默认值 1.3.1. 需求代码如下 2. 修改Webpack打包入口和出口 2.1. 步骤&#xff1a; 2.2. 注意 3. Webpack自动生成html文件 3.1.…

NPW(监控片的)的要点精讲

半导体的生产过程已经历经数十年的发展&#xff0c;其中主要有两个大的发展趋势&#xff0c;第一&#xff0c;晶圆尺寸越做越大&#xff0c;到目前已有超过70%的产能是12寸晶圆&#xff0c;不过18寸晶圆产业链推进缓慢&#xff1b;第二&#xff0c;电子器件的关键尺寸越做越小&…

VLAN间路由

部署了VLAN的传统交换机不能实现不同VLAN间的二层报文转发&#xff0c;因此必须引入路由技术来实现不同VLAN间的通信。VLAN路由可以通过二层交换机配合路由器来实现&#xff0c;也可以通过三层交换机来实现&#xff1b; VLAN间通讯限制 每个VLAN都是一个独立的广播域&#xff…

ubuntu-server部署hive-part1-安装jdk

参照 https://blog.csdn.net/qq_41946216/article/details/134345137 操作系统版本&#xff1a;ubuntu-server-22.04.3 虚拟机&#xff1a;virtualbox7.0 安装jdk 上传解压 以root用户&#xff0c;将jdk上传至/opt目录下 tar zxvf jdk-8u271-linux-x64.tar.gz 配置环境变量…

LLM - 大语言模型 基于人类反馈的强化学习(RLHF)

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/137269049 基于人类反馈的强化学习(RLHF,Reinforcement Learning from Human Feedback),结合 强化学习(RL) 和 人类反馈 来优化模型的性能。这种方法主要包…

多态--下

文章目录 概念多态如何实现的指向谁调谁&#xff1f;例子分析 含有虚函数类的大小是多少&#xff1f;虚函数地址虚表地址多继承的子类的大小怎么计算&#xff1f;练习题虚函数和虚继承 概念 优先使用组合、而不是继承; 继承会破坏父类的封装、因为子类也可以调用到父类的函数;…

【OpenCV】图像像素的遍历

1 前言 介绍两种遍历像素的方法&#xff08;非指针、指针&#xff09;。注意&#xff1a;.at() .ptr()的作用、用法。相关API&#xff1a; Mat对象.ptr() Mat对象.at() 2 代码及内容 #include "iostream" #include "opencv2/opencv.hpp"using namespac…

知识图谱简介:探索知识的宇宙

知识图谱简介&#xff1a;探索知识的宇宙 一、引言 在这个由数据驱动的世界里&#xff0c;信息呈现出爆炸式的增长&#xff0c;人们对于管理和利用这些庞大数据量的需求也随之增长。知识图谱以其独特的方式&#xff0c;成为了整合和利用这些信息的有力工具。它不仅有助于组织杂…

小象超市(原美团买菜) 的大屏图表

文章目录 概要技术细节技术名词解释小结 概要 20203年12月1日&#xff0c;美团旗下自营零售品牌“美团买菜”升级为全新品牌“小象超市”。 &#xff0c;“小象超市”坚持美团自营零售模式&#xff0c;通过在社区设立的集存储、分拣、配送为一体的便民服务站&#xff0c;为社区…

ES8 学习 -- async 和 await / 对象方法扩展 / 字符串填充

文章目录 1. async 和 await1.1 基本语法1.2 使用示例1.3 案例练习 2. 对象方法扩展2.1 Object.values(obj)2.2 Object.entries(obj)2.3 Object.getOwnPropertyDescriptors(obj)使用示例 3. 字符串填充4. 函数参数的末尾加逗号 1. async 和 await async 函数&#xff0c;使得异…

Android Studio学习10——资源res的使用

一、String,StringArray的使用 一次修改&#xff0c;多出生效 String StringArray 二、color的使用 颜色代码对应表 和上面的相似用法 三、Dimen(尺寸)的使用 用的少&#xff0c;一般直接写尺寸 四、如何写一个drawable作为背景 五、如何写一个可以改变的drawable(按钮按下…

【合合TextIn】OCR身份证 / 银行卡识别功能适配鸿蒙系统

目录 一、 鸿蒙系统与信创国产化的背景 二、两款产品的兼容性升级详情 三、身份证产品介绍 四、银行卡识别产品 五、承诺与展望 一、 鸿蒙系统与信创国产化的背景 自鸿蒙系统推出以来&#xff0c;其不仅成为了华为在软件领域的重要里程碑&#xff0c;更是国产操作系统的一…

redis的键值基本操作

设置数据 首先设置键值对 删除age&#xff0c;会得到nil&#xff0c;表示这个键已经被删除掉了 判断age键还在不在 查找所有键 查找所有以me结尾的键 删除所有键 redis的键和值都是二进制存储的&#xff0c;所以默认不支持中文。 但是&#xff0c;我们重新登录客户端&#xff…