【Pytorch】Fizz Buzz

在这里插入图片描述

文章目录

  • 1 数据编码
  • 2 网络搭建
  • 3 网络配置,训练
  • 4 结果预测
  • 5 翻车现场

学习参考来自:

  • Fizz Buzz in Tensorflow
  • https://github.com/wmn7/ML_Practice/tree/master/2019_06_10
  • Fizz Buzz in Pytorch

I need you to print the numbers from 1 to 100, except that if the number is divisible by 3 print “fizz”, if it’s divisible by 5 print “buzz”, and if it’s divisible by 15 print “fizzbuzz”.

编程题很简单,我们用 MLP 实现试试

思路,训练集数据101~1024,对其进行某种规则的编码,标签为经分类 one-hot 编码后的标签
测试集,1~100

don’t say so much, show me the code.

1 数据编码

import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data

def binary_encode(i, num_digits):
    """将每个input转换为binary digits(转换为二进制的表示, 最多可是表示2^num_digits)
    :param i:
    :param num_digits:
    :return:
    """
    return np.array([i >> d & 1 for d in range(num_digits)])

编码形式,依次除以 2 0 , 1 , 2 , 3 , . . . 2^{0,1,2,3,...} 20,1,2,3,...,结果按位与 1

m & 1,结果为 0 表示 m 为偶数, 结果为 1 表示 m 为奇数

> > m >> m >>m 右移表示除以 2 m 2^m 2m

第一位就能表示奇偶了,所有数字编码都不一样

eg,101 进行 num_digits=10 编码后结果为 1 0 1 0 0 1 1 0 0 0

步骤

101 / 1 = 101 奇数 1
101 / 2 = 50 偶数 0
101 / 4 = 25 奇数 1
101 / 8 = 12 偶数 0
101 / 16 = 6 偶数 0
101 / 32 = 3 奇数 1
101 / 64 = 1 奇数 1
101 / 128 = 0 偶数 0
101 / 256= 0 偶数 0
101 / 512= 0 偶数 0

标签,0,1,2,3 四个类别

def fizz_buzz_encode(i):
    """将output转换为lebel
    :param i:
    :return:
    """
    if i % 15 == 0:  # fizzbuzz
        return 3
    elif i % 5 == 0:  # buzz
        return 2
    elif i % 3 == 0:  # fizz
        return 1
    else:
        return 0

编码长度设定,数据集 101 ~ 1024

NUM_DIGITS = 10
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2**NUM_DIGITS)])  # 101~1024
trY = np.array([fizz_buzz_encode(i) for i in range(101, 2**NUM_DIGITS)])

# print(len(trX), len(trY))  # 923 923
# print(trX[:5])
"""
[[1 0 1 0 0 1 1 0 0 0]
 [0 1 1 0 0 1 1 0 0 0]
 [1 1 1 0 0 1 1 0 0 0]
 [0 0 0 1 0 1 1 0 0 0]
 [1 0 0 1 0 1 1 0 0 0]]
"""
# print(trY[:5])  # [0 1 0 0 3]

2 网络搭建

搭建简单的 MLP 网络

class FizzBuzzModel(nn.Module):
    def __init__(self, in_features, out_classes, hidden_size, n_hidden_layers):
        super(FizzBuzzModel,self).__init__()
        layers = []
        for i in range(n_hidden_layers):
            layers.append(nn.Linear(hidden_size,hidden_size))
            # layers.append(nn.Dropout(0.5))
            layers.append(nn.BatchNorm1d(hidden_size))
            layers.append(nn.ReLU())
        self.inputLayer = nn.Linear(in_features, hidden_size)
        self.relu = nn.ReLU()
        self.layers = nn.Sequential(*layers)  # 重复的搭建隐藏层
        self.outputLayer = nn.Linear(hidden_size, out_classes)

    def forward(self, x):
        x = self.inputLayer(x)
        x = self.relu(x)
        x = self.layers(x)
        out = self.outputLayer(x)
        return out

初始化网络,看看网络结构

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# define the model
simpleModel = FizzBuzzModel(NUM_DIGITS, 4, 150, 3).to(device)
print(simpleModel)
"""
FizzBuzzModel(
  (inputLayer): Linear(in_features=10, out_features=150, bias=True)
  (relu): ReLU()
  (layers): Sequential(
    (0): Linear(in_features=150, out_features=150, bias=True)
    (1): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Linear(in_features=150, out_features=150, bias=True)
    (4): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): Linear(in_features=150, out_features=150, bias=True)
    (7): BatchNorm1d(150, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): ReLU()
  )
  (outputLayer): Linear(in_features=150, out_features=4, bias=True)
)
"""

输入 10, 输出4,隐藏层维度 150,隐藏层重复了 3 次

3 网络配置,训练

