Background information:
I have written a TensorFlow model very similar to the premade iris classification model provided by TensorFlow. The differences are relatively minor:
- I am classifying football exercises, not iris species.
- I have 10 features and one label, not 4 features and one label.
- I have 5 different exercises, as opposed to 3 iris species.
- My trainData contains around 3500 rows, not only 120.
- My testData contains around 330 rows, not only 30.
- I am using a DNN classifier with n_classes=6, not 3.
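For reference, my understanding of the `DNNClassifier` documentation is that integer labels have to lie in the range `[0, n_classes)`. So `n_classes=6` with 5 exercises would only be needed if the labels are numbered 1 to 5, in which case class 0 simply never occurs. A small pure-Python sketch of that check; the exercise names below are hypothetical placeholders, not my real label values:

```python
# Hypothetical exercise names; the real label column may use different values.
EXERCISES = ["pass", "shot", "dribble", "header", "tackle"]


def encode_labels(names, start=0):
    """Map each exercise name to an integer ID, beginning at `start`."""
    return {name: i + start for i, name in enumerate(names)}


def min_n_classes(label_ids):
    """Smallest n_classes for which every label ID lies in [0, n_classes)."""
    return max(label_ids) + 1


zero_based = encode_labels(EXERCISES)          # IDs 0..4
one_based = encode_labels(EXERCISES, start=1)  # IDs 1..5

print(min_n_classes(zero_based.values()))  # 5
print(min_n_classes(one_based.values()))   # 6
```

With 0-based IDs, `n_classes=5` would be enough; the extra class only comes from starting the numbering at 1.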
I now want to export the model as a .tflite file. But according to the TensorFlow Developer Guide, I need to first export the model to a tf.GraphDef file, then freeze it, and only then will I be able to convert it. However, the tutorial provided by TensorFlow to create a .pb file from a custom model only seems to be optimized for image classification models.

So how do I convert a model like the iris classification example model into a .tflite file? Is there an easier, more direct way to do it, without having to export it to a .pb file, then freeze it and so on? An example based on the iris classification code or a link to a more explicit tutorial would be very useful!
Other information:
- OS: macOS 10.13.4 High Sierra
- TensorFlow Version: 1.8.0
- Python Version: 3.6.4
- Using PyCharm Community 2018.1.3
The iris classification code can be cloned by entering the following command:
git clone
But in case you don't want to download the whole package, here it is:
This is the classifier file called
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An Example of a DNNClassifier for the Iris dataset."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import tensorflow as tf

import iris_data


parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=100, type=int, help='batch size')
parser.add_argument('--train_steps', default=1000, type=int,
                    help='number of training steps')


def main(argv):
    args = parser.parse_args(argv[1:])

    # Fetch the data
    (train_x, train_y), (test_x, test_y) = iris_data.load_data()

    # Feature columns describe how to use the input.
    my_feature_columns = []
    for key in train_x.keys():
        my_feature_columns.append(tf.feature_column.numeric_column(key=key))

    # Build 2 hidden layer DNN with 10, 10 units respectively.
    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        # Two hidden layers of 10 nodes each.
        hidden_units=[10, 10],
        # The model must choose between 3 classes.
        n_classes=3)

    # Train the Model.
    classifier.train(
        input_fn=lambda: iris_data.train_input_fn(train_x, train_y,
                                                  args.batch_size),
        steps=args.train_steps)

    # Evaluate the model.
    eval_result = classifier.evaluate(
        input_fn=lambda: iris_data.eval_input_fn(test_x, test_y,
                                                 args.batch_size))

    print('\nTest set accuracy: {accuracy:0.3f}\n'.format(**eval_result))

    # Generate predictions from the model
    expected = ['Setosa', 'Versicolor', 'Virginica']
    predict_x = {
        'SepalLength': [5.1, 5.9, 6.9],
        'SepalWidth': [3.3, 3.0, 3.1],
        'PetalLength': [1.7, 4.2, 5.4],
        'PetalWidth': [0.5, 1.5, 2.1],
    }

    predictions = classifier.predict(
        input_fn=lambda: iris_data.eval_input_fn(predict_x,
                                                 labels=None,
                                                 batch_size=args.batch_size))

    template = '\nPrediction is "{}" ({:.1f}%), expected "{}"'

    for pred_dict, expec in zip(predictions, expected):
        class_id = pred_dict['class_ids'][0]
        probability = pred_dict['probabilities'][class_id]

        print(template.format(iris_data.SPECIES[class_id],
                              100 * probability, expec))


if __name__ == '__main__':
    # tf.logging.set_verbosity(tf.logging.INFO)
    # Parse flags and call main(argv).
    tf.app.run(main)
And this is the data file called
import pandas as pd
import tensorflow as tf

TRAIN_URL = ""
TEST_URL = ""

CSV_COLUMN_NAMES = ['SepalLength', 'SepalWidth',
                    'PetalLength', 'PetalWidth', 'Species']
SPECIES = ['Setosa', 'Versicolor', 'Virginica']


def maybe_download():
    train_path = tf.keras.utils.get_file(TRAIN_URL.split('/')[-1], TRAIN_URL)
    test_path = tf.keras.utils.get_file(TEST_URL.split('/')[-1], TEST_URL)

    return train_path, test_path


def load_data(y_name='Species'):
    """Returns the iris dataset as (train_x, train_y), (test_x, test_y)."""
    train_path, test_path = maybe_download()

    train = pd.read_csv(train_path, names=CSV_COLUMN_NAMES, header=0)
    train_x, train_y = train, train.pop(y_name)

    test = pd.read_csv(test_path, names=CSV_COLUMN_NAMES, header=0)
    test_x, test_y = test, test.pop(y_name)

    return (train_x, train_y), (test_x, test_y)


def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset


def eval_input_fn(features, labels, batch_size):
    """An input function for evaluation or prediction"""
    features = dict(features)
    if labels is None:
        # No labels, use only features.
        inputs = features
    else:
        inputs = (features, labels)

    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices(inputs)

    # Batch the examples
    assert batch_size is not None, "batch_size must not be None"
    dataset = dataset.batch(batch_size)

    # Return the dataset.
    return dataset
** UPDATE **
OK, so I have found a seemingly very useful piece of code on this page:
import tensorflow as tf

img = tf.placeholder(name="img", dtype=tf.float32, shape=(1, 64, 64, 3))
val = img + tf.constant([1., 2., 3.]) + tf.constant([1., 4., 4.])
out = tf.identity(val, name="out")
with tf.Session() as sess:
    tflite_model = tf.contrib.lite.toco_convert(sess.graph_def, [img], [out])
    open("test.tflite", "wb").write(tflite_model)
This little guy directly converts a simple model to a TensorFlow Lite model. Now all I have to do is find a way to adapt this to the iris classification model. Any suggestions?
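One difference I can already see: `toco_convert` takes a list of input tensors with fixed shapes, while the Estimator is fed a dict of feature columns. Assuming I go with a single float input of shape `(batch, 4)` (my own choice here, not something the snippet prescribes; one placeholder per feature would be another option), the features would have to be packed in a fixed column order. A pure-Python sketch of that packing, using the iris `predict_x` values from the classifier file; `to_rows` is a hypothetical helper name:

```python
# Hypothetical fixed column order; it must match the order the input tensor expects.
FEATURE_NAMES = ['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth']


def to_rows(features, names=FEATURE_NAMES):
    """Turn a dict of equal-length feature columns into rows in `names` order."""
    lengths = {len(features[name]) for name in names}
    if len(lengths) != 1:
        raise ValueError("all feature columns must have the same length")
    n_rows = lengths.pop()
    return [[float(features[name][i]) for name in names] for i in range(n_rows)]


predict_x = {
    'SepalLength': [5.1, 5.9, 6.9],
    'SepalWidth': [3.3, 3.0, 3.1],
    'PetalLength': [1.7, 4.2, 5.4],
    'PetalWidth': [0.5, 1.5, 2.1],
}

rows = to_rows(predict_x)
print(rows[0])  # [5.1, 3.3, 1.7, 0.5]
```

For the football model, `FEATURE_NAMES` would list the 10 feature columns instead, giving rows of length 10.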