【深度学习】用LSTM写诗,生成式的方式写诗系列之一

Epoch 4: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=18.5, loss=5.8]
[5] loss: 5.828, accuracy: 18.389 , lr:0.001000
Epoch 5: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=19.2, loss=5.68]
[6] loss: 5.739, accuracy: 18.732 , lr:0.001000
Epoch 6: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=19.6, loss=5.57]
[7] loss: 5.629, accuracy: 19.197 , lr:0.001000
Epoch 7: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=19.9, loss=5.45]
[8] loss: 5.517, accuracy: 19.745 , lr:0.001000
Epoch 8: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=20.4, loss=5.38]
[9] loss: 5.402, accuracy: 20.316 , lr:0.001000
Epoch 9: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=21, loss=5.27]
[10] loss: 5.299, accuracy: 20.887 , lr:0.001000
Epoch 10: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=21.8, loss=5.16]
[11] loss: 5.210, accuracy: 21.427 , lr:0.001000
Epoch 11: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=22.2, loss=5.1]
[12] loss: 5.136, accuracy: 21.923 , lr:0.001000
Epoch 12: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=22.5, loss=5.06]
[13] loss: 5.071, accuracy: 22.379 , lr:0.001000
Epoch 13: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=22.9, loss=5.02]
[14] loss: 5.011, accuracy: 22.819 , lr:0.001000
Epoch 14: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=23.3, loss=4.94]
[15] loss: 4.959, accuracy: 23.212 , lr:0.001000
Epoch 15: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=23.7, loss=4.91]
[16] loss: 4.910, accuracy: 23.564 , lr:0.001000
Epoch 16: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=24.3, loss=4.82]
[17] loss: 4.862, accuracy: 23.914 , lr:0.001000
Epoch 17: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=24, loss=4.83]
[18] loss: 4.818, accuracy: 24.228 , lr:0.001000
Epoch 18: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=24.7, loss=4.77]
[19] loss: 4.775, accuracy: 24.523 , lr:0.001000
Epoch 19: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=24.6, loss=4.73]
[20] loss: 4.734, accuracy: 24.808 , lr:0.001000
Epoch 20: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=25, loss=4.69]
[21] loss: 4.694, accuracy: 25.090 , lr:0.001000
Epoch 21: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=25, loss=4.71]
[22] loss: 4.657, accuracy: 25.346 , lr:0.001000
Epoch 22: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=25.8, loss=4.62]
[23] loss: 4.619, accuracy: 25.587 , lr:0.001000
Epoch 23: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=25.9, loss=4.59]
[24] loss: 4.584, accuracy: 25.825 , lr:0.001000
Epoch 24: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.3, loss=4.52]
[25] loss: 4.549, accuracy: 26.078 , lr:0.001000
Epoch 25: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=26.3, loss=4.53]
[26] loss: 4.516, accuracy: 26.280 , lr:0.001000
Epoch 26: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.6, loss=4.49]
[27] loss: 4.483, accuracy: 26.517 , lr:0.001000
Epoch 27: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=26.8, loss=4.46]
[28] loss: 4.451, accuracy: 26.746 , lr:0.001000
Epoch 28: 100%|███████████████████████████████████████████████████████| 63/63 [1:00:07<00:00, 57.26s/batch, acc=27.1, loss=4.41]
[29] loss: 4.422, accuracy: 26.937 , lr:0.001000
Epoch 29: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=27.2, loss=4.38]
[30] loss: 4.389, accuracy: 27.182 , lr:0.001000
Epoch 30: 100%|████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=27, loss=4.4]
[31] loss: 4.361, accuracy: 27.371 , lr:0.001000
Epoch 31: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=27.5, loss=4.34]
[32] loss: 4.332, accuracy: 27.589 , lr:0.001000
Epoch 32: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=27.6, loss=4.31]
[33] loss: 4.304, accuracy: 27.791 , lr:0.001000
Epoch 33: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.89batch/s, acc=27.9, loss=4.28]
[34] loss: 4.277, accuracy: 28.014 , lr:0.001000
Epoch 34: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=28, loss=4.26]
[35] loss: 4.248, accuracy: 28.200 , lr:0.001000
Epoch 35: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=28.6, loss=4.22]
[36] loss: 4.222, accuracy: 28.433 , lr:0.001000
Epoch 36: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=28.3, loss=4.21]
[37] loss: 4.196, accuracy: 28.625 , lr:0.001000
Epoch 37: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=29.1, loss=4.16]
[38] loss: 4.169, accuracy: 28.858 , lr:0.001000
Epoch 38: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29.2, loss=4.13]
[39] loss: 4.142, accuracy: 29.056 , lr:0.001000
Epoch 39: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29, loss=4.13]
[40] loss: 4.116, accuracy: 29.282 , lr:0.001000
Epoch 40: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=29.5, loss=4.12]
[41] loss: 4.092, accuracy: 29.477 , lr:0.001000
Epoch 41: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=29.7, loss=4.08]
[42] loss: 4.066, accuracy: 29.716 , lr:0.001000
Epoch 42: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=29.8, loss=4.06]
[43] loss: 4.042, accuracy: 29.918 , lr:0.001000
Epoch 43: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.5, loss=3.99]
[44] loss: 4.016, accuracy: 30.146 , lr:0.001000
Epoch 44: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.2, loss=4.01]
[45] loss: 3.990, accuracy: 30.398 , lr:0.001000
Epoch 45: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=30.6, loss=3.96]
[46] loss: 3.968, accuracy: 30.607 , lr:0.001000
Epoch 46: 100%|█████████████████████████████████████████████████████████| 63/63 [40:05<00:00, 38.19s/batch, acc=30.6, loss=3.96]
[47] loss: 3.945, accuracy: 30.814 , lr:0.001000
Epoch 47: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=30.9, loss=3.94]
[48] loss: 3.918, accuracy: 31.073 , lr:0.001000
Epoch 48: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=31.1, loss=3.91]
[49] loss: 3.893, accuracy: 31.322 , lr:0.001000
Epoch 49: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=32, loss=3.86]
[50] loss: 3.869, accuracy: 31.574 , lr:0.001000
Epoch 50: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=31.2, loss=3.9]
[51] loss: 3.846, accuracy: 31.811 , lr:0.001000
Epoch 51: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=31.7, loss=3.85]
[52] loss: 3.823, accuracy: 32.042 , lr:0.001000
Epoch 52: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=32.4, loss=3.8]
[53] loss: 3.798, accuracy: 32.325 , lr:0.001000
Epoch 53: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.2, loss=3.8]
[54] loss: 3.776, accuracy: 32.552 , lr:0.001000
Epoch 54: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.88batch/s, acc=32.4, loss=3.79]
[55] loss: 3.755, accuracy: 32.794 , lr:0.001000
Epoch 55: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.8, loss=3.75]
[56] loss: 3.729, accuracy: 33.081 , lr:0.001000
Epoch 56: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=32.8, loss=3.74]
[57] loss: 3.708, accuracy: 33.301 , lr:0.001000
Epoch 57: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=33.8, loss=3.68]
[58] loss: 3.683, accuracy: 33.597 , lr:0.001000
Epoch 58: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=33.5, loss=3.67]
[59] loss: 3.661, accuracy: 33.838 , lr:0.001000
Epoch 59: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=34, loss=3.65]
[60] loss: 3.639, accuracy: 34.106 , lr:0.001000
Epoch 60: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=34, loss=3.65]
[61] loss: 3.619, accuracy: 34.350 , lr:0.001000
Epoch 61: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=34.1, loss=3.64]
[62] loss: 3.595, accuracy: 34.632 , lr:0.001000
Epoch 62: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=34.6, loss=3.57]
[63] loss: 3.573, accuracy: 34.872 , lr:0.001000
Epoch 63: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=34.8, loss=3.58]
[64] loss: 3.553, accuracy: 35.140 , lr:0.001000
Epoch 64: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=35.1, loss=3.53]
[65] loss: 3.531, accuracy: 35.394 , lr:0.001000
Epoch 65: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=34.8, loss=3.56]
[66] loss: 3.512, accuracy: 35.636 , lr:0.001000
Epoch 66: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=35.1, loss=3.55]
[67] loss: 3.490, accuracy: 35.896 , lr:0.001000
Epoch 67: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=36.1, loss=3.49]
[68] loss: 3.471, accuracy: 36.147 , lr:0.001000
Epoch 68: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36, loss=3.48]
[69] loss: 3.451, accuracy: 36.413 , lr:0.001000
Epoch 69: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=36.5, loss=3.44]
[70] loss: 3.436, accuracy: 36.595 , lr:0.001000
Epoch 70: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36.5, loss=3.45]
[71] loss: 3.412, accuracy: 36.873 , lr:0.001000
Epoch 71: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=36.2, loss=3.44]
[72] loss: 3.393, accuracy: 37.130 , lr:0.001000
Epoch 72: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=36.4, loss=3.44]
[73] loss: 3.375, accuracy: 37.342 , lr:0.001000
Epoch 73: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=37.1, loss=3.4]
[74] loss: 3.355, accuracy: 37.608 , lr:0.001000
Epoch 74: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=37.2, loss=3.37]
[75] loss: 3.337, accuracy: 37.853 , lr:0.001000
Epoch 75: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=37.9, loss=3.35]
[76] loss: 3.318, accuracy: 38.105 , lr:0.001000
Epoch 76: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=37.5, loss=3.35]
[77] loss: 3.303, accuracy: 38.282 , lr:0.001000
Epoch 77: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=37.9, loss=3.31]
[78] loss: 3.285, accuracy: 38.523 , lr:0.001000
Epoch 78: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.1, loss=3.3]
[79] loss: 3.267, accuracy: 38.738 , lr:0.001000
Epoch 79: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=38.9, loss=3.28]
[80] loss: 3.250, accuracy: 38.972 , lr:0.001000
Epoch 80: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.6, loss=3.27]
[81] loss: 3.230, accuracy: 39.248 , lr:0.001000
Epoch 81: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=39.1, loss=3.22]
[82] loss: 3.216, accuracy: 39.435 , lr:0.001000
Epoch 82: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=38.8, loss=3.25]
[83] loss: 3.197, accuracy: 39.675 , lr:0.001000
Epoch 83: 100%|██████████████████████████████████████████████████████████| 63/63 [00:38<00:00,  1.62batch/s, acc=39.7, loss=3.2]
[84] loss: 3.180, accuracy: 39.914 , lr:0.001000
Epoch 84: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=39.4, loss=3.2]
[85] loss: 3.165, accuracy: 40.108 , lr:0.001000
Epoch 85: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.87batch/s, acc=40.1, loss=3.17]
[86] loss: 3.152, accuracy: 40.277 , lr:0.001000
Epoch 86: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=39.9, loss=3.18]
[87] loss: 3.135, accuracy: 40.508 , lr:0.001000
Epoch 87: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=40.4, loss=3.14]
[88] loss: 3.118, accuracy: 40.736 , lr:0.001000
Epoch 88: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.84batch/s, acc=40.5, loss=3.14]
[89] loss: 3.104, accuracy: 40.918 , lr:0.001000
Epoch 89: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=40.7, loss=3.11]
[90] loss: 3.093, accuracy: 41.061 , lr:0.001000
Epoch 90: 100%|██████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=40.8, loss=3.1]
[91] loss: 3.074, accuracy: 41.315 , lr:0.001000
Epoch 91: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.5, loss=3.06]
[92] loss: 3.057, accuracy: 41.559 , lr:0.001000
Epoch 92: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.1, loss=3.09]
[93] loss: 3.043, accuracy: 41.745 , lr:0.001000
Epoch 93: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=41.5, loss=3.06]
[94] loss: 3.029, accuracy: 41.924 , lr:0.001000
Epoch 94: 100%|███████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=42, loss=3.03]
[95] loss: 3.015, accuracy: 42.133 , lr:0.001000
Epoch 95: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.85batch/s, acc=41.6, loss=3.04]
[96] loss: 3.001, accuracy: 42.302 , lr:0.001000
Epoch 96: 100%|██████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.82batch/s, acc=42, loss=3]
[97] loss: 2.988, accuracy: 42.483 , lr:0.001000
Epoch 97: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=42.9, loss=2.96]
[98] loss: 2.972, accuracy: 42.694 , lr:0.001000
Epoch 98: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.86batch/s, acc=42.1, loss=3.01]
[99] loss: 2.964, accuracy: 42.804 , lr:0.001000
Epoch 99: 100%|█████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=42.2, loss=3.01]
[100] loss: 2.953, accuracy: 42.973 , lr:0.001000
Finished Training using %.3f seconds 6896.93013882637