定义下超参数,损失函数,优化器,载入数据训练,输出训练精度与损失

# Loss and optimizer
learning_rate = 0.05
criterion = nn.CrossEntropyLoss()
# optimizer = torch.optim.Adam(simpleModel.parameters(), lr=learning_rate)
optimizer = torch.optim.SGD(simpleModel.parameters(), lr=learning_rate)

# 使用batch进行训练
FizzBuzzDataset = Data.TensorDataset(torch.from_numpy(trX).float().to(device),
                                     torch.from_numpy(trY).long().to(device))

loader = Data.DataLoader(dataset=FizzBuzzDataset,
                         batch_size=128*5,
                         shuffle=True)

# 进行训练
simpleModel.train()
epochs = 3000

for epoch in range(1, epochs):
    for step, (batch_x, batch_y) in enumerate(loader):
        out = simpleModel(batch_x)  # 前向传播
        loss = criterion(out, batch_y)  # 计算损失
        optimizer.zero_grad()  # 梯度清零
        loss.backward()  # 反向传播
        optimizer.step()  # 随机梯度下降
    correct = 0
    total = 0
    _, predicted = torch.max(out.data, 1)
    total += batch_y.size(0)
    correct += (predicted == batch_y).sum().item()
    acc = 100*correct/total
    print('Epoch : {:0>4d} | Loss : {:<6.4f} | Train Accuracy : {:<6.2f}%'.format(epoch, loss, acc))

"""
Epoch : 0001 | Loss : 1.5343 | Train Accuracy : 14.63 %
Epoch : 0002 | Loss : 1.9779 | Train Accuracy : 42.58 %
Epoch : 0003 | Loss : 2.4198 | Train Accuracy : 53.41 %
Epoch : 0004 | Loss : 1.7360 | Train Accuracy : 53.41 %
Epoch : 0005 | Loss : 1.3161 | Train Accuracy : 49.73 %
Epoch : 0006 | Loss : 1.4866 | Train Accuracy : 22.75 %
Epoch : 0007 | Loss : 1.3993 | Train Accuracy : 25.57 %
Epoch : 0008 | Loss : 1.2428 | Train Accuracy : 28.49 %
Epoch : 0009 | Loss : 1.1906 | Train Accuracy : 44.31 %
Epoch : 0010 | Loss : 1.1929 | Train Accuracy : 52.44 %
...
Epoch : 2990 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2991 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2992 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2993 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2994 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2995 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2996 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2997 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2998 | Loss : 0.0000 | Train Accuracy : 100.00%
Epoch : 2999 | Loss : 0.0000 | Train Accuracy : 100.00%
"""

训练集上精度是 OK 的,能到 100%,下面看看测试集上的精度

4 结果预测

把 one-hot 标签转化成 fizz buzz 的形式

def fizz_buzz_decode(i, prediction):
    return [str(i), "fizz", "buzz", "fizzbuzz"][prediction]

载入测试集,开始预测

simpleModel.eval()
# 进行预测
testX = np.array([binary_encode(i, NUM_DIGITS) for i in range(1, 101)])
predicts = simpleModel(torch.from_numpy(testX).float().to(device))
# 预测的结果
_, res = torch.max(predicts, 1)
print(res)
"""
tensor([0, 0, 0, 1, 0, 0, 0, 2, 1, 0, 1, 3, 3, 1, 1, 0, 0, 0, 0, 0, 0, 3, 1, 0,
        0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 1, 0, 1, 1, 1, 0,
        0, 0, 0, 0], device='cuda:0')
"""

# 格式的转换
predictions = [fizz_buzz_decode(i, prediction) for (i, prediction) in zip(range(1, 101), res)]
print(predictions)
"""
['1', '2', '3', 'fizz', '5', '6', '7', 'buzz', 'fizz', '10', 'fizz', 'fizzbuzz', 'fizzbuzz', 'fizz', 'fizz', '16', '17', '18', '19', '20', '21', 'fizzbuzz', 'fizz', '24', '25', '26', '27', '28', '29', '30', 'fizz', '32', '33', '34', '35', '36', '37', '38', '39', 'fizz', '41', 'fizz', '43', '44', '45', 'fizz', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', 'fizz', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', 'fizz', 'fizz', 'fizz', 'fizz', 'buzz', 'buzz', 'fizz', '80', '81', '82', '83', '84', '85', '86', '87', 'fizzbuzz', '89', '90', 'fizz', '92', 'fizz', 'fizz', 'fizz', '96', '97', '98', '99', '100']
"""

5 翻车现场

对比下标签

labels = []
for i in range(1, 101):
    if i % 15 == 0:  # fizzbuzz
        labels.append("fizzbuzz")
    elif i % 5 == 0:  # buzz
        labels.append("buzz")
    elif i % 3 == 0:  # fizz
        labels.append("fizz")
    else:
        labels.append(str(i))
print(labels)
print(labels == predictions)

"""
['1', '2', 'fizz', '4', 'buzz', 'fizz', '7', '8', 'fizz', 'buzz', '11', 'fizz', '13', '14', 'fizzbuzz', '16', '17', 'fizz', '19', 'buzz', 'fizz', '22', '23', 'fizz', 'buzz', '26', 'fizz', '28', '29', 'fizzbuzz', '31', '32', 'fizz', '34', 'buzz', 'fizz', '37', '38', 'fizz', 'buzz', '41', 'fizz', '43', '44', 'fizzbuzz', '46', '47', 'fizz', '49', 'buzz', 'fizz', '52', '53', 'fizz', 'buzz', '56', 'fizz', '58', '59', 'fizzbuzz', '61', '62', 'fizz', '64', 'buzz', 'fizz', '67', '68', 'fizz', 'buzz', '71', 'fizz', '73', '74', 'fizzbuzz', '76', '77', 'fizz', '79', 'buzz', 'fizz', '82', '83', 'fizz', 'buzz', '86', 'fizz', '88', '89', 'fizzbuzz', '91', '92', 'fizz', '94', 'buzz', 'fizz', '97', '98', 'fizz', 'buzz']
False
"""

哈哈哈, False 翻车了,尝试了很多次,很难 True

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/230099.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字化转型怎么才能做成功?_光点科技

数字化转型对于现代企业来说是一场必要的革命。它不仅仅是技术的更迭&#xff0c;更是企业战略、文化和运营方式全面升级的体现。一个成功的数字化转型能够使企业更具竞争力、更灵活应对市场变化&#xff0c;并最终实现业务增长和效率提升。那么&#xff0c;数字化转型怎么才能…

JVM常见垃圾回收器

串行垃圾回收器 Serial和Serial Old串行垃圾回收器&#xff0c;是指使用单线程进行垃圾回收&#xff0c;堆内存较小&#xff0c;适合个人电脑 Serial作用于新生代&#xff0c;采用复制算法 Serial Old作用于老年代&#xff0c;采用标记-整理算法 垃圾回收时&#xff0c;只有…

Navicat 技术指引 | 适用于 GaussDB 分布式的数据生成功能

Navicat Premium&#xff08;16.3.3 Windows 版或以上&#xff09;正式支持 GaussDB 分布式数据库。GaussDB 分布式模式更适合对系统可用性和数据处理能力要求较高的场景。Navicat 工具不仅提供可视化数据查看和编辑功能&#xff0c;还提供强大的高阶功能&#xff08;如模型、结…

物联网后端个人第十四周总结

物联网方面进度 1.登陆超时是因为后端运行的端口和前端监听的接口不一样&#xff0c;所以后端也没有报错&#xff0c;将二者修改一致即可 2.登录之后会进行平台的初始化&#xff0c;但是初始化的时候会卡住,此时只需要将路径的IP端口后边的内容去掉即可 3.阅读并完成了jetlinks…

log4j(日志的配置)

日志一般配置在resources的config下面的&#xff0c;并且Util当中的initLogRecord中的initLog&#xff08;&#xff09;方法就是加载这个log4j.properties的. 首先先看log4j.properties的配置文件 log4j.rootLoggerdebug, stdout, Rlog4j.appender.stdoutorg.apache.log4j.Co…

【UE 材质】任务目标点效果

效果 步骤 1. 新建一个工程&#xff0c;创建一个Basic关卡 2. 新建一个材质&#xff0c;这里命名为“M_GoalPoint” 打开“M_GoalPoint”&#xff0c;设置混合模式为“半透明”&#xff0c;勾选“双面” 在材质图表中添加如下节点 此时预览效果如下 继续添加如下节点 此时效果…

iPaaS架构深入探讨

在数字化时代全面来临之际&#xff0c;企业正面临着前所未有的挑战与机遇。技术的迅猛发展与数字化转型正在彻底颠覆各行各业的格局&#xff0c;不断推动着企业迈向新的前程。然而&#xff0c;这一数字化时代亦衍生出一系列复杂而深奥的难题&#xff1a;各异系统之间数据孤岛、…

3D材质编辑:制作被火烧的木头

在线工具推荐&#xff1a; 3D数字孪生场景编辑器 - GLTF/GLB材质纹理编辑器 - 3D模型在线转换 - Three.js AI自动纹理开发包 - YOLO 虚幻合成数据生成器 - 三维模型预览图生成器 - 3D模型语义搜索引擎 当谈到游戏角色的3D模型风格时&#xff0c;有几种不同的风格&#xf…

使用STM32定时器实现精确的时间测量和延时

✅作者简介&#xff1a;热爱科研的嵌入式开发者&#xff0c;修心和技术同步精进&#xff0c; 代码获取、问题探讨及文章转载可私信。 ☁ 愿你的生命中有够多的云翳,来造就一个美丽的黄昏。 &#x1f34e;获取更多嵌入式资料可点击链接进群领取&#xff0c;谢谢支持&#xff01;…

3DCAT+上汽奥迪:打造新零售汽车配置器实时云渲染解决方案

在 5G、云计算等技术飞速发展的加持下&#xff0c;云渲染技术迎来了突飞猛进的发展。在这样的背景下&#xff0c;3DCAT应运而生&#xff0c;成为了业内知名的实时云渲染服务商之一。 交互式3D实时云看车作为云渲染技术的一种使用场景&#xff0c;也逐步成为一种新的看车方式&a…

AWS Remote Control ( Wi-Fi ) on i.MX RT1060 EVK - 3 “编译 NXP i.MX RT1060”( 完 )

此章节叙述如何修改、建构 i.MX RT1060 的 Sample Code“aws_remote_control_wifi_nxp” 1. 点击“Import SDK example(s)” 2. 选择“MIMXRT1062xxxxA”>“evkmimxrt1060”&#xff0c;并确认 SDK 版本后&#xff0c;点击“Next>” 3. 选择“aws_examples”>“aw…

项目优化(异步化)

项目优化&#xff08;异步化&#xff09; 1. 认识异步化 1.1 同步与异步 同步&#xff1a;一件事情做完&#xff0c;再做另外一件事情&#xff0c;不能同时进行其他的任务。异步&#xff1a;不用等一件事故完&#xff0c;就可以做另外一件事情。等第一件事完成时&#xff0c…

处理哈希冲突的常见方法(五种)

1、开放地址法&#xff08;Open Addressing&#xff09;&#xff1a; 线性探测&#xff08;Linear Probing&#xff09;&#xff1a; 当发生冲突时&#xff0c;顺序地查找下一个可用的槽位&#xff0c;直到找到空槽或者整个表被搜索一遍。这个方法的缺点是可能出现“聚簇”&am…

webpack该如何打包

1.我们先创建一个空的大文件夹 2.打开该文件夹的终端 输入npm init -y 2.1.打开该文件夹的终端 2.2在该终端运行 npm init -y 3.安装webpack 3.1打开webpack网址 点击“中文文档” 3.2点击“指南”在点击“起步” 3.3复制基本安装图片画线的代码 4.在一开始的文件夹下在创建一…

【性能测试】Jmeter 配置元件(一):计数器

Jmeter 配置元件&#xff08;一&#xff09;&#xff1a;计数器 在 Jmeter 中&#xff0c;通过函数 ${__counter(,)} 可以实现每次加 1 1 1 的计数效果。但如果步长不为 1 1 1&#xff0c;则要利用到我们的计数器。 函数作用${__counter(,)}计数器&#xff0c;每次加 1${__d…

int 和 Integer 有什么区别,还有 Integer 缓存的实现

✨前言✨   Java本文主要介绍Java int 和 Integer的区别以及Integer 缓存的实现 &#x1f352;欢迎点赞 &#x1f44d; 收藏 ⭐留言评论 &#x1f4dd;私信必回哟&#x1f601; &#x1f352;博主将持续更新学习记录收获&#xff0c;友友们有任何问题可以在评论区留言 文章目…

AI别墅设计

这两年我致力于研发别墅AI自动化设计&#xff0c;包括设计别墅各层的平面图以及导出三维效果图。目的是可以快速生成大量别墅的设计图和效果图让自建房用户可以第一时间挑选自己需要的房型并看到房屋建成后效果&#xff0c;大大提高建筑施工和设计人员的工作效率。 别墅设计包括…

记一次测试环境git翻车经历

本来想拉一个功能分支进行新的功能开发&#xff0c;合并代码发现没有冲突居然有文件被修改了&#xff0c;贸然选择最近的一次回滚提交&#xff0c;没想到不假思索的push -f 导致一部分dev主干的代码不见了。 事故记录 开发分支origin/dev&#xff0c;功能分支file 合并之后发…

pytorch-mask-rcnn 官方

This is a Pytorch implementation 实现 of Mask R-CNN that is in large parts based on Matterports Mask_RCNN. Matterports repository is an implementation on Keras and TensorFlow. The following parts of the README are excerpts 摘录 from the Matterport README. …

电子学会C/C++编程等级考试2021年09月(五级)真题解析

C/C++等级考试(1~8级)全部真题・点这里 第1题:抓牛 农夫知道一头牛的位置,想要抓住它。农夫和牛都位于数轴上,农夫起始位于点N(0<=N<=100000),牛位于点K(0<=K<=100000)。农夫有两种移动方式: 1、从X移动到X-1或X+1,每次移动花费一分钟 2、从X移动到2*X,每…