企業(yè)展示型電商網(wǎng)站模板google關鍵詞規(guī)劃師
Pytorch | 從零構(gòu)建ParNet/Non-Deep Networks對CIFAR10進行分類
- CIFAR10數(shù)據(jù)集
- ParNet
- 架構(gòu)特點
- 優(yōu)勢
- 應用
- ParNet結(jié)構(gòu)代碼詳解
- 結(jié)構(gòu)代碼
- 代碼詳解
- SSE
- ParNetBlock 類
- DownsamplingBlock 類
- FusionBlock 類
- ParNet 類
- 訓練過程和測試結(jié)果
- 代碼匯總
- parnet.py
- train.py
- test.py
前面文章我們構(gòu)建了AlexNet、Vgg、GoogleNet、ResNet、MobileNet、EfficientNet對CIFAR10進行分類:
Pytorch | 從零構(gòu)建AlexNet對CIFAR10進行分類
Pytorch | 從零構(gòu)建Vgg對CIFAR10進行分類
Pytorch | 從零構(gòu)建GoogleNet對CIFAR10進行分類
Pytorch | 從零構(gòu)建ResNet對CIFAR10進行分類
Pytorch | 從零構(gòu)建MobileNet對CIFAR10進行分類
Pytorch | 從零構(gòu)建EfficientNet對CIFAR10進行分類
這篇文章我們來構(gòu)建ParNet(Non-Deep Networks).
CIFAR10數(shù)據(jù)集
CIFAR-10數(shù)據(jù)集是由加拿大高級研究所(CIFAR)收集整理的用于圖像識別研究的常用數(shù)據(jù)集,基本信息如下:
- 數(shù)據(jù)規(guī)模:該數(shù)據(jù)集包含60,000張彩色圖像,分為10個不同的類別,每個類別有6,000張圖像。通常將其中50,000張作為訓練集,用于模型的訓練;10,000張作為測試集,用于評估模型的性能。
- 圖像尺寸:所有圖像的尺寸均為32×32像素,這相對較小的尺寸使得模型在處理該數(shù)據(jù)集時能夠相對快速地進行訓練和推理,但也增加了圖像分類的難度。
- 類別內(nèi)容:涵蓋了飛機(plane)、汽車(car)、鳥(bird)、貓(cat)、鹿(deer)、狗(dog)、青蛙(frog)、馬(horse)、船(ship)、卡車(truck)這10個不同的類別,這些類別都是現(xiàn)實世界中常見的物體,具有一定的代表性。
下面是一些示例樣本:
ParNet
ParNet是一種高效的深度學習網(wǎng)絡架構(gòu)由谷歌研究人員于2021年提出,以下從其架構(gòu)特點、優(yōu)勢及應用等方面進行詳細介紹:
架構(gòu)特點
- 并行子結(jié)構(gòu):ParNet的核心在于其并行的子結(jié)構(gòu)設計。它由多個并行的分支組成,每個分支都包含一系列的卷積層和池化層等操作。這些分支在網(wǎng)絡中同時進行計算,就像多條并行的道路同時運輸信息一樣,大大提高了信息處理的效率。
- 多尺度特征融合:不同分支在不同的尺度上對輸入圖像進行處理,然后將這些多尺度的特征進行融合。例如,一個分支可能專注于提取圖像中的局部細節(jié)特征,而另一個分支則更擅長捕捉圖像的全局上下文信息。通過融合這些不同尺度的特征,ParNet能夠更全面、更準確地理解圖像內(nèi)容。
- 深度可分離卷積:在網(wǎng)絡的卷積操作中,大量使用了深度可分離卷積。這種卷積方式將傳統(tǒng)的卷積操作分解為深度卷積和逐點卷積兩個步驟,大大減少了計算量,同時提高了模型的運行速度,使其更適合在移動設備等資源受限的環(huán)境中應用。
優(yōu)勢
- 高效性:由于其并行結(jié)構(gòu)和深度可分離卷積的使用,ParNet在計算效率上具有很大的優(yōu)勢。它可以在保證模型性能的前提下,大大減少模型的參數(shù)量和計算量,從而實現(xiàn)快速的推理和訓練。
- 靈活性:ParNet的并行子結(jié)構(gòu)和多尺度特征融合方式使其具有很強的靈活性。它可以根據(jù)不同的任務和數(shù)據(jù)集進行調(diào)整和優(yōu)化,輕松適應各種圖像識別和處理任務。
- 可擴展性:該網(wǎng)絡架構(gòu)具有良好的可擴展性,可以方便地增加或減少分支的數(shù)量和深度,以滿足不同的性能需求。
應用
- 圖像分類:在圖像分類任務中,ParNet能夠快速準確地對圖像中的物體進行分類。例如,在CIFAR-10和ImageNet等標準圖像分類數(shù)據(jù)集上,ParNet取得了與現(xiàn)有先進模型相當?shù)臏蚀_率,同時具有更快的推理速度。
- 目標檢測:在目標檢測任務中,ParNet可以有效地檢測出圖像中的目標物體,并確定其位置和類別。通過對多尺度特征的融合和利用,ParNet能夠更好地處理不同大小和形狀的目標物體,提高檢測的準確率和召回率。
- 語義分割:在語義分割任務中,ParNet能夠?qū)D像中的每個像素進行分類,將圖像分割成不同的語義區(qū)域。其多尺度特征融合的特點使得它在處理復雜的場景和物體邊界時具有更好的效果,能夠生成更準確的分割結(jié)果。
ParNet結(jié)構(gòu)代碼詳解
結(jié)構(gòu)代碼
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SSE(nn.Module):def __init__(self, in_channels):super(SSE, self).__init__()self.global_avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(in_channels, in_channels)def forward(self, x):out = self.global_avgpool(x)out = out.view(out.size(0), -1)out = self.fc(out)out = torch.sigmoid(out)out = out.view(out.size(0), out.size(1), 1, 1)return x * outclass ParNetBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ParNetBlock, self).__init__()self.branch1x1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.branch3x3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.sse = SSE(out_channels)def forward(self, x):branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3(x)out = branch1x1 + branch3x3out = self.sse(out)out = F.silu(out)return outclass DownsamplingBlock(nn.Module):def __init__(self, in_channels, out_channels):super(DownsamplingBlock, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.se(out)return outclass FusionBlock(nn.Module):def __init__(self, in_channels, out_channels):super(FusionBlock, self).__init__()self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)def forward(self, x1, x2):x1, x2 = self.conv1x1(x1), self.conv1x1(x2)x1, x2 = self.bn(x1), self.bn(x2)x1, x2 = self.relu(x1), self.relu(x2)x1, x2 = self.se(x1), self.se(x2)out = torch.cat([x1, x2], dim=1)out = self.concat(out)return outclass ParNet(nn.Module):def __init__(self, num_classes):super(ParNet, self).__init__()self.downsampling_blocks = nn.ModuleList([DownsamplingBlock(3, 64),DownsamplingBlock(64, 128),DownsamplingBlock(128, 256),])self.streams = nn.ModuleList([nn.Sequential(ParNetBlock(64, 64),ParNetBlock(64, 64),ParNetBlock(64, 64),DownsamplingBlock(64, 128)),nn.Sequential(ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128)),nn.Sequential(ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256))])self.fusion_blocks = nn.ModuleList([FusionBlock(128, 256),FusionBlock(256, 256)])self.final_downsampling = DownsamplingBlock(256, 1024)self.fc = nn.Linear(1024, num_classes)def forward(self, x):downsampled_features = []for i, downsampling_block in enumerate(self.downsampling_blocks):x = downsampling_block(x)downsampled_features.append(x)stream_features = []for i, stream in enumerate(self.streams):stream_feature = stream(downsampled_features[i])stream_features.append(stream_feature)fused_features = stream_features[0]for i in range(1, len(stream_features)):fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])x = self.final_downsampling(fused_features)x = F.adaptive_avg_pool2d(x, (1, 1))x = x.view(x.size(0), -1)x = self.fc(x)return x
代碼詳解
以下是對上述提供的ParNet
代碼的詳細解釋,這段代碼使用PyTorch
框架構(gòu)建了一個名為ParNet
的神經(jīng)網(wǎng)絡模型,整體結(jié)構(gòu)符合ParNet
網(wǎng)絡架構(gòu)的特點,下面從不同模塊依次進行分析:
SSE
class SSE(nn.Module):def __init__(self, in_channels):super(SSE, self).__init__()self.global_avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(in_channels, in_channels)def forward(self, x):out = self.global_avgpool(x)out = out.view(out.size(0), -1)out = self.fc(out)out = torch.sigmoid(out)out = out.view(out.size(0), out.size(1), 1, 1)return x * out
-
功能概述:
這個類實現(xiàn)了類似Squeeze-and-Excitation(SE)模塊的功能,旨在對輸入特征進行通道維度的重加權(quán),突出重要的通道特征,抑制相對不重要的通道特征。 -
__init__
方法:- 首先通過
nn.AdaptiveAvgPool2d(1)
創(chuàng)建了一個自適應平均池化層,它可以將輸入特征圖在空間維度上壓縮為大小為(1, 1)
的特征圖,也就是將每個通道的特征進行全局平均池化,得到通道維度上的統(tǒng)計信息,無論輸入特征圖的尺寸是多少都可以自適應處理。 - 接著創(chuàng)建了一個全連接層
nn.Linear(in_channels, in_channels)
,其輸入和輸出維度都是in_channels
,目的是學習通道維度上的變換權(quán)重。
- 首先通過
-
forward
方法:- 先將輸入
x
經(jīng)過全局平均池化層得到壓縮后的特征表示out
,然后通過view
操作將其維度調(diào)整為二維形式(批次大小,通道數(shù)),方便后續(xù)全連接層處理。 - 接著將這個特征送入全連接層進行線性變換,再經(jīng)過
sigmoid
激活函數(shù),將輸出值映射到(0, 1)
區(qū)間,得到每個通道對應的權(quán)重。 - 最后將權(quán)重的維度調(diào)整回四維(批次大小,通道數(shù),1,1),并與原始輸入
x
進行逐元素相乘,實現(xiàn)對不同通道特征的重加權(quán)。
- 先將輸入
ParNetBlock 類
class ParNetBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ParNetBlock, self).__init__()self.branch1x1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.branch3x3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.sse = SSE(out_channels)def forward(self, x):branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3(x)out = branch1x1 + branch3x3out = self.sse(out)out = F.silu(out)return out
-
功能概述:
該類定義了ParNet
中的一個基礎并行塊結(jié)構(gòu),包含兩個并行分支(1x1
卷積分支和3x3
卷積分支)以及一個SSE
模塊,用于提取和融合特征,并進行通道重加權(quán)和非線性激活。 -
__init__
方法:- 構(gòu)建了兩個并行分支,
branch1x1
是一個由1x1
卷積層、批歸一化層和ReLU
激活函數(shù)組成的序列,1x1
卷積主要用于調(diào)整通道維度,同時可以融合不同通道間的信息,且計算量相對較小。 branch3x3
同樣是由3x3
卷積層(帶有合適的填充保證特征圖尺寸不變)、批歸一化層和ReLU
激活函數(shù)組成,3x3
卷積能夠捕捉局部空間特征信息。- 最后實例化了一個
SSE
模塊,用于后續(xù)對融合后的特征進行通道維度的重加權(quán)。
- 構(gòu)建了兩個并行分支,
-
forward
方法:- 首先將輸入
x
分別送入兩個并行分支進行處理,得到兩個分支的輸出branch1x1
和branch3x3
,然后將它們對應元素相加進行特征融合。 - 接著把融合后的特征送入
SSE
模塊進行通道重加權(quán),最后使用F.silu
(也就是swish
函數(shù))激活函數(shù)對結(jié)果進行非線性激活,并返回處理后的特征。
- 首先將輸入
DownsamplingBlock 類
class DownsamplingBlock(nn.Module):def __init__(self, in_channels, out_channels):super(DownsamplingBlock, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.se(out)return out
-
功能概述:
用于對輸入特征圖進行下采樣操作,同時融合了批歸一化、非線性激活以及類似SE
的通道重加權(quán)功能,以減少特征圖的空間尺寸并提取更抽象的特征。 -
__init__
方法:
創(chuàng)建了一個3x3
卷積層,其步長設置為2
,配合合適的填充,在進行卷積操作時可以實現(xiàn)特征圖在空間維度上長寬各減半的下采樣效果,同時調(diào)整通道維度到out_channels
。還定義了批歸一化層、ReLU
激活函數(shù)以及一個SSE
模塊。 -
forward
方法:
按照順序依次將輸入x
經(jīng)過卷積層、批歸一化層、ReLU
激活函數(shù)進行處理,然后再通過SSE
模塊進行通道重加權(quán),最終返回下采樣并處理后的特征圖。
FusionBlock 類
class FusionBlock(nn.Module):def __init__(self, in_channels, out_channels):super(FusionBlock, self).__init__()self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)def forward(self, x1, x2):x1, x2 = self.conv1x1(x1), self.conv1x1(x2)x1, x2 = self.bn(x1), self.bn(x2)x1, x2 = self.relu(x1), self.relu(x2)x1, x2 = self.se(x1), self.se(x2)out = torch.cat([x1, x2], dim=1)out = self.concat(out)return out
-
功能概述:
該類用于融合不同分支或不同階段的特征,通過一系列操作包括調(diào)整通道維度、批歸一化、激活以及通道重加權(quán),然后將兩個特征在通道維度上進行拼接并進一步融合。 -
__init__
方法:- 首先創(chuàng)建了
1x1
卷積層,步長設置為2
,用于對輸入的兩個特征分別進行通道維度調(diào)整以及下采樣操作(特征圖空間尺寸減半)。 - 接著定義了批歸一化層、
ReLU
激活函數(shù)以及SSE
模塊,用于對下采樣后的特征進行處理。還創(chuàng)建了一個1x1
卷積層concat
,用于將拼接后的特征進一步融合為指定的通道維度。
- 首先創(chuàng)建了
-
forward
方法:
分別對輸入的兩個特征x1
和x2
依次進行1x1
卷積、批歸一化、ReLU
激活以及SSE
模塊的處理,然后將它們在通道維度上進行拼接(torch.cat
操作,維度dim=1
表示按通道維度拼接),最后通過concat
卷積層將拼接后的特征融合為指定的通道維度,并返回融合后的特征。
ParNet 類
class ParNet(nn.Module):def __init__(self, num_classes):super(ParNet, self).__init__()self.downsampling_blocks = nn.ModuleList([DownsamplingBlock(3, 64),DownsamplingBlock(64, 128),DownsamplingBlock(128, 256),])self.streams = nn.ModuleList([nn.Sequential(ParNetBlock(64, 64),ParNetBlock(64, 64),ParNetBlock(64, 64),DownsamplingBlock(64, 128)),nn.Sequential(ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128)),nn.Sequential(ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256))])self.fusion_blocks = nn.ModuleList([FusionBlock(128, 256),FusionBlock(256, 256)])self.final_downsampling = DownsamplingBlock(256, 1024)self.fc = nn.Linear(1024, num_classes)def forward(self, x):downsampled_features = []for i, downsampling_block in enumerate(self.downsampling_blocks):x = downsampling_block(x)downsampled_features.append(x)stream_features = []for i, stream in enumerate(self.streams):stream_feature = stream(downsampled_features[i])stream_features.append(stream_feature)fused_features = stream_features[0]for i in range(1, len(stream_features)):fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])x = self.final_downsampling(fused_features)x = F.adaptive_avg_pool2d(x, (1, 1))x = x.view(x.size(0), -1)x = self.fc(x)return x
-
功能概述:
這是整個ParNet
網(wǎng)絡的定義類,整合了前面定義的各個模塊,構(gòu)建出完整的網(wǎng)絡結(jié)構(gòu),包括下采樣、并行分支處理、特征融合以及最后的分類全連接層等部分,能夠接收輸入圖像數(shù)據(jù)并輸出對應的分類預測結(jié)果。 -
__init__
方法:downsampling_blocks
:通過nn.ModuleList
創(chuàng)建了一個包含三個下采樣塊的列表,用于對輸入圖像依次進行下采樣,將圖像的空間尺寸逐步縮小,同時增加通道數(shù),從最初的3
通道(對應RGB圖像)逐步變?yōu)?code>64、128
、256
通道。streams
:同樣是nn.ModuleList
,定義了三個并行的流(stream),每個流由多個ParNetBlock
和一個DownsamplingBlock
組成,不同流在不同的特征圖尺度和通道維度上進行特征提取和處理,每個流內(nèi)部的ParNetBlock
用于提取和融合局部特征,最后的DownsamplingBlock
用于進一步下采樣。fusion_blocks
:也是nn.ModuleList
,包含兩個特征融合塊,用于融合不同流的特征,將各個流提取到的不同層次的特征進行融合,以綜合利用多尺度信息。final_downsampling
:定義了一個下采樣塊,用于對融合后的特征再進行一次下采樣,將通道數(shù)提升到1024
,進一步提取更抽象的全局特征。fc
:創(chuàng)建了一個全連接層,用于將最終提取到的特征映射到指定的類別數(shù)量num_classes
,實現(xiàn)圖像分類任務的輸出。
-
forward
方法:- 首先,通過循環(huán)將輸入
x
依次經(jīng)過各個下采樣塊進行下采樣,并將每次下采樣后的特征保存到downsampled_features
列表中,得到不同階段下采樣后的特征圖。 - 接著,針對每個流,將對應的下采樣后的特征圖送入流中進行處理,每個流內(nèi)部的模塊會進一步提取和融合特征,得到每個流輸出的特征,并保存在
stream_features
列表中。 - 然后,先取第一個流的特征作為初始的融合特征,再通過循環(huán)依次使用特征融合塊將其他流的特征與已有的融合特征進行融合,不斷更新融合特征。
- 之后,將融合后的特征送入最后的下采樣塊進行進一步下采樣處理。
- 再通過自適應平均池化
F.adaptive_avg_pool2d
將特征圖在空間維度上壓縮為(1, 1)
大小,然后使用view
操作將其展平為二維向量。 - 最后將展平后的特征送入全連接層進行分類預測,返回最終的分類結(jié)果。
- 首先,通過循環(huán)將輸入
總體而言,這段代碼構(gòu)建了一個符合ParNet
架構(gòu)特點的神經(jīng)網(wǎng)絡模型,通過多個模塊的組合實現(xiàn)了高效的特征提取、融合以及分類功能,可應用于圖像分類等相關任務。
訓練過程和測試結(jié)果
訓練過程損失函數(shù)變化曲線:
訓練過程準確率變化曲線:
測試結(jié)果:
代碼匯總
項目github地址
項目結(jié)構(gòu):
|--data
|--models|--__init__.py|-parnet.py|--...
|--results
|--weights
|--train.py
|--test.py
parnet.py
import torch
import torch.nn as nn
import torch.nn.functional as Fclass SSE(nn.Module):def __init__(self, in_channels):super(SSE, self).__init__()self.global_avgpool = nn.AdaptiveAvgPool2d(1)self.fc = nn.Linear(in_channels, in_channels)def forward(self, x):out = self.global_avgpool(x)out = out.view(out.size(0), -1)out = self.fc(out)out = torch.sigmoid(out)out = out.view(out.size(0), out.size(1), 1, 1)return x * outclass ParNetBlock(nn.Module):def __init__(self, in_channels, out_channels):super(ParNetBlock, self).__init__()self.branch1x1 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.branch3x3 = nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),nn.BatchNorm2d(out_channels),nn.ReLU(inplace=True))self.sse = SSE(out_channels)def forward(self, x):branch1x1 = self.branch1x1(x)branch3x3 = self.branch3x3(x)out = branch1x1 + branch3x3out = self.sse(out)out = F.silu(out)return outclass DownsamplingBlock(nn.Module):def __init__(self, in_channels, out_channels):super(DownsamplingBlock, self).__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)def forward(self, x):out = self.conv(x)out = self.bn(out)out = self.relu(out)out = self.se(out)return outclass FusionBlock(nn.Module):def __init__(self, in_channels, out_channels):super(FusionBlock, self).__init__()self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2, bias=False)self.bn = nn.BatchNorm2d(out_channels)self.relu = nn.ReLU(inplace=True)self.se = SSE(out_channels)self.concat = nn.Conv2d(out_channels * 2, out_channels, kernel_size=1, bias=False)def forward(self, x1, x2):x1, x2 = self.conv1x1(x1), self.conv1x1(x2)x1, x2 = self.bn(x1), self.bn(x2)x1, x2 = self.relu(x1), self.relu(x2)x1, x2 = self.se(x1), self.se(x2)out = torch.cat([x1, x2], dim=1)out = self.concat(out)return outclass ParNet(nn.Module):def __init__(self, num_classes):super(ParNet, self).__init__()self.downsampling_blocks = nn.ModuleList([DownsamplingBlock(3, 64),DownsamplingBlock(64, 128),DownsamplingBlock(128, 256),])self.streams = nn.ModuleList([nn.Sequential(ParNetBlock(64, 64),ParNetBlock(64, 64),ParNetBlock(64, 64),DownsamplingBlock(64, 128)),nn.Sequential(ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128),ParNetBlock(128, 128)),nn.Sequential(ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256),ParNetBlock(256, 256))])self.fusion_blocks = nn.ModuleList([FusionBlock(128, 256),FusionBlock(256, 256)])self.final_downsampling = DownsamplingBlock(256, 1024)self.fc = nn.Linear(1024, num_classes)def forward(self, x):downsampled_features = []for i, downsampling_block in enumerate(self.downsampling_blocks):x = downsampling_block(x)downsampled_features.append(x)stream_features = []for i, stream in enumerate(self.streams):stream_feature = stream(downsampled_features[i])stream_features.append(stream_feature)fused_features = stream_features[0]for i in range(1, len(stream_features)):fused_features = self.fusion_blocks[i - 1](fused_features, stream_features[i])x = self.final_downsampling(fused_features)x = F.adaptive_avg_pool2d(x, (1, 1))x = x.view(x.size(0), -1)x = self.fc(x)return x
train.py
test.py