1、基于卷积神经网络的调制分类的原理及流程
基于卷积神经网络(CNN)的调制分类是一种常见的信号处理任务,用于识别或分类不同调制方式的信号。下面是基于CNN的调制分类的原理和流程:
原理:
- CNN是一种深度学习模型,通过卷积层、池化层和全连接层等结构来提取数据中的特征。在调制分类任务中,CNN可以学习到调制信号的特征以区分不同的调制方式。
- 输入到CNN模型的数据是经过预处理和特征提取后的信号样本,通常是时域信号或频域信号。CNN将这些信号作为输入,并通过网络中的不同层来提取特征并完成调制分类任务。
流程:
-
数据准备:准备好用于训练和测试的信号样本数据集,每个样本包含一个已知调制方式的信号。
-
数据预处理:对信号数据进行预处理,可能包括归一化、降噪、平滑处理等,以确保数据质量。
-
数据特征提取:将信号数据转换为适合CNN输入的格式,例如在时域或频域下进行信号特征提取,将其转换为矩阵形式。
-
构建CNN模型:定义CNN模型的结构,包括卷积层、池化层、激活函数层和全连接层等。可以根据具体需求自定义网络结构。
-
模型训练:使用训练集数据对CNN模型进行训练,通过反向传播算法不断调整模型参数以使模型输出尽可能接近真实标签。
-
模型评估:使用测试集数据评估训练好的模型性能,包括准确率、召回率等指标,对模型进行优化和调整。
-
模型应用:将训练好的CNN模型用于未知信号的调制分类,通过模型预测得到信号的调制方式。
-
参数调优:根据模型评估结果,调整模型结构、超参数等进行优化,以提高调制分类的准确性和性能。
在Matlab中,可以使用深度学习工具箱等相关工具进行CNN模型的搭建和训练
2、基于卷积神经网络的调制分类的说明
使用卷积神经网络 (CNN) 进行调制分类
生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类
3、使用 CNN 预测调制类型
1)调制数据类型
二相相移键控 (BPSK)
四相相移键控 (QPSK)
八相相移键控 (8-PSK)
十六相正交调幅 (16-QAM)
六十四相正交调幅 (64-QAM)
四相脉冲振幅调制 (PAM4)
高斯频移键控 (GFSK)
连续相位频移键控 (CPFSK)
广播 FM (B-FM)
双边带振幅调制 (DSB-AM)
单边带振幅调制 (SSB-AM)
2)实现代码
modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ...
"16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
"B-FM", "DSB-AM", "SSB-AM"]));
3)加载训练网络代码
load trainedModulationClassificationNetwork
trainedNet
trainedNet =
dlnetwork with properties:
Layers: [19×1 nnet.cnn.layer.Layer]
Connections: [18×2 table]
Learnables: [22×3 table]
State: [10×3 table]
InputNames: {'Input Layer'}
OutputNames: {'SoftMax'}
Initialized: 1
View summary with summary.
4、加载训练的网络
1)说明
经过训练的 CNN 接受 1024 个通道减损采样,并预测每个帧的调制类型
生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。
以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。
randi:生成随机位
pammod (Communications Toolbox):PAM4 调制位
rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器
filter:脉冲确定符号的形状
comm.RicianChannel (Communications Toolbox):应用莱斯多径通道
comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移
interp1:应用时钟偏移引起的计时漂移
awgn (Communications Toolbox):添加 AWGN
2)实现代码
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));
% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
'SampleRate', fs, ...
'PathDelays', [0 1.8 3.4] / 200e3, ...
'AveragePathGains', [0 -2 -10], ...
'KFactor', 4, ...
'MaximumDopplerShift', 4);
frequencyShifter = comm.PhaseFrequencyOffset(...
'SampleRate', fs);
% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);
% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;
% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);
% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);
% Add noise
rx = awgn(outTimeDrift,SNR,0);
% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);
% Classification
scores1 = predict(trainedNet,unknownFrames);
prediction1 = scores2label(scores1,modulationTypes);
3)返回分类器预测
prediction1
prediction1 = 7×1 categorical
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
PAM4
4) 分类器还返回一个包含每一帧分数的向量
代码
helperModClassPlotScores(scores1,modulationTypes)
视图效果
5、生成用于训练的波形
1)说明1
为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。
网络训练阶段使用训练和验证帧
使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。
2)代码实现
trainNow = false;
if trainNow == true
numFramesPerModType = 10000;
else
numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;
sps = 8; % Samples per symbol
spf = 1024; % Samples per frame
fs = 200e3; % Sample rate
fc = [902e6 100e6]; % Center frequencies
3)说明2
创建通道减损
让每帧通过通道并具有
-
AWGN
-
莱斯多径衰落
-
时钟偏移,导致中心频率偏移和采样时间漂移
由于本示例中的网络基于单个帧作出决定,因此每个帧必须通过独立的通道。
AWGN
通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道。
莱斯多径
通道使用 comm.RicianChannel (Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。假设延迟分布为 [0 1.8 3.4] 个样本,对应的平均路径增益为 [0 -2 -10] dB。K 因子为 4,最大多普勒频移为 4 Hz,等效于 902 MHz 的步行速度。使用以下设置实现通道。
时钟偏移
时钟偏移是发射机和接收机的内部时钟源不准确造成的。
代码
maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);
4)说明3
频率偏移
基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移。使用 comm.PhaseFrequencyOffset (Communications Toolbox) 实现通道。
采样率偏移
基于时钟偏移因子 C,对每帧进行采样率偏移。使用 interp1 函数实现通道,以 C×fs 的新速率对帧进行重新采样。
合并后的通道
使用 helperModClassTestChannel 对象对帧应用所有三种通道减损。
代码
channel = helperModClassTestChannel(...
'SampleRate', fs, ...
'SNR', SNR, ...
'PathDelays', [0 1.8 3.4] / fs, ...
'AveragePathGains', [0 -2 -10], ...
'KFactor', 4, ...
'MaximumDopplerShift', 4, ...
'MaximumClockOffset', 5, ...
'CenterFrequency', 902e6)
channel =
helperModClassTestChannel with properties:
SNR: 30
CenterFrequency: 902000000
SampleRate: 200000
PathDelays: [0 9.0000e-06 1.7000e-05]
AveragePathGains: [0 -2 -10]
KFactor: 4
MaximumDopplerShift: 4
MaximumClockOffset: 5
5)波形生成
说明
创建一个循环,它为每种调制类型生成通道减损的帧并将这些帧及其对应标签存储在 MAT 文件中。通过将数据保存到文件中,您无需每次运行此示例时都生成数据。您还可以更高效地共享数据。
从每帧的开头删除随机数量的样本,以去除瞬变并确保帧相对于符号边界具有随机起点。
代码
rng(12)
tic
numModulationTypes = length(modulationTypes);
channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)
fileNameRoot = "frame";
% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
if length(files) == numModulationTypes*numFramesPerModType
dataFilesExist = true;
end
end
if ~dataFilesExist
disp("Generating data and saving in data files...")
[success,msg,msgID] = mkdir(dataDirectory);
if ~success
error(msgID,msg)
end
for modType = 1:numModulationTypes
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Generating %s frames\n', ...
elapsedTime, modulationTypes(modType))
label = modulationTypes(modType);
numSymbols = (numFramesPerModType / sps);
dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
% Analog modulation types use a center frequency of 100 MHz
channel.CenterFrequency = 100e6;
else
% Digital modulation types use a center frequency of 902 MHz
channel.CenterFrequency = 902e6;
end
for p=1:numFramesPerModType
% Generate random data
x = dataSrc();
% Modulate
y = modulator(x);
% Pass through independent channels
rxSamples = channel(y);
% Remove transients from the beginning, trim to size, and normalize
frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
% Save data file
fileName = fullfile(dataDirectory,...
sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
save(fileName,"frame","label")
end
end
else
disp("Data files exist. Skip data generation.")
end
Generating data and saving in data files...
00:00:09 - Generating 16QAM frames
00:00:11 - Generating 64QAM frames
00:00:13 - Generating 8PSK frames
00:00:15 - Generating B-FM frames
00:00:17 - Generating BPSK frames
00:00:20 - Generating CPFSK frames
00:00:22 - Generating DSB-AM frames
00:00:24 - Generating GFSK frames
00:00:26 - Generating PAM4 frames
00:00:28 - Generating QPSK frames
00:00:30 - Generating SSB-AM frames
6)效果显示
实虚部振幅代码
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)
视图效果
帧代码
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)
视图效果
7)创建数据存储代码
frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);
8) 拆分为训练、验证和测试代码
splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);
9) 将数据导入内存代码
% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDS, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
validFrames = transform(validDS, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);
% Read the training and validation labels into the memory
trainLabels = transform(trainDS, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDS, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);
6、训练 CNN
1)说明
使用的 CNN 由五个卷积层和一个全连接层组成。除最后一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层。在最后一个卷积层中,最大池化层被一个全局平均池化层取代。输出层具有 softmax 激活。
2)实现代码
modClassNet = helperModClassCNN(modulationTypes,sps,spf);
3)配置网络代码
maxEpochs = 20;
miniBatchSize = 1024;
trainingPlots = "none";
metrics = [];
verbose = true;
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
options = trainingOptions('sgdm', ...
InitialLearnRate = 3e-1, ...
MaxEpochs = maxEpochs, ...
MiniBatchSize = miniBatchSize, ...
Shuffle = 'every-epoch', ...
Plots = trainingPlots, ...
Verbose = verbose, ...
ValidationData = {rxValidFrames,rxValidLabels}, ...
ValidationFrequency = validationFrequency, ...
ValidationPatience = 5, ...
Metrics = metrics, ...
LearnRateSchedule = 'piecewise', ...
LearnRateDropPeriod = 6, ...
LearnRateDropFactor = 0.75, ...
OutputNetwork='best-validation-loss');
4)训练网络代码
if trainNow == true
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Training the network\n', elapsedTime)
trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options);
else
load trainedModulationClassificationNetwork
end
5)训练结果评估代码
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
% Read the test frames into the memory
testFrames = transform(testDS, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);
% Read the test labels into the memory
testLabels = transform(testDS, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);
scores = predict(trainedNet,cat(3,rxTestFrames{:}));
rxTestPred = scores2label(scores,modulationTypes);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
7、使用 SDR 进行测试
1)说明
使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。要执行此测试,您必须有专用的 SDR 用于发送和接收。
2)代码实现
radioPlatform = "ADALM-PLUTO";
switch radioPlatform
case "ADALM-PLUTO"
if helperIsPlutoSDRInstalled() == true
radios = findPlutoRadio();
if length(radios) >= 2
helperModClassSDRTest(radios);
else
disp('Selected radios not found. Skipping over-the-air test.')
end
end
case {"USRP B2xx","USRP X3xx","USRP N2xx"}
if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
txRadio = findPlutoRadio();
rxRadio = findsdru();
switch radioPlatform
case "USRP B2xx"
idx = contains({rxRadio.Platform}, {'B200','B210'});
case "USRP X3xx"
idx = contains({rxRadio.Platform}, {'X300','X310'});
case "USRP N2xx"
idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
end
rxRadio = rxRadio(idx);
if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
helperModClassSDRTest(rxRadio);
else
disp('Selected radios not found. Skipping over-the-air test.')
end
end
end
3)视图效果
8、总结
基于卷积神经网络(CNN)的调制分类在Matlab中可以通过深度学习工具箱等相关工具来实现。下面是对基于CNN的调制分类在Matlab中的关键步骤的总结:
总结步骤:
-
数据准备:准备带有标签的调制信号数据集,确保每个样本包含一个已知调制方式的信号。
-
数据预处理:对信号数据进行预处理,包括归一化、降噪等操作,以保证数据的质量。
-
数据特征提取:将信号数据转换为适合CNN输入的格式,可以在时域或频域下提取信号特征,并将其表示为矩阵形式。
-
构建CNN模型:定义CNN模型的结构,包括卷积层、池化层、激活函数层和全连接层等。可以根据具体需求自定义网络结构。
-
模型训练:使用训练集数据对CNN模型进行训练,通过反向传播算法不断调整模型参数以优化模型性能。
-
模型评估:使用测试集数据评估训练好的CNN模型的性能,包括准确率、召回率等指标,对模型进行优化和调整。
-
模型应用:将训练好的CNN模型用于未知信号的调制分类,通过模型预测得到信号的调制方式。
-
参数调优:根据模型评估结果,调整模型结构、超参数等进行优化,以提高调制分类的准确性和性能。
通过以上步骤,可以在Matlab中实现基于CNN的调制分类任务,从而对不同调制方式的信号进行准确分类和识别。在实际应用中,可以根据具体问题的需求对模型进行定制和调整,以获得更好的性能和效果。
9、源代码
代码
%% 基于卷积神经网络的调制分类
%使用卷积神经网络 (CNN) 进行调制分类
%生成合成的、通道减损波形。使用生成的波形作为训练数据,训练 CNN 进行调制分类
%% 使用 CNN 预测调制类型
%可识别以下八种数字调制类型和三种模拟调制类型
%二相相移键控 (BPSK)
%四相相移键控 (QPSK)
%八相相移键控 (8-PSK)
%十六相正交调幅 (16-QAM)
%六十四相正交调幅 (64-QAM)
%四相脉冲振幅调制 (PAM4)
%高斯频移键控 (GFSK)
%连续相位频移键控 (CPFSK)
%广播 FM (B-FM)
%双边带振幅调制 (DSB-AM)
%单边带振幅调制 (SSB-AM)
modulationTypes = categorical(sort(["BPSK", "QPSK", "8PSK", ...
"16QAM", "64QAM", "PAM4", "GFSK", "CPFSK", ...
"B-FM", "DSB-AM", "SSB-AM"]));
%% 加载训练的网络
load trainedModulationClassificationNetwork
trainedNet
%经过训练的 CNN 接受 1024 个通道减损采样,并预测每个帧的调制类型
%生成几个因莱斯多径衰落、中心频率和采样时间漂移以及 AWGN 而有所减损的 PAM4 帧。
%以下函数生成合成信号来测试 CNN。然后使用 CNN 预测帧的调制类型。
%randi:生成随机位
%pammod (Communications Toolbox):PAM4 调制位
%rcosdesign (Signal Processing Toolbox):设计平方根升余弦脉冲整形滤波器
%filter:脉冲确定符号的形状
%comm.RicianChannel (Communications Toolbox):应用莱斯多径通道
%comm.PhaseFrequencyOffset (Communications Toolbox):应用时钟偏移引起的相位和/或频率偏移
%interp1:应用时钟偏移引起的计时漂移
%awgn (Communications Toolbox):添加 AWGN
% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(123456)
% Random bits
d = randi([0 3], 1024, 1);
% PAM4 modulation
syms = pammod(d,4);
% Square-root raised cosine filter
filterCoeffs = rcosdesign(0.35,4,8);
tx = filter(filterCoeffs,1,upsample(syms,8));
% Channel
SNR = 30;
maxOffset = 5;
fc = 902e6;
fs = 200e3;
multipathChannel = comm.RicianChannel(...
'SampleRate', fs, ...
'PathDelays', [0 1.8 3.4] / 200e3, ...
'AveragePathGains', [0 -2 -10], ...
'KFactor', 4, ...
'MaximumDopplerShift', 4);
frequencyShifter = comm.PhaseFrequencyOffset(...
'SampleRate', fs);
% Apply an independent multipath channel
reset(multipathChannel)
outMultipathChan = multipathChannel(tx);
% Determine clock offset factor
clockOffset = (rand() * 2*maxOffset) - maxOffset;
C = 1 + clockOffset / 1e6;
% Add frequency offset
frequencyShifter.FrequencyOffset = -(C-1)*fc;
outFreqShifter = frequencyShifter(outMultipathChan);
% Add sampling time drift
t = (0:length(tx)-1)' / fs;
newFs = fs * C;
tp = (0:length(tx)-1)' / newFs;
outTimeDrift = interp1(t, outFreqShifter, tp);
% Add noise
rx = awgn(outTimeDrift,SNR,0);
% Frame generation for classification
unknownFrames = helperModClassGetNNFrames(rx);
% Classification
scores1 = predict(trainedNet,unknownFrames);
prediction1 = scores2label(scores1,modulationTypes);
%返回分类器预测
prediction1
%分类器还返回一个包含每一帧分数的向量
%分数对应于每个帧具有预测的调制类型的概率。绘制分数图。
helperModClassPlotScores(scores1,modulationTypes)
%% 生成用于训练的波形
%为每种调制类型生成 10000 个帧,其中 80% 用于训练,10% 用于验证,10% 用于测试。
%网络训练阶段使用训练和验证帧
%使用测试帧获得最终分类准确度。每帧的长度为 1024 个样本,采样率为 200 kHz。对于数字调制类型,八个采样表示一个符号。
trainNow = false;
if trainNow == true
numFramesPerModType = 10000;
else
numFramesPerModType = 200;
end
percentTrainingSamples = 80;
percentValidationSamples = 10;
percentTestSamples = 10;
sps = 8; % Samples per symbol
spf = 1024; % Samples per frame
fs = 200e3; % Sample rate
fc = [902e6 100e6]; % Center frequencies
%创建通道减损:AWGN/莱斯多径衰落/时钟偏移,导致中心频率偏移和采样时间漂移
%AWGN:通道增加 SNR 为 30 dB 的 AWGN。使用 awgn (Communications Toolbox) 函数实现通道
%莱斯多径:通道使用 comm.RicianChannel (Communications Toolbox) System object™ 通过莱斯多径衰落通道传递信号。
%时钟偏移:时钟偏移是发射机和接收机的内部时钟源不准确造成的。
maxDeltaOff = 5;
deltaOff = (rand()*2*maxDeltaOff) - maxDeltaOff;
C = 1 + (deltaOff/1e6);
%频率偏移:基于时钟偏移因子 C 和中心频率,对每帧进行频率偏移
%采样率偏移:基于时钟偏移因子 C,对每帧进行采样率偏移。
%合并后的通道:使用 helperModClassTestChannel 对象对帧应用所有三种通道减损
channel = helperModClassTestChannel(...
'SampleRate', fs, ...
'SNR', SNR, ...
'PathDelays', [0 1.8 3.4] / fs, ...
'AveragePathGains', [0 -2 -10], ...
'KFactor', 4, ...
'MaximumDopplerShift', 4, ...
'MaximumClockOffset', 5, ...
'CenterFrequency', 902e6)
%使用 info 对象函数查看有关通道的基本信息
chInfo = info(channel)
%波形生成
% Set the random number generator to a known state to be able to regenerate
% the same frames every time the simulation is run
rng(12)
tic
numModulationTypes = length(modulationTypes);
channelInfo = info(channel);
transDelay = 50;
pool = getPoolSafe();
if ~isa(pool,"parallel.ClusterPool")
dataDirectory = fullfile(tempdir,"ModClassDataFiles");
else
dataDirectory = uigetdir("","Select network location to save data files");
end
disp("Data file directory is " + dataDirectory)
fileNameRoot = "frame";
% Check if data files exist
dataFilesExist = false;
if exist(dataDirectory,'dir')
files = dir(fullfile(dataDirectory,sprintf("%s*",fileNameRoot)));
if length(files) == numModulationTypes*numFramesPerModType
dataFilesExist = true;
end
end
if ~dataFilesExist
disp("Generating data and saving in data files...")
[success,msg,msgID] = mkdir(dataDirectory);
if ~success
error(msgID,msg)
end
for modType = 1:numModulationTypes
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Generating %s frames\n', ...
elapsedTime, modulationTypes(modType))
label = modulationTypes(modType);
numSymbols = (numFramesPerModType / sps);
dataSrc = helperModClassGetSource(modulationTypes(modType), sps, 2*spf, fs);
modulator = helperModClassGetModulator(modulationTypes(modType), sps, fs);
if contains(char(modulationTypes(modType)), {'B-FM','DSB-AM','SSB-AM'})
% Analog modulation types use a center frequency of 100 MHz
channel.CenterFrequency = 100e6;
else
% Digital modulation types use a center frequency of 902 MHz
channel.CenterFrequency = 902e6;
end
for p=1:numFramesPerModType
% Generate random data
x = dataSrc();
% Modulate
y = modulator(x);
% Pass through independent channels
rxSamples = channel(y);
% Remove transients from the beginning, trim to size, and normalize
frame = helperModClassFrameGenerator(rxSamples, spf, spf, transDelay, sps);
% Save data file
fileName = fullfile(dataDirectory,...
sprintf("%s%s%03d",fileNameRoot,modulationTypes(modType),p));
save(fileName,"frame","label")
end
end
else
disp("Data files exist. Skip data generation.")
end
%显示波形
helperModClassPlotTimeDomain(dataDirectory,modulationTypes,fs)
helperModClassPlotSpectrogram(dataDirectory,modulationTypes,fs,sps)
%创建数据存储
%使用 signalDatastore 对象来管理包含生成的复杂波形的文件
frameDS = signalDatastore(dataDirectory,'SignalVariableNames',["frame","label"]);
%拆分为训练、验证和测试
splitPercentages = [percentTrainingSamples,percentValidationSamples,percentTestSamples];
[trainDS,validDS,testDS] = helperModClassSplitData(frameDS,splitPercentages);
%将数据导入内存
%神经网络训练是迭代进行
% Read the training and validation frames into the memory
pctExists = parallelComputingLicenseExists();
trainFrames = transform(trainDS, @helperModClassReadFrame);
rxTrainFrames = readall(trainFrames,"UseParallel",pctExists);
validFrames = transform(validDS, @helperModClassReadFrame);
rxValidFrames = readall(validFrames,"UseParallel",pctExists);
% Read the training and validation labels into the memory
trainLabels = transform(trainDS, @helperModClassReadLabel);
rxTrainLabels = readall(trainLabels,"UseParallel",pctExists);
validLabels = transform(validDS, @helperModClassReadLabel);
rxValidLabels = readall(validLabels,"UseParallel",pctExists);
%% 训练 CNN
%CNN 由五个卷积层和一个全连接层组成
%一个卷积层外,每个卷积层后面都有一个批量归一化层、修正线性单元 (ReLU) 激活层和最大池化层
modClassNet = helperModClassCNN(modulationTypes,sps,spf);
%配置 TrainingOptionsSGDM 以使用小批量大小为 1024 的 SGDM 求解器
maxEpochs = 20;
miniBatchSize = 1024;
trainingPlots = "none";
metrics = [];
verbose = true;
validationFrequency = floor(numel(rxTrainLabels)/miniBatchSize);
options = trainingOptions('sgdm', ...
InitialLearnRate = 3e-1, ...
MaxEpochs = maxEpochs, ...
MiniBatchSize = miniBatchSize, ...
Shuffle = 'every-epoch', ...
Plots = trainingPlots, ...
Verbose = verbose, ...
ValidationData = {rxValidFrames,rxValidLabels}, ...
ValidationFrequency = validationFrequency, ...
ValidationPatience = 5, ...
Metrics = metrics, ...
LearnRateSchedule = 'piecewise', ...
LearnRateDropPeriod = 6, ...
LearnRateDropFactor = 0.75, ...
OutputNetwork='best-validation-loss');
%训练神网络
if trainNow == true
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Training the network\n', elapsedTime)
trainedNet = trainnet(rxTrainFrames,rxTrainLabels,modClassNet,"crossentropy",options);
else
load trainedModulationClassificationNetwork
end
%通过获得测试帧的分类准确度来评估经过训练的网络
elapsedTime = seconds(toc);
elapsedTime.Format = 'hh:mm:ss';
fprintf('%s - Classifying test frames\n', elapsedTime)
% Read the test frames into the memory
testFrames = transform(testDS, @helperModClassReadFrame);
rxTestFrames = readall(testFrames,"UseParallel",pctExists);
% Read the test labels into the memory
testLabels = transform(testDS, @helperModClassReadLabel);
rxTestLabels = readall(testLabels,"UseParallel",pctExists);
scores = predict(trainedNet,cat(3,rxTestFrames{:}));
rxTestPred = scores2label(scores,modulationTypes);
testAccuracy = mean(rxTestPred == rxTestLabels);
disp("Test accuracy: " + testAccuracy*100 + "%")
%% 使用 SDR 进行测试
%使用 helperModClassSDRTest 函数,通过空口信号测试经过训练的网络的性能。
radioPlatform = "ADALM-PLUTO";
switch radioPlatform
case "ADALM-PLUTO"
if helperIsPlutoSDRInstalled() == true
radios = findPlutoRadio();
if length(radios) >= 2
helperModClassSDRTest(radios);
else
disp('Selected radios not found. Skipping over-the-air test.')
end
end
case {"USRP B2xx","USRP X3xx","USRP N2xx"}
if (helperIsUSRPInstalled() == true) && (helperIsPlutoSDRInstalled() == true)
txRadio = findPlutoRadio();
rxRadio = findsdru();
switch radioPlatform
case "USRP B2xx"
idx = contains({rxRadio.Platform}, {'B200','B210'});
case "USRP X3xx"
idx = contains({rxRadio.Platform}, {'X300','X310'});
case "USRP N2xx"
idx = contains({rxRadio.Platform}, 'N200/N210/USRP2');
end
rxRadio = rxRadio(idx);
if (length(txRadio) >= 1) && (length(rxRadio) >= 1)
helperModClassSDRTest(rxRadio);
else
disp('Selected radios not found. Skipping over-the-air test.')
end
end
end
工程文件
https://download.csdn.net/download/XU157303764/89498445