In this post I want to explore some of the key similarities and differences between two popular deep learning frameworks: PyTorch and TensorFlow. Why those two and not the others? There are many deep learning frameworks and many of them are viable tools; I chose these two simply because I was interested in comparing them specifically.
Origins
TensorFlow is developed by Google Brain and actively used at Google both for research and production needs. Its closed-source predecessor is called DistBelief.
PyTorch is a cousin of the Lua-based Torch framework, which is actively used at Facebook. However, PyTorch is not a simple set of wrappers to support a popular language; it was rewritten and tailored to be fast and feel native.
The best way to compare two frameworks is to code something up in both of them. I’ve written a companion Jupyter notebook for this post and you can get it here. All code will be provided in the post too.
First, let’s code a simple approximator for the following function in both frameworks:
We will try to find the unknown parameter phi given data x and function values f(x). Yes, using stochastic gradient descent for this is overkill and an analytical solution can be found easily, but this problem will serve our purpose well as a simple example.
We will solve this with PyTorch first:
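A minimal sketch of such a hand-written gradient descent loop, assuming the target function is f(x) = x^phi (the later mentions of recovering "the exponent" suggest this). The data range, learning rate, step count and the name fit_exponent_manual are illustrative assumptions, and the code is written against the current torch API, so it may differ from the companion notebook:

```python
import torch


def fit_exponent_manual(true_phi=3.0, steps=500, lr=0.01):
    """Recover phi in f(x) = x ** phi with a hand-written gradient step."""
    # Assumed toy data: a fixed grid on [1, 2] keeps x ** phi well behaved.
    x = torch.linspace(1.0, 2.0, 50)
    y = x ** true_phi

    # The only trainable parameter, deliberately started away from the truth.
    phi = torch.tensor(1.0, requires_grad=True)

    for _ in range(steps):
        loss = ((x ** phi - y) ** 2).mean()
        loss.backward()              # autograd fills phi.grad
        with torch.no_grad():
            phi -= lr * phi.grad     # plain gradient descent update
        phi.grad.zero_()             # gradients accumulate, so reset them
    return phi.item()


estimated_phi = fit_exponent_manual()
```

The update and the gradient reset are spelled out by hand here, which is exactly the bookkeeping an optimizer object would otherwise take care of.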
If you have some experience with deep learning frameworks you may have noticed that we are implementing gradient descent by hand. Not very convenient, huh? Luckily, PyTorch has the optim module, which contains implementations of popular optimization algorithms such as RMSProp or Adam. We will use SGD with momentum:
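A sketch of the same fit using torch.optim.SGD with momentum, again assuming f(x) = x^phi. The hyperparameters and the function name fit_exponent_sgd are assumptions chosen so the toy problem converges; they are not values taken from the notebook:

```python
import torch


def fit_exponent_sgd(true_phi=3.0, steps=1000, lr=0.001, momentum=0.9):
    """Recover phi in f(x) = x ** phi using torch.optim.SGD with momentum."""
    x = torch.linspace(1.0, 2.0, 50)     # assumed toy data
    y = x ** true_phi
    phi = torch.tensor(1.0, requires_grad=True)

    # The optimizer owns the update rule; we only hand it the parameters.
    optimizer = torch.optim.SGD([phi], lr=lr, momentum=momentum)

    for _ in range(steps):
        optimizer.zero_grad()            # clear gradients from the last step
        loss = ((x ** phi - y) ** 2).mean()
        loss.backward()
        optimizer.step()                 # momentum update applied to phi
    return phi.item()


phi_sgd = fit_exponent_sgd()
```

Swapping in another algorithm such as torch.optim.Adam only changes the line that constructs the optimizer.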
As you can see, we quickly inferred the true exponent from the training data. Now let’s move on to TensorFlow:
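A sketch of a static-graph version, assuming the same target f(x) = x^phi. It uses the TensorFlow 1.x graph API through the tensorflow.compat.v1 module so that it can also run on TensorFlow 2 installations; the data, hyperparameters and the name fit_exponent_tf are assumptions, and the iteration count here says nothing about the notebook's results:

```python
import numpy as np
import tensorflow.compat.v1 as tf


def fit_exponent_tf(true_phi=3.0, steps=500, lr=0.01):
    """Recover phi in f(x) = x ** phi with a statically defined graph."""
    xs = np.linspace(1.0, 2.0, 50).astype(np.float32)   # assumed toy data
    ys = xs ** true_phi

    # Step 1: describe the whole computation before running anything.
    graph = tf.Graph()
    with graph.as_default():
        x = tf.placeholder(tf.float32, shape=[None], name="x")
        y = tf.placeholder(tf.float32, shape=[None], name="y")
        phi = tf.Variable(1.0, name="phi")
        loss = tf.reduce_mean(tf.square(tf.pow(x, phi) - y))
        train_op = tf.train.GradientDescentOptimizer(lr).minimize(loss)
        init_op = tf.global_variables_initializer()

    # Step 2: execute the graph inside a session, feeding data at runtime.
    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        for _ in range(steps):
            sess.run(train_op, feed_dict={x: xs, y: ys})
        return float(sess.run(phi))


phi_tf = fit_exponent_tf()
```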
As you can see, the implementation in TensorFlow works too (surprisingly 🙃). It took more iterations to recover the exponent, but I am sure the cause is that I did not fiddle with the optimiser’s parameters enough to reach comparable results.
Now we are ready to explore some differences.
Difference #0 — adoption
Currently, TensorFlow is considered a go-to tool by many researchers and industry professionals. The framework is well documented, and if the documentation does not suffice there are many extremely well-written tutorials on the internet. You can find hundreds of implemented and trained models on GitHub, start here.
PyTorch is relatively new compared to its competitor (and is still in beta), but it is quickly gaining momentum. Documentation and official tutorials are also nice. PyTorch also includes several implementations of popular computer vision architectures which are super easy to use.
Difference #1 — dynamic vs static graph definition
Both frameworks operate on tensors and view any model as a directed acyclic graph (DAG), but they differ drastically in how you can define them.
TensorFlow follows the ‘data as code and code is data’ idiom. In TensorFlow you define the graph statically before a model can run. All communication with the outer world is performed via the tf.Session object and tf.placeholder tensors, which are substituted by external data at runtime.
In PyTorch things are far more imperative and dynamic: you can define, change and execute nodes as you go, with no special session interfaces or placeholders. Overall, the framework is more tightly integrated with the Python language and feels more native most of the time. When you write in TensorFlow, you sometimes feel that your model is behind a brick wall with several tiny holes to communicate through. Anyway, this still sounds more or less like a matter of taste.
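To make the define-as-you-go idea concrete, here is a small sketch in which the number of graph nodes depends on the data, because an ordinary Python while loop inspects a tensor value. The function name repeated_halving and the example values are hypothetical:

```python
import torch


def repeated_halving(x, threshold=1.0):
    """Halve x until its norm drops below threshold; depth depends on the data."""
    steps = 0
    while x.norm() > threshold:    # plain Python control flow on a tensor value
        x = x * 0.5                # each pass adds one more node to the graph
        steps += 1
    return x, steps


w = torch.tensor([8.0, 6.0], requires_grad=True)    # norm is 10
out, n_steps = repeated_halving(w)
out.sum().backward()               # backprop through however many steps ran
```

With this input the loop runs four times, so each component of w.grad is 0.5 ** 4. A different input would build a graph of a different depth without any change to the code.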
However, the two approaches differ not only from a software engineering perspective: there are several dynamic neural network architectures that benefit from the dynamic approach. Recall RNNs: with static graphs, the input sequence length stays constant. This means that if you develop a sentiment analysis model for English sentences, you must fix the sentence length to some maximum value and pad all shorter sequences with zeros. Not too convenient, huh? And you will get more problems in the domain of recursive RNNs and tree-RNNs. Currently TensorFlow has limited support for dynamic inputs via TensorFlow Fold. PyTorch has it by default.
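The padding step described above can be sketched in plain Python. The helper name pad_sequences, the token ids and the choice to truncate sequences longer than the maximum are illustrative assumptions:

```python
def pad_sequences(sequences, max_len, pad_value=0):
    """Truncate or right-pad each sequence of token ids to exactly max_len."""
    padded = []
    for seq in sequences:
        clipped = list(seq[:max_len])                    # drop anything too long
        padding = [pad_value] * (max_len - len(clipped)) # fill the remainder
        padded.append(clipped + padding)
    return padded


batch = pad_sequences([[4, 8, 15], [16], [23, 42, 7, 9, 1]], max_len=4)
```

Every row now has the fixed length a static graph expects, at the cost of wasted computation on the padding and lost tokens in the truncated sentence.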
Difference #2 — Debugging
Since the computation graph in PyTorch is defined at runtime, you can use your favorite Python debugging tools such as pdb, ipdb, the PyCharm debugger or trusty old print statements.
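As a small illustration, intermediate tensors are ordinary Python objects while the model runs, so they can be recorded, printed or inspected at a breakpoint. The function forward_with_trace and its inputs are hypothetical:

```python
import torch


def forward_with_trace(x, weight):
    """Run a tiny forward pass and record facts about intermediate tensors."""
    trace = []
    hidden = x @ weight
    trace.append(("hidden", tuple(hidden.shape)))
    # breakpoint()  # uncomment to stop here and inspect `hidden` in pdb
    out = torch.relu(hidden)
    trace.append(("out_min", out.min().item()))
    return out, trace


out, trace = forward_with_trace(torch.ones(2, 3), -torch.ones(3, 4))
```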
This is not the case with TensorFlow. You have the option of using a special tool called tfdbg, which lets you evaluate TensorFlow expressions at runtime and browse all tensors and operations in session scope. Of course, you won’t be able to debug any Python code with it, so it will be necessary to use pdb separately.
Difference #3 — Visualization
TensorBoard is awesome when it comes to visualization 😎. This tool comes with TensorFlow and is very useful for debugging and for comparing different training runs. For example, suppose you trained a model, then tuned some hyperparameters and trained it again. Both runs can be displayed in TensorBoard simultaneously to show possible differences. TensorBoard can:
- Display model graph
- Plot scalar variables
- Visualize distributions and histograms
- Visualize images
- Visualize embeddings
- Play audio
TensorBoard can display various summaries, which can be collected via the tf.summary module. We will define summary operations for our toy exponent example and use tf.summary.FileWriter to save them to disk.
To launch TensorBoard, execute tensorboard --logdir=./tensorboard. This tool is very convenient to use on cloud instances since it is a web app.
Currently, PyTorch has no equivalent of TensorBoard; however, integrations do exist. Otherwise, you are free to use standard plotting tools such as matplotlib and seaborn.
Difference #4 — Deployment
When it comes to deployment, TensorFlow is a clear winner for now: it has TensorFlow Serving, a framework for deploying your models on a specialized gRPC server. Mobile is also supported.
When we switch back to PyTorch, we may use Flask or another alternative to code up a REST API on top of the model. This could be done with TensorFlow models as well if gRPC is not a good match for your use case. However, TensorFlow Serving may be a better option if performance is a concern.
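A rough sketch of what such a REST endpoint could look like. To stay free of third-party dependencies it uses Python's standard http.server module in place of Flask, and predict is a hypothetical stand-in for a call into a trained model; the request format and handler names are assumptions as well:

```python
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


def predict(features):
    """Hypothetical stand-in for a trained model's forward pass."""
    return sum(features) / len(features)


def handle_predict(body):
    """Turn a JSON request body into a JSON response body."""
    payload = json.loads(body)
    result = predict(payload["features"])
    return json.dumps({"prediction": result}).encode("utf-8")


class PredictHandler(BaseHTTPRequestHandler):
    def do_POST(self):
        # Read exactly the number of bytes the client announced.
        length = int(self.headers.get("Content-Length", 0))
        response = handle_predict(self.rfile.read(length))
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(response)))
        self.end_headers()
        self.wfile.write(response)


def serve(port=8000):
    """Blocks forever; call it only when the server should actually run."""
    HTTPServer(("127.0.0.1", port), PredictHandler).serve_forever()
```

Keeping the JSON handling in handle_predict, separate from the HTTP plumbing, means the same function could be reused behind Flask or any other web framework.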
TensorFlow also supports distributed training, which PyTorch lacks for now.
Difference #5 — A Framework or a library
Let’s build a CNN classifier for handwritten digits. Now PyTorch will really start to look like a framework. Recall that a programming framework gives us useful abstractions in a certain domain and a convenient way to use them to solve concrete problems. That is the essence that separates a framework from a library.
Here we introduce the datasets module, which contains wrappers for popular datasets used to benchmark deep learning architectures. Also, nn.Module is used to build a custom convolutional neural network classifier. nn.Module is a building block PyTorch gives us to create complex deep learning architectures. There is a large number of ready-to-use modules in the torch.nn package that we can use as a base for our model. Notice how PyTorch uses an object-oriented approach to define basic building blocks and gives us some 'rails' to move on, while providing the ability to extend functionality via subclassing.
Here goes a slightly modified version of https://github.com/pytorch/examples/blob/master/mnist/main.py:
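A compact sketch of an nn.Module classifier for 28x28 single-channel digit images. The layer sizes are assumptions patterned on the small network commonly shown in that example, dataset loading and the training loop are omitted, and random tensors stand in for MNIST images:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small CNN: two conv layers followed by two fully connected layers."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        # 28x28 -> conv 24x24 -> pool 12x12
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        # 12x12 -> conv 8x8 -> pool 4x4
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)                      # 20 channels * 4 * 4
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        return F.log_softmax(self.fc2(x), dim=1) # log-probabilities per digit


model = Net().eval()                             # eval mode disables dropout
with torch.no_grad():
    log_probs = model(torch.randn(4, 1, 28, 28)) # stand-in batch of 4 images
```

Subclassing nn.Module is what registers the layers' parameters, so model.parameters() can be passed straight to an optimizer.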
Plain TensorFlow feels a lot more like a library rather than a framework: all operations are pretty low-level and you will need to write lots of boilerplate code even when you might not want to (let’s define those biases and weights again and again and …).
As time has passed, a whole ecosystem of high-level wrappers has emerged around TensorFlow. Each of them aims to simplify the way you work with the library. Many of them are currently located in the tensorflow.contrib module (which is not considered a stable API) and some have started to migrate to the main repository (see tf.layers).
So, you have a lot of freedom in how to use TensorFlow and which framework will suit the task best: TFLearn, tf.contrib.learn, Sonnet, Keras, plain tf.layers, etc. To be honest, Keras deserves another post but is currently out of the scope of this comparison.
Here we will use tf.layers and tf.contrib.learn to build our CNN classifier. The code follows the official tutorial on tf.layers:
So, both TensorFlow and PyTorch provide useful abstractions to reduce the amount of boilerplate code and speed up model development. The main difference between them is that PyTorch may feel more “pythonic” and has an object-oriented approach, while TensorFlow has several options from which you may choose.
Personally, I consider PyTorch to be clearer and more developer-friendly. Its torch.nn.Module gives you the ability to define reusable modules in an OOP manner, and I find this approach very flexible and powerful. Later you can compose all kinds of modules via torch.nn.Sequential (hi Keras ✋🏻). Also, you have all built-in modules in a functional form, which can be very convenient. Overall, all parts of the API play well together.
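A short sketch of the two styles side by side: the same layers composed with torch.nn.Sequential and called through torch.nn.functional. The layer sizes and variable names are arbitrary choices made for illustration:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden = nn.Linear(4, 8)
head = nn.Linear(8, 2)

# Module style: layers composed declaratively, much like a Keras model.
sequential_model = nn.Sequential(hidden, nn.ReLU(), head)


def functional_model(x):
    # Functional style: the same layers, with the activation as a function call.
    return head(F.relu(hidden(x)))


x = torch.randn(3, 4)
same_output = torch.allclose(sequential_model(x), functional_model(x))
```

Because both versions share the same Linear layers, they compute the same result; the choice between them is mostly about how much custom logic the forward pass needs.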
Of course, you can write very clean code in plain TensorFlow, but it takes more skill and trial and error before you get there. When it comes to higher-level frameworks such as Keras or TFLearn, get ready to lose at least some of the flexibility TensorFlow has to offer.
Conclusion
TensorFlow is a very powerful and mature deep learning library with strong visualization capabilities and several options for high-level model development. It has production-ready deployment options and support for mobile platforms. TensorFlow is a good option if you:
- Develop models for production
- Develop models which need to be deployed on mobile platforms
- Want good community support and comprehensive documentation
- Want rich learning resources in various forms (TensorFlow has an entire MOOC)
- Want or need to use TensorBoard
- Need to use large-scale distributed model training
PyTorch is still a young framework that is gaining momentum fast. You may find it a good fit if you:
- Do research or your production non-functional requirements are not very demanding
- Want better development and debugging experience
- Love all things Pythonic
If you have the time, the best advice is to try both and see which fits your needs best.