🍋🍋1.安装thop
安装thop有两种方式。
🏆第一种
pip install thop
🏆第二种
用源码编译安装
- 从官网下载【github】thop安装压缩包
- 下载压缩文件,解压到虚拟环境的site-packages文件下
- 激活进入自己的虚拟环境
- cd到压缩包解压的路径下
- 输入命令python setup.py install,等待安装完成即可
setup.py源码:
#!/usr/bin/env python
import os, sys
import shutil
import datetime
from setuptools import setup, find_packages
from setuptools.command.install import install
readme = open("README.md").read()
fp = open("thop/__version__.py", "r").read()
VERSION = eval(fp.strip().split()[-1])
requirements = [
"torch",
]
# import subprocess
# commit_hash = subprocess.check_output("git rev-parse HEAD", shell=True).decode('UTF-8').rstrip()
# VERSION += "_" + str(int(commit_hash, 16))[:8]
VERSION += "_" + datetime.datetime.now().strftime("%Y%m%d%H%M")[2:]
print(VERSION)
setup(
# Metadata
name="thop",
version=VERSION,
author="Ligeng Zhu",
author_email="ligeng.zhu+github@gmail.com",
url="https://github.com/Lyken17/pytorch-OpCounter/",
description="A tool to count the FLOPs of PyTorch model.",
long_description=readme,
long_description_content_type="text/markdown",
license="MIT",
# Package info
packages=find_packages(exclude=("*test*",)),
#
zip_safe=True,
install_requires=requirements,
# Classifiers
classifiers=[
"Programming Language :: Python :: 3",
],
)
🍉🍉2.通过thop.profile计算网络参数量和计算效率
测试代码:
from torchvision.models import resnet50
from thop import profile
model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))
print(macs)
print(params)
代码步骤
- 定义一个自己的网络结构
- 输入一个满足尺寸要求的随机输入
- 随机影像数据输入网络
- 通过profile(model,(input,))进行网络参数量和运行效率算
比如自己自定义网络Test:
from thop import profile
class Test(nn.Module):
def __init__(self, input_size, output_szie):
super(Test, self).__init__()
self.out = nn.Linear(input_size, output_szie)
def forward(self, x):
output = self.out(x)
return output
t = Test(10, 2)
x = torch.randn(4, 10)
flops,params = profile(t, (x,), verbose=False)
print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
print('Params = ' + str(params / 1000 ** 2) + 'M')
🍌🍌3.增加输出结果可读性
为了增加结果的可读性,除了自己通过公式进行计算,将结果转换成Gb,Mb或者其他格式的,官网还提供了参数clever_format,可以通过该参数将结果输出成可读性较友好的结果。使用方法如下:
#all thop.clever_format to give a better format of the output.
from thop import clever_format
macs, params = clever_format([macs, params], "%.3f")
🍎🍎4.常见网络参数量展示
整理不易,欢迎一键三连!!!
送你们一条美丽的--分割线--
🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