Neural Networks in pure JAX (with automatic differentiation)


(Reverse-mode) automatic differentiation is the secret sauce of deep learning, allowing us to differentiate through almost arbitrary neural architectures. Let's use the abstractions of the JAX deep learning framework in Python to implement a simple MLP. Here is the code: https://github.com/Ceyron/machine-lea...
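
If you prefer skimming code, here is a minimal sketch of the same idea: an MLP forward (primal) pass, Glorot-style initialization, an MSE loss, and a gradient function obtained purely via JAX's reverse-mode autodiff. The layer sizes and helper names (init_params, forward, loss_fn) are illustrative assumptions, not necessarily what the linked repository uses.

import jax
import jax.numpy as jnp

# Illustrative layer sizes (1 input -> two hidden layers -> 1 output);
# the exact architecture used in the video/repo may differ.
layer_sizes = [1, 10, 10, 1]

def init_params(key, sizes):
    # Xavier/Glorot-style initialization for each (W, b) pair
    params = []
    for fan_in, fan_out in zip(sizes[:-1], sizes[1:]):
        key, subkey = jax.random.split(key)
        scale = jnp.sqrt(2.0 / (fan_in + fan_out))
        W = scale * jax.random.normal(subkey, (fan_out, fan_in))
        b = jnp.zeros((fan_out,))
        params.append((W, b))
    return params

def forward(params, x):
    # Forward/primal pass: affine layers with tanh, linear output layer
    for W, b in params[:-1]:
        x = jnp.tanh(W @ x + b)
    W, b = params[-1]
    return W @ x + b

def loss_fn(params, xs, ys):
    # Mean squared error over the batched toy dataset
    preds = jax.vmap(forward, in_axes=(None, 0))(params, xs)
    return jnp.mean((preds - ys) ** 2)

# Reverse-mode autodiff: only the forward pass is written by hand;
# jax.grad produces gradients w.r.t. every entry of the params pytree.
grad_fn = jax.grad(loss_fn)

Calling grad_fn(params, xs, ys) returns a pytree of gradients with the same structure as params, which is all that plain gradient descent needs.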

-------

👉 This educational series is supported by the world leaders in integrating machine learning and artificial intelligence with simulation and scientific computing, Pasteur Labs and Institute for Simulation Intelligence. Check out https://simulation.science/ for more on their pursuit of 'Nobel-Turing' technologies (https://arxiv.org/abs/2112.03235), and for partnership or career opportunities.

-------

📝 : Check out the GitHub Repository of the channel, where I upload all the handwritten notes and source-code files (contributions are very welcome): https://github.com/Ceyron/machine-lea...

📢 : Follow me on LinkedIn or Twitter for updates on the channel and other cool Machine Learning & Simulation stuff:   / felix-koehler   and   / felix_m_koehler  

💸 : If you want to support my work on the channel, you can become a patron on Patreon here:   / mlsim  

🪙: Or you can make a one-time donation via PayPal: https://www.paypal.com/paypalme/Felix...

-------

Timestamps:
00:00 Intro
01:18 Dataset that somehow looks like a sine function
01:56 Forward pass of the Multilayer Perceptron
03:22 Weight initialization following Xavier Glorot
04:20 Idea of "Learning" as approximate optimization
04:49 Reverse-mode autodiff requires us to only write the forward pass
05:34 Imports
05:52 Constants and Hyperparameters
06:19 Producing the random toy dataset
08:33 Draw initial parameter guesses
12:05 Implementing the forward/primal pass
13:58 Implementing the loss metric
14:57 Transform forward pass to get gradients by autodiff
20:03 Training loop (using plain gradient descent)
23:21 Improving training speed by JIT compilation
24:25 Plotting loss history
24:47 Plotting final network prediction & Discussion
25:44 Summary
26:59 Outro
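
For reference, here is a hedged sketch of how the training loop and JIT compilation from the timestamps above could look. It reuses the illustrative init_params, loss_fn, and layer_sizes from the earlier sketch; the hyperparameters and the noisy-sine toy dataset are assumptions, not the exact values from the video.

# Illustrative hyperparameters; the constants in the video may differ.
LEARNING_RATE = 0.1
N_EPOCHS = 5000

# Toy dataset that "somehow looks like a sine function": noisy sine samples
key = jax.random.PRNGKey(0)
xs = jnp.linspace(0.0, 2.0 * jnp.pi, 200).reshape(-1, 1)
ys = jnp.sin(xs) + 0.1 * jax.random.normal(key, xs.shape)

# Draw initial parameter guesses
params = init_params(jax.random.PRNGKey(42), layer_sizes)

@jax.jit
def update(params, xs, ys):
    # One step of plain gradient descent, JIT-compiled for speed
    loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    return new_params, loss

loss_history = []
for epoch in range(N_EPOCHS):
    params, loss = update(params, xs, ys)
    loss_history.append(loss)

Decorating update with jax.jit traces and compiles the whole step (loss, gradients, and parameter update) via XLA once, which is where the training speed-up discussed at 23:21 comes from.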
