visual.draw_tensor_graph
- lucid.visual.draw_tensor_graph(tensor: Tensor, horizontal: bool = False, title: str | None = None, start_id: int | None = None) -> Figure
Visualizes the computational graph of a given Tensor object, showing the flow of operations and tensor dependencies. This function is useful for understanding how gradients propagate through a model during backpropagation.
Function Signature
def draw_tensor_graph(
    tensor: Tensor,
    horizontal: bool = False,
    title: str | None = None,
    start_id: int | None = None,
) -> plt.Figure
Parameters
tensor (Tensor): The root output tensor from which graph traversal starts.
horizontal (bool, optional): If True, the graph is drawn left-to-right; otherwise it is drawn top-down. Defaults to False.
title (str or None, optional): Title for the graph plot. Defaults to None.
start_id (int or None, optional): If provided, highlights the tensor with the specified ID in yellow (see the color legend in the note below). Defaults to None.
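To make the idea of traversing from a root tensor concrete, here is a minimal, library-independent sketch. It does not use lucid and is not how `draw_tensor_graph` is implemented; the `Node` class and its `parents` attribute are hypothetical stand-ins for a tensor and the tensors it was computed from.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)
class Node:
    """Toy stand-in for a tensor; this is NOT lucid's Tensor class."""

    name: str
    parents: list["Node"] = field(default_factory=list)  # hypothetical attribute


def collect_edges(root: Node) -> list[tuple[str, str]]:
    """Walk from the root back toward the leaves, returning (parent, child) edges."""
    edges: list[tuple[str, str]] = []
    seen: set[int] = set()
    stack = [root]
    while stack:
        node = stack.pop()
        if id(node) in seen:
            continue  # a node shared by several consumers is expanded only once
        seen.add(id(node))
        for parent in node.parents:
            edges.append((parent.name, node.name))
            stack.append(parent)
    return edges


# Mirrors the shape of a convolution: x, w and b all feed a single output.
x, w, b = Node("x"), Node("w"), Node("b")
out = Node("out", parents=[x, w, b])
edges = collect_edges(out)
print(edges)
```

Starting at the output and following dependencies backwards means only nodes that actually contribute to the root appear in the result, which matches the description of `tensor` as the root of the traversal.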
Returns
plt.Figure: The matplotlib Figure object containing the plotted graph.
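Example
import lucid
import lucid.nn.functional as F
from lucid.visual import draw_tensor_graph

# Input batch plus convolution weight and bias, all tracked for gradients
x = lucid.random.rand(1, 3, 8, 8, requires_grad=True)
w = lucid.random.randn(4, 3, 3, 3, requires_grad=True)
b = lucid.random.randn(4, requires_grad=True)

# Forward pass; `out` becomes the root of the drawn graph
out = F.conv2d(x, w, b, stride=1, padding=1)

# Draw left-to-right and display the returned matplotlib Figure
fig = draw_tensor_graph(out, horizontal=True, title="Conv2D Output Graph")
fig.show()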
Note
The visualization shows Tensor shapes and operations with color coding:
lightgreen: operations
lightblue: intermediate tensors requiring grad
lightgray: intermediate tensors not requiring grad
violet: Parameter tensors
red: the output tensor
yellow: tensor marked by start_id
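When reading a rendered graph, it can help to keep the legend at hand as a lookup table. The sketch below copies the colors from the note above; the role keys and the `role_of` helper are hypothetical names chosen for illustration and are not identifiers from lucid.

```python
# Colors copied from the note above. The role keys are hypothetical labels
# for this sketch, not names used by lucid itself.
LEGEND = {
    "operation": "lightgreen",
    "intermediate tensor (requires grad)": "lightblue",
    "intermediate tensor (no grad)": "lightgray",
    "parameter": "violet",
    "output tensor": "red",
    "tensor marked by start_id": "yellow",
}


def role_of(color: str) -> str:
    """Return the role a node color stands for, or 'unknown' if it is not in the legend."""
    for role, legend_color in LEGEND.items():
        if legend_color == color:
            return role
    return "unknown"


print(role_of("violet"))
```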