MNIST Image Classification
In this example, we will build a 3-layer RPN model with taylor_expansion, identity_reconciliation, and zero_remainder functions to classify the MNIST dataset of handwritten digit images.
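To make the three component functions concrete, here is a minimal plain-Python sketch of a single layer that expands the input, reshapes the parameters, and adds a remainder. It is an illustration only, not the library's implementation: the helper signatures, the expansion order of 2, and the toy input are all assumptions made for this sketch.

```python
from itertools import product


def taylor_expansion(x, order=2):
    """Concatenate x, x (x) x, ... up to `order`, flattened into one list.

    The order of 2 is an assumption; the example's config may differ.
    """
    expanded = []
    for d in range(1, order + 1):
        for idx in product(range(len(x)), repeat=d):
            term = 1.0
            for i in idx:
                term *= x[i]
            expanded.append(term)
    return expanded


def identity_reconciliation(w, n, D):
    """Use the flat parameter list directly, reshaped into an n x D matrix."""
    return [w[r * D:(r + 1) * D] for r in range(n)]


def zero_remainder(x, n):
    """Contribute nothing to the output."""
    return [0.0] * n


def rpn_layer(x, w, n, order=2):
    """Inner product of expanded input and reconciled parameters, plus remainder."""
    kx = taylor_expansion(x, order)
    W = identity_reconciliation(w, n, len(kx))
    rem = zero_remainder(x, n)
    return [sum(a * b for a, b in zip(row, kx)) + r for row, r in zip(W, rem)]


x = [1.0, 2.0]
kx = taylor_expansion(x)        # [1.0, 2.0, 1.0, 2.0, 2.0, 4.0]
w = [1.0] * 6 + [0.0] * 6       # hypothetical parameters for 2 outputs
out = rpn_layer(x, w, n=2)
print(kx, out)
```

With a 2-dimensional input and order 2, the expansion has 2 + 4 = 6 entries, so a layer with n outputs needs 6n parameters under identity reconciliation. This growth in expanded dimension is why the expansion order matters for a 784-pixel MNIST input.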
We will use mps as the device in the config file; change it to match your machine (for example, cuda or cpu) before running the script.
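If you are unsure which device string to put in the config, a small check like the following can help. It is a sketch that assumes the example runs on PyTorch; `pick_device` is a hypothetical helper introduced here, not part of the example's code.

```python
def pick_device(mps_available, cuda_available):
    """Return a device string, preferring mps, then cuda, then cpu."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"


try:
    import torch  # assumption: the example is built on PyTorch

    device = pick_device(
        torch.backends.mps.is_available(),
        torch.cuda.is_available(),
    )
except (ImportError, AttributeError):
    # PyTorch missing, or too old to expose torch.backends.mps
    device = "cpu"

print(device)
```

The printed value is what you would write in place of mps in the config file.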
Python Code and Model Configurations
RPN with identity reconciliation for MNIST classification: output
training model...
100%|██████████| 938/938 [00:36<00:00, 25.44it/s, epoch=0/25, loss=0.0452, lr=0.002, metric_score=1, time=36.9]
Epoch: 0, Test Loss: 0.14725748572838165, Test Score: 0.9555, Time Cost: 3.299213171005249
100%|██████████| 938/938 [00:35<00:00, 26.08it/s, epoch=1/25, loss=0.0666, lr=0.0018, metric_score=0.969, time=76.2]
Epoch: 1, Test Loss: 0.08832730701012288, Test Score: 0.9717, Time Cost: 3.023775100708008
100%|██████████| 938/938 [00:36<00:00, 25.47it/s, epoch=2/25, loss=0.0129, lr=0.00162, metric_score=1, time=116]
Epoch: 2, Test Loss: 0.08240678668799617, Test Score: 0.9765, Time Cost: 3.0006258487701416
100%|██████████| 938/938 [00:36<00:00, 25.85it/s, epoch=3/25, loss=0.000843, lr=0.00146, metric_score=1, time=155]
Epoch: 3, Test Loss: 0.09966024028365429, Test Score: 0.9731, Time Cost: 3.0690932273864746
100%|██████████| 938/938 [00:35<00:00, 26.38it/s, epoch=4/25, loss=0.000343, lr=0.00131, metric_score=1, time=194]
Epoch: 4, Test Loss: 0.08925511088793404, Test Score: 0.9739, Time Cost: 3.025567054748535
100%|██████████| 938/938 [00:35<00:00, 26.17it/s, epoch=5/25, loss=0.0211, lr=0.00118, metric_score=1, time=233]
Epoch: 5, Test Loss: 0.11491756975460037, Test Score: 0.9699, Time Cost: 3.1549229621887207
100%|██████████| 938/938 [00:35<00:00, 26.25it/s, epoch=6/25, loss=0.129, lr=0.00106, metric_score=0.969, time=272]
Epoch: 6, Test Loss: 0.09543848116054737, Test Score: 0.9765, Time Cost: 3.0333187580108643
100%|██████████| 938/938 [00:35<00:00, 26.34it/s, epoch=7/25, loss=0.0192, lr=0.000957, metric_score=1, time=310]
Epoch: 7, Test Loss: 0.06982691217252265, Test Score: 0.9811, Time Cost: 3.0253307819366455
100%|██████████| 938/938 [00:35<00:00, 26.46it/s, epoch=8/25, loss=0.000674, lr=0.000861, metric_score=1, time=349]
Epoch: 8, Test Loss: 0.10375890898652708, Test Score: 0.9732, Time Cost: 3.0132219791412354
100%|██████████| 938/938 [00:35<00:00, 26.44it/s, epoch=9/25, loss=0.00115, lr=0.000775, metric_score=1, time=387]
Epoch: 9, Test Loss: 0.08423868822431006, Test Score: 0.979, Time Cost: 3.0301530361175537
100%|██████████| 938/938 [00:35<00:00, 26.30it/s, epoch=10/25, loss=0.0073, lr=0.000697, metric_score=1, time=426]
Epoch: 10, Test Loss: 0.09018090074097593, Test Score: 0.9792, Time Cost: 3.027726173400879
100%|██████████| 938/938 [00:35<00:00, 26.46it/s, epoch=11/25, loss=0.00995, lr=0.000628, metric_score=1, time=465]
Epoch: 11, Test Loss: 0.09117337604856153, Test Score: 0.978, Time Cost: 3.0312867164611816
100%|██████████| 938/938 [00:35<00:00, 26.36it/s, epoch=12/25, loss=0.000759, lr=0.000565, metric_score=1, time=503]
Epoch: 12, Test Loss: 0.11916581861087498, Test Score: 0.9772, Time Cost: 3.162600040435791
100%|██████████| 938/938 [00:35<00:00, 26.28it/s, epoch=13/25, loss=0.102, lr=0.000508, metric_score=0.969, time=542]
Epoch: 13, Test Loss: 0.09118759378719253, Test Score: 0.9824, Time Cost: 2.988072156906128
100%|██████████| 938/938 [00:35<00:00, 26.11it/s, epoch=14/25, loss=0.0461, lr=0.000458, metric_score=0.969, time=581]
Epoch: 14, Test Loss: 0.0869425951682757, Test Score: 0.9803, Time Cost: 3.0128040313720703
100%|██████████| 938/938 [00:35<00:00, 26.35it/s, epoch=15/25, loss=6.94e-6, lr=0.000412, metric_score=1, time=620]
Epoch: 15, Test Loss: 0.08488982132411006, Test Score: 0.9826, Time Cost: 2.9784598350524902
100%|██████████| 938/938 [00:35<00:00, 26.29it/s, epoch=16/25, loss=0.000103, lr=0.000371, metric_score=1, time=658]
Epoch: 16, Test Loss: 0.08816149920134123, Test Score: 0.9823, Time Cost: 2.9910941123962402
100%|██████████| 938/938 [00:35<00:00, 26.33it/s, epoch=17/25, loss=4.12e-5, lr=0.000334, metric_score=1, time=697]
Epoch: 17, Test Loss: 0.10713126366508123, Test Score: 0.9829, Time Cost: 3.158099889755249
100%|██████████| 938/938 [00:35<00:00, 26.09it/s, epoch=18/25, loss=0.00225, lr=0.0003, metric_score=1, time=736]
Epoch: 18, Test Loss: 0.09688288162248873, Test Score: 0.9829, Time Cost: 3.0078930854797363
100%|██████████| 938/938 [00:35<00:00, 26.15it/s, epoch=19/25, loss=3.83e-6, lr=0.00027, metric_score=1, time=775]
Epoch: 19, Test Loss: 0.11367125700611343, Test Score: 0.9831, Time Cost: 2.995252847671509
100%|██████████| 938/938 [00:35<00:00, 26.37it/s, epoch=20/25, loss=2.62e-5, lr=0.000243, metric_score=1, time=813]
Epoch: 20, Test Loss: 0.11589900395485465, Test Score: 0.9826, Time Cost: 2.9824440479278564
100%|██████████| 938/938 [00:35<00:00, 26.38it/s, epoch=21/25, loss=5.81e-5, lr=0.000219, metric_score=1, time=852]
Epoch: 21, Test Loss: 0.10221088574088256, Test Score: 0.9838, Time Cost: 2.989346742630005
100%|██████████| 938/938 [00:35<00:00, 26.32it/s, epoch=22/25, loss=3.58e-6, lr=0.000197, metric_score=1, time=891]
Epoch: 22, Test Loss: 0.11218179630007304, Test Score: 0.9842, Time Cost: 3.181006908416748
100%|██████████| 938/938 [00:35<00:00, 26.31it/s, epoch=23/25, loss=0.00181, lr=0.000177, metric_score=1, time=929]
Epoch: 23, Test Loss: 0.10169062788332937, Test Score: 0.9857, Time Cost: 3.075958013534546
100%|██████████| 938/938 [00:36<00:00, 26.04it/s, epoch=24/25, loss=1.94e-5, lr=0.00016, metric_score=1, time=969]
Epoch: 24, Test Loss: 0.10714568164599787, Test Score: 0.9855, Time Cost: 3.008065700531006
model checkpoint saving to ./ckpt/mnist_configs_checkpoint...
evaluating result...
accuracy 0.9855
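Two settings can be read back out of the log above. The lr values in the progress bars are consistent with an exponential decay of 0.9 per epoch starting from 0.002, and 938 iterations per epoch matches the 60,000 MNIST training images split into batches of 64. Both are inferred from the log rather than read from the configuration, so the constants in this sketch are assumptions.

```python
import math

# Assumptions inferred from the training log, not taken from the config file.
INITIAL_LR = 2e-3
GAMMA = 0.9
TRAIN_SIZE = 60_000   # size of the standard MNIST training split
BATCH_SIZE = 64


def lr_at_epoch(epoch):
    """Learning rate under per-epoch exponential decay."""
    return INITIAL_LR * GAMMA ** epoch


# A few lr values copied from the progress bars (three significant digits).
logged_lr = {0: 0.002, 1: 0.0018, 2: 0.00162, 3: 0.00146, 10: 0.000697, 24: 0.00016}

for epoch, logged in logged_lr.items():
    predicted = lr_at_epoch(epoch)
    # Allow for the rounding applied when the progress bar prints the value.
    assert math.isclose(predicted, logged, rel_tol=5e-3), (epoch, predicted, logged)

batches_per_epoch = math.ceil(TRAIN_SIZE / BATCH_SIZE)
print(batches_per_epoch)
```

The test score peaks at 0.9857 in epoch 23 and the final checkpoint reports 0.9855, while the test loss stops improving after epoch 7, so most of the later epochs refine accuracy only marginally at this decay rate.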