MNIST Classification with Dual LPHM Reconciliation Function
As introduced in the Quickstart tutorial, the dual LPHM reconciliation function can substantially reduce the number of learnable parameters in the RPN model. In this example, we re-design the 3-layer RPN model built in the previous example by replacing the identity_reconciliation function with the dual_lphm_reconciliation function for MNIST image classification. This change dramatically reduces the number of learnable parameters from 39,696,000 to 9,330, which greatly saves the memory space required for storing the parameters and their gradients during learning.
According to the experimental evaluation, RPN with dual_lphm_reconciliation achieves a decent testing accuracy of about 0.9810. The testing accuracy can be further improved by tuning the rank parameter r of the dual_lphm_reconciliation function.
We will still use mps as the device in the config file.
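To see where the reduction from 39,696,000 to 9,330 learnable parameters comes from, recall how dual LPHM reconciles each layer's parameter matrix W of shape n x D as the Kronecker product of two low-rank factorizations. The following is only a sketch of the counting argument; the partition sizes p, q and rank r used below are assumed values that are consistent with the totals quoted above, not necessarily the ones in the actual config file:

$$
\mathbf{W} = \big(\mathbf{P}\mathbf{Q}^\top\big) \otimes \big(\mathbf{S}\mathbf{T}^\top\big),
\qquad
\mathbf{P} \in \mathbb{R}^{p \times r},\;
\mathbf{Q} \in \mathbb{R}^{q \times r},\;
\mathbf{S} \in \mathbb{R}^{\frac{n}{p} \times r},\;
\mathbf{T} \in \mathbb{R}^{\frac{D}{q} \times r},
$$

so each layer needs only $r\,(p + q + \tfrac{n}{p} + \tfrac{D}{q})$ learnable parameters instead of $n \cdot D$. For the first layer of the 3-layer model ($n = 64$, and $D = 784 + 784^2 = 615{,}440$ assuming, as in the previous example, a 2nd-order Taylor expansion of the 784-dimensional input), taking $p = 8$, $q = 784$ and $r = 5$ gives $5 \times (8 + 784 + 8 + 785) = 7{,}925$ parameters in place of $64 \times 615{,}440 = 39{,}388{,}160$; the two remaining, much smaller layers contribute the rest of the 9,330 total.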
Python Code and Model Configurations
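As described above, the model re-uses the 3-layer design from the previous example, swapping identity_reconciliation for dual_lphm_reconciliation in the model configuration and keeping mps as the device. As a standalone illustration of what this reconciliation does to a layer's weights, here is a minimal PyTorch sketch (not the tinybig API; the partition sizes and rank are the same assumed values as above) that rebuilds the first layer's weight matrix from its low-rank factors and checks the parameter count:

```python
# Minimal PyTorch sketch of dual LPHM parameter reconciliation
# (illustrative only; not the tinybig API). The shapes below
# (p=8, q=784, r=5 for the first layer) are assumed values that
# are consistent with the 9,330-parameter total quoted above.
import torch


def dual_lphm_weight(P, Q, S, T):
    """Reconcile a full weight matrix W = (P Q^T) kron (S T^T)."""
    A = P @ Q.t()            # (p, q) low-rank factor
    B = S @ T.t()            # (n/p, D/q) low-rank factor
    return torch.kron(A, B)  # (n, D) reconciled weight matrix


# First layer of the 3-layer MNIST model: n = 64 outputs,
# D = 784 + 784**2 = 615,440 expanded features (2nd-order Taylor).
n, D = 64, 784 + 784 ** 2
p, q, r = 8, 784, 5          # assumed partition sizes and rank

P = torch.randn(p, r)
Q = torch.randn(q, r)
S = torch.randn(n // p, r)
T = torch.randn(D // q, r)

W = dual_lphm_weight(P, Q, S, T)
print(W.shape)               # torch.Size([64, 615440])

learnable = sum(t.numel() for t in (P, Q, S, T))
print(learnable)             # 7925, versus 64 * 615440 = 39,388,160 with identity reconciliation
```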
rpn with dual lphm reconciliation for mnist classification output
training model...
100%|██████████| 938/938 [00:32<00:00, 29.12it/s, epoch=0/20, loss=0.154, lr=0.002, metric_score=0.906, time=32.2]
Epoch: 0, Test Loss: 0.1287281839414292, Test Score: 0.9614, Time Cost: 3.4373199939727783
100%|██████████| 938/938 [00:32<00:00, 29.25it/s, epoch=1/20, loss=0.103, lr=0.0018, metric_score=0.969, time=67.7]
Epoch: 1, Test Loss: 0.10673198028865039, Test Score: 0.969, Time Cost: 3.2739510536193848
100%|██████████| 938/938 [00:31<00:00, 30.01it/s, epoch=2/20, loss=0.0153, lr=0.00162, metric_score=1, time=102]
Epoch: 2, Test Loss: 0.09450612840640127, Test Score: 0.9715, Time Cost: 3.2066781520843506
100%|██████████| 938/938 [00:31<00:00, 30.25it/s, epoch=3/20, loss=0.149, lr=0.00146, metric_score=0.938, time=136]
Epoch: 3, Test Loss: 0.09049446262555659, Test Score: 0.9724, Time Cost: 3.3913190364837646
100%|██████████| 938/938 [00:31<00:00, 29.56it/s, epoch=4/20, loss=0.0124, lr=0.00131, metric_score=1, time=172]
Epoch: 4, Test Loss: 0.08241389291420294, Test Score: 0.9754, Time Cost: 3.3962209224700928
100%|██████████| 938/938 [00:32<00:00, 29.08it/s, epoch=5/20, loss=0.357, lr=0.00118, metric_score=0.875, time=207]
Epoch: 5, Test Loss: 0.07358021313060598, Test Score: 0.9781, Time Cost: 3.2792601585388184
100%|██████████| 938/938 [00:32<00:00, 28.77it/s, epoch=6/20, loss=0.0193, lr=0.00106, metric_score=1, time=243]
Epoch: 6, Test Loss: 0.07322978683759215, Test Score: 0.9792, Time Cost: 3.478147029876709
100%|██████████| 938/938 [00:32<00:00, 29.27it/s, epoch=7/20, loss=0.012, lr=0.000957, metric_score=1, time=279]
Epoch: 7, Test Loss: 0.08070176002194299, Test Score: 0.9786, Time Cost: 3.3920912742614746
100%|██████████| 938/938 [00:31<00:00, 29.89it/s, epoch=8/20, loss=0.0161, lr=0.000861, metric_score=1, time=313]
Epoch: 8, Test Loss: 0.07754188040375211, Test Score: 0.9802, Time Cost: 3.2393269538879395
100%|██████████| 938/938 [00:30<00:00, 30.34it/s, epoch=9/20, loss=0.0198, lr=0.000775, metric_score=1, time=348]
Epoch: 9, Test Loss: 0.07705280321200117, Test Score: 0.9804, Time Cost: 3.2612552642822266
100%|██████████| 938/938 [00:31<00:00, 30.11it/s, epoch=10/20, loss=0.00647, lr=0.000697, metric_score=1, time=382]
Epoch: 10, Test Loss: 0.08428076295295929, Test Score: 0.9782, Time Cost: 3.4526174068450928
100%|██████████| 938/938 [00:31<00:00, 29.53it/s, epoch=11/20, loss=0.283, lr=0.000628, metric_score=0.969, time=417]
Epoch: 11, Test Loss: 0.07899063697800701, Test Score: 0.9814, Time Cost: 3.278550148010254
100%|██████████| 938/938 [00:31<00:00, 30.20it/s, epoch=12/20, loss=0.00808, lr=0.000565, metric_score=1, time=452]
Epoch: 12, Test Loss: 0.08305486166922214, Test Score: 0.9798, Time Cost: 3.269629955291748
100%|██████████| 938/938 [00:31<00:00, 29.41it/s, epoch=13/20, loss=0.00124, lr=0.000508, metric_score=1, time=487]
Epoch: 13, Test Loss: 0.0823607053136387, Test Score: 0.9806, Time Cost: 3.2996010780334473
100%|██████████| 938/938 [00:31<00:00, 29.90it/s, epoch=14/20, loss=0.0641, lr=0.000458, metric_score=0.969, time=521]
Epoch: 14, Test Loss: 0.08280810932788232, Test Score: 0.981, Time Cost: 3.3404059410095215
100%|██████████| 938/938 [00:32<00:00, 28.97it/s, epoch=15/20, loss=0.00654, lr=0.000412, metric_score=1, time=557]
Epoch: 15, Test Loss: 0.0896499605672028, Test Score: 0.9809, Time Cost: 3.358131170272827
100%|██████████| 938/938 [00:31<00:00, 29.48it/s, epoch=16/20, loss=0.000128, lr=0.000371, metric_score=1, time=592]
Epoch: 16, Test Loss: 0.08846969538192688, Test Score: 0.9812, Time Cost: 3.2872140407562256
100%|██████████| 938/938 [00:30<00:00, 30.36it/s, epoch=17/20, loss=0.002, lr=0.000334, metric_score=1, time=627]
Epoch: 17, Test Loss: 0.09712753198534886, Test Score: 0.9805, Time Cost: 3.468175172805786
100%|██████████| 938/938 [00:31<00:00, 30.07it/s, epoch=18/20, loss=0.00103, lr=0.0003, metric_score=1, time=661]
Epoch: 18, Test Loss: 0.10845135996438492, Test Score: 0.9795, Time Cost: 3.36928391456604
100%|██████████| 938/938 [00:30<00:00, 30.94it/s, epoch=19/20, loss=0.00051, lr=0.00027, metric_score=1, time=695]
Epoch: 19, Test Loss: 0.1052672725711357, Test Score: 0.9807, Time Cost: 3.479506254196167
model checkpoint saving to ./ckpt/mnist_dual_lphm_configs_checkpoint...
evaluating result...
accuracy 0.9807