Spark记录未整理

Spark记录未整理,请以较平静的心态阅读。
目的: 根据user_id进行分组,同时将同一user_id看过的anime_id转化为一个字符串数组(anime_ids),将anime_ids转化为二维的list [[[20, 81, 170, 263…],[]…],最后构建一个关于anime_ids的邻接矩阵。

本文的目的

1)根据是anime_id构建一个邻接矩阵
2)构建deep walk的转移矩阵和入口矩阵
3)根据deep walk算法实现sample
4)使用spark的word2vec训练稠密向量
4)使用redis缓存user_id的embedding参数
5)使用局部敏感hash算法完成近似查找

原始数据

+-------+--------+------+                                                       
|user_id|anime_id|rating|
+-------+--------+------+
|      1|    8074|    10|
|      1|   11617|    10|
|      1|   11757|    10|
|      1|   15451|    10|
|      2|   11771|    10|
+-------+--------+------+

包引入

import pyspark
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
import os
import numpy as np
from collections import defaultdict

spark 运行需要jdk,导入环境变量JAVA_HOME
os.environ["JAVA_HOME"] = "/usr/lib/jvm/jre-1.8.0-openjdk" 

邻接矩阵构建

row和col是anime_id,值是看过行anime_id又看过列anime_id的user_id数

# 辅助函数
def print_(name):
    print_str = "=" * 50 + ">>>" + name + "<<<" + "=" * 50
    print(print_str)


spark = SparkSession.builder.appName("feat-eng").getOrCreate()

rating_df = spark.read.csv(
    "/data/jupyter/recommendation_data/rating.csv", header=True, inferSchema=True
)
rating_df = rating_df.where("rating > 7")
# 原始数据
# print(rating_df.show(5))

watch_seq_df = rating_df.groupBy("user_id").agg(
    F.collect_list(F.col("anime_id").cast("string")).alias("anime_ids")
)

# print(watch_seq_df.show(2))
# +-------+--------------------+
# |user_id|           anime_ids|
# +-------+--------------------+
# |    148|[20, 81, 170, 263...|
# |    463|[20, 24, 68, 102,...|
# +-------+--------------------+
# print(watch_seq_df.printSchema())
# root
#  |-- user_id: integer (nullable = true)
#  |-- anime_ids: array (nullable = true)
#  |    |-- element: string (containsNull = false)

# 转成一个list
watch_seq = watch_seq_df.collect()
watch_seq = [s["anime_ids"] for s in watch_seq]

# 邻接矩阵
matrix = defaultdict(lambda: defaultdict(int))

for i in range(len(watch_seq)):
    seq = watch_seq[i]
    for x in range(len(seq)):
        for y in range(x + 1, len(seq)):
            a = seq[x]
            b = seq[y]
            if a == b:
                continue  # 对角线不统计
            matrix[a][b] += 1
            matrix[b][a] += 1

看过同时看过anime_id=20和其他anime_id的人数,dict的key是anime_id=20的邻居节点,value是同时观看的用户人数,结果如下:

print(matrix["20"])

在这里插入图片描述

概率转移矩阵和入口转移矩阵

概率转移矩阵:决定当前节点在deep walk中如何选择下一个节点
入口转移矩阵:deep walk 初始化时选择的节点

概率转移矩阵格式:某一个anime_id的neighbours和他的转移概率(tranfer_probs)

{
    "anime_id": {
        "neighbours": [2, 3, 5, 7],
        "probs": [0.16, 0.16, 0.32, 0.32]
    },
}

入口转移矩阵格式:计算某一个anime_id看在的次数在所有用户看过的anime_id的总次数的占比

[0.001953672356421993, 0.0004123166720890604, 0.0008729517041885576, ...]

代码

def get_transfer_prob(vs):
    neighbours = list(vs.keys())
    total_weight = sum(vs.values())
    probs = [vs[k] / total_weight for k in vs.keys()]

    return {"neighbours": neighbours, "prob": probs}


tranfer_probs = {k: get_transfer_prob(v) for k, v in matrix.items()}

entrance_items = list(tranfer_probs)

neighbour_sum = {k: sum(matrix[k].values()) for k in entrance_items}
total_sum = sum(neighbour_sum.values())

entrence_probs = [neighbour_sum[e] / total_sum for e in entrance_items]

构建随机游走

流程:
1)在入口转移概率中随机选一个节点,加入路径中
2)循环length次在转移概率中随机选一个节点,加入路径中
3)更新当前节点
4)重复循环形成一个length+1的一个节点路径
4)外部循环n次形成n组采样

rng = np.random.default_rng()

def one_walk(length, entrance_items, entrence_probs, tranfer_probs):
    start_point = rng.choice(entrance_items, 1, p=entrence_probs)[0]
    path = [str(start_point)]
    current_point = start_point
    for _ in range(length):
        neighbours = tranfer_probs[current_point]["neighbours"]
        transfor_prob = tranfer_probs[current_point]["prob"]
        next_point = rng.choice(neighbours, 1, p=transfor_prob)[0]
        path.append(str(next_point))
        current_point = next_point
    return path


n = 500
sample = [one_walk(20, entrance_items, entrence_probs, tranfer_probs) for _ in range(n)]

采样格式如下,数数字代表anime_id:

在这里插入图片描述

spark的word2vec

将sample采样后数据转化成DataFrame格式,利用spark的word2vec转化为稠密向量


sample_df = spark.createDataFrame([[row] for row in sample], ["anime_ids"])


from pyspark.ml.feature import Word2Vec

item2vec = Word2Vec(vectorSize=5, maxIter=2, windowSize=15)  # skip model
item2vec.setInputCol("anime_ids")
item2vec.setOutputCol("anime_ids_vec")
model = item2vec.fit(sample_df)
# test 这里是找到和anime_id=20比较近的10条数据
# rec = model.findSynonyms("20", 10)

# 获取训练参数,便于进行下一步操作
item_vec = model.getVectors().collect()
item_emb = {}
for item in item_vec:
    item_emb[item.word] = item.vector.toArray()


@F.udf(returnType="array<float>")
def build_user_emb(anime_seq):
    anime_embs = [item_emb[aid] if aid in item_emb else [] for aid in anime_seq]
    anime_embs = list(filter(lambda l: len(l) > 0, anime_embs))
    emb = np.mean(anime_embs, axis=0)
    return emb.tolist()


user_emb_df = watch_seq_df.withColumn("user_emb", build_user_emb(F.col("anime_ids")))
print(user_emb_df.show(3))
# +-------+--------------------+--------------------+
# |user_id|           anime_ids|            user_emb|
# +-------+--------------------+--------------------+
# |    148|[20, 81, 170, 263...|[0.22693123, -0.0...|
# |    463|[20, 24, 68, 102,...|[0.24494155, -0.0...|
# |    471| [1604, 6702, 10681]|[0.37150145, 0.04...|
# +-------+--------------------+--------------------+

build_user_emb函数根据前面将某一个用户看过的所有anime_id逐个进行embedding,之后平均处理。这样最终我们就可以根据某一个用户看过的所有anime_id获得某一个用户的embedding。

redis保存缓存数据

from redis import Redis

redis = Redis()
user_emb = user_emb_df.collect()
# user_id和其对应的embedding
user_emb = {row.user_id: row.user_emb for row in user_emb}

# 辅助函数将float类别转化为字符串用:分割
def vec2str(vec):
    if vec is None:
        return ""
    return ":".join([str(v) for v in vec])

def save_user_emb(embs):
    str_emb = {item_id: vec2str(v) for item_id, v in embs.items()}
    redis.hset("recall-user-emb", mapping=str_emb)

# 用户的embedding保存到redis
save_user_emb(user_emb)
# test语句
# redis.hget("recall-user-emb", "148")

# 辅助函数将str读成float的列表
def str2vec(s):
    if len(s) == 0:
        return None
    return [float(x) for x in s.split(":")]

def load_user_emb():
    result = redis.hgetall("recall-user-emb")
    return {user_id.decode(): str2vec(emb.decode()) for user_id, emb in result.items()}
# 读取所有数据
load_user_emb()

近似查找ANN算法实现

使用包faiss的局部敏感hash函数(LSH)

import faiss
import numpy as np

emb_items = item_emb.items()
# emb_items 格式
# dict_items([('9936', array([-0.02695967, -0.14685549, -0.13547155, -0.00479672, -0.00417542])), 
#('710', array([ 0.03320758, -0.11462902,  0.17806329, -0.3202453 ,  0.12857111])),...])
emb_items = list(__builtin__.filter(lambda t: len(t[1]) > 0, emb_items))

item_ids = [i[0] for i in emb_items]
embs = [i[1] for i in emb_items]

index = faiss.IndexLSH(len(embs[0]), 256)
index.add(np.asarray(embs, dtype=np.float32))
# embs[99]=[-0.03878592  0.15011692  0.01134511  0.03049661  0.11688153] 
# user_id=99最接近的10个原始向量的index
# D 是距离(faiss内部定义) I 是原始向量的index
D, I = index.search(np.asanyarray([embs[99]], dtype=np.float32), 10)
print(D) # [[ 0.  8. 23. 25. 29. 32. 33. 33. 39. 39.]]
print(I) # [[ 99 690 540 347 754 788 170 186 206 612]]

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/527693.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【芯片设计- RTL 数字逻辑设计入门 1.1 -- Verdi 使用入门介绍 1】

请阅读【芯片设计 RTL 数字逻辑设计扫盲 】 文章目录 Verdi 介绍Verdi 特点和功能Verdi 基本操作Verdi -elab与-dbdir区别-elab 参数介绍-dbdir 参数介绍区别总结Verdi 介绍 Verdi 是由Synopsys公司开发的一款业界领先的自动化电子设计自动化(EDA)工具,主要用于功能验证和调…

java数据结构与算法刷题-----LeetCode628. 三个数的最大乘积

java数据结构与算法刷题目录&#xff08;剑指Offer、LeetCode、ACM&#xff09;-----主目录-----持续更新(进不去说明我没写完)&#xff1a;https://blog.csdn.net/grd_java/article/details/123063846 文章目录 排序选择线性搜索最值 排序 解题思路&#xff1a;时间复杂度O( …

React - 你知道在React组件的哪个阶段发送Ajax最合适吗

难度级别:中级及以上 提问概率:65% 如果求职者被问到了这个问题,那么只是单纯的回答在哪个阶段发送Ajax请求恐怕是不够全面的。最好是先详细描述React组件都有哪些生命周期,最后再回过头来点题作答,为什么应该在这个阶段发送Ajax请求。那…

【踩坑】修复Latex表格竖线分割/竖线割断/竖线不完整问题

转载请注明出处&#xff1a;小锋学长生活大爆炸[xfxuezhang.blog.csdn.net] 推荐一下 Latex 三线表 横线竖线短横线【踩坑】Latex中multicolumn/multirow单元格竖线消失的恢复方法LaTeX简单常用方法笔记Latex论文写作小技巧记录 1、有时候在画表格的时候&#xff0c;可能会出现…

51单片机之自己配串口寄存器实现波特率9600

本配置是根据手册进行开发配置的 1、首先配置SCON 所以综上所诉 SCON 0x40 &#xff08;0100 0000&#xff09; 2、PCON不用配置 3、配置定时器1 4、波特率的计算 5、配置AUXR 6、对比 7、实现 8、优化&#xff08;实现字符串&#xff09; 引入TI &#xff08;智能延时&…

CLIPSeg如果报“目标计算机积极拒绝,无法连接。”怎么办?

CLIPSeg这个插件在使用的时候&#xff0c;偶尔会遇到以下报错&#xff1a; Error occurred when executing CLIPSeg: (MaxRetryError("HTTPSConnectionPool(hosthuggingface.co, port443): Max retries exceeded with url: /CIDAS/clipseg-rd64-refined/resolve/main/toke…

基于jenkins+gitlab+docker部署zabbix

背景 我现在已经在一台服务器上部署了jenkins和gitlab&#xff0c;现在有一个场景是需要在服务器上再部署一个zabbix&#xff0c;需要通过jenkins加上gitlab部署&#xff0c;并且要求zabbix是通过docker部署的 前提条件 jenkins、gitlab已完成部署并能正常访问&#xff0c;服…

从路由器syslog日志监控路由器流量

路由器是关键的网络基础设施组件&#xff0c;需要随时监控&#xff0c;定期监控路由器可以帮助管理员确保路由器通信正常。日常监控还可以清楚地显出通过网络的流量&#xff0c;通过分析路由器流量&#xff0c;安全管理员可及早识别可能发生的网络事件&#xff0c;从而避免停机…

C语言 | Leetcode C语言题解之第9题回文数

题目&#xff1a; 题解&#xff1a; bool isPalindrome(int x) {if(x < 0)return false;long int sum0;long int nx;while(n!0){sumsum*10n%10;nn/10;}if(sumx)return true;elsereturn false; }

MongoDB基本操作之备份与恢复【验证有效】

资源获取 MongoDB Database Tools 解压zip包&#xff0c;将其中的工具复制到bin目录下 mongodump与mongorestore – 备份 mongodump -h localhost:27017 -u admin -p pass --authenticationDatabase admin -d runoob -o /usr/local/mongo/bak/ --forceTableScan –切换数据库…

《系统架构设计师教程(第2版)》第8章-系统质量属性与架构评估-03-ATAM方法架构评估实践(下)

文章目录 3. 测试阶段3.1 头脑风暴和优先场景&#xff08;第7步&#xff09;3.1.1 理论部分3.1.2 示例 3.2 分析架构方法&#xff08;第8步&#xff09;3.2.1 调查架构方法1&#xff09;安全性2&#xff09;性能 3.2.2 创建分析问题3.2.3 分析问题的答案胡佛架构银行体系结构 3…

深入理解JVM垃圾收集器

相关系列 深入理解JVM垃圾收集算法-CSDN博客 目前市面常见的垃圾收集器有Serial、ParNew、Parallel、CMS、Serial Old、Parallel Old、G1、ZGC以及有二种不常见的Epsilon、Shenandoah的&#xff0c;从上图可以看到有连线的的垃圾收集器是可以组合使用&#xff0c;是年轻代老年代…

快速删除node_modules

1.rd /s /q node_modules 2.rimraf node_modules/ 亲测可用

Java零基础入门-封装

一、概述 谈起面向对面编程&#xff0c;我们都知道有三大特征【封装、继承、多态】&#xff0c;跟随我一起学习的小伙伴都知道&#xff0c;对于三大特征的后两种&#xff0c;我们在前两期已经讲过了&#xff0c;至于我为啥没有按照特征顺序来教学&#xff0c;是因为我常不按规律…

MySQL8.3.0 主从复制方案(master/slave)

一 、什么是MySQL主从 MySQL主从&#xff08;Master-Slave&#xff09;复制是一种数据复制机制&#xff0c;用于将一个MySQL数据库服务器&#xff08;主服务器&#xff09;的数据复制到其他一个或多个MySQL数据库服务器&#xff08;从服务器&#xff09;。这种复制机制可以提供…

Android Studio中查看和修改project的编译jdk版本

android studio中查看和修改project的编译jdk版本操作如下&#xff1a; File->settings->Build,Execution,deployment->Build Tools->Gradles 进入Gradles页面可以查看并修改project的编译jdk版本&#xff0c;如图所示

基于 Lambda 实现 Claude3 的流式响应

在如今的大语言模型推理输出场景中&#xff0c;流式响应基本已成为必备的功能之一。一方面符合大语言模型生成方式的本质&#xff0c;另一方面当模型推理效率不是很高时&#xff0c;流式响应比起全部 generate 后再输出、能大幅缩短从开始请求到输出第一个 Token 的时间&#x…

访问网站显示不安全是什么原因?怎么解决?

访问网站时显示“不安全”&#xff0c;主要原因以及解决办法&#xff1a; 1.没用HTTPS加密&#xff1a;网站还在用老的HTTP协议&#xff0c;数据传输没加密&#xff0c;容易被人偷看或篡改。解决办法是网站管理员启用HTTPS&#xff0c;也就是给网站装个“SSL证书”。这个是最常…

5.6 mybatis之RowBounds分页用法

文章目录 mybatis 中&#xff0c;使用 RowBounds 进行分页&#xff0c;非常方便&#xff0c;不需要在 sql 语句中写 limit&#xff0c;即可完成分页功能。但是由于它是在 sql 查询出所有结果的基础上截取数据的&#xff0c;所以在数据量大的sql中并不适用&#xff0c;它更适合在…

深度学习学习日记4.8(下午)

1.softmax 函数的得出的结果是样本被预测到每个类别的概率&#xff0c;所有类别的概率相加总和等于1。使用 softmax 进行数据归一化&#xff0c;将数字转换成概率。 2.熵&#xff0c;不确定性&#xff0c;越低越好 3.KL 散度交叉熵-信息熵 预测越准&#xff0c;交叉熵越小&am…