import {
	Box,
	Heading,
	Kbd,
	Text,
	Image,
	Center,
	VStack,
	HStack,
} from '@chakra-ui/react';
import React, { useEffect } from 'react';
import MathJaxTex from '../../components/MathJaxTex';
import { useLocation, useNavigate } from 'react-router-dom';

const BackPropTheOverlord: React.FC = () => {
	const navigate = useNavigate();

	const { pathname } = useLocation();

	useEffect(() => {
		window.scrollTo(0, 0);
	}, [pathname]);

	return (
		<Box p={5}>
			<Heading>Backpropagation - The search for a new god</Heading>
			<Text>
				The title and the content of this blog may seem a little
				contradictory to some readers. If you think this blog is about
				the search for God, or for the godly elements we know or believe
				in, be prepared to be disappointed. Rather, this blog explains
				the key ingredient on the path to creating the new god(s) that
				humans may create and worship in the future, aka AGI, the
				singularity, blah blah. Enough of the dystopian discussion;
				let's dive into the interesting part: the logic and math behind
				creating them, or in simple words, AI training.
			</Text>
			<Text mt={5}>
				AI is a black box to a certain extent, but many people think AI
				is a total black box just because they don’t understand the math
				behind it. People are often too comfortable using highly
				abstracted AI frameworks to build their models and then struggle
				when a model doesn’t perform the way they intended. This blog
				aims to impart some intuition about the AI training process for
				a multi-layer neural network through the math behind it, aka
				backpropagation of partial derivatives. Hang on, don’t get
				disheartened by the terminology. The term (and field) AI is very
				broad, and artificial neural networks are the subclass of AI we
				hear about and use most often. The goal of any AI model is to
				take inputs, perform some processing and produce an output.
				Examples include an image of a face as input and the age/mood as
				output, or a story-writing task as input and a story as output.
			</Text>
			<Text mt={5}>
				To start with, let me introduce a very basic and cliché
				graphical representation of a multi-layer neural network.
			</Text>
			<Center>
				<VStack>
					<Image
						width="250px"
						src="https://pub-17b7496e137e40fcbe7057d6a4735482.r2.dev/backpropagation-the-search-for-a-new-god:1.png"
						alt="backprop-1"
					/>
					<Text fontSize="2xs" as="i" color="gray.500">
						Image source: My computer
					</Text>
				</VStack>
			</Center>
			<Text mt={5}>
				The small blue circle in the above diagram is called a hidden
				unit of a neural network, or a neuron (if you want to draw
				analogies with biological neural networks), and the black arrows
				can be thought of as synapses. A collection of these hidden units
				(the rectangular box of circles) forms one hidden layer, and a
				multi-layer network may have one or more hidden layers. Hidden
				units are some sort of black boxes that take input from the
				previous layer, do some magic and produce an output. The ‘magic’
				in the previous sentence refers to a non-linear mathematical
				function applied to the sum of the products of the inputs and
				the ‘weights’. AI training, or in this case neural network
				training, is just the process of finding the right set of
				‘weights’ that makes our neural network work the way we
				intended. Non-linear functions let the network model complex
				input-output relationships, and they’re referred to as
				activation functions. We won’t dwell on questions like how
				activation functions are chosen, why non-linear functions are
				used, or why a particular activation function is preferred; the
				answers deserve a separate blog. For now, let’s choose sigmoid
				as the activation function. Mathematically, a hidden unit
				computes{' '}
				<MathJaxTex
					text="$
						z=\sum\limits_{i=1}^{n}{w_i x_i}
					  $"
				/>{' '}
				and{' '}
				<MathJaxTex
					text="$
						y=\sigma({z})=\frac{1}{1+e^{-z}}
					  $"
				/>
				, where <MathJaxTex text="$y$" /> is the output and{' '}
				<MathJaxTex text="$x$" /> is a collection of{' '}
				<MathJaxTex text="$n$" /> input features. For example, an image
				input (e.g., a face) is just a collection of pixels represented
				as a matrix, and the pixels in an image are its simplest
				features (
				<a
					style={{ color: 'blue', textDecoration: 'underline' }}
					onClick={() => navigate('/blog/feel-the-pixels')}
				>
					feel the pixels
				</a>
				). <MathJaxTex text="$y$" /> is then passed to the subsequent
				layer as input, and this process continues until the output
				layer. Iterating over each element <MathJaxTex text="$x_i$" />{' '}
				of the input vector <MathJaxTex text="$x$" />, computing its
				product with <MathJaxTex text="$w_i$" /> and summing the results
				one element at a time is slow, especially for large inputs. To
				make it computationally efficient, let’s recall our linear
				algebra classes and the definition of the dot product of two
				vectors{' '}
				<MathJaxTex text="$a=\left[\begin{array}{cc}x_1 & x_2\end{array}\right]^T$" />{' '}
				and{' '}
				<MathJaxTex text="$b=\left[\begin{array}{cc}y_1 & y_2\end{array}\right]^T$" />
				:{' '}
				<MathJaxTex text="$a \cdot b = x_1y_1+x_2y_2 = a^Tb = \sum\limits_{i=1}^{2}{x_i y_i}$" />
				. GPUs are very fast at matrix multiplication, so writing our
				equation as a matrix/vector product lets us compute{' '}
				<MathJaxTex text="$z$" /> very efficiently. In vector notation,{' '}
				<MathJaxTex text="$z=w^Tx$" /> and{' '}
				<MathJaxTex text="$y=\sigma({z})$" />. Hopefully, you now have
				a basic idea of how neural networks are represented. As
				mentioned earlier, AI/neural network training refers to finding
				the optimal set of weights <MathJaxTex text="$w$" /> so that we
				get the correct <MathJaxTex text="$y$" />. An AI model cannot
				produce the correct <MathJaxTex text="$y$" /> if it doesn’t know
				what kind of output is expected, e.g., if an image of a face is
				passed to a model that doesn’t know whether the output should be
				an age or a mood. So, a large set of known input-output pairs{' '}
				<MathJaxTex text="$(x, t)$" />, called a dataset, is collected,
				and the model is trained on this dataset to predict the actual
				output/target <MathJaxTex text="$t$" /> for a given input{' '}
				<MathJaxTex text="$x$" />. For example, a large set of face
				images (input) paired with ages (output) is collected to train a
				model that predicts age, and face images paired with moods if
				the model should predict mood. Initially, the model weights{' '}
				<MathJaxTex text="$w$" /> are initialized to zero or small
				random numbers. Naturally, the output <MathJaxTex text="$y$" />{' '}
				computed with the initial weights for an input{' '}
				<MathJaxTex text="$x$" /> won’t match the actual output{' '}
				<MathJaxTex text="$t$" />, because we don’t know the correct
				magic potion (weights) yet. To find the correct set of weights,
				the error between the predicted output <MathJaxTex text="$y$" />{' '}
				and the actual output <MathJaxTex text="$t$" /> is measured with
				a mathematical function called an error function. The
				least-squares error, half the sum of squared differences, is one
				of many commonly used error functions,{' '}
				<MathJaxTex text="$E=\frac{1}{2}\sum\limits_{j=1}^{m}{(y_j-t_j)^2}$" />
				, where <MathJaxTex text="$m$" /> is the total number of data
				points available, i.e., the size of the dataset. If you look
				closely, each term is just the squared Euclidean distance
				between the points <MathJaxTex text="$y_j$" /> and{' '}
				<MathJaxTex text="$t_j$" />. To bring the predicted output{' '}
				<MathJaxTex text="$y_j$" /> close to the actual output{' '}
				<MathJaxTex text="$t_j$" />, we want this distance to be as low
				as possible, ideally zero. In other words, we want to change the
				weights so that we reach a point of zero (or low) error. In
				simple terms, we want to find a minimum of the function{' '}
				<MathJaxTex text="$E$" />, and if you recall calculus, the
				derivative of a function is zero at its minimum, i.e.,{' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{w}}=0$" />.
			</Text>
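			<Text mt={5}>
				As a small worked example, with numbers made up purely for
				illustration (they do not come from any real dataset), take a
				single sigmoid unit with two inputs{' '}
				<MathJaxTex text="$x=\left[\begin{array}{cc}1 & 2\end{array}\right]^T$" />
				, weights{' '}
				<MathJaxTex text="$w=\left[\begin{array}{cc}0.5 & -0.25\end{array}\right]^T$" />{' '}
				and target <MathJaxTex text="$t=1$" />. The forward pass and the
				error for this single data point (<MathJaxTex text="$m=1$" />)
				are:
			</Text>
			<Center mt={3}>
				<VStack>
					<MathJaxTex text="$z=w^Tx=0.5 \cdot 1+(-0.25) \cdot 2=0$" />
					<MathJaxTex text="$y=\sigma(0)=\frac{1}{1+e^{0}}=0.5$" />
					<MathJaxTex text="$E=\frac{1}{2}(y-t)^2=\frac{1}{2}(0.5-1)^2=0.125$" />
				</VStack>
			</Center>
			<Text mt={3}>
				The prediction is still far from the target, so the weights have
				to change. Changing them is what training does.
			</Text>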
			<Center>
				<VStack>
					<Image
						width="250px"
						src="https://pub-17b7496e137e40fcbe7057d6a4735482.r2.dev/backpropagation-the-search-for-a-new-god:2.png"
						alt="backprop-5"
					/>
					<Text fontSize="2xs" as="i" color="gray.500">
						Image source: My computer
					</Text>
				</VStack>
			</Center>
			<Text mt={5}>
				If we want to move towards a destination (in our case, low
				error), we should know the direction to move in, and if you
				recall calculus, the derivative of a function gives us that
				direction.{' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{w}} > 0$" />{' '}
				means the error increases if we increase{' '}
				<MathJaxTex text="$w$" />, and{' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{w}} < 0$" />{' '}
				means the error decreases (we move towards a minimum) if we
				increase <MathJaxTex text="$w$" />. So, mathematically, this
				process of changing the weights can be represented by{' '}
				<MathJaxTex text="$w \leftarrow w - l\frac{\partial{E}}{\partial{w}}$" />
				, where <MathJaxTex text="$l$" /> is the learning rate, which
				controls how fast or slow we move/change the weights, and the
				negative sign moves the weights in the direction opposite to the
				derivative (or gradient) so that we approach a minimum. Now we
				have a basic learning algorithm to train an AI model, and this
				particular one is called gradient descent (there are different
				versions of GD, and we won’t get to them now). Using the chain
				rule of derivatives, we can derive (try it yourself to verify){' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{w}} = \frac{\partial{E}}{\partial{z}}x$" />
				, where{' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{z}} = \frac{\partial{E}}{\partial{y}}\sigma'({z})$" />{' '}
				and{' '}
				<MathJaxTex text="$\frac{\partial{E}}{\partial{y}} = (y-t)$" />.{' '}
				If you observe closely, we are sending gradients (partial
				derivatives) backwards to the weights <MathJaxTex text="$w$" />,
				all the way from <MathJaxTex text="$E$" />. In simple words, the
				gradients are propagated backwards, which is why this is called
				backpropagation. So, AI training is nothing but computing a
				predicted output <MathJaxTex text="$y$" />, measuring the error{' '}
				<MathJaxTex text="$E$" /> between <MathJaxTex text="$y$" /> and
				the actual output <MathJaxTex text="$t$" />, and backpropagating
				gradients to update the weights so that the error shrinks,
				repeating the process until the error is minimized ‘enough’.
				Performing these computations on large datasets over many
				repetitions is a Herculean task. Imagine computing millions of
				partial derivatives (gradients) in a large network, repeatedly,
				for weeks, months or even years. So, we need a better way to
				represent the derivatives of complex error functions. As
				mentioned earlier, the partial derivatives of an error function
				are computed using the chain rule. Let’s set the earlier
				notation aside and derive some relations with it. If{' '}
				<MathJaxTex text="$z=f(x(t), y(t))$" />, then by the chain rule{' '}
				<MathJaxTex text="$\frac{\partial{z}}{\partial{t}}=\frac{\partial{z}}{\partial{x}}\frac{\partial{x}}{\partial{t}}+\frac{\partial{z}}{\partial{y}}\frac{\partial{y}}{\partial{t}}$" />
				. Let’s call{' '}
				<MathJaxTex text="$\frac{\partial{z}}{\partial{t}} = \bar{t}; \ \frac{\partial{z}}{\partial{x}} = \bar{x}; \ \frac{\partial{z}}{\partial{y}} = \bar{y}$" />{' '}
				and rewrite the chain rule equation as{' '}
				<MathJaxTex text="$\bar{t}=\bar{x}\frac{\partial{x}}{\partial{t}}+\bar{y}\frac{\partial{y}}{\partial{t}}$" />.
			</Text>
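			<Text mt={5}>
				To make the update rule concrete, here is one gradient descent
				step worked by hand for a single sigmoid unit. The values are
				made up purely for illustration: inputs{' '}
				<MathJaxTex text="$x=\left[\begin{array}{cc}1 & 2\end{array}\right]^T$" />
				, weights{' '}
				<MathJaxTex text="$w=\left[\begin{array}{cc}0.5 & -0.25\end{array}\right]^T$" />
				, target <MathJaxTex text="$t=1$" /> and an assumed learning
				rate <MathJaxTex text="$l=1$" />. These give{' '}
				<MathJaxTex text="$z=w^Tx=0$" /> and{' '}
				<MathJaxTex text="$y=\sigma(0)=0.5$" />, so the error before the
				step is{' '}
				<MathJaxTex text="$E=\frac{1}{2}(y-t)^2=0.125$" />. The sigmoid
				derivative has the convenient closed form{' '}
				<MathJaxTex text="$\sigma'(z)=\frac{e^{-z}}{(1+e^{-z})^2}=\sigma(z)(1-\sigma(z))$" />
				.
			</Text>
			<Center mt={3}>
				<VStack>
					<MathJaxTex text="$\frac{\partial{E}}{\partial{y}}=y-t=0.5-1=-0.5$" />
					<MathJaxTex text="$\frac{\partial{E}}{\partial{z}}=\frac{\partial{E}}{\partial{y}}\sigma'(z)=-0.5 \cdot 0.5 \cdot (1-0.5)=-0.125$" />
					<MathJaxTex text="$\frac{\partial{E}}{\partial{w}}=\frac{\partial{E}}{\partial{z}}x=-0.125\left[\begin{array}{cc}1 & 2\end{array}\right]^T=\left[\begin{array}{cc}-0.125 & -0.25\end{array}\right]^T$" />
					<MathJaxTex text="$w \leftarrow w-l\frac{\partial{E}}{\partial{w}}=\left[\begin{array}{cc}0.625 & 0\end{array}\right]^T$" />
				</VStack>
			</Center>
			<Text mt={3}>
				With the updated weights,{' '}
				<MathJaxTex text="$z=0.625$" />,{' '}
				<MathJaxTex text="$y=\sigma(0.625)\approx 0.651$" /> and{' '}
				<MathJaxTex text="$E\approx 0.061$" />, down from{' '}
				<MathJaxTex text="$0.125$" /> before the step. Repeating the
				update keeps moving <MathJaxTex text="$y$" /> closer to{' '}
				<MathJaxTex text="$t$" />.
			</Text>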
			<Center>
				<VStack>
					<HStack>
						<Image
							width="250px"
							src="https://pub-17b7496e137e40fcbe7057d6a4735482.r2.dev/backpropagation-the-search-for-a-new-god:3.png"
							alt="backprop-2"
						/>
						<Image
							width="250px"
							src="https://pub-17b7496e137e40fcbe7057d6a4735482.r2.dev/backpropagation-the-search-for-a-new-god:4.png"
							alt="backprop-3"
						/>
					</HStack>
					<Text fontSize="2xs" as="i" color="gray.500">
						Image source: My computer
					</Text>
				</VStack>
			</Center>
			<Text mt={5}>
				To explain in simple terms, we would already know{' '}
				<MathJaxTex text="$\bar{x}$" /> and{' '}
				<MathJaxTex text="$\bar{y}$" /> before arriving at{' '}
				<MathJaxTex text="$\bar{t}$" /> (remember, we are traveling
				backwards), so we only have to compute{' '}
				<MathJaxTex text="$\frac{\partial{x}}{\partial{t}}$" /> and{' '}
				<MathJaxTex text="$\frac{\partial{y}}{\partial{t}}$" />. That is
				relatively easy, because they are nothing but the partial
				derivatives of the very simple (primitive) functions{' '}
				<MathJaxTex text="$x(t)$" /> and <MathJaxTex text="$y(t)$" />.
				To generalize,{' '}
				<MathJaxTex text="$\bar{v}_i=\sum\limits_{j \in \mathrm{children}(i)}{\bar{v}_j\frac{\partial{v_j}}{\partial{v_i}}}$" />
				, where <MathJaxTex text="$i$" /> is a parent node and{' '}
				<MathJaxTex text="$j$" /> is one of its child nodes. In{' '}
				<MathJaxTex text="$x(t)$" />, <MathJaxTex text="$t$" /> would be
				the parent node and <MathJaxTex text="$x$" /> would be a child
				node. Computers are good with vectors, and{' '}
				<MathJaxTex text="$\sum\limits_{i}{a_i b_i}=a \cdot b = a^Tb$" />,
				remember? In vector notation, for a primitive{' '}
				<MathJaxTex text="$y=f(x)$" /> with <MathJaxTex text="$n$" />{' '}
				inputs and <MathJaxTex text="$m$" /> outputs,{' '}
				<MathJaxTex text="$\bar{x}^T = \bar{y}^TJ$" />, where{' '}
				<MathJaxTex text="$J=\frac{\partial{y}}{\partial{x}}=\left[\begin{array}{ccc}\frac{\partial{y_1}}{\partial{x_1}} & \cdots & \frac{\partial{y_1}}{\partial{x_n}}\\\vdots & \ddots & \vdots\\\frac{\partial{y_m}}{\partial{x_1}} & \cdots & \frac{\partial{y_m}}{\partial{x_n}}\end{array}\right]$" />{' '}
				is called the Jacobian, a matrix of the partial derivatives of
				the primitive. I've skipped many things in an attempt to
				simplify the concept and will probably write a more detailed,
				math-intensive blog later. Enough with math and theory; in the
				next blog (in draft), let’s see how AI training is implemented
				with simple code, without the use of highly abstracted
				frameworks.
			</Text>
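			<Text mt={5}>
				Before moving on, here is a small hand-worked check of the
				reverse accumulation rule. The primitives are an illustrative
				choice made only for this example:{' '}
				<MathJaxTex text="$x(t)=t^2$" />,{' '}
				<MathJaxTex text="$y(t)=3t$" /> and{' '}
				<MathJaxTex text="$z=f(x, y)=xy$" />, evaluated at{' '}
				<MathJaxTex text="$t=2$" />. The forward pass gives{' '}
				<MathJaxTex text="$x=4$" />, <MathJaxTex text="$y=6$" /> and{' '}
				<MathJaxTex text="$z=24$" />. Traveling backwards from{' '}
				<MathJaxTex text="$\bar{z}=\frac{\partial{z}}{\partial{z}}=1$" />:
			</Text>
			<Center mt={3}>
				<VStack>
					<MathJaxTex text="$\bar{x}=\frac{\partial{z}}{\partial{x}}=y=6, \ \bar{y}=\frac{\partial{z}}{\partial{y}}=x=4$" />
					<MathJaxTex text="$\bar{t}=\bar{x}\frac{\partial{x}}{\partial{t}}+\bar{y}\frac{\partial{y}}{\partial{t}}=6 \cdot 2t+4 \cdot 3=6 \cdot 4+12=36$" />
				</VStack>
			</Center>
			<Text mt={3}>
				Differentiating directly gives the same answer:{' '}
				<MathJaxTex text="$z=3t^3$" />, so{' '}
				<MathJaxTex text="$\frac{dz}{dt}=9t^2=36$" /> at{' '}
				<MathJaxTex text="$t=2$" />. In the vector form, the Jacobian of
				the primitive <MathJaxTex text="$f(x, y)=xy$" /> is{' '}
				<MathJaxTex text="$J=\left[\begin{array}{cc}y & x\end{array}\right]=\left[\begin{array}{cc}6 & 4\end{array}\right]$" />
				, and{' '}
				<MathJaxTex text="$\bar{z}J=\left[\begin{array}{cc}6 & 4\end{array}\right]=\left[\begin{array}{cc}\bar{x} & \bar{y}\end{array}\right]$" />
				, which matches the values above.
			</Text>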
		</Box>
	);
};

export default BackPropTheOverlord;
