Libsvm中grid.py文件的解读

1.导入相关文件

这里重点讲一下 _ _all_ _ = ['find_parameters']

_all__ = ['find_parameters'] 是 Python 中用于定义模块级别的变量 __all__ 的语法__all__ 是一个包含模块中应该被公开(即可以通过 from module import * 导入)的变量名的列表

  • __all__ 是一个约定俗成的变量名,用于指定在使用 from module import * 语句时,应该导入哪些变量名。这样可以控制模块的命名空间,避免不必要的变量污染。

  • ['find_parameters'] 是一个包含在 __all__ 中的列表,其中包含了模块中应该被导入的变量名。在这个例子中,只有一个变量名 find_parameters 被包含在 __all__ 中。

通过这个设置,当其他模块使用 from module import * 导入这个模块时,只有  find_parameters  这个变量名会被导入,其他未在 __all__ 中指定的变量不会被导入。这是一种良好的编程实践,因为它可以提供更清晰的模块接口,避免不必要的命名冲突和变量污染。 

2.GridOption类的定义

      构造函数接收两个参数:dataset_pathname 和 options

      根据操作系统设置svm-train.exe和gnuplot.exe 的路径,这个要根据自己系统的实际按照情况  来进行路径的设置。

      默认参数的设置以及解析传入参数的函数parse_options

      最后,检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性。

class GridOption:
    '''
    构造函数 __init__:
    接收两个参数 dataset_pathname 和 options
                dataset_pathname 是数据集的路径
                options 是一个包含其他配置选项的字典
    获取当前脚本所在目录,并根据操作系统设置 svmtrain_pathname 和 gnuplot_pathname
    '''
    def __init__(self, dataset_pathname, options):
        dirname = os.path.dirname(__file__)
        # 使用 sys.platform 来检查操作系统
        # 如果不是 Windows (sys.platform != 'win32'),则设置 svmtrain_pathname 为在当前脚本所在目录下的 ‘…/svm-train’,并设置 gnuplot_pathname 为 ‘/usr/bin/gnuplot’
        if sys.platform != 'win32':
            self.svmtrain_pathname = os.path.join(dirname, '../svm-train')
            self.gnuplot_pathname = '/usr/bin/gnuplot'
        else:
            # example for windows
            # 如果是 Windows,则设置 svmtrain_pathname 为在当前脚本所在目录下的 r’…\windows\svm-train.exe’,并设置 gnuplot_pathname 为 r’c:\tmp\gnuplot\binary\pgnuplot.exe’
            self.svmtrain_pathname = os.path.join(dirname, r'..\windows\svm-train.exe')
            # svmtrain_pathname = r'c:\Program Files\libsvm\windows\svm-train.exe'
            self.gnuplot_pathname = r'c:\tmp\gnuplot\binary\pgnuplot.exe'
        # 默认参数的设置
        # 设置了一系列参数的默认值,例如 fold、c_begin、c_end、c_step、g_begin、g_end、g_step 等,用于定义网格搜索的参数范围和步长
        # 设置了 grid_with_c 和 grid_with_g 为 True,表示要在网格搜索中搜索 C 和 gamma 参数
        self.fold = 5
        self.c_begin, self.c_end, self.c_step = -5,  15,  2
        self.g_begin, self.g_end, self.g_step =  3, -15, -2
        self.grid_with_c, self.grid_with_g = True, True
        self.dataset_pathname = dataset_pathname  # 将传入的 dataset_pathname 赋值给 self.dataset_pathname
        self.dataset_title = os.path.split(dataset_pathname)[1]   # 提取数据集的标题部分,通过 os.path.split(dataset_pathname) 和 [1] 获取,赋值给 self.dataset_title
        self.out_pathname = '{0}.out'.format(self.dataset_title)  # 设置 out_pathname 为 ‘{0}.out’,其中 {0} 是数据集标题
        self.png_pathname = '{0}.png'.format(self.dataset_title)  # 设置 png_pathname 为 ‘{0}.png’,其中 {0} 是数据集标题
        self.pass_through_string = ' '  # 设置 pass_through_string 为一个空格
        self.resume_pathname = None     # 设置 resume_pathname 为 None
        self.parse_options(options)     # 调用 parse_options 方法,该方法用于解析传入的选项,并更新类的属性值

    # 定义了 parse_options 方法,该方法用于解析传入的选项列表,更新 GridOption 类的属性值
    def parse_options(self, options):
        # options 是传入的选项,可以是字符串,也可以是由字符串组成的列表
        # 如果 options 是字符串,通过 options.split() 将其分割成列表
        if type(options) == str:
            options = options.split()
        i = 0  # 初始化变量 i 为 0,用于迭代 options 列表
        # 初始化空列表 pass_through_options,用于存储未被解析的选项
        pass_through_options = []
        
        # 使用 while 循环遍历 options 列表
        # 通过检查当前选项,更新相应的 GridOption 类属性
        while i < len(options):
            '''
            -log2c 和 -log2g:解析参数范围和步长,如果值为 'null',则相应的网格搜索标志设为 False
            -v:设置交叉验证的折数
            -c 和 -g:抛出错误,提示使用 -log2c 和 -log2g
            -svmtrain:设置 SVM 训练可执行文件路径
            -gnuplot:设置 Gnuplot 可执行文件路径,如果值为 'null',则设为 None
            -out:设置输出文件路径,如果值为 'null',则设为 None
            -png:设置 PNG 文件路径
            -resume:设置恢复训练的文件路径,如果未提供则使用默认文件名
            '''
            if options[i] == '-log2c':
                i = i + 1
                if options[i] == 'null':
                    self.grid_with_c = False
                else:
                    self.c_begin, self.c_end, self.c_step = map(float,options[i].split(','))
            elif options[i] == '-log2g':
                i = i + 1
                if options[i] == 'null':
                    self.grid_with_g = False
                else:
                    self.g_begin, self.g_end, self.g_step = map(float,options[i].split(','))
            elif options[i] == '-v':
                i = i + 1
                self.fold = options[i]
            elif options[i] in ('-c','-g'):
                raise ValueError('Use -log2c and -log2g.')
            elif options[i] == '-svmtrain':
                i = i + 1
                self.svmtrain_pathname = options[i]
            elif options[i] == '-gnuplot':
                i = i + 1
                if options[i] == 'null':
                    self.gnuplot_pathname = None
                else:
                    self.gnuplot_pathname = options[i]
            elif options[i] == '-out':
                i = i + 1
                if options[i] == 'null':
                    self.out_pathname = None
                else:
                    self.out_pathname = options[i]
            elif options[i] == '-png':
                i = i + 1
                self.png_pathname = options[i]
            elif options[i] == '-resume':
                if i == (len(options)-1) or options[i+1].startswith('-'):
                    self.resume_pathname = self.dataset_title + '.out'
                else:
                    i = i + 1
                    self.resume_pathname = options[i]
            else:
                pass_through_options.append(options[i])  # 未识别的选项将被添加到 pass_through_options 列表中
            i = i + 1
        # 使用 ' '.join(pass_through_options) 将未识别的选项组合成一个字符串,更新 pass_through_string 属性
        self.pass_through_string = ' '.join(pass_through_options)

        # 检查 SVM 训练可执行文件路径、数据集路径和 Gnuplot 可执行文件路径的存在性
        if not os.path.exists(self.svmtrain_pathname):
            raise IOError('svm-train executable not found')
        if not os.path.exists(self.dataset_pathname):
            raise IOError('dataset not found')
        if self.resume_pathname and not os.path.exists(self.resume_pathname):
            raise IOError('file for resumption not found')  # 如果 resume_pathname 存在,检查其存在性
        if not self.grid_with_c and not self.grid_with_g:   # 如果同时设置了 -log2c 和 -log2g 为 False,抛出错误
            raise ValueError('-log2c and -log2g should not be null simultaneously')
        if self.gnuplot_pathname and not os.path.exists(self.gnuplot_pathname):
        # 如果 Gnuplot 可执行文件不存在,输出错误信息并将其设为 None
            sys.stderr.write('gnuplot executable not found\n')
            self.gnuplot_pathname = None

