# Standard library imports.
import math
import sys

# Third-party imports.
import yaml
import neural_compressor as inc
from neural_compressor.config import AccuracyCriterion, PostTrainingQuantConfig, TuningCriterion
from neural_compressor.quantization import fit

# Project-local imports.
import alexnet
import mnist_dataset

# Report which neural_compressor release is active; the quantization API used
# below (fit / PostTrainingQuantConfig) requires the 2.x-style interface.
print("neural_compressor version {}".format(inc.__version__))
def save_int8_frezon_pb(q_model, path):
    """Serialize a quantized model's graph to a frozen ``.pb`` file.

    Args:
        q_model: quantized model returned by neural_compressor ``fit``;
            its ``.graph`` attribute must expose ``as_graph_def()``.
        path: destination file path for the serialized GraphDef.
    """
    from tensorflow.python.platform import gfile
    # Context manager guarantees the file handle is closed even if the
    # write fails (the original leaked the GFile handle).
    with gfile.GFile(path, 'wb') as f:
        f.write(q_model.graph.as_graph_def().SerializeToString())
    print("Save to {}".format(path))
2817
2918
@@ -44,23 +33,31 @@ def __iter__(self):
4433 yield x_test [begin :], label_test [begin :]
4534
4635
def auto_tune(input_graph_path, config, batch_size):
    """Run post-training quantization with accuracy-driven auto-tuning.

    Args:
        input_graph_path: path to the FP32 frozen graph (.pb) to quantize;
            passed directly to neural_compressor ``fit`` as the model.
        config: dict (parsed from quant_config.yaml) with keys
            ``tuning_criterion``, ``accuracy_criterion`` and ``quant_config``,
            each holding keyword arguments for the matching config class.
        batch_size: calibration/evaluation batch size for the Dataloader.

    Returns:
        The quantized model produced by ``fit`` (None if tuning fails).
    """
    # NOTE(review): the original also loaded the graph via alexnet.load_pb()
    # into an unused local — fit() takes the path itself, so that dead work
    # (and the always-true assert on the dataloader) is dropped here.
    calib_dataloader = Dataloader(batch_size)

    tuning_criterion = TuningCriterion(**config["tuning_criterion"])
    accuracy_criterion = AccuracyCriterion(**config["accuracy_criterion"])
    quant_config = PostTrainingQuantConfig(
        **config["quant_config"],
        tuning_criterion=tuning_criterion,
        accuracy_criterion=accuracy_criterion,
    )
    return fit(
        model=input_graph_path,
        conf=quant_config,
        calib_dataloader=calib_dataloader,
    )
5852
5953
# Driver: quantize the FP32 AlexNet frozen graph and save the INT8 result.
batch_size = 200
fp32_frezon_pb_file = "fp32_frezon.pb"
int8_pb_file = "alexnet_int8_model.pb"

# Load tuning/accuracy criteria and quantization settings from YAML.
# (The original had a stray no-op `config` expression — a notebook
# artifact — which is removed here.)
with open("quant_config.yaml") as f:
    config = yaml.safe_load(f)

q_model = auto_tune(fp32_frezon_pb_file, config, batch_size)
save_int8_frezon_pb(q_model, int8_pb_file)
# (web-page paste artifact — "0 commit comments" — not part of the script)