RegNet模型案例
1、RegNet研究概述
论文 :Designing Network Design Spaces
研究背景
**RegNet(Designing Network Design Spaces)**在这项工作中,作者提出了一种新的网络设计范式,它结合了人工设计和NAS的优点。作者不是专注于设计单个网络实例,而是设计参数化网络总体的设计空间。就像手工设计一样,目标是可解释性,并发现描述网络的通用设计原则,这些原则简单、工作良好,并且可以跨设置进行推广。与NAS一样,作者的目标是利用半自动化的过程来帮助实现这些目标。采用的一般策略是逐步设计初始的、相对不受限制的设计空间的简化版本,同时保持或提高其质量。整个过程类似于手工设计,提升到群体水平,并通过网络设计空间的分布估计进行指导。作为这个范例的测试平台,重点是在假设包括VGG、ResNet和ResNeXt在内的标准模型族的情况下探索网络结构(例如宽度、深度、组等)。作者从一个相对不受约束的设计空间开始,称之为AnyNet(例如,宽度和深度在不同阶段自由变化),到达一个由简单的“规则”网络组成的低维设计空间,作者称之为RegNet。RegNet设计空间的核心很简单:stage的宽度和深度由量化的线性函数决定。与AnyNet相比,RegNet设计空间的模型更简单,更容易解释,好模型的集中度更高。
论文实验结果
RegNet网络和当时最先进的不同体量网络在ImageNet数据集上进行了比较。在轻量级网络的对比结果中,体量大致相当的RegNetY-600MF取得了第二名的成绩,仅次于最佳结果AMOEBANET-C,如下图所示。
在全体量网络的对比中,RegNet与当时的最先进网络EfficientNet进行了比较,结果显示在二者的体量都较小时,EfficientNet优于RegNet,但体量较大时,RegNet则超越了EfficientNet的结果,如图所示。
2、网络的基本模块结构与源码分析
RegNet构建块
RegNet网络并不特意手工设计新的网络构造块来提升网络性能,而是将网络构造块的一些参数作为搜索内容,通过实验结果来发现能让这些参数产生良好性能网络的设计规律。以RegNetX网络为例,其X构建块如上图所示,是基于标准的残差瓶颈构建块提取一些结构参数而来。
最终,总结搜索得到的参数规律,将其以公式约束和表示,就能够计算得到不同体量、规格的RegNet网络。
源码分析
这里我们采用MindSpore框架来实现RegNet构建块结构的计算和生成,在实现的过程中参考了facebookresearch/pycls库的写法,下面简单介绍实现过程。
首先实现RegNet构建块结构的计算,根据论文的描述,这一步主要分为初步计算出RegNet构建块结构参数的值和调整参数的适应性。
# 根据传入的参数值初步计算得到构建块参数
def generate_regnet(w_a, w_0, w_m, d, q=8):
"""Generates per stage widths and depths from RegNet parameters."""
assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0
ws_cont = np.arange(d) * w_a + w_0
ks = np.round(np.log(ws_cont / w_0) / np.log(w_m))
ws_all = w_0 * np.power(w_m, ks)
ws_all = np.round(np.divide(ws_all, q)).astype(int) * q
ws, ds = np.unique(ws_all, return_counts=True)
num_stages, total_stages = len(ws), ks.max() + 1
ws, ds, ws_all, ws_cont = (x.tolist() for x in (ws, ds, ws_all, ws_cont))
return ws, ds, num_stages, total_stages, ws_all, ws_cont
# 调整计算出的参数的适应性,防止出现不合理的参数组合等
def adjust_block_compatibility(ws, bs, gs):
"""Adjusts the compatibility of widths, bottlenecks, and groups."""
assert len(ws) == len(bs) == len(gs)
assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs))
assert all(b < 1 or b % 1 == 0 for b in bs)
vs = [int(max(1, w * b)) for w, b in zip(ws, bs)]
gs = [int(min(g, v)) for g, v in zip(gs, vs)]
ms = [np.lcm(g, int(b)) if b > 1 else g for g, b in zip(gs, bs)]
vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)]
ws = [int(v / b) for v, b in zip(vs, bs)]
assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs))
return ws, bs, gs
# 合并两个计算过程
def generate_regnet_full(w_a, w_0, w_m, d, stride, bot_mul, group_w):
"""Generates per stage ws, ds, gs, bs, and ss from RegNet cfg."""
ws, ds = generate_regnet(w_a, w_0, w_m, d)[0:2]
ss = [stride for _ in ws]
bs = [bot_mul for _ in ws]
gs = [group_w for _ in ws]
ws, bs, gs = adjust_block_compatibility(ws, bs, gs)
return ws, ds, ss, bs, gs
根据这些计算出的参数,就可以实现构建具体的、结构不同的RegNet网络构建块:
# 根据参数构建构造块,组装成网络
class RegNet(AnyNet):
"""RegNet model."""
@staticmethod
def regnet_get_params(w_a, w_0, w_m, d, stride, bot_mul, group_w, stem_type, stem_w, block_type, head_w,
num_classes, se_r):
"""Get AnyNet parameters that correspond to the RegNet."""
ws, ds, ss, bs, gs = generate_regnet_full(w_a, w_0, w_m, d, stride, bot_mul, group_w)
return {
"stem_type": stem_type,
"stem_w": stem_w,
"block_type": block_type,
"depths": ds,
"widths": ws,
"strides": ss,
"bot_muls": bs,
"group_ws": gs,
"head_w": head_w,
"se_r": se_r,
"num_classes": num_classes,
}
def __init__(self, w_a, w_0, w_m, d, group_w, stride=2, bot_mul=1.0, stem_type='simple_stem_in', stem_w=32,
block_type='res_bottleneck_block', head_w=0, num_classes=1000, se_r=0.0, in_channels=3):
params = RegNet.regnet_get_params(w_a, w_0, w_m, d, stride, bot_mul, group_w, stem_type, stem_w, block_type,
head_w, num_classes, se_r)
super(RegNet, self).__init__(params['depths'], params['stem_type'], params['stem_w'], params['block_type'],
params['widths'], params['strides'], params['bot_muls'], params['group_ws'],
params['head_w'], params['num_classes'], params['se_r'], in_channels)
3、RegNet网络结构与源码分析
整体网络结构
RegNet的整体结构如下图所示:
对上图做如下说明:
1)整体上看,RegNet采用的依然是类似ResNet的网络结构,整个网络可以分为stem、body和head三个部分。
2)stem和head部分在本文没有详细讨论,只着重针对于其中的body部分结构进行实验。在ResNet中,网络的body又可以进一步划分为几个stage,而每个stage则由一些block组成。
3)RegNet的body部分的具体结构即由公式计算结果给出,例如stage内包含的构建块的数量,每个构建块的具体计算通道数等。
4)网络的stem和head可以被更换,构建块的结构也有几种选择,比如采用一般残差构建块或者瓶颈残差构建块等。此外,对于采用了SE模块的RegNet版本成为RegNetY,否则则成为RegNetX。
源码分析
随后介绍实现RegNet网络的剩余工作,计算出了网络的具体结构参数后,剩下的工作还有根据具体的参数生成构建块,并且组装stage、stem和head:
# 生成head
class AnyHead(nn.Cell):
"""AnyNet head: optional conv, AvgPool, 1x1."""
def __init__(self, w_in, head_width, num_classes):
super(AnyHead, self).__init__()
self.head_width = head_width
if head_width > 0:
self.conv = conv2d(w_in, head_width, 1)
self.bn = norm2d(head_width)
self.af = activation()
w_in = head_width
self.avg_pool = gap2d()
self.fc = linear(w_in, num_classes, bias=True)
def construct(self, x):
x = self.af(self.bn(self.conv(x))) if self.head_width > 0 else x
x = self.avg_pool(x)
x = self.fc(x)
return x
# 生成stage
class AnyStage(nn.Cell):
"""AnyNet stage (sequence of blocks w/ the same output shape)."""
def __init__(self, w_in, w_out, stride, d, block_fun, params):
super(AnyStage, self).__init__()
self.blocks = nn.CellList()
for _ in range(d):
block = block_fun(w_in, w_out, stride, params)
self.blocks.append(block)
stride, w_in = 1, w_out
def construct(self, x):
for block in self.blocks:
x = block(x)
return x
# 组装网络
class AnyNet(nn.Cell):
"""AnyNet model."""
@staticmethod
def anynet_get_params(depths, stem_type, stem_w, block_type, widths, strides, bot_muls, group_ws, head_w,
num_classes, se_r):
nones = [None for _ in depths]
return {
"stem_type": stem_type,
"stem_w": stem_w,
"block_type": block_type,
"depths": depths,
"widths": widths,
"strides": strides,
"bot_muls": bot_muls if bot_muls else nones,
"group_ws": group_ws if group_ws else nones,
"head_w": head_w,
"se_r": se_r,
"num_classes": num_classes,
}
def __init__(self, depths, stem_type, stem_w, block_type, widths, strides, bot_muls, group_ws, head_w, num_classes,
se_r, in_channels):
super(AnyNet, self).__init__()
p = AnyNet.anynet_get_params(depths, stem_type, stem_w, block_type, widths, strides, bot_muls, group_ws, head_w,
num_classes, se_r)
stem_fun = get_stem_fun(p["stem_type"])
block_fun = get_block_fun(p["block_type"])
self.stem = stem_fun(in_channels, p["stem_w"])
prev_w = p["stem_w"]
keys = ["depths", "widths", "strides", "bot_muls", "group_ws"]
self.stages = nn.CellList()
for i, (d, w, s, b, g) in enumerate(zip(*[p[k] for k in keys])):
params = {"bot_mul": b, "group_w": g, "se_r": p["se_r"]}
stage = AnyStage(prev_w, w, s, d, block_fun, params)
self.stages.append(stage)
prev_w = w
self.head = AnyHead(prev_w, p["head_w"], p["num_classes"])
def construct(self, x):
x = self.stem(x)
for module in self.stages:
x = module(x)
x = self.head(x)
return x
4、与PyTorch实现的差异
网络搭建所用算子
PyTorch |
MindSpore |
差异 |
torch.nn.functional.adaptive_avg_pool2d |
mindspore.nn.AvgPool2d |
前者为函数式编程,且池化方式为自适应平均池化,MindSpore为全局平均池化 |
torch.nn.Conv2d |
mindspore.nn.Conv2D |
卷积算子差异 |
torch.nn.BatchNorm2d |
mindspore.nn.BatchNorm2d |
BN算子差异 |
torch.nn.MaxPool2d |
mindspore.nn.MaxPool2d |
MaxPool2d算子差异 |
torch.nn.Linear |
mindspore.nn.Dense |
全连接层差异 |
参考资料
1.facebookresearch/pycls
2.PyTorch与MindSpore API映射表