Paper Notes: Where Would I Go Next? Large Language Models as Human Mobility Predictor
1 Main function
1.1 Imports
import os
import pickle
import time
import ast
import logging
from datetime import datetime
import pandas as pd
from openai import OpenAI
client = OpenAI(
    api_key=...
)
1.2 Parameter setup
dataname = "geolife"
# dataset name
num_historical_stay = 40
# number of historical stays used as long-term mobility
num_context_stay = 5
# number of context stays used as short-term mobility
top_k = 10
# number of candidate locations the model should output
with_time = False
# whether to include the time information of the target stay
sleep_single_query = 0.1
'''
the sleep time between queries
after the recent updates, the reliability of the API is greatly improved
so we can reduce the sleep time
'''
sleep_if_crash = 1
'''
the sleep time if the server crashes
'''
output_dir = f"output/{dataname}/top10_wot"
'''
the output path ("wot" in the directory name indicates with_time = False, i.e. no target-time information)
'''
log_dir = f"logs/{dataname}/top10_wot"
'''
the log dir
'''
1.3 Load the dataset
tv_data, test_file = get_dataset(dataname)
#Number of total test sample: 3459
'''
this count is smaller than the number of rows in f"data/{dataname}/{dataname}_test.csv"
'''
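Judging from how organise_data indexes test_file later in this note, each element of test_file appears to be a dict keyed by user/context/target fields. A rough illustration follows; the keys are taken from the code, but the values are made-up placeholders:
# Hypothetical example of one geolife test sample (keys inferred from organise_data below; values are placeholders)
example_test_sample = {
    'user_X': [1, 1, 1],              # user id, one entry per context stay
    'start_min_X': [548, 1268, 635],  # start time of each context stay, in minutes since midnight
    'weekday_X': [2, 3, 3],           # day of week of each context stay (0 = Monday)
    'dur_X': [466, 187, 146],         # duration of each context stay (geolife only)
    'X': [10, 17, 1],                 # location ids of the context stays
    'start_min_Y': 36,                # start time of the target stay
    'weekday_Y': 4,                   # day of week of the target stay
    'Y': 1,                           # ground-truth location id of the target stay
}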
1.4 Create the logger
logger = get_logger('my_logger', log_dir=log_dir)
1.5 Extract user ids
uid_list = get_unqueried_user(dataname, output_dir)
print(f"uid_list: {uid_list}")
'''
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
Number of the remaining id: 45
uid_list: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45]
'''
1.6 Generate the queries
query_all_user(
    client,
    dataname,
    uid_list,
    logger,
    tv_data,
    num_historical_stay,
    num_context_stay,
    test_file,
    output_dir=output_dir,
    top_k=top_k,
    is_wt=with_time,
    sleep_query=sleep_single_query,
    sleep_crash=sleep_if_crash)
2 get_dataset
def get_dataset(dataname):
    # Get training and validation set and merge them
    train_data = pd.read_csv(f"data/{dataname}/{dataname}_train.csv")
    valid_data = pd.read_csv(f"data/{dataname}/{dataname}_valid.csv")
    # read the training and validation sets
    # Get test data
    with open(f"data/{dataname}/{dataname}_testset.pk", "rb") as f:
        test_file = pickle.load(f)  # test_file is a list of dict
    # the test set
    # merge train and valid data
    tv_data = pd.concat([train_data, valid_data], ignore_index=True)
    tv_data.sort_values(['user_id', 'start_day', 'start_min'], inplace=True)
    if dataname == 'geolife':
        tv_data['duration'] = tv_data['duration'].astype(int)
    # merged training + validation data (duration cast to int for geolife)
    print("Number of total test sample: ", len(test_file))
    return tv_data, test_file
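For orientation, the columns that the rest of the walkthrough reads from tv_data are listed below; the example row is invented, not taken from the dataset:
# Columns expected in the geolife train/valid CSVs (sample row is illustrative only):
#   user_id  start_day  start_min  weekday  duration  location_id
#         1         12       1268        2       466           10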
3 get_logger
def get_logger(logger_name, log_dir='logs/'):
    # Create log dir
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    # Create a logger instance named logger_name and set its level to DEBUG
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.DEBUG)
    # Create a console handler that writes log messages to the console, and set its log level
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    # Create a file handler and set its log level
    current_datetime = datetime.now()
    formatted_datetime = current_datetime.strftime("%Y%m%d_%H%M%S")
    # the current date and time formatted as "YYYYMMDD_HHMMSS", used to build the log file name
    log_file = 'log_file' + formatted_datetime + '.log'
    # append the formatted timestamp to "log_file", e.g. log_file20230424_153000.log
    log_file_path = os.path.join(log_dir, log_file)
    # os.path.join(log_dir, log_file) gives the full log file path
    file_handler = logging.FileHandler(log_file_path)
    file_handler.setLevel(logging.DEBUG)
    # the file handler writes log messages to the file above
    # Create a formatter and add it to the handlers
    formatter = logging.Formatter('%(message)s')
    # the log format only contains the message body, i.e. '%(message)s'
    console_handler.setFormatter(formatter)
    file_handler.setFormatter(formatter)
    # Add the handlers (console and file) to the logger
    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    return logger
4 get_unqueried_user
Extract the user ids of the dataset that have not been queried yet
def get_unqueried_user(dataname, output_dir='output/'):
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    if dataname == "geolife":
        all_user_id = [i+1 for i in range(45)]
    elif dataname == "fsq":
        all_user_id = [i+1 for i in range(535)]
    # user ids whose result file already exists in output_dir have been processed
    processed_id = [int(file.split('.')[0]) for file in os.listdir(output_dir) if file.endswith('.csv')]
    remain_id = [i for i in all_user_id if i not in processed_id]
    print(remain_id)
    print(f"Number of the remaining id: {len(remain_id)}")
    return remain_id
5 query_all_user
def query_all_user(client, dataname, uid_list, logger, train_data, num_historical_stay,
                   num_context_stay, test_file, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    for uid in uid_list:
        logger.info(f"=================Processing user {uid}==================")
        user_train = get_user_data(train_data, uid, num_historical_stay, logger)
        # the long-term historical mobility (M stays) of the current uid
        historical_data, predict_X, predict_y = organise_data(dataname, user_train, test_file, uid, logger, num_context_stay)
        '''
        For this user id, organise_data returns:
        - long-term mobility (shared across the different test samples)
        - short-term mobility (the 5 most recent context stays)
        - ground truth
        Each record has the format ('09:08 PM', 'Wednesday', 466, 10)
        '''
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y, logger, top_k=top_k,
                          is_wt=is_wt, output_dir=output_dir, sleep_query=sleep_query, sleep_crash=sleep_crash)
5.1 get_user_data
Extract the long-term historical mobility (M stays) of the current uid
def get_user_data(train_data, uid, num_historical_stay, logger):
    user_train = train_data[train_data['user_id']==uid]
    # all records belonging to the current user id
    logger.info(f"Length of user {uid} train data: {len(user_train)}")
    # how many records this user id has in total
    user_train = user_train.tail(num_historical_stay)
    logger.info(f"Number of user historical stays: {len(user_train)}")
    # keep only the most recent num_historical_stay records as long-term mobility
    return user_train
5.2 organise_data
For this user id, it returns:
- long-term mobility (shared across the different test samples)
- short-term mobility (the 5 most recent context stays)
- ground truth
Each record has the format ('09:08 PM', 'Wednesday', 466, 10)
def organise_data(dataname, user_train, test_file, uid, logger, num_context_stay=5):
    # Use another way of organising data
    # user_train only contains the most recent M records
    historical_data = []
    if dataname == 'geolife':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])),
                 int2dow(row['weekday']),
                 int(row['duration']),
                 row['location_id'])
            )
    elif dataname == 'fsq':
        for _, row in user_train.iterrows():
            historical_data.append(
                (convert_to_12_hour_clock(int(row['start_min'])),
                 int2dow(row['weekday']),
                 row['location_id'])
            )
    '''
    each appended tuple contains:
    time-of-day:  the start time converted to hh:mm AM/PM
    day-of-week:  the weekday converted to its name
    duration:     the stay duration cast to int (geolife only)
    location id:  the id of the location
    e.g.
    [('09:08 PM', 'Wednesday', 466, 10),
     ('04:58 AM', 'Thursday', 187, 17),
     ('08:07 AM', 'Thursday', 146, 1),
     ('10:35 AM', 'Thursday', 193, 17),
     ('01:54 PM', 'Thursday', 556, 10)]
    '''
    logger.info(f"historical_data: {historical_data}")
    logger.info(f"Number of historical_data: {len(historical_data)}")
    # Get this user's test data
    list_user_dict = []
    for i_dict in test_file:
        if dataname == 'geolife':
            i_uid = i_dict['user_X'][0]
        elif dataname == 'fsq':
            i_uid = i_dict['user_X']
        if i_uid == uid:
            list_user_dict.append(i_dict)
    # test records whose user id matches the current uid go into list_user_dict,
    # i.e. the trajectories of this user that need to be predicted
    predict_X = []
    predict_y = []
    for i_dict in list_user_dict:
        construct_dict = {}
        if dataname == 'geolife':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               [int(i) for i in i_dict['dur_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        elif dataname == 'fsq':
            context = list(zip([convert_to_12_hour_clock(int(item)) for item in i_dict['start_min_X'][-num_context_stay:]],
                               [int2dow(i) for i in i_dict['weekday_X'][-num_context_stay:]],
                               i_dict['X'][-num_context_stay:]))
        '''
        for geolife, context is a list with five elements,
        each in the same format as the tuples appended to historical_data
        '''
        target = (convert_to_12_hour_clock(int(i_dict['start_min_Y'])), int2dow(i_dict['weekday_Y']), None, "<next_place_id>")
        # e.g. ('12:36 AM', 'Friday', None, '<next_place_id>')
        construct_dict['context_stay'] = context
        construct_dict['target_stay'] = target
        # the constructed input: the N most recent context stays plus the time and weekday of the target stay
        predict_y.append(i_dict['Y'])
        # the ground-truth location id
        predict_X.append(construct_dict)
        # the constructed input
    logger.info(f"Number of predict_data: {len(predict_X)}")
    # how many test records this user id has
    logger.info(f"predict_y: {predict_y}")
    logger.info(f"Number of predict_y: {len(predict_y)}")
    # should be the same as the number of predict_X
    return historical_data, predict_X, predict_y
'''
For this user id, the function returns:
- long-term mobility (shared across the different test samples)
- short-term mobility (the 5 most recent context stays)
- ground truth
'''
5.2.1 convert_to_12_hour_clock
# convert minutes-since-midnight to hh:mm AM/PM format
def convert_to_12_hour_clock(minutes):
    # in the raw data, minutes counts from midnight (0:00) of that day
    if minutes < 0 or minutes >= 1440:
        return "Invalid input. Minutes should be between 0 and 1439."
    hours = minutes // 60
    minutes %= 60
    period = "AM"
    if hours >= 12:
        period = "PM"
    if hours == 0:
        hours = 12
    elif hours > 12:
        hours -= 12
    return f"{hours:02d}:{minutes:02d} {period}"
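A quick usage check of the function above (the inputs are chosen for illustration; 1268 reproduces the '09:08 PM' value seen in the historical_data example):
# Illustrative calls to convert_to_12_hour_clock
print(convert_to_12_hour_clock(1268))   # 09:08 PM  (1268 = 21*60 + 8)
print(convert_to_12_hour_clock(0))      # 12:00 AM
print(convert_to_12_hour_clock(1500))   # Invalid input. Minutes should be between 0 and 1439.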
5.2.2 int2dow
# convert an integer day-of-week to its name
def int2dow(int_day):
    tmp = {0: 'Monday', 1: 'Tuesday', 2: 'Wednesday',
           3: 'Thursday', 4: 'Friday', 5: 'Saturday', 6: 'Sunday'}
    return tmp[int_day]
5.3 single_user_query
Query the model for one user and save the location prediction results
def single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                      logger, top_k, is_wt, output_dir, sleep_query, sleep_crash):
    # Initialize variables
    total_queries = len(predict_X)
    logger.info(f"Total_queries: {total_queries}")
    # how many queries this user id has in total
    processed_queries = 0
    current_results = pd.DataFrame({
        'user_id': None,
        'ground_truth': None,
        'prediction': None,
        'reason': None
    }, index=[])
    out_filename = f"{uid:02d}" + ".csv"
    out_filepath = os.path.join(output_dir, out_filename)
    try:
        # Attempt to load previous results if available
        current_results = load_results(out_filepath)
        processed_queries = len(current_results)
        logger.info(f"Loaded {processed_queries} previous results.")
    except FileNotFoundError:
        logger.info("No previous results found. Starting from scratch.")
    '''
    load the prediction results that were already produced for this user
    '''
    # Process remaining queries
    for i in range(processed_queries, total_queries):
        # predict this user's remaining queries
        logger.info(f'The {i+1}th sample: ')
        if dataname == 'geolife':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        elif dataname == 'fsq':
            if is_wt is True:
                if top_k == 1:
                    completions = single_query_top1_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
            else:
                if top_k == 1:
                    completions = single_query_top1_wot_fsq(client, historical_data, predict_X[i])
                elif top_k == 10:
                    completions = single_query_top10_wot_fsq(client, historical_data, predict_X[i])
                else:
                    raise ValueError(f"The top_k must be one of 1, 10. However, {top_k} was provided")
        '''
        the complete GPT response for the corresponding setting
        '''
        response = completions.choices[0].message.content
        # the content of the GPT response
        # Log the prediction results and usage.
        logger.info(f"Pred results: {response}")
        logger.info(f"Ground truth: {predict_y[i]}")
        logger.info(dict(completions).get('usage'))
        # the number of tokens used
        try:
            res_dict = ast.literal_eval(response)
            # parse the GPT output into a dict
            if top_k != 1:
                res_dict['prediction'] = str(res_dict['prediction'])
            res_dict['user_id'] = uid
            res_dict['ground_truth'] = predict_y[i]
        except Exception as e:
            res_dict = {'user_id': uid, 'ground_truth': predict_y[i], 'prediction': -100, 'reason': None}
            logger.info(e)
            logger.info(f"API request failed for the {i+1}th query")
            # time.sleep(sleep_crash)
            # if any of the steps above fails, the prediction is marked as failed (prediction = -100)
        finally:
            new_row = pd.DataFrame(res_dict, index=[0])
            # a dataframe with only one record: the prediction for the current stay
            current_results = pd.concat([current_results, new_row], ignore_index=True)
            # append the new row to this user id's accumulated predictions
        # Save the current results
        current_results.to_csv(out_filepath, index=False)
        # save_results(current_results, out_filename)
        logger.info(f"Saved {len(current_results)} results to {out_filepath}")
        # save this user's location prediction results
    # Continue processing remaining queries
    if len(current_results) < total_queries:
        # remaining_predict_X = predict_X[len(current_results):]
        # remaining_predict_y = predict_y[len(current_results):]
        # remaining_queries = queries[len(current_results):]
        logger.info("Restarting queries from the last successful point.")
        single_user_query(client, dataname, uid, historical_data, predict_X, predict_y,
                          logger, top_k, is_wt, output_dir, sleep_query, sleep_crash)
5.3.1 load_results
Simply loads the prediction results previously saved for this user
def load_results(filename):
    # Load previously saved results from a CSV file
    results = pd.read_csv(filename)
    return results
5.3.2 single_query_top1_wot
5.3.3 single_query_top1
5.3.4 single_query_top10
The top-10 version differs very little from the top-1 version (they are almost identical); its prompt just adds one more sentence.
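The prompt-building functions themselves are not reproduced in this note. As a rough, hypothetical sketch of how such a function could be wired to get_chat_completion below (the prompt wording here is invented for illustration and is not the paper's actual prompt):
# Hypothetical sketch only; the real single_query_* functions use the paper's own prompts.
def single_query_top1_wot_sketch(client, historical_data, X):
    # X is one element of predict_X: {'context_stay': [...], 'target_stay': (...)}
    prompt = (
        "Predict the next place this user will visit.\n"
        f"Historical stays: {historical_data}\n"
        f"Context stays: {X['context_stay']}\n"
        f"Target stay: {X['target_stay']}\n"
        "Answer as a Python dictionary with keys 'prediction' and 'reason'."
    )
    return get_chat_completion(client, prompt)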
5.3.5 get_chat_completion
Send the prompt to GPT and return the completion
def get_chat_completion(client, prompt, model="gpt-3.5-turbo-0613", json_mode=False, max_tokens=1200):
    """
    args:
        client: the openai client object (new in 1.x version)
        prompt: the prompt to be completed
        model: specify the model to use
        json_mode: whether to return the response in json format (new in 1.x version)
        max_tokens: the maximum number of tokens to generate
    """
    messages = [{"role": "user", "content": prompt}]
    if json_mode:
        completion = client.chat.completions.create(
            model=model,
            response_format={"type": "json_object"},
            messages=messages,
            temperature=0,  # the degree of randomness of the model's output
            max_tokens=max_tokens  # the maximum number of tokens to generate
        )
    else:
        completion = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0,
            max_tokens=max_tokens
        )
    # res_content = response.choices[0].message["content"]
    # token_usage = response.usage
    return completion
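For completeness, this is how the returned completion object is consumed back in single_user_query; the prompt string below is a throwaway placeholder:
# Minimal usage illustration (placeholder prompt), mirroring single_user_query above
completions = get_chat_completion(client, "Reply with a Python dict containing 'prediction' and 'reason'.")
response = completions.choices[0].message.content   # the text of the model's reply
print(response)
print(dict(completions).get('usage'))               # token usage, logged the same way in single_user_query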