Model Serialization: Lessons Learnt from Tensorflow 1.x and 2.x

.... And Why I'm So Fucked by Tensorflow

YouTube Stream

About Me

  • A Python Developer
  • Interested in machine learning, applied math, and their development
  • Core developer of uTensor

uTensor

  • utensor_cgen: code generator for uTensor

utensor-cgen

Model Development and Deployment

  • define the graph
  • train the graph
  • transform the graph: graph rewriting, including quantization, node fusion, node removal, etc.
  • save the graph: model serialization (a 1.x skeleton of the whole pipeline follows)
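Concretely, in 1.x terms the four steps look roughly like this. This is only a plausible way a file like simple_model.pb could have been produced; the model itself (a dense layer plus argmax) is made up for illustration, and step 3 is expanded in the cells below:

import tensorflow as tf  # 1.x

graph = tf.Graph()
with graph.as_default():                           # 1. define the graph
    x = tf.placeholder(tf.float32, [None, 10], name='x')
    logits = tf.layers.dense(x, 2)
    y_pred = tf.identity(tf.argmax(logits, axis=1), name='y_pred')

with tf.Session(graph=graph) as sess:              # 2. train the graph
    sess.run(tf.global_variables_initializer())
    # ... training loop elided ...
    # freeze variables into constants so the graph is self-contained
    const_graph_def = tf.graph_util.convert_variables_to_constants(
        sess, graph.as_graph_def(), output_node_names=['y_pred'])

with tf.gfile.GFile('simple_model.pb', 'wb') as fid:  # 4. save the graph
    fid.write(const_graph_def.SerializeToString())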

Quantization

weight-quantization credit
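The gist of weight quantization: map each float weight tensor onto 8-bit integers with an affine transform, and store only the integers plus the (min, max) range. A minimal numpy sketch of the idea, illustrative only, not TF's or uTensor's actual kernels:

import numpy as np

def quantize(w, num_bits=8):
    # affine (asymmetric) quantization: float32 -> uint8
    qmin, qmax = 0, 2 ** num_bits - 1
    w_min, w_max = float(w.min()), float(w.max())
    scale = (w_max - w_min) / (qmax - qmin)
    zero_point = int(round(qmin - w_min / scale))
    q = np.clip(np.round(w / scale) + zero_point, qmin, qmax).astype(np.uint8)
    return q, scale, zero_point

def dequantize(q, scale, zero_point):
    return (q.astype(np.float32) - zero_point) * scale

w = np.random.randn(4, 4).astype(np.float32)
q, scale, zp = quantize(w)
print(np.abs(w - dequantize(q, scale, zp)).max())  # error is at most ~scale/2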

In [2]:
# Tensorflow 1.x
import tensorflow as tf
from tensorflow import import_graph_def
from tensorflow.tools.graph_transforms import TransformGraph

print(tf.__version__)

graph = tf.Graph()
with graph.as_default():
    with tf.gfile.GFile("simple_model.pb", "rb") as fid:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(fid.read())
    out_tensor, = import_graph_def(
        graph_def,
        return_elements=["y_pred:0"]
    )
out_tensor
1.13.0-rc1
Out[2]:
<tf.Tensor 'import/y_pred:0' shape=(10,) dtype=int64>

float model

simple-model-float

In [3]:
# Quantization in Tensorflow 1.x
quant_graph_def = TransformGraph(
    graph_def,
    inputs=[],
    outputs=["y_pred"],
    transforms=["quantize_weights", "quantize_nodes"]
)

with open('quant_simple_model.pb', 'wb') as fid:
    fid.write(quant_graph_def.SerializeToString())

quantized model

simple-model-quant

dynamic quantization

simple-model-quant-zoom
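As for what "dynamic" buys you: activation ranges are not known at conversion time, so quantize_nodes inserts ops that compute min/max at runtime, right before each quantized kernel (that is my reading of the zoomed graph above). In numpy terms, roughly:

import numpy as np

def dynamic_quantize(x, num_bits=8):
    # the range comes from the runtime tensor itself;
    # nothing about it is baked into the serialized graph
    x_min, x_max = float(x.min()), float(x.max())
    scale = (x_max - x_min) / (2 ** num_bits - 1) or 1.0  # guard all-equal tensors
    zero_point = int(round(-x_min / scale))
    q = np.clip(np.round(x / scale) + zero_point, 0, 255).astype(np.uint8)
    return q, x_min, x_max  # min/max travel with the tensor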

In [1]:
import tensorflow as tf

print(tf.__version__)
2.3.1
In [ ]:
# Tensorflow 2.x: Tensorflow Lite
model = ... # A tensorflow.keras.Model instance, **trained**
model.save('model_path') # save model, normal keras save/load api

# trainable graph -> constant graph in TF 2.x
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2

model_func = tf.function(lambda x: model(x))
model_func = model_func.get_concrete_function(tf.TensorSpec(...)) # set up the input spec
model_func = convert_variables_to_constants_v2(model_func, lower_control_flow=False)

# save the frozen graph as a pb file
with open('const_graph.pb', 'wb') as fid:
    fid.write(model_func.graph.as_graph_def().SerializeToString())

# create a converter that turns a keras model into a tflite flatbuffer
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# represent_ds is a callable that returns a generator yielding representative
# input samples, used to calibrate the quantization ranges (sketch below)
converter.representative_dataset = represent_ds
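The cell above leaves represent_ds abstract; for reference, a minimal version might look like this, where x_calib is a stand-in for roughly a hundred real input samples with the model's input shape:

import numpy as np

x_calib = np.random.rand(100, 28, 28, 1).astype(np.float32)  # hypothetical calibration data

def represent_ds():
    # a generator yielding one list per step, one array per model input
    for x in x_calib:
        yield [x[np.newaxis, ...]]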
In [ ]:
# tflite_buffer is a bytes object (the serialized flatbuffer)
tflite_buffer = converter.convert()

with open('model.tflite', 'wb') as fid:
    fid.write(tflite_buffer)
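A quick way to sanity-check the resulting flatbuffer is to run it with the TFLite interpreter; a sketch (input shape and dtype depend on your model):

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')
interpreter.allocate_tensors()

inp = interpreter.get_input_details()[0]
out = interpreter.get_output_details()[0]

# feed a dummy input with the expected shape/dtype
interpreter.set_tensor(inp['index'], np.zeros(inp['shape'], dtype=inp['dtype']))
interpreter.invoke()
print(interpreter.get_tensor(out['index']))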

Keras float model

keras-float

Keras quantized model (TFLite)

keras-quant

Graph Rewriting

Implemented with isomorphic subgraph matching
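To give a feel for what that means: treat the graph as op-name -> (op-type, inputs) and walk it from an anchor node. A toy matcher for the MatMul + Add + Relu pattern (utensor_cgen's real matcher also has to deal with commutative inputs and overlapping matches, among other things):

# toy graph: op name -> (op type, list of input names)
graph = {
    'w':    ('Const',  []),
    'mm':   ('MatMul', ['x', 'w']),
    'b':    ('Const',  []),
    'add':  ('Add',    ['mm', 'b']),
    'relu': ('Relu',   ['add']),
}

def match_fc(graph, out_name):
    # returns (matmul, add, relu) node names if Relu(Add(MatMul(..), ..)) matches
    op, inputs = graph[out_name]
    if op != 'Relu':
        return None
    add = inputs[0]
    if graph[add][0] != 'Add':
        return None
    mm = graph[add][1][0]
    if graph[mm][0] != 'MatMul':
        return None
    return (mm, add, out_name)

print(match_fc(graph, 'relu'))  # ('mm', 'add', 'relu')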

Why I'm Sooo Fucked

fucked-by-tf

Inconsistent Operation Names

  • ex: Add vs QuantizedAdd vs AddOp
  • Hard to identify the type of an operation/node in the graph
  • Hard to implement/test isomorphic subgraph matching

Operation name/type legalization is required
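One way to do it is a lookup table applied to every node before any pattern matching runs; a trivial sketch (the canonical names here are made up):

# map framework-specific op names to one canonical type
_CANONICAL_OP = {
    'Add':             'AddOperator',
    'AddV2':           'AddOperator',
    'QuantizedAdd':    'AddOperator',
    'MatMul':          'MatMulOperator',
    'QuantizedMatMul': 'MatMulOperator',
}

def legalize_op_type(op_type):
    try:
        return _CANONICAL_OP[op_type]
    except KeyError:
        raise ValueError('unknown op type: {}'.format(op_type))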

Fused Operations

  • ex: MatMul + Add + <activation_func> => FullyConnected
  • Hard to define a generic intermediate representation
    • FullyConnected => MatMul + Add + <activation_func>?
    • MatMul + Add + <activation_func> => FullyConnected?
    • Which is better and why?
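Reusing graph and match_fc from the matching sketch above: fusing in the MatMul + Add + Relu => FullyConnected direction is just a dictionary rewrite, and lowering FullyConnected back into primitive ops is the mirror image, which is exactly why the choice of canonical form matters:

def fuse_fc(graph, match):
    mm, add, relu = match
    x, w = graph[mm][1]
    b = graph[add][1][1]
    fused = dict(graph)
    for name in (mm, add, relu):
        del fused[name]
    # keep the relu's name so downstream consumers still resolve
    fused[relu] = ('FullyConnected', [x, w, b])
    return fused

print(fuse_fc(graph, match_fc(graph, 'relu')))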

Implementation Differences Across Versions

  • Take tf.nn.dropout as an example

Dropout in Tensorflow 1.x

dropout-v1

Dropout in Tensorflow 2.x

dropout-v2
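You can reproduce both figures by dumping the ops a single tf.nn.dropout call expands into. Under 1.x it lowers to a Floor-based random mask (roughly random_uniform + add + floor + div + mul, if memory serves), while 2.x builds the mask with GreaterEqual + Cast, so any rewrite rule written against the 1.x pattern silently stops matching. The 1.x side:

import tensorflow as tf  # run under 1.x

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, shape=[None, 10], name='x')
    y = tf.nn.dropout(x, keep_prob=0.5)

for op in graph.get_operations():
    print(op.type, op.name)  # one dropout call, many primitive nodes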

Breaking Changes in Frameworks

  • Changes in quantization scheme
    • Dynamic Quantization vs. Static Quantization
    • Quantization-Aware Training
  • Inconsistent Saving/Loading APIs
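For the record, quantization-aware training in the 2.x world lives in the separate tensorflow-model-optimization package rather than the graph transforms shown earlier; roughly like this, where model, x_train, and y_train are stand-ins:

import tensorflow_model_optimization as tfmot

# wraps layers with fake-quant ops so training sees quantization error
q_aware_model = tfmot.quantization.keras.quantize_model(model)
q_aware_model.compile(optimizer='adam',
                      loss='sparse_categorical_crossentropy',
                      metrics=['accuracy'])
q_aware_model.fit(x_train, y_train, epochs=1)  # fine-tune with fake quant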

Q & A

joker