补充“win32” 是 Windows 操作系统的平台标识符。在 Python 中,sys.platform 返回一个字符串,表示当前运行 Python 解释器的平台。对于 Windows 系统,这个字符串通常是"win32"。所以,if sys.platform != 'win32' 这个条件语句检查当前操作系统是否为 Windows 之外的其他操作系统。

 3. 定义redraw 函数,用于在图形界面中绘制 SVM 参数搜索的轮廓图

def redraw(db,best_param,gnuplot,options,tofile=False):
    if len(db) == 0: return
    begin_level = round(max(x[2] for x in db)) - 3
    step_size = 0.5

    best_log2c,best_log2g,best_rate = best_param

    # if newly obtained c, g, or cv values are the same,
    # then stop redrawing the contour.
    if all(x[0] == db[0][0]  for x in db): return
    if all(x[1] == db[0][1]  for x in db): return
    if all(x[2] == db[0][2]  for x in db): return

    if tofile:
        gnuplot.write(b"set term png transparent small linewidth 2 medium enhanced\n")
        gnuplot.write("set output \"{0}\"\n".format(options.png_pathname.replace('\\','\\\\')).encode())
        #gnuplot.write(b"set term postscript color solid\n")
        #gnuplot.write("set output \"{0}.ps\"\n".format(options.dataset_title).encode().encode())
    elif sys.platform == 'win32':
        gnuplot.write(b"set term windows\n")
    else:
        gnuplot.write( b"set term x11\n")
    gnuplot.write(b"set xlabel \"log2(C)\"\n")
    gnuplot.write(b"set ylabel \"log2(gamma)\"\n")
    gnuplot.write("set xrange [{0}:{1}]\n".format(options.c_begin,options.c_end).encode())
    gnuplot.write("set yrange [{0}:{1}]\n".format(options.g_begin,options.g_end).encode())
    gnuplot.write(b"set contour\n")
    gnuplot.write("set cntrparam levels incremental {0},{1},100\n".format(begin_level,step_size).encode())
    gnuplot.write(b"unset surface\n")
    gnuplot.write(b"unset ztics\n")
    gnuplot.write(b"set view 0,0\n")
    gnuplot.write("set title \"{0}\"\n".format(options.dataset_title).encode())
    gnuplot.write(b"unset label\n")
    gnuplot.write("set label \"Best log2(C) = {0}  log2(gamma) = {1}  accuracy = {2}%\" \
                  at screen 0.5,0.85 center\n". \
                  format(best_log2c, best_log2g, best_rate).encode())
    gnuplot.write("set label \"C = {0}  gamma = {1}\""
                  " at screen 0.5,0.8 center\n".format(2**best_log2c, 2**best_log2g).encode())
    gnuplot.write(b"set key at screen 0.9,0.9\n")
    gnuplot.write(b"splot \"-\" with lines\n")

    db.sort(key = lambda x:(x[0], -x[1]))

    prevc = db[0][0]
    for line in db:
        if prevc != line[0]:
            gnuplot.write(b"\n")
            prevc = line[0]
        gnuplot.write("{0[0]} {0[1]} {0[2]}\n".format(line).encode())
    gnuplot.write(b"e\n")
    gnuplot.write(b"\n") # force gnuplot back to prompt when term set failure
    gnuplot.flush()

4. 函数 calculate_jobs 的定义

 该函数接受一个参数 options,并返回两个值:jobs 和 resumed_jobs,同时里面嵌套定义了函数 range_f 和函数 permute_sequence。

 函数的主要目的是生成一系列的任务(jobs),每个任务是一个参数组合,用于训练支持向量机(SVM)。这些参数是通过对给定的一组参数范围进行排列组合得到的。

def calculate_jobs(options):

    def range_f(begin,end,step):
        # like range, but works on non-integer too
        seq = []
        while True:
            if step > 0 and begin > end: break
            if step < 0 and begin < end: break
            seq.append(begin)
            begin = begin + step
        return seq

    def permute_sequence(seq):
        n = len(seq)
        if n <= 1: return seq

        mid = int(n/2)
        left = permute_sequence(seq[:mid])
        right = permute_sequence(seq[mid+1:])

        ret = [seq[mid]]
        while left or right:
            if left: ret.append(left.pop(0))
            if right: ret.append(right.pop(0))

        return ret


    c_seq = permute_sequence(range_f(options.c_begin,options.c_end,options.c_step))
    g_seq = permute_sequence(range_f(options.g_begin,options.g_end,options.g_step))

    if not options.grid_with_c:
        c_seq = [None]
    if not options.grid_with_g:
        g_seq = [None]

    nr_c = float(len(c_seq))
    nr_g = float(len(g_seq))
    i, j = 0, 0
    jobs = []

    while i < nr_c or j < nr_g:
        if i/nr_c < j/nr_g:
            # increase C resolution
            line = []
            for k in range(0,j):
                line.append((c_seq[i],g_seq[k]))
            i = i + 1
            jobs.append(line)
        else:
            # increase g resolution
            line = []
            for k in range(0,i):
                line.append((c_seq[k],g_seq[j]))
            j = j + 1
            jobs.append(line)

    resumed_jobs = {}

    if options.resume_pathname is None:
        return jobs, resumed_jobs

    for line in open(options.resume_pathname, 'r'):
        line = line.strip()
        rst = re.findall(r'rate=([0-9.]+)',line)
        if not rst:
            continue
        rate = float(rst[0])

        c, g = None, None
        rst = re.findall(r'log2c=([0-9.-]+)',line)
        if rst:
            c = float(rst[0])
        rst = re.findall(r'log2g=([0-9.-]+)',line)
        if rst:
            g = float(rst[0])

        resumed_jobs[(c,g)] = rate

    return jobs, resumed_jobs

range_f函数:

  • range_f 函数是一个自定义的函数,类似于内置函数 range,但可以处理非整数的步长。它生成一个序列,从 begin 开始,以 step 为步长,直到不再满足条件。

permute_sequence函数:

  • permute_sequence 函数用于对给定序列进行排列组合。它采用分而治之的方法,将序列分成两半,然后递归地对左右两半进行排列组合,最终将结果合并。

参数生成:

  • 使用 range_f 函数生成了两个序列 c_seq 和 g_seq,分别表示参数 c 和 g 的可能取值。如果选项 options.grid_with_c 或 options.grid_with_g 为 False,则相应的参数序列为单一值,即  [None]

生成任务列表:

  • 使用生成的参数序列,通过两个循环(while 循环)生成所有可能的参数组合,存储在 jobs 列表中。

处理恢复任务:

  • 如果存在恢复路径 options.resume_pathname,则从该路径读取已经完成的任务信息,提取出参数组合和对应的性能率,并存储在 resumed_jobs 字典中。

返回结果:

  • 最终,函数返回两个值:生成的任务列表 jobs 和已经完成的任务信息字典 resumed_jobs

这段代码主要用于生成一系列参数组合,以及处理从先前运行中恢复的任务信息。这类功能通常在超参数搜索和模型训练中使用,以便系统能够自动尝试多种参数组合。

5.类WorkerStopToken的定义

通常用作信号或标记,用于通信或控制多线程或多进程的执行流程。在这里, WorkerStopToken 的目的是作为一个简单的标记,用于通知工作线程停止或表示工作线程已经停止。在实际应用中,它可能会与其他线程或进程之间的通信机制一起使用,以实现协同工作或关闭。

class WorkerStopToken::定义了一个新的类,类名为 WorkerStopToken 

pass:在Python中,pass 是一个占位符语句,不执行任何操作。在这里,它被用作类的主体部分,表示这个类是一个空类,没有任何成员或方法。

6. 类Worker的定义

 Worker类继承自Python中的Thread类 ,这个类表示一个工作线程,用于执行支持向量机(SVM)的训练任务,该类定义了三个函数:_ _init_ _方法、run方法、get_cmd方法

class Worker(Thread):
    def __init__(self,name,job_queue,result_queue,options):
        Thread.__init__(self)
        self.name = name
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.options = options

    

__init__ 方法:

  • 初始化方法,接受四个参数:name(线程名称)、job_queue(任务队列)、result_queue(结果队列)、options(选项参数)
  • 将这些参数保存为实例变量(也可以说是成员变量),用于在线程运行时访问
  • self:表示对象的实例

 

    def run(self):
        while True:
            (cexp,gexp) = self.job_queue.get()
            if cexp is WorkerStopToken:
                self.job_queue.put((cexp,gexp))
                # print('worker {0} stop.'.format(self.name))
                break
            try:
                c, g = None, None
                if cexp != None:
                    c = 2.0**cexp
                if gexp != None:
                    g = 2.0**gexp
                rate = self.run_one(c,g)
                if rate is None: raise RuntimeError('get no rate')
            except:
                # we failed, let others do that and we just quit

                traceback.print_exception(sys.exc_info()[0], sys.exc_info()[1], sys.exc_info()[2])

                self.job_queue.put((cexp,gexp))
                sys.stderr.write('worker {0} quit.\n'.format(self.name))
                break
            else:
                self.result_queue.put((self.name,cexp,gexp,rate))

 run 方法:

  • run 方法是 Thread 类的默认方法,在启动线程时会自动调用。这里是线程的主要执行逻辑。
  • 使用无限循环 (while True) 从任务队列 (job_queue) 获取任务,任务是 (cexp, gexp),其中 cexp 和 gexp 表示对应的参数指数。
  • 如果接收到 WorkerStopToken,表示线程应该停止,将任务重新放回队列,并通过 break 退出循环,结束线程。
  • 否则,尝试将指数转换为实际的参数值 c 和 g,然后调用 run_one 方法执行具体的 SVM 训练,并获取性能率。
  • 如果执行出错,将异常信息输出到标准错误流,并将任务重新放回队列,然后通过  sys.stderr.write  输出线程终止的信息,并通过 break 退出循环,结束线程。
  • 如果一切正常,将线程的名字、cexpgexp 和性能率放入结果队列 (result_queue)。

这段代码实现了一个工作线程的逻辑,用于执行 SVM 训练任务。它通过任务队列接收参数组合,执行训练,并将结果放入结果队列。这样的多线程结构通常用于加速大规模参数搜索和训练任务  


    def get_cmd(self,c,g):
        options=self.options
        cmdline = '"' + options.svmtrain_pathname + '"'
        if options.grid_with_c:
            cmdline += ' -c {0} '.format(c)
        if options.grid_with_g:
            cmdline += ' -g {0} '.format(g)
        cmdline += ' -v {0} {1} {2} '.format\
            (options.fold,options.pass_through_string,options.dataset_pathname)
        return cmdline

get_cmd 方法:

  • 用于生成 SVM 训练的命令行字符串,其中包括 SVM 训练器的路径、参数 -c(如果启用)、参数 -g(如果启用)、参数 -v、折数、透传参数和数据集路径。

下面我再来详细地讲解一下get_cmd方法 :

def get_cmd(self,c,g)

      定义了一个方法 get_cmd,接受两个参数 c 和 g,表示 SVM 训练 的参数

options = self.options

      将类实例中的 options 属性赋给局部变量 options,以便在后续代码中使用

cmdline = '"' + options.svmtrain_pathname + '"'

     构建命令行字符串的开头部分,包含 SVM 训练器的路径。使用双引号将路径括起来,以防止 路径中包含空格时出现问题。

if options.grid_with_c:

     检查选项 grid_with_c 是否为真,即是否启用了参数 c 的网格搜索

cmdline += ' -c {0} '.format(c)

     如果启用了参数 c 的网格搜索,则将参数 c 的值添加到命令行字符串中

if options.grid_with_g:

     检查选项 grid_with_g 是否为真,即是否启用了参数 g 的网格搜索

cmdline += ' -g {0} '.format(g)

     如果启用了参数 g 的网格搜索,则将参数 g 的值添加到命令行字符串中

cmdline += ' -v {0} {1} {2} '.format(options.fold, options.pass_through_string, options.dataset_pathname)

    添加 SVM 训练的其他参数,包括:

  • -v:表示要进行交叉验证
  • {0}:使用 options.fold 指定的折数
  • {1}:用户传递的额外参数
  • {2}:数据集的路径,由 options.dataset_pathname 指定

return cmdline

     返回构建好的 SVM 训练命令行字符串 

总体而言,这段代码的作用是根据给定的参数 c 和 g 以及一些配置选项生成用于执行 SVM 训练的命令行字符串。生成的命令行包括 SVM 训练器的路径、参数 -c(如果启用)、参数 -g(如果启用)、参数 -v、交叉验证的折数、额外参数和数据集的路径。

 

 7.类LocalWorker的定义

定义了一个名为 LocalWorker 的类,它继承自先前提到的 Worker 类,并重写了 run_one 方法

class LocalWorker(Worker):
    def run_one(self,c,g):
        cmdline = self.get_cmd(c,g)
        result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
        for line in result.readlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

run_one方法

该方法接受两个参数 c 和 g,表示 SVM 训练的参数

cmdline = self.get_cmd(c,g)

 调用父类 Worker 的 get_cmd 方法,获取 SVM 训练的命令行字符串,并将其赋给 cmdline

result =  Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout

使用 subprocess.Popen 创建一个新的进程,运行 SVM 训练的命令行,其中

  • cmdline 是要执行的命令行字符串
  • shell=True 表示使用系统的 shell 执行命令
  • stdout=PIPE 表示将命令的标准输出捕获到 result 变量中
  • stderr=PIPE 表示将命令的标准错误捕获,但在这段代码中没有使用
  • stdin=PIPE 表示标准输入连接到管道,但在这段代码中没有使用

for line in result.readlines():

遍历命令的标准输出的每一行

if str(line).find('Cross') != -1:

判断当前行是否包含字符串 ‘Cross’。如果包含,说明这一行包含了交叉验证的结果信息

return float(line.split()[-1][0:-1])

如果找到包含 ‘Cross’ 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点 数。这个值表示 SVM 训练的性能率。

总体而言,这段代码实现了在本地环境运行 SVM 训练任务的逻辑。它通过创建新的进程执行 SVM 训练命令行,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。 

 8.类SSHWorker的定义

class SSHWorker(Worker):
    def __init__(self,name,job_queue,result_queue,host,options):
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.cwd = os.getcwd()
    def run_one(self,c,g):
        cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"'.format\
            (self.host,self.cwd,self.get_cmd(c,g))
        result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout
        for line in result.readlines():
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

 定义了一个名为 SSHWorker 的类,它同样继承自之前提到的 Worker 类,并进行了一些定制化。

该类定义了初始化函数和run_one函数

 __init__方法

初始化方法,除了调用父类的初始化方法外,还接受一个额外的参数 host,表示远程主机的地 址。

  self.host = host:将传入的 host 参数保存为实例变量,以便在后续代码中使用

  self.cwd = os.getcwd():获取当前工作目录,并保存为实例变量 cwd

run_one方法

重写了 run_one 方法,该方法接受两个参数 c 和 g,表示 SVM 训练的参数

cmdline = 'ssh -x -t -t {0} "cd {1}; {2}"' .format (self.host, self.cwd, self.get_cmd(c,g))

构建了一个 SSH 命令行字符串,该命令行用于在远程主机上执行 SVM 训练任务

  ssh -x -t -t:表示使用 SSH 连接,并在远程主机上执行命令

     {0}:用传入的 host 替换占位符,表示远程主机的地址

  "cd {1}; {2}":在远程主机上执行的命令,首先切换到当前工作目录(cwd),然后执行通过

      调 用 get_cmd 方法生成的 SVM 训练命令

result = Popen(cmdline,shell=True,stdout=PIPE,stderr=PIPE,stdin=PIPE).stdout:

使用 subprocess.Popen 创建一个新的进程,运行 SSH 命令行

  • cmdline 是要执行的 SSH 命令行字符串
  • stdout=PIPE 表示将命令的标准输出捕获到 result 变量中

for line in result.readlines():

遍历命令的标准输出的每一行

if str(line).find('Cross') != -1:

判断当前行是否包含字符串 ‘Cross’。如果包含,说明这一行包含了交叉验证的结果信息

return float(line.split()[-1][0:-1])

如果找到包含 ‘Cross’ 的行,提取该行的最后一个单词,去掉末尾的换行符,并将其转换为浮点数。这个值表示在远程主机上运行 SVM 训练的性能率

总体而言,这段代码实现了在远程主机上通过 SSH 运行 SVM 训练任务的逻辑。它构建了相应的 SSH 命令行,执行远程任务,并从命令的标准输出中提取包含交叉验证结果的行,最终返回性能率作为结果。

 9.类TelnetWorker的定义

class TelnetWorker(Worker):
    def __init__(self,name,job_queue,result_queue,host,username,password,options):
        Worker.__init__(self,name,job_queue,result_queue,options)
        self.host = host
        self.username = username
        self.password = password
    def run(self):
        import telnetlib
        self.tn = tn = telnetlib.Telnet(self.host)
        tn.read_until('login: ')
        tn.write(self.username + '\n')
        tn.read_until('Password: ')
        tn.write(self.password + '\n')

        # XXX: how to know whether login is successful?
        tn.read_until(self.username)
        #
        print('login ok', self.host)
        tn.write('cd '+os.getcwd()+'\n')
        Worker.run(self)
        tn.write('exit\n')
    def run_one(self,c,g):
        cmdline = self.get_cmd(c,g)
        result = self.tn.write(cmdline+'\n')
        (idx,matchm,output) = self.tn.expect(['Cross.*\n'])
        for line in output.split('\n'):
            if str(line).find('Cross') != -1:
                return float(line.split()[-1][0:-1])

总体而言,这段代码实现了在远程主机上通过 Telnet 运行 SVM 训练任务的逻辑。它通过 Telnet 协议连接远程主机,执行相应的命令,并从输出中提取包含交叉验证结果的行,最终返回性能率作为结果。需要注意的是,代码中对登录成功的判断逻辑可能需要进一步完善。

 10.函数find_parameters的定义

这段代码实现了对 SVM 模型参数的并行搜索和优化,通过多线程/进程执行不同参数组合的训练 任务,然后比较性能,最终找到最佳的参数组合。

用于参数搜索和优化的部分,具体来说,它使用了多线程/进程的方式来执行 SVM 参数的搜索工作

def find_parameters(dataset_pathname, options=''):

    def update_param(c,g,rate,best_c,best_g,best_rate,worker,resumed):
        if (rate > best_rate) or (rate==best_rate and g==best_g and c<best_c):
            best_rate,best_c,best_g = rate,c,g
        stdout_str = '[{0}] {1} {2} (best '.format\
            (worker,' '.join(str(x) for x in [c,g] if x is not None),rate)
        output_str = ''
        if c != None:
            stdout_str += 'c={0}, '.format(2.0**best_c)
            output_str += 'log2c={0} '.format(c)
        if g != None:
            stdout_str += 'g={0}, '.format(2.0**best_g)
            output_str += 'log2g={0} '.format(g)
        stdout_str += 'rate={0})'.format(best_rate)
        print(stdout_str)
        if options.out_pathname and not resumed:
            output_str += 'rate={0}\n'.format(rate)
            result_file.write(output_str)
            result_file.flush()

        return best_c,best_g,best_rate

 def find_parameters(dataset_pathname, options=''):

  • 定义了一个名为 find_parameters 的函数,用于寻找 SVM 模型的最佳参数

def update_param(c, g, rate, best_c, best_g, best_rate, worker, resumed):

  • 定义了一个辅助函数 update_param,用于更新最佳参数和最佳性能率

 

options = GridOption(dataset_pathname, options);

    if options.gnuplot_pathname:
        gnuplot = Popen(options.gnuplot_pathname,stdin = PIPE,stdout=PIPE,stderr=PIPE).stdin
    else:
        gnuplot = None

options = GridOption(dataset_pathname, options);

  • 使用 GridOption 类处理参数选项,GridOption 类是对参数进行解析和处理的一个自定义类

if options.gnuplot_pathname:

  • 判断是否提供了 gnuplot 路径,如果提供了,则创建一个与 gnuplot 进程进行通信的管道

  # put jobs in queue

    jobs,resumed_jobs = calculate_jobs(options)
    job_queue = Queue(0)
    result_queue = Queue(0)

    for (c,g) in resumed_jobs:
        result_queue.put(('resumed',c,g,resumed_jobs[(c,g)]))

    for line in jobs:
        for (c,g) in line:
            if (c,g) not in resumed_jobs:
                job_queue.put((c,g))

    # hack the queue to become a stack --
    # this is important when some thread
    # failed and re-put a job. It we still
    # use FIFO, the job will be put
    # into the end of the queue, and the graph
    # will only be updated in the end

    job_queue._put = job_queue.queue.appendleft

jobs, resumed_jobs = calculate_jobs(options)

调用 calculate_jobs 函数,生成需要执行的任务列表 jobs 和已经恢复的任务列表 resumed_jobs 

job_queue = Queue(0) 和 result_queue = Queue(0):

创建两个队列,job_queue 用于存放待执行的任务,result_queue 用于存放执行结果

for (c, g) in resumed_jobs: 和 for line in jobs:

  • 循环遍历已经恢复的任务和待执行的任务

job_queue._put = job_queue.queue.appendleft

将 job_queue 的 _put 方法指向 appendleft 方法,将队列变成一个栈,以确保重新放入的任务在队列头部


 # fire telnet workers

    if telnet_workers:
        nr_telnet_worker = len(telnet_workers)
        username = getpass.getuser()
        password = getpass.getpass()
        for host in telnet_workers:
            worker = TelnetWorker(host,job_queue,result_queue,
                     host,username,password,options)
            worker.start()

    # fire ssh workers

    if ssh_workers:
        for host in ssh_workers:
            worker = SSHWorker(host,job_queue,result_queue,host,options)
            worker.start()

    # fire local workers

    for i in range(nr_local_worker):
        worker = LocalWorker('local',job_queue,result_queue,options)
        worker.start()

    # gather results

    done_jobs = {}

    if options.out_pathname:
        if options.resume_pathname:
            result_file = open(options.out_pathname, 'a')
        else:
            result_file = open(options.out_pathname, 'w')

   

if telnet_workers: 和 if ssh_workers:

根据是否提供了 Telnet 或 SSH 主机列表,启动相应的 TelnetWorker 或 SSHWorker

for i in range(nr_local_worker): 启动本地工作线程,数量由 nr_local_worker 决定

done_jobs = {}:  用于存放已完成的任务及其结果

if options.out_pathname:如果提供了输出路径,则打开一个文件用于记录结果



    db = []
    best_rate = -1
    best_c,best_g = None,None

    for (c,g) in resumed_jobs:
        rate = resumed_jobs[(c,g)]
        best_c,best_g,best_rate = update_param(c,g,rate,best_c,best_g,best_rate,'resumed',True)

    for line in jobs:
        for (c,g) in line:
            while (c,g) not in done_jobs:
                (worker,c1,g1,rate1) = result_queue.get()
                done_jobs[(c1,g1)] = rate1
                if (c1,g1) not in resumed_jobs:
                    best_c,best_g,best_rate = update_param(c1,g1,rate1,best_c,best_g,best_rate,worker,False)
            db.append((c,g,done_jobs[(c,g)]))
        if gnuplot and options.grid_with_c and options.grid_with_g:
            redraw(db,[best_c, best_g, best_rate],gnuplot,options)
            redraw(db,[best_c, best_g, best_rate],gnuplot,options,True)

db = [] 和 best_rate = -1用于存放任务执行结果的数据库和记录最佳性能率的变量

for (c, g) in resumed_jobs: 遍历已恢复的任务,更新最佳参数和最佳性能率

for line in jobs: 遍历待执行的任务

while (c, g) not in done_jobs: 循环等待任务执行完成,并将执行结果放入 done_jobs

(worker, c1, g1, rate1) = result_queue.get()从结果队列中获取执行结果

db.append((c, g, done_jobs[(c, g)])):将任务执行结果加入数据库

if gnuplot and options.grid_with_c and options.grid_with_g:

如果提供了 gnuplot 路径,并且需要绘制图形,则调用 redraw 函数绘制图形



    if options.out_pathname:
        result_file.close()
    job_queue.put((WorkerStopToken,None))
    best_param, best_cg  = {}, []
    if best_c != None:
        best_param['c'] = 2.0**best_c
        best_cg += [2.0**best_c]
    if best_g != None:
        best_param['g'] = 2.0**best_g
        best_cg += [2.0**best_g]
    print('{0} {1}'.format(' '.join(map(str,best_cg)), best_rate))

    return best_rate, best_param

 if options.out_pathname:

  • 如果提供了输出路径,就关闭之前打开的文件result_file

job_queue.put((WorkerStopToken, None))

  • 向任务队列中放入停止信号,以停止工作线程

best_param, best_cg = {}, [] 和 print('{0} {1}'.format(' '.join(map(str, best_cg)), best_rate))

  • 输出最佳参数和最佳性能率

return best_rate, best_param  返回最佳性能率和最佳参数

 

11.程序入口函数的定义

 这是一个命令行工具的入口,用于解析命令行参数并调用  find_parameters 函数进行参数搜索

if __name__ == '__main__':

    def exit_with_help():
        print("""\
Usage: grid.py [grid_options] [svm_options] dataset

grid_options :
-log2c {begin,end,step | "null"} : set the range of c (default -5,15,2)
    begin,end,step -- c_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with c
-log2g {begin,end,step | "null"} : set the range of g (default 3,-15,-2)
    begin,end,step -- g_range = 2^{begin,...,begin+k*step,...,end}
    "null"         -- do not grid with g
-v n : n-fold cross validation (default 5)
-svmtrain pathname : set svm executable path and name
-gnuplot {pathname | "null"} :
    pathname -- set gnuplot executable path and name
    "null"   -- do not plot
-out {pathname | "null"} : (default dataset.out)
    pathname -- set output file path and name
    "null"   -- do not output file
-png pathname : set graphic output file path and name (default dataset.png)
-resume [pathname] : resume the grid task using an existing output file (default pathname is dataset.out)
    This is experimental. Try this option only if some parameters have been checked for the SAME data.

svm_options : additional options for svm-train""")
        sys.exit(1)

   
    if len(sys.argv) < 2:
        exit_with_help()
    dataset_pathname = sys.argv[-1]
    options = sys.argv[1:-1]
    try:
        find_parameters(dataset_pathname, options)
    except (IOError,ValueError) as e:
        sys.stderr.write(str(e) + '\n')
        sys.stderr.write('Try "grid.py" for more information.\n')
        sys.exit(1)

if __name__ == '__main__':

  • 这是 Python 中的惯用写法,表示以下代码块将在作为脚本直接执行时运行

def exit_with_help():

  • 定义了一个辅助函数 exit_with_help,用于打印使用帮助信息并退出程序

print('' '' ''\ …'' '' '')和 sys.exit(1)

  •  打印使用帮助信息,并使用 sys.exit(1) 退出程序 

if len(sys.argv) < 2: 和 exit_with_help()

如果命令行参数数量小于 2,则调用 exit_with_help 函数打印使用帮助信息并退出程序 

dataset_pathname = sys.argv[-1] 和 options = sys.argv[1:-1]:

  • 将命令行参数中的最后一个参数(数据集路径)赋值给 dataset_pathname,将除第一个参数和最后一个参数外的其他参数赋值给 options

try: ... except (IOError, ValueError) as e: ...

  • 使用 try...except 结构捕获可能发生的 IOError 和 ValueError 异常
  • 在 try 块中调用 find_parameters 函数,传入数据集路径和其他参数
  • 如果捕获到异常,则将异常信息写入标准错误输出,打印提示信息,并退出程序

总体而言,这段代码实现了一个命令行工具的入口,用于解析命令行参数并调用  find_parameters 函数进行参数搜索。如果命令行参数不符合要求或者执行过程中出现异常,将打印使用帮助信息或错误信息,并退出程序。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/227121.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【开源】基于Vue+SpringBoot的快乐贩卖馆管理系统

项目编号&#xff1a; S 064 &#xff0c;文末获取源码。 \color{red}{项目编号&#xff1a;S064&#xff0c;文末获取源码。} 项目编号&#xff1a;S064&#xff0c;文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、功能模块2.1 数据中心模块2.2 搞笑视频模块2.3 视…

CentOS上安装和配置Apache HTTP服务器

在CentOS系统上安装和配置Apache HTTP服务器可以为您的网站提供可靠的托管环境。Apache是开源的Web服务器软件&#xff0c;具有广泛的支持和强大的功能。下面是在CentOS上安装和配置Apache HTTP服务器的步骤&#xff1a; 步骤一&#xff1a;安装Apache HTTP服务器 打开终端&am…

总结一篇本地idea配合阿里云服务器使用docker

idea打包打镜像发到阿里云服务器 先说一下使用docker desktop软件怎么使用 1.下载docker desktop官网&#xff0c;先注册个账号吧&#xff0c;后面桌面软件登录会用到&#xff08;当然&#xff0c;配合这个软件使用需要科学上网&#xff09; 安装这个要配合wsl使用&#xf…

电脑搜不自己的手机热点,其余热点均可!

一、现象&#xff1a; 之前可正常连接&#xff0c;突然间发现收不到自己的WiFi信号&#xff0c;其余人均可收到。通过重复手机电脑关机、改变热点设置中的频段等方式均没解决&#xff0c;同事电脑和手机可搜索到我的WiFi。 二、问题&#xff1a; WiF驱动程序更新 三&#x…

ingress介绍和ingress通过LoadBalancer暴露服务配置

目录 一.ingress基本原理介绍 1.将原有用于暴露服务和负载均衡的服务的三四层负载均衡变为一个七层负载均衡 2.controller和ingress 3.通过下面这个图可能会有更直观的理解 二.为什么会出现ingress 1.NodePort存在缺点 2.LoadBalancer存在缺点 三.ingress三种暴露服务的…

『TypeScript』从零开始编写你的第一个TypeScript程序

&#x1f4e3;读完这篇文章里你能收获到 了解TypeScript及为什么使用TypeScriptTypeScript的安装过程编写第一个HelloTs程序 文章目录 一、TypeScript简介1. 什么是TypeScript&#xff1f;2. 为什么选择使用TypeScript&#xff1f;2.1 静态类型检查2.2 更好的代码维护性2.3 更…

Prometheus 配置文件和标签 Pmsql

1.Prometheus配置文件 Prometheus可以通过命令行或者配置文件的方式对服务进行配置。 命令行方式一般用于不可变的系统参数配置&#xff0c;例如存储位置、要保留在磁盘和内存中的数据量等&#xff1b;配置文件用于定义与数据动态获取相关的配置选项和文件等内容。命令行方式…

Sql Server关于表的建立、修改、删除

表的创建&#xff1a; &#xff08;1&#xff09;在“对象资源管理器”面板中展开“数据库”节点&#xff0c;可以看到自己创建的数据库&#xff0c;比如Product。展开Product节点&#xff0c;右击“表”节点&#xff0c;在弹出的快捷菜单中选择“新建表”项&#xff0c;进入“…

【keil备忘录】2. stm32 keil仿真时的时间测量功能

配置仿真器Trace内核时钟为单片机实际的内核时钟&#xff0c;需要勾选Enable设置&#xff0c;设置完成后Enable取消勾选也可以&#xff0c;经测试时钟频率配置仍然生效&#xff0c;此处设置为48MHZ: 时间测量时必须打开register窗口&#xff0c;否则可能不会计数 右下角有计…

AMD 发布新芯片MI300,支持训练和运行大型语言模型

AMD 宣布推出 MI300 芯片&#xff0c;其 Ryzen 8040移动处理器将于2024年用于笔记本电脑。 AMD官方网站&#xff1a;AMD ׀ together we advance_AI AMD——美国半导体公司专门为计算机、通信和消费电子行业设计和制造各种创新的微处理器&#xff08;CPU、GPU、主板芯片组、电…

【AIGC】prompt工程从入门到精通

注&#xff1a;本文示例默认“文心大模型3.5”演示&#xff0c;表示为>或w>&#xff08;wenxin)&#xff0c;有时为了对比也用百川2.0展示b>&#xff08;baichuan) 有时候为了模拟错误输出&#xff0c;会用到m>&#xff08;mock)表示&#xff08;因为用的大模型都会…

SLAM算法与工程实践——SLAM基本库的安装与使用(3):Pangolin库

SLAM算法与工程实践系列文章 下面是SLAM算法与工程实践系列文章的总链接&#xff0c;本人发表这个系列的文章链接均收录于此 SLAM算法与工程实践系列文章链接 下面是专栏地址&#xff1a; SLAM算法与工程实践系列专栏 文章目录 SLAM算法与工程实践系列文章SLAM算法与工程实践…

网络管理相关

管理功能分为管理站manager和代理agent两部分。 网络管理&#xff1a; 网络管理系统中&#xff0c;每一个网络节点都包含有一组与管理有关的软件&#xff0c;叫做网络管理实体NME。 管理站的另外一组软件叫做网络管理应用NMA&#xff0c;提供用户接口&#xff0c;根据用户命令显…

http与apache

目录 1.http相关概念 2.http请求的完整过程 3.访问浏览器背后的原理过程 4.动态页面与静态页面区别 静态页面&#xff1a; 动态页面&#xff1a; 5.http协议版本 6.http请求方法 7.HTTP协议报文格式 8.http响应状态码 1xx&#xff1a;提示信息 2xx&#xff1a;成功…

css 十字分割线(含四等分布局)

核心技术 伪类选择器含义li:nth-child(2)第2个 lili:nth-child(n)所有的lili:nth-child(2n)所有的第偶数个 lili:nth-child(2n1)所有的第奇数个 lili:nth-child(-n5)前5个 lili:nth-last-child(-n5)最后5个 lili:nth-child(7n)选中7的倍数 border-right: 3px solid white;borde…

ssh安装和Gitee(码云)源码拉取

文章目录 安装ssh服务注册码云公钥设置码云账户SSH公钥安装git客户端和git-lfs源码获取 安装ssh服务 更新软件源&#xff1a; sudo apt-get update安装ssh服务 sudo apt-get install openssh-server检查ssh是否安装成功 which ssh输出&#xff1a; /usr/bin/ssh启动ssh 服…

AI并行计算:CUDA和ROCm

1 介绍 1.1 CUDA CUDA&#xff08;Compute Unified Device Architecture&#xff09;是Nvidia于2006年推出的一套通用并行计算架构&#xff0c;旨在解决在GPU上的并行计算问题。其易用性和便捷性能够方便开发者方便的进行GPU编程&#xff0c;充分利用GPU的并行能力&#xff0…

PHP对接企业微信

前言 最近在做项目中&#xff0c;要求在后台管理中有企业微信管理的相关功能。相关准备工作&#xff0c;需要准备好企业微信账号&#xff0c;添加自建应用&#xff0c;获得相应功能的权限&#xff0c;以及agentid、secre等。 参考文档&#xff1a; 企业微信开发文档 功能实现 因…

leetcode:1365. 有多少小于当前数字的数字(python3解法)

难度&#xff1a;简单 给你一个数组 nums&#xff0c;对于其中每个元素 nums[i]&#xff0c;请你统计数组中比它小的所有数字的数目。 换而言之&#xff0c;对于每个 nums[i] 你必须计算出有效的 j 的数量&#xff0c;其中 j 满足 j ! i 且 nums[j] < nums[i] 。 以数组形式…

Android : XUI- SimpleImageBanner+BannerItem带标题的轮播图-简单应用

示例图&#xff1a; 1.导入XUI http://t.csdnimg.cn/qgGaN 2.MainActivity.java package com.example.viewpagerbanne;import android.os.Bundle; import android.view.View; import android.widget.Toast; import androidx.appcompat.app.AppCompatActivity; import com.xu…