Since 2012, the amount of compute used in the largest machine learning training runs has been doubling every three and a half months, which makes deploying state-of-the-art models expensive. Because inference is slow and hard to scale, compression techniques like knowledge distillation become key to realizing the full potential of these models on low-power platforms.
Knowledge distillation helps reduce the latency and size of applications ranging from lane detection and steering to semantic segmentation and pose estimation by teaching a lighter model to mimic the outputs and internal data representations of its parent model.
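To make the idea of "mimicking the output" concrete, here is a minimal sketch of an output-level distillation loss in plain TypeScript. It is not the code used in this project: the function names, the default temperature, and the choice of a KL divergence scaled by the squared temperature (a common convention in the distillation literature) are assumptions for illustration, and matching internal representations is left out entirely.

```typescript
// Hypothetical helper: softmax with a temperature.
// A higher temperature produces a softer (flatter) distribution.
function softmax(logits: number[], temperature = 1): number[] {
  const scaled = logits.map((z) => z / temperature);
  const max = Math.max(...scaled); // subtract the max for numerical stability
  const exps = scaled.map((z) => Math.exp(z - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

// Hypothetical helper: KL(teacher || student) over temperature-softened
// class distributions, multiplied by T^2 (an assumed, conventional scaling).
function distillationLoss(
  teacherLogits: number[],
  studentLogits: number[],
  temperature = 2
): number {
  if (teacherLogits.length !== studentLogits.length) {
    throw new Error('Teacher and student must predict the same classes.');
  }
  const p = softmax(teacherLogits, temperature);
  const q = softmax(studentLogits, temperature);
  let kl = 0;
  for (let i = 0; i < p.length; i++) {
    if (p[i] > 0) {
      kl += p[i] * Math.log(p[i] / q[i]);
    }
  }
  return kl * temperature * temperature;
}
```

The loss is zero when the student reproduces the teacher's logits exactly and grows as the two distributions drift apart, so minimizing it pushes the lighter model towards the parent's predictions.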
The original goal was to add five new low-power, high-accuracy applications controlled via an interactive dashboard, providing a proof of concept for the mobile-first paradigm of machine learning.
After the sync with the core TensorFlow.js team, the project scope was then
narrowed down to align with the promise of
module is similar in form to
tfjs-models: both show the capabilities of TensorFlow.js with
demos and implement overlapping models. Nevertheless, they are different in
spirit: while the
repo serves as a convenient toolbox providing off-the-shelf building blocks,
repo is a starter kit for personal projects.
The model garden cannot capture all of the use cases, striving instead to make fine-tuning as easy as possible. Since some models, like MobileNet, are more suitable for re-training than others, like Pix2Pix, the ones with the larger minimum viable audience are given priority.
DeepLab assigns a semantic label (a human, road, Harley-Davidson, and so on) to each pixel of the input image. Three types of pre-trained models are available:
Although CityScapes recognizes the fewest objects, it is the slowest and most compute-intensive of the three variants. This is a known bug that is yet to be resolved.
identification of healthy and cancerous cells from CT scans
- geographical: early detection of forest fires
- artistic: low-cost CGI overlays
EfficientNet classifies images into 1000 ImageNet classes, building on the success of MobileNet.
Despite the greater accuracy than
other alternatives focusing on mobile platforms and a lot of excitement around
its release, the model did not pass the quality assurance tests of
tfjs-models: even B0, the lightest variant of EfficientNet, is
slower in-browser than
When full WebGPU API support comes to TensorFlow.js, an order-of-magnitude improvement in performance might make heavier, more accurate models a viable alternative to the existing solutions.
Converting EfficientNet from the pre-trained checkpoint revealed a
tfjs-converter, brought by breaking changes associated with the
upcoming TF 2.0 release. This might have been resolved by the recent updates,
but further testing is required to identify the source of the issue.
automated sorting of the produce into grades
- retail: visual search of similar products
- marketing: adaptive ads reacting to customers wearing specific brands
PSENet detects text in several steps. First, the image is fed through a feature pyramid network, which extracts features from an image classifier stripped of its top dense and activation layers. Next, the progressive scale expansion algorithm extracts the pixels that most likely correspond to text regions and separates them into distinct components. Finally, the components are reduced to bounding boxes.
Two non-trivial post-processing methods are available out of the box:
The progressive scale expansion algorithm is written in pure TypeScript to take advantage of the JS engine optimizations and avoid the overhead associated with the TensorFlow.js implementation details.
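As a rough illustration of what progressive scale expansion does, the sketch below labels the connected components of the smallest kernel and then grows them breadth-first through successively larger kernels, with contested pixels going to whichever component reaches them first. It is a simplified reading of the algorithm, not the implementation shipped with the port: the `Mask` type, the function names, the binary-mask input format, and the 4-neighbour connectivity are all assumptions.

```typescript
// Assumed input format: binary masks, 1 = text pixel, 0 = background.
// Masks are ordered from the smallest kernel to the largest.
type Mask = number[][];

// 4-connected neighbourhood (an assumption of this sketch).
const NEIGHBORS: ReadonlyArray<[number, number]> = [
  [1, 0], [-1, 0], [0, 1], [0, -1],
];

// Assign a distinct positive label to each connected component of a mask.
function labelComponents(kernel: Mask): number[][] {
  const height = kernel.length;
  const width = kernel[0].length;
  const labels = kernel.map((row) => row.map(() => 0));
  let nextLabel = 0;
  for (let y = 0; y < height; y++) {
    for (let x = 0; x < width; x++) {
      if (kernel[y][x] !== 1 || labels[y][x] !== 0) continue;
      nextLabel++;
      labels[y][x] = nextLabel;
      const queue: Array<[number, number]> = [[y, x]];
      for (let head = 0; head < queue.length; head++) {
        const [cy, cx] = queue[head];
        for (const [dy, dx] of NEIGHBORS) {
          const ny = cy + dy;
          const nx = cx + dx;
          if (ny < 0 || ny >= height || nx < 0 || nx >= width) continue;
          if (kernel[ny][nx] === 1 && labels[ny][nx] === 0) {
            labels[ny][nx] = nextLabel;
            queue.push([ny, nx]);
          }
        }
      }
    }
  }
  return labels;
}

// Grow the labels of the smallest kernel through each larger kernel.
function progressiveScaleExpansion(kernels: Mask[]): number[][] {
  const labels = labelComponents(kernels[0]);
  const height = labels.length;
  const width = labels[0].length;
  for (let k = 1; k < kernels.length; k++) {
    const kernel = kernels[k];
    // Seed the breadth-first search with every pixel labelled so far.
    const queue: Array<[number, number]> = [];
    for (let y = 0; y < height; y++) {
      for (let x = 0; x < width; x++) {
        if (labels[y][x] > 0) queue.push([y, x]);
      }
    }
    for (let head = 0; head < queue.length; head++) {
      const [cy, cx] = queue[head];
      for (const [dy, dx] of NEIGHBORS) {
        const ny = cy + dy;
        const nx = cx + dx;
        if (ny < 0 || ny >= height || nx < 0 || nx >= width) continue;
        // Only unlabelled pixels inside the current kernel can be claimed.
        if (kernel[ny][nx] === 1 && labels[ny][nx] === 0) {
          labels[ny][nx] = labels[cy][cx];
          queue.push([ny, nx]);
        }
      }
    }
  }
  return labels;
}
```

Starting from the smallest kernels is what keeps adjacent text instances apart: two words that touch in the largest mask still begin as separate seeds, and the expansion never merges labels.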
The model is available in two variants:
Ported using the
by Michael Liu
The GIF above demonstrates this model.
Check out the commit
to load the appropriate weights together with corresponding pre-processing
and post-processing methods:
The model size is 115 MB non-quantized, 59 MB quantized to 2 bytes, and 29 MB quantized to 1 byte, while the inference time is 7-10 seconds on average.

    # cd tfjs-models
    git checkout 4d963c4
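
The size figures quoted for the quantized weights follow from the storage width: a 32-bit float takes 4 bytes, so storing each weight in 1 byte cuts the payload to roughly a quarter (115 MB / 4 ≈ 29 MB), and 2 bytes to roughly a half. The sketch below shows one common way to do this, linear min/max quantization to 8-bit integers. It illustrates the general technique only; the names are hypothetical, and it makes no claim about how the converter used for this port implements quantization internally.

```typescript
// Hypothetical container for 1-byte quantized weights.
interface QuantizedWeights {
  values: Uint8Array; // one byte per weight
  min: number;        // smallest original weight
  scale: number;      // step between adjacent quantization levels
}

// Map float weights onto 256 evenly spaced levels between min and max.
function quantize(weights: Float32Array): QuantizedWeights {
  let min = Infinity;
  let max = -Infinity;
  for (let i = 0; i < weights.length; i++) {
    if (weights[i] < min) min = weights[i];
    if (weights[i] > max) max = weights[i];
  }
  // Fall back to a scale of 1 when all weights are equal (or none exist).
  const scale = max > min ? (max - min) / 255 : 1;
  const values = new Uint8Array(weights.length);
  for (let i = 0; i < weights.length; i++) {
    values[i] = Math.round((weights[i] - min) / scale);
  }
  return { values, min, scale };
}

// Recover approximate float weights at load time.
function dequantize(q: QuantizedWeights): Float32Array {
  const weights = new Float32Array(q.values.length);
  for (let i = 0; i < q.values.length; i++) {
    weights[i] = q.values[i] * q.scale + q.min;
  }
  return weights;
}
```

The trade-off is precision: each restored weight can be off by up to half a quantization step, which is why quantized variants are worth re-validating before they replace the full-size model.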
Adapted from the PyTorch implementation by the original authors
This is the primary supported variant.
Since the inference time and model size disqualified the vanilla PSENet
from the model garden, the second part of GSoC focused on optimizing text
detection for mobile inference. The PyTorch implementation of the model
from the original authors promised to improve the quality of predictions
with 3 major differences in the approach from the
Switching to a more lightweight FPN with MobileNet rather than ResNet 50 as the backbone reduced the raw model size roughly sevenfold, from 115 MB to 16 MB.
After 185 attempts to make training work on AI Platform, however, the results still looked strange. Although the training loss and metrics improved as expected, validation results were disappointing: 30 to 40 epochs were not enough to learn even the simplest examples. At this point, it became clear that the problem lay much deeper in the stack.
|Sample input||Sample label|
tf.estimator-based setup and resolved for the final, 186th AI
Platform job by re-writing the pipeline using only the machinery of
Despite the weight improvements and good validation performance (0.98
accuracy, 0.99 precision, 0.99 F1-score on 2000 images), several pipeline
design decisions prevented the new model from beating the results of the
Some examples behave well...
And some miss important features...
While others show the imperfections of training beyond any doubt.
Despite these hurdles, PSENet offers a promising approach for in-browser text detection, and the TensorFlow.js port will be finalized when the weights are updated.
- privacy: masking of sensitive text information from photos
- education: extraction of text written on whiteboards
- knowledge management: detecting key parts of architectural drawings for automated annotation
Adopting the following heuristics would have helped to avoid many of the time sinks in the development process:
Check the training setup early
Overfitting on a single batch is a simple and effective way to spot problems early, since they may come from unexpected places not necessarily reflected in stack traces or loss anomalies.
Andrej Karpathy gives this and other advice in his recipe on training neural networks.
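
The single-batch check can be tried on a toy problem before trusting it on a real pipeline. The sketch below is framework-free TypeScript rather than the TensorFlow setup used in this project, and the model, data, and hyperparameters are assumptions chosen for illustration: a one-variable linear model is trained by gradient descent on one fixed batch, and the loss is expected to approach zero. If it does not, the bug is in the training loop or the data rather than in the model's capacity.

```typescript
// Hypothetical toy batch: inputs and targets for a 1-D regression.
interface Batch {
  xs: number[];
  ys: number[];
}

// Mean squared error of the linear model y = w * x + b on the batch.
function meanSquaredError(w: number, b: number, batch: Batch): number {
  let total = 0;
  for (let i = 0; i < batch.xs.length; i++) {
    const error = w * batch.xs[i] + b - batch.ys[i];
    total += error * error;
  }
  return total / batch.xs.length;
}

// Repeatedly fit the same batch with plain gradient descent.
function overfitSingleBatch(
  batch: Batch,
  steps = 2000,
  learningRate = 0.05
): { w: number; b: number; loss: number } {
  let w = 0;
  let b = 0;
  const n = batch.xs.length;
  for (let step = 0; step < steps; step++) {
    let gradW = 0;
    let gradB = 0;
    for (let i = 0; i < n; i++) {
      const error = w * batch.xs[i] + b - batch.ys[i];
      gradW += (2 / n) * error * batch.xs[i];
      gradB += (2 / n) * error;
    }
    w -= learningRate * gradW;
    b -= learningRate * gradB;
  }
  return { w, b, loss: meanSquaredError(w, b, batch) };
}

// A batch generated by y = 2x + 1; a healthy setup should memorize it.
const sanityBatch: Batch = { xs: [0, 1, 2, 3], ys: [1, 3, 5, 7] };
const sanityResult = overfitSingleBatch(sanityBatch);
if (sanityResult.loss > 1e-6) {
  console.warn('Could not overfit a single batch: inspect the pipeline.');
}
```

The same pattern carries over to a real model: freeze one batch, disable augmentation and regularization, and confirm that the training loss collapses before launching a long job.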
Maintain a reproducible development environment
Code with many dependencies breaks often, and even if it works fine now, in two weeks it might be impossible to say how. Freezing dependencies in the virtual environment after a successful deployment reduces the pain of starting anew for someone else (who might as well be yourself).
Take a chance on bleeding-edge software before the major release
Upgrading to TensorFlow 2.0 resolved cryptic problems with
models and could have eliminated the issues with
tf.estimator, which had features propagated down from the TF
2.0 beta releases.
Other ideas for growing the model garden are collected in the scratchpad.
Porting the models called for contributions to the TensorFlow.js ecosystem and beyond.
The list below highlights all of them.
tf.data on AI Platform
I am grateful to the following amazing people for the gift of a valuable learning experience that GSoC has become, helping me grow into a pro and making the world a better place: