class Variable:
    """A node in a dynamically built computation graph for reverse-mode autodiff.

    Wraps a numpy array (``data``), remembers the Variables it was computed
    from (``parents``), and accumulates d(output)/d(self) into ``grad`` when
    ``backward()`` is called on the final output of an expression. Supports
    ``+``, ``-``, ``*``, ``/`` (with numpy broadcasting) and unary negation.
    """

    _next_id = 0  # Class variable: next unique instance ID to hand out.

    @staticmethod
    def get_next_id():
        """Return a fresh unique ID and advance the class-level counter."""
        current_id = Variable._next_id
        Variable._next_id += 1
        return current_id

    @staticmethod
    def _unbroadcast(grad, shape):
        """Reduce ``grad`` to ``shape`` by summing over broadcast axes.

        numpy broadcasting can make an op's output larger than one of its
        inputs; the gradient flowing back to that input must be summed over
        the broadcast dimensions so it matches the input's own shape.
        """
        grad = np.asarray(grad, dtype=float)
        # Sum out leading axes that broadcasting prepended.
        while grad.ndim > len(shape):
            grad = grad.sum(axis=0)
        # Sum over axes that were stretched from size 1.
        for axis, dim in enumerate(shape):
            if dim == 1 and grad.shape[axis] != 1:
                grad = grad.sum(axis=axis, keepdims=True)
        return grad

    def __init__(self, data, requires_grad=False, parents=None, op=None):
        """
        :param data: Value to wrap; coerced to a float ndarray if needed.
        :param requires_grad: Whether gradients should accumulate on this node.
        :param parents: Variables this one was computed from (graph edges).
        :param op: Name of the op that produced this node (None for leaves).
        """
        if not isinstance(data, np.ndarray):
            data = np.array(data, dtype=float)
        self.id = Variable.get_next_id()  # Assign unique ID to each instance
        self.data = data
        self.requires_grad = requires_grad
        self.parents = parents if parents else []
        self.op = op
        self.grad = None  # Will hold the gradient w.r.t. this Variable's data
        # For convenience, store backward function for this node
        self._grad_fn = None
        # Initialize grad locally. Do NOT call zero_grad() here: it recurses
        # into parents and would wipe gradients already accumulated on
        # existing nodes every time a new graph node is constructed.
        if self.requires_grad:
            self.grad = np.zeros_like(self.data, dtype=float)

    def zero_grad(self, _visited=None):
        """Reset this Variable's gradient (and its ancestors') to all zeros.

        :param _visited: Internal set of node ids already reset; prevents
            revisiting shared sub-graphs (exponential blowup on diamond
            graphs). Callers should omit it.
        """
        if _visited is None:
            _visited = set()
        if id(self) in _visited:
            return
        _visited.add(id(self))
        self.grad = np.zeros_like(self.data, dtype=float)
        for parent in self.parents:
            parent.zero_grad(_visited)

    def set_grad_fn(self, fn):
        """Attach a gradient callback mapping upstream grad -> [grad per parent]."""
        self._grad_fn = fn

    def backward(self, grad=None):
        """
        Backpropagate through this Variable.

        :param grad: The gradient from the 'upstream' node. If None and the
            Variable is (effectively) scalar (``size == 1``), a gradient of
            1.0 is used by default.
        :raises RuntimeError: If ``grad`` is omitted for a non-scalar Variable.
        """
        if not self.requires_grad:
            # Nothing upstream depends on this node's gradient; stop here.
            return
        if grad is None:
            # If it's a scalar, start with d(output)/d(output) = 1
            if self.data.size == 1:
                grad = np.ones_like(self.data)
            else:
                raise RuntimeError(
                    "grad must be specified for non-scalar Variable."
                )
        grad = np.asarray(grad, dtype=float)
        if self.grad is None:
            # requires_grad may have been toggled on after construction.
            self.grad = np.zeros_like(self.data, dtype=float)
        # Accumulate: a node may receive gradient along several graph paths.
        self.grad += grad
        # Call the stored gradient function (if any) to propagate grads to parents
        if self._grad_fn is not None:
            grads_to_parents = self._grad_fn(grad)
            # grads_to_parents is a list of gradients w.r.t each parent
            for parent, g in zip(self.parents, grads_to_parents):
                parent.backward(g)

    def __repr__(self):
        # No trailing newline: repr strings conventionally end cleanly (the
        # original emitted a literal backslash-n artifact).
        return (f"Variable(id={self.id}, data={self.data}, grad={self.grad}, "
                f"requires_grad={self.requires_grad}, op={self.op})")

    def parameters(self):
        """
        Collect this Variable plus every ancestor reachable through op nodes.

        NOTE(review): despite the name, this returns intermediate nodes as
        well as leaves; callers wanting only trainable leaves should filter
        on ``op is None and requires_grad``.
        """
        children = set([self])

        def recurse(variable):
            # Leaves (op is None) have no parents to walk into.
            if variable.op and hasattr(variable, 'parents'):
                for parent in variable.parents:
                    if parent not in children:
                        children.add(parent)
                        recurse(parent)

        recurse(self)
        return children

    def __add__(self, other):
        """Elementwise addition; accepts a Variable or raw scalar/ndarray."""
        other = other if isinstance(other, Variable) else Variable(other)
        # Always keep BOTH parents. The original collapsed duplicates
        # (e.g. ``x + x``) to a single parent while _grad_fn still returned
        # two gradients, so zip() silently dropped one contribution and the
        # gradient came out wrong. With duplicates kept, backward()'s +=
        # accumulation sums both contributions correctly.
        out = Variable(self.data + other.data,
                       requires_grad=(self.requires_grad or other.requires_grad),
                       parents=[self, other],
                       op='add')

        def _grad_fn(grad_out):
            # dL/dx = grad_out * 1 ; dL/dy = grad_out * 1
            # (summed over any broadcast axes to match each input's shape).
            grad_self = Variable._unbroadcast(grad_out, self.data.shape)
            grad_other = Variable._unbroadcast(grad_out, other.data.shape)
            return [grad_self, grad_other]

        out.set_grad_fn(_grad_fn)
        return out

    def __radd__(self, other):
        return self.__add__(other)

    def __sub__(self, other):
        """Elementwise subtraction; accepts a Variable or raw scalar/ndarray."""
        other = other if isinstance(other, Variable) else Variable(other)
        # Both parents always kept; see __add__ for why (x - x must yield 0).
        out = Variable(self.data - other.data,
                       requires_grad=(self.requires_grad or other.requires_grad),
                       parents=[self, other],
                       op='sub')

        def _grad_fn(grad_out):
            # dL/dx = grad_out * 1 ; dL/dy = grad_out * -1
            grad_self = Variable._unbroadcast(grad_out, self.data.shape)
            grad_other = Variable._unbroadcast(-grad_out, other.data.shape)
            return [grad_self, grad_other]

        out.set_grad_fn(_grad_fn)
        return out

    def __rsub__(self, other):
        other = other if isinstance(other, Variable) else Variable(other)
        return other.__sub__(self)

    def __mul__(self, other):
        """Elementwise multiplication; accepts a Variable or raw scalar/ndarray."""
        other = other if isinstance(other, Variable) else Variable(other)
        # Both parents always kept; see __add__ (x * x needs 2x*grad, not x*grad).
        out = Variable(self.data * other.data,
                       requires_grad=(self.requires_grad or other.requires_grad),
                       parents=[self, other],
                       op='mul')

        def _grad_fn(grad_out):
            # dL/dx = grad_out * y ; dL/dy = grad_out * x
            grad_self = Variable._unbroadcast(grad_out * other.data, self.data.shape)
            grad_other = Variable._unbroadcast(grad_out * self.data, other.data.shape)
            return [grad_self, grad_other]

        out.set_grad_fn(_grad_fn)
        return out

    def __rmul__(self, other):
        return self.__mul__(other)

    def __truediv__(self, other):
        """Elementwise true division; accepts a Variable or raw scalar/ndarray."""
        other = other if isinstance(other, Variable) else Variable(other)
        out = Variable(self.data / other.data,
                       requires_grad=(self.requires_grad or other.requires_grad),
                       parents=[self, other],
                       op='div')

        def _grad_fn(grad_out):
            # If z = x / y,
            # dL/dx = grad_out * (1 / y)
            # dL/dy = grad_out * (-x / y^2)
            grad_self = Variable._unbroadcast(
                grad_out * (1.0 / other.data), self.data.shape)
            grad_other = Variable._unbroadcast(
                grad_out * (-self.data / (other.data ** 2)), other.data.shape)
            return [grad_self, grad_other]

        out.set_grad_fn(_grad_fn)
        return out

    def __rtruediv__(self, other):
        other = other if isinstance(other, Variable) else Variable(other)
        return other.__truediv__(self)

    # Implement unary negation for use with expressions like -x
    def __neg__(self):
        out = Variable(-self.data,
                       requires_grad=self.requires_grad,
                       parents=[self],
                       op='neg')

        def _grad_fn(grad_out):
            # d(-x)/dx = -1
            return [-grad_out]

        out.set_grad_fn(_grad_fn)
        return out