先训练了100轮次,后面应该还能增长,但是不等了

数据初探:

class DictObj(object):
    def __init__(self, map):
        self.map = map
        
    def __getattr__(self, attr):
        if attr in self.map:
            return self.map[attr]
        else:
            raise AttributeError("No such attribute: " + attr)
        
Config = DictObj({
    'poem_path':os.path.join(base_dir, "tang.npz"),
    "tensorboard_path":os.path.join(base_dir, "tensorboard"),
    "model_save_path":os.path.join(base_dir,"modelDict"),
    "embedding_dim":100,
    "hidden_dim":1024,
    "lr":0.001,
    "LSTM_layers":2,
    'batch_size':512,
    'epochs':500,
    'dropout':0.2,
    'ealier_stop':10,
    'device':torch.device('cuda' if torch.cuda.is_available() else 'cpu')
})
def view_data(poem_path):
    datas = np.load(poem_path, allow_pickle=True)
    data = datas['data'] #(57580,125)
    ix2word = datas['ix2word'].item() # datas['word2ix'].item() 8293
    word2ix = datas['word2ix'].item() # datas['word2ix'].item() 8293
    word_data = np.zeros((1,data.shape[1]), dtype = str) # 将所有的0 转化成 ''
    # 看一下其中一行的数据是什么?
    row = np.random.randint(0, data.shape[0]) # 随机选一行,左闭右开没问题
    print(data[row])
    for i in range(data.shape[1]):
        word_data[0][i] = ix2word[data[row][i]]
    print(word_data)

view_data(Config.poem_path)

数据处理:

class PoemDataset(Dataset):
    def __init__(self, poem_path, seq_len):
        super().__init__()
        # np 文件的地址
        self.poem_path = poem_path
        # 序列长度,48 是认为规定的,也可以是其它值,因为大部分是5言或者7言,加上表达就是 6,或8, 取48确保是整句话
        self.seq_len = seq_len
        self.poem_data, self.ix2word, self.word2ix = self.get_raw_data()
        self.no_space_data = self.filter_space()
        print("no_space_data len:", self.no_space_data[0:200])
    
    def __len__(self):
        return len(self.no_space_data)//(self.seq_len)

    def __getitem__(self, idx):
        txt = self.no_space_data[idx*self.seq_len:(idx+1)*self.seq_len]
        label = self.no_space_data[idx*self.seq_len+1:(idx+1)*self.seq_len+1]
        return torch.LongTensor(txt), torch.LongTensor(label)
    
    def filter_space(self):
        # 7197500 个文本
        tensor_data = torch.from_numpy(self.poem_data).view(-1)
        no_space_data = []  
        for i in range(tensor_data.shape[0]):
            word_idx = tensor_data[i].item()
            if word_idx!= 8292:
                no_space_data.append(word_idx)
        return no_space_data

        
    def get_raw_data(self):
        datas = np.load(self.poem_path, allow_pickle=True)
        data = datas['data']
        ix2word = datas['ix2word'].item()
        word2ix = datas['word2ix'].item()
        return data, ix2word, word2ix
poem_dataset = PoemDataset(Config.poem_path, 96)
[8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292 8292
 8292 8292 8292 8292 8292 8292 8292 8291 5428 6933 3469 7066 3465 6407
 8248 7009   82 7435  925 3469 3576  232  786 5272 2296 7066 4807 6103
 6663 2958 2003 2173   28 7066 1987 8061 4299  848 4874 7435 8290]
[['<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<' '<'
  '<' '<' '冬' '月' '内' ',' '无' '叶' '艾' '枝' '枯' '。' '草' '内' '急' '寻' '蛇' '床'
  '子' ',' '烧' '烟' '入' '中' '自' '消' '除' ',' '速' '救' '免' '灾' '虞' '。' '<']]
no_space_data len: [8291, 6731, 4770, 1787, 8118, 7577, 7066, 4817, 648, 7121, 1542, 6483, 7435, 7686, 2889, 1671, 5862, 1949, 7066, 2596, 4785, 3629, 1379, 2703, 7435, 6064, 6041, 4666, 4038, 4881, 7066, 4747, 1534, 70, 3788, 3823, 7435, 4907, 5567, 201, 2834, 1519, 7066, 782, 782, 2063, 2031, 846, 7435, 8290, 8291, 2309, 2596, 6483, 2260, 7316, 7066, 6332, 5274, 2125, 5029, 7792, 7435, 4186, 8087, 7047, 6622, 6933, 7066, 6134, 3564, 3766, 6920, 6157, 7435, 7086, 4770, 5849, 4776, 4981, 7066, 4857, 2649, 3020, 332, 1727, 7435, 7458, 7294, 3465, 5149, 1671, 7066, 2834, 6000, 3942, 3534, 1534, 7435, 4102, 7460, 758, 3961, 3374, 7066, 7904, 6811, 4449, 2121, 6802, 7435, 6182, 27, 7912, 1756, 7440, 7066, 201, 7909, 8118, 201, 4662, 7435, 7824, 1508, 3154, 152, 5862, 7066, 7976, 6043, 258, 47, 7878, 7435, 8290, 8291, 3495, 70, 7113, 4839, 5237, 7066, 65, 3941, 2031, 2260, 5418, 7435, 411, 6773, 2878, 4686, 482, 7066, 1989, 5617, 4992, 8245, 676, 7435, 4236, 1418, 4915, 7686, 7363, 7066, 5708, 7541, 7440, 5237, 2192, 7435, 3114, 5913, 7989, 3069, 1845, 7066, 7047, 3534, 4921, 6622, 6933, 7435, 1664, 2260, 2003, 4816, 7151, 7066, 5036, 2219, 5849, 4898, 174, 7435, 201, 7228, 222]

因为有空格,啥的,要先吧空格之类的去掉。

def show_dataset():
    idx,label = poem_dataset[0]
    for id in idx:
        print(poem_dataset.ix2word[id.item()], end=' ')
    print("\n")
    for la in label:
        print(poem_dataset.ix2word[la.item()], end=' ')
    '''
    <START> 度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 

    度 门 能 不 访 , 冒 雪 屡 西 东 。 已 想 人 如 玉 , 遥 怜 马 似 骢 。 乍 迷 金 谷 路 , 稍 变 上 阳 宫 。 还 比 相 思 意 , 纷 纷 正 满 空 。 <EOP> <START> 逍 遥 东 城 隅 , 双 树 寒 葱 蒨 。 广 庭 流 华 月 , 高 阁 凝 余 霰 。 杜 门 非 养 素 , 抱 疾 阻 良 䜩 。 孰 谓 无 他 人 , 思 君 岁 云
    '''
# 构建模型
class PoemModel(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(hidden_dim, vocab_size)

    def forward(self, input, hidden=None):
        embeds = self.embedding(input)
        batch_size,seq_len,embedding_dim = embeds.shape
        if hidden is None:
            h0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)
            c0 = torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim).to(Config.device)
        else:
            h0,c0 = hidden
        output, hidden = self.lstm(embeds, (h0, c0))
        # output = torch.tanh(self.dropout(self.fc1(output)))
        output = self.fc(output)
        return output, hidden

vocab_size = len(poem_dataset.word2ix)
model = PoemModel(vocab_size, Config.embedding_dim, Config.hidden_dim, Config.LSTM_layers, Config.dropout).to(Config.device)
input_data, label_data = next(iter(dataloader))
print(input_data.shape, label_data.shape)
output, hidden = model(input_data.to(Config.device))
# output.shape torch.Size([1024, 96, 8293]) hidden[0].shape torch.Size([3, 1024, 1024]) hidden[1].shape torch.Size([3, 1024, 1024])  label_data.shape torch.Size([1024, 96])
a = 1

def accuracy(output, label_data):
    pred = output.argmax(dim=2)
    correct = (pred == label_data).sum().item()
    total = label_data.numel()
    return correct / total * 100
# 训练模型
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=Config.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
def train(model, dataloader, criterion, optimizer, scheduler, epochs):
    if not os.path.exists(Config.model_save_path):
        os.makedirs(Config.model_save_path)
    best_acc = 0.0
    early_stop = 0
    start_time = time.time()
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        running_acc = 0.0
        last_acc = 0.0
        with tqdm(dataloader, unit="batch") as tepoch:
            for input_data, label_data in tepoch:
                tepoch.set_description(f"Epoch {epoch}")
                input_data, label_data = input_data.to(Config.device), label_data.to(Config.device)
                optimizer.zero_grad()
                output, hidden = model(input_data)
                current_acc = accuracy(output, label_data)
                running_acc += current_acc
                loss = criterion(output.view(-1, vocab_size), label_data.view(-1))
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                tepoch.set_postfix(loss=loss.item(), acc=current_acc)
        scheduler.step()
        last_acc = running_acc / len(dataloader)
        if last_acc > best_acc:
            best_acc = last_acc
            torch.save(model.state_dict(), os.path.join(Config.model_save_path, "best_model.pth"))
        else:
            early_stop += 1
        torch.save(model.state_dict(), os.path.join(Config.model_save_path, "last_model.pth"))
        print('[%d] loss: %.3f, accuracy: %.3f , lr:%.6f'  % (epoch + 1, running_loss / len(dataloader), last_acc,scheduler.get_last_lr()[0]))

        if early_stop >= Config.ealier_stop:
            print("Early Stop")
            print("Best Accuracy: %.3f" % best_acc)
            break
    print('Finished Training using %.3f seconds', time.time() - start_time)
train(model, dataloader, criterion, optimizer, scheduler, Config.epochs)  

模型构建以及训练如上,

现在看 500轮次,50个忍耐度的效果比较好

Epoch 145: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.81batch/s, acc=61.6, loss=1.62]
[146] loss: 1.580, accuracy: 62.374 , lr:0.001000
Epoch 146: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.80batch/s, acc=62.5, loss=1.57]
[147] loss: 1.581, accuracy: 62.360 , lr:0.001000
Epoch 147: 100%|██████████████████████████████████████████████████████████████████████████| 63/63 [00:07<00:00,  8.83batch/s, acc=61.9, loss=1.61]
[148] loss: 1.582, accuracy: 62.324 , lr:0.001000
Early Stop
Best Accuracy: 62.397
Finished Training using %.3f seconds 1205.000694513321

使用效果

不满意的地方,写死96=seq_len 是不对的。
应该是 配合 padding 使用,并mask padding来指导损失 @todo, 下一篇文章我会搞定!

import torch
from train03 import Config
from train03 import PoemModel

from train03 import PoemDataset
import os

poem_dataset = PoemDataset(Config.poem_path, 96)

vocab_size = len(poem_dataset.word2ix)
model = PoemModel(
    vocab_size, 
    Config.embedding_dim, 
    Config.hidden_dim, 
    Config.LSTM_layers, 
    Config.dropout).to(Config.device)
model.load_state_dict(torch.load(os.path.join(Config.model_save_path, "best_model.pth")))

def generate(model, start_words, ix2word, word2ix, device):
     results = list(start_words)
     start_words_len = len(start_words)
     # 第一个词语是<START>
     input = torch.Tensor([word2ix['<START>']]).view(1, 1).long()
     # 最开始的隐状态初始为0矩阵
     # torch.zeros(Config.LSTM_layers, batch_size, Config.hidden_dim)
     hidden = torch.zeros((2,Config.LSTM_layers * 1, 1, Config.hidden_dim), dtype=torch.float32).to(Config.device)
     input = input.to(Config.device)
     hidden = hidden.to(Config.device)
     model.eval()
     with torch.no_grad():
        for i in range(48):
            output, hidden = model(input, hidden)
            # 如果在给定的句首中,input为句首中的下一个字
            if i < start_words_len:
                w = results[i]
                input = input.data.new([word2ix[w]]).view(1, 1)
            else:
                top_index = output.data[0].topk(1)[1][0].item()
                w = ix2word[top_index]
                results.append(w)
                input = input.data.new([top_index]).view(1, 1)
            if w == '<EOP>':
                del results[-1]
                break
        return results

雨 余 虚 馆 竹 阴 清 , 独 坐 寒 窗 昼 未 醒 。 云 布 远 村 红 叶 返 , 水 深 秋 竹 翠 梢 寒 。 泉 声 入 阁 慙 嘉 石 , 山 色 题 诗 好 赋 诗 。

但是我有一点不太理解。
他是输入一个字,输出一个字,这一点好像不妥。不应该是 输入一个 生成1个, 然后输入两个,生成1个,然后输入3个生成1个么。。。 大神请指教一下吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/908449.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring3(代理模式 Spring1案例补充 Aop 面试题)

Spring3 代理模式概述介绍什么是代理模式&#xff1f;为什么要使用代理模式&#xff1f;有哪几种代理模式&#xff1f;静态代理基于接口实现的动态代理(JDK自带)基于子类的动态代理 Spring_AOP_01案例补充(添加事务管理)实现完整代码&#xff1a;常规实现&#xff1a;代理实现 …

开源模型应用落地-Qwen2.5-7B-Instruct与TGI实现推理加速

一、前言 目前&#xff0c;大语言模型已升级至Qwen2.5版本。无论是语言模型还是多模态模型&#xff0c;均在大规模多语言和多模态数据上进行预训练&#xff0c;并通过高质量数据进行后期微调以贴近人类偏好。在本篇学习中&#xff0c;将集成 Hugging Face的TGI框架实现模型推理…

Android 使用ninja加速编译的方法

ninja的简介 随着Android版本的更迭&#xff0c;makefile体系逐渐增多&#xff0c;导致make单编模块的时间越来越长&#xff0c;每次都需要半个小时甚至更长时间&#xff0c;其原因为每次make都会重新加载所有mk文件&#xff0c;再生成ninja编译&#xff0c;此完整过程十分耗时…

javaNIO核心知识.中

Channel&#xff08;通道&#xff09; Channel 是一个通道&#xff0c;它建立了与数据源&#xff08;如文件、网络套接字等&#xff09;之间的连接。我们可以利用它来读取和写入数据&#xff0c;就像打开了一条自来水管&#xff0c;让数据在 Channel 中自由流动。 BIO 中的流…

缓存、注解、分页

一.缓存 作用&#xff1a;应用查询上&#xff0c;内存中的块区域。 缓存查询结果&#xff0c;减少与数据库的交互&#xff0c;从而提高运行效率。 1.SqlSession 缓存 1. 又称为一级缓存&#xff0c;mybatis自动开启。 2. 作用范围&#xff1a;同一…

流畅!HTMLCSS打造网格方块加载动画

效果演示 这个动画的效果是五个方块在网格中上下移动&#xff0c;模拟了一个连续的加载过程。每个方块的动画都是独立的&#xff0c;但是它们的时间间隔和路径被设计为相互协调&#xff0c;以创建出流畅的动画效果。 HTML <div class"loadingspinner"><…

【skywalking 】More than 15,000 ‘grammar‘ tokens have been presented. 【未解决请求答案】

问题 skywalking相关版本信息 jdk&#xff1a;17skywalking&#xff1a;10.1.0apache-skywalking-java-agent&#xff1a;9.3.0ElasticSearch : 8.8.2 问题描述 More than 15,000 grammar tokens have been presented. To prevent Denial Of Service attacks, parsing has b…

docker desktop使用ubuntu18.04带图形化+运行qemu

记录一下docker desktop使用ubuntu18.04带图形化命令和使用步骤 1. 下载镜像 参考&#xff1a;【Docker教程】Docker部署Ubuntu18.04(带图形化界面) 命令&#xff1a; docker pull kasmweb/ubuntu-bionic-desktop:1.10.02. 启动镜像 命令&#xff1a; docker run -d -it …

jmeter压测工具环境搭建(Linux、Mac)

目录 java环境安装 1、anaconda安装java环境&#xff08;推荐&#xff09; 2、直接在本地环境安装java环境 yum方式安装jdk 二进制方式安装jdk jmeter环境安装 1、jmeter单机安装 启动jmeter 配置环境变量 jmeter配置中文 2、jmeter集群搭建 多台机器部署jmeter集群…

ai翻唱部分步骤

模型部署 我是用的RVC进行的训练&#xff0c;也可以使用so-vits-svc。 通过百度网盘分享的文件&#xff1a;RVC-beta 链接&#xff1a;https://pan.baidu.com/s/1c99jR2fLChoqUFqf9gLUzg 提取码&#xff1a;4090 以Nvida显卡为例&#xff0c;分别下载“RVC1006Nvidia”和…

算法笔记-Day09(字符篇)

151. 反转字符串中的单词 class Solution {public String reverseWords(String s) {int lens.length(),count0;StringBuffer tempnew StringBuffer();StringBuffer ansnew StringBuffer();for(int i0;i<len;i){if(s.charAt(i)! &&(i0 || s.charAt(i-1) )){while(i&l…

安科瑞电能质量治理产品在光伏电站的应用有效解决了光伏电站面临的功率因数过低和谐波问题-安科瑞黄安南

1. 概述 随着全球对可再生能源需求的增加&#xff0c;分布式光伏电站的建设和发展迅速。然而&#xff0c;分布式光伏电站的运行过程中面临着一系列问题&#xff0c;比如导致企业关口计量点功率因数过低、谐波污染等。这些问题不仅影响光伏电站自身的运行效率&#xff0c;还会对…

Leetcode137只出现一次的数字|| 及其拓展

简述&#xff1a; 虽然标题是这么描述的&#xff0c;但是我们不是一上来就解这道题&#xff0c;先看一下他的子题和扩展 子题&#xff1a;136. 只出现一次的数字 - 力扣&#xff08;LeetCode&#xff09; 扩展题&#xff1a; 所以我们由易到难&#xff0c;先来看第一道&#x…

leetcode 382.链表随机结点

1.题目要求: 2.题目代码: /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListNode() : val(0), next(nullptr) {}* ListNode(int x) : val(x), next(nullptr) {}* ListNode(int x, ListNode *next) : val(x)…

GaussDB Ustore存储引擎解读

目录 一、数据库存储引擎 二、GaussDB Ustore存储引擎 总结 本文将介绍GaussDB中的Ustore存储引擎&#xff0c;包括Ustore的设计背景、特点介绍和适用业务场景等。 一、数据库存储引擎 数据库的存储引擎负责在内存和磁盘上存储、检索和管理数据&#xff0c;确保每个节点的…

使用 Sortable.js 库 实现 Vue3 elementPlus 的 el-table 拖拽排序

文章目录 实现效果Sortable.js介绍下载依赖添加类名导入sortablejs初始化拖拽实例拖拽完成后的处理总结 在开发过程中&#xff0c;我们经常需要处理表格数据&#xff0c;并为用户提供便捷的排序方式。特别是在需要管理长列表、分类数据或动态内容时&#xff0c;拖拽排序功能显得…

机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎

&#x1f497;&#x1f497;&#x1f497;欢迎来到我的博客&#xff0c;你将找到有关如何使用技术解决问题的文章&#xff0c;也会找到某个技术的学习路线。无论你是何种职业&#xff0c;我都希望我的博客对你有所帮助。最后不要忘记订阅我的博客以获取最新文章&#xff0c;也欢…

.net core 接口,动态接收各类型请求的参数

[HttpPost] public async Task<IActionResult> testpost([FromForm] object info) { //Postman工具测试结果&#xff1a; //FromBody,Postman的body只有rawjson时才进的来 //参数为空时&#xff0c;Body(form-data、x-www-form-urlencoded)解析到的数据也有所…

探索Unity:从游戏引擎到元宇宙体验,聚焦内容创作

unity是实时3D互动内容创作和运营平台&#xff0c;包括游戏开发、美术、建筑、汽车设计、影视在内的所有创作者&#xff0c;借助Unity将创意变成现实。提供一整套完善的软件解决方案&#xff0c;可用于创作、运营和变现任何实时互动的2D和3D内容&#xff0c;支持平台包括手机、…

构造有向(无向)加权图

邻接表的一般构造 #include<bits/stdc.h> #define N 1e4 using namespace std; typedef struct BP{ int P;//边所指的顶点位置 struct BP *nextB;//指向下一条边的指针 int Q;//储存边的信息 }BP; typedef struct DP{ int date;//顶点信息 BP *FirstB;//指向第一条连接…