本文实现了基于mpi4py的多进程算法
mpi不过多介绍,某些函数的用法也不是介绍范围,这里只给出怎么实现多进程的方程求根算法。区间划分求解方程,在串行程序里,二分法是非常经典的算法,现在对其进行拓展,实现划分n个区间的求根算法,并利用多个进程计算各自区间。
一、原理
方程求根问题,对于:f(x) = 0 ,绘制y = f(x) 函数图像如下:
对于给定区间[l,r],将其均匀划分为两个区间[l,m]和[m,r]。由零点定理知连
续单调函数的零点存在于两端函数值异号的区间,这里不妨假设为[m,r]区间,则目前区间范围可以由原[l,r]缩小为[m,r]将[m,r]视为新区间进一步进行迭代,最终知道区间大小符合精度要求,区间中点作为方程的根,终止算法。
将上述算法推广到 n+1 个区间的情形,如下图所示:
为算法分配 n 个处理单元,每个处理单元负责计算各自区间端点的值(其中一个处理单元计算两个端点值),筛选出符合零点定理的区间进一步迭代更新区间,最终求解方程根的近似值。具体的,n 条线,划分出 n+1 个区间,总共包含 n+2 个端点值,分配 n 个处理进程。从任务分配来看,有三类进程:
(1)进程 0 为主进程,需要从用户读入初始区间(或者同时读入函数和区间),然后进行区间划分,将整个区间左右端点、以及划分区间长度广播至其他进程,以便每个进程定位自己需要处理的区间。其次需要处理前两个小区间,这就包括了计算第一个小区间左端点的函数值和第一个区间的右端点值(第二个区间的左端点),并且需要从下一个进程处获取第二个区间的右端点值。最后收集每次迭代后根所处的区间。
(2)第二类进程是 1~n-2 号进程,他们的操作都一致,对于第 i 号进程,计算第i+2 个区间的左端点函数值,并从 i+1 号进程获取区间的右端点函数值,判断区间是否包含根,若包含根,向主进程 0 发送消息。
(3)第三类进程是 n-1 号进程,其需要处理最后一个区间的左端点函数计算,但是不需要从其他进程处获取右端点,其右端点由主进程广播所传递获取并计算对应函数值,其他的处理逻辑与第二类进程类似。
最后区间范围缩小到误差允许范围内时停止迭代,主进程直接广播[0,0,0]数
据,其余进程根据广播也停止各自的计算任务。
使用 python MPI 库编程实现上述逻辑,其中以计算方程
x
3
+
2
x
+
1
=
0
x^3 + 2x + 1 = 0
x3+2x+1=0为例,代码实现见第二部分,结果见第三部分。
二、源码
根据原理,使用 mpi4py 库实现代码:
from mpi4py import MPI
import sys
comm = MPI.COMM_WORLD
rank = comm.rank
size = comm.size
eps = 0.000005 # 求解精度
def func(x):
return x ** 3 + 2 * x + 1 # -0.453398
# while循环写作最外面每次进程都要多走几次判断条件,开销更大一些
li, ri, step = 0, 0, 0
if rank == 0:
# 主进程,负责读入初始区间,区间端点和区间长度分发给各个进程
# 计算前两个区间左端点值,从下一个区间获取第二个区间右端点值
print("input the endpoint values of the interval l, r: ")
li, ri = map(float, input().split())
while True:
if ri - li <= eps:
comm.bcast((0, 0, 0), root=0) # 终止时也发送一次广播
break
step = (ri - li) * 1.0 / (size + 1) # 区间长度
comm.bcast((li, ri, step), root=0)
print("l=%f, r=%f" % (li, ri))
sys.stdout.flush()
fl = func(li)
# 区间左端点l的函数值由0进程负责,右端点r对应的值让最后一个进程处理
# 0号进程处理两个区间
r0 = func(li + step)
r1 = comm.recv(source=rank + 1) # 主进程不需要发送端点给其他进程
if fl * r0 < 0:
ri = li + step # 第一个区间含根,更新右边界
elif r0 * r1 < 0:
li, ri = li + step, li + 2 * step # 第二个区间含根
else: # 新区间在其他进程处,等待消息
li, ri = comm.recv(source=MPI.ANY_SOURCE)
print("result of x^3+2x+1=0: %.6f" % ((li + ri) * 1.0 / 2))
elif rank == size - 1:
# 最后一个进程处理最后一个区间,并且计算最后一个区间两端点的值
# 不从其他进程获取任何数据(除广播数据外)
while True:
lt, rt, st = comm.bcast((li, ri, step), root=0)
if rt - lt <= eps:
break
nl = lt + (rank + 1) * st
r0 = func(nl)
comm.send(r0, dest=rank - 1)
r1 = func(rt)
print('rank %d,l=%.4f, r=%.4f' % (rank, nl, nl + st))
sys.stdout.flush()
if r0 * r1 < 0: # 包含根,向0进程汇报
comm.send([nl, nl + st], dest=0)
else:
# 其余进程处理逻辑一致,计算区间左端点,从下一个进程获取右端点
while True:
lt, rt, st = comm.bcast((li, ri, step), root=0)
if rt - lt <= eps:
break
nl = lt + (rank + 1) * st
r0 = func(nl)
comm.send(r0, dest=rank - 1)
r1 = comm.recv(source=rank + 1)
print('rank %d,l=%.4f, r=%.4f' % (rank, nl, nl + st))
sys.stdout.flush()
if r0 * r1 < 0: # 包含根,向0进程汇报
comm.send([nl, nl + st], dest=0)
三、 运行结果
运行命令:
mpiexec -n 6 python caculate.py
运行上述程序,其中初始化区间为[-100,100],6 个进程计算结果为: -0.453397,将x=-0.453397 代回 x 3 + 2 x + 1 x^3 + 2x + 1 x3+2x+1计算结果为:0.000001705,可见误差已经符合要求。
注意,代码里没有特别的处理无解情况下的逻辑,输入的初始区间不包含根代码无法正常结束。