543. Diameter of Binary Tree

Jacob
4 min read · Dec 1, 2022

Solution to Leetcode Problem 543


Introduction

Binary tree problems can be tricky to solve because, by nature, they almost always call for recursion. Solving them means understanding both the binary tree data structure and recursion, each of which is a whole topic on its own.

As a result, this is definitely one of the harder ‘Easy’ problems in Leetcode, but as we’ll see, it becomes easier to understand once we begin breaking it down.

The Problem (from Leetcode)

Question: Given the root of a binary tree, return the length of the diameter of the tree.

The diameter of a binary tree is the length of the longest path between any two nodes in a tree. This path may or may not pass through the root.

The length of a path between two nodes is represented by the number of edges between them.

Source: Leetcode
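
To make the definition concrete before we get to the real solution, here is a rough brute-force sketch that follows the wording literally: measure the path, in edges, between every pair of nodes and keep the longest. The function name brute_force_diameter and the edge-list input format are assumptions made for this illustration only, not part of the Leetcode problem. The example tree is chosen so that the longest path does not pass through the root.

```python
from collections import deque
from typing import Dict, List, Tuple


def brute_force_diameter(edges: List[Tuple[int, int]]) -> int:
    """Longest path, counted in edges, between any two nodes of a tree.

    `edges` is a list of (parent, child) pairs. This is a hypothetical
    input format for illustration, not the one Leetcode uses.
    """
    # Treat the tree as an undirected graph so paths can go up and down.
    neighbours: Dict[int, List[int]] = {}
    for parent, child in edges:
        neighbours.setdefault(parent, []).append(child)
        neighbours.setdefault(child, []).append(parent)

    longest = 0
    for start in neighbours:
        # Breadth-first search gives the edge count from start to every node.
        distance = {start: 0}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nxt in neighbours[node]:
                if nxt not in distance:
                    distance[nxt] = distance[node] + 1
                    queue.append(nxt)
        longest = max(longest, max(distance.values()))
    return longest


# Node 1 is the root and has a single child, 2.
# The longest path is 5-3-2-4-6, which never touches the root.
edges = [(1, 2), (2, 3), (2, 4), (3, 5), (4, 6)]
print(brute_force_diameter(edges))  # 4
```

This repeats a full search from every node, so it is far slower than the single-pass recursive solution below. It is only meant to show what "diameter" measures.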

Key Terms

  • Diameter: length of the longest path between any two nodes in a tree. This path may or may not pass through the root.
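  • Edge: the direct link between a node and one of its children
  • Node: a data structure which stores a value, a left pointer, and a right pointer
  • Height: the length of the longest path from a node down to a leaf node. The code in this post counts height in nodes rather than edges, so a leaf node has a height of 1 and a missing (None) node has a height of 0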
  • Leaf Node: a final node at the end of a tree with no child nodes
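
As a rough illustration of these terms, here is a small sketch of a node class modelled on the TreeNode definition that Leetcode supplies for Python. The is_leaf and height helpers are illustrative names, not part of the problem, and height here is counted in nodes (a missing node is 0, a leaf is 1).

```python
from typing import Optional


class TreeNode:
    """A binary tree node: a value plus left and right child pointers."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def is_leaf(node: TreeNode) -> bool:
    """A leaf node has no child nodes."""
    return node.left is None and node.right is None


def height(node: Optional[TreeNode]) -> int:
    """Height counted in nodes: a missing node is 0, a leaf is 1."""
    if node is None:
        return 0
    return max(height(node.left), height(node.right)) + 1


# Node 1 is the root, 2 and 3 are its children, 4 and 5 are children of 2.
root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(is_leaf(root.right))  # True: node 3 has no children
print(height(root))         # 3: the chain 1 -> 2 -> 4 contains three nodes
```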

The Solution

The key to the solution lies in understanding that we require a function that is able to do three things:

  1. traverse through the tree recursively
  2. compare and set the maximum diameter as it traverses the tree
  3. return the height of a node

class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        # initialise and store the longest diameter in a variable
        self.diameter = 0

The instance attribute self.diameter is initialised to 0. It behaves like a shared variable: the nested recursive function we create just below can read and update it on every call.

        # traverse through the tree depth first
        def calculateNodeHeight(root: Optional[TreeNode]) -> int:
            # BASE CASE: when hitting a None node, return height of zero.
            if not root:
                return 0

            # return height of left child nodes and right child nodes
            left_height = calculateNodeHeight(root.left)
            right_height = calculateNodeHeight(root.right)

            # compare the current largest diameter against the sum of left_height + right_height
            self.diameter = max(self.diameter, left_height + right_height)

            # if node is not None, return height of current node, which is the larger of either the left or right child node, + 1
            return max(left_height, right_height) + 1

Let’s take it step-by-step with the recursive function calculateNodeHeight:

  1. The Base Case, which is necessary to end a recursive function, returns 0 when a node is None. A node that doesn’t exist has a height of 0.
  2. If the node exists, we want to find out the height of the left child, and so we recursively call calculateNodeHeight to get the height.
  3. Similarly, we also want to find out the height of the right child, and we recursively call calculateNodeHeight to get the height.
  4. We want to see if the longest path that passes through the CURRENT node (height of left child + height of right child) is larger than the current value stored in diameter. If so, the larger value is stored by using max(self.diameter, left_height + right_height)
  5. Finally, if the node exists, we return the larger height of either the left or right child. As we have moved up one node along the path, 1 must be added to the height.
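
The five steps above can also be written without storing anything on self, by returning the best diameter alongside the height. This is only a sketch of an alternative, not the approach used in this post: the function name height_and_diameter is made up for the illustration, and the TreeNode class is modelled on the one Leetcode provides.

```python
from typing import Optional, Tuple


class TreeNode:
    """Modelled on Leetcode's Python TreeNode definition."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def height_and_diameter(node: Optional[TreeNode]) -> Tuple[int, int]:
    """Return (height in nodes, longest path in edges) for the subtree at node."""
    # Step 1: base case, a missing node has height 0 and diameter 0.
    if node is None:
        return 0, 0

    # Steps 2 and 3: recurse into the left and right children.
    left_height, left_best = height_and_diameter(node.left)
    right_height, right_best = height_and_diameter(node.right)

    # Step 4: the longest path through this node uses both child heights.
    best = max(left_best, right_best, left_height + right_height)

    # Step 5: this node adds 1 to the taller of its two children.
    return max(left_height, right_height) + 1, best


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(height_and_diameter(root))  # (3, 3)
```

The trade-off is a slightly busier return value in exchange for a function with no side effects.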

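We can see how this plays out as the recursive function progresses through a tree in which node 1 is the root, nodes 2 and 3 are its children, and nodes 4 and 5 are the children of node 2. Here calculateNodeHeight(2) is shorthand for calling the function on the node with value 2: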

# initial call on the root, node(1)
diameter = 0
left_height = calculateNodeHeight(2)
right_height = calculateNodeHeight(3)
diameter = max(diameter, calculateNodeHeight(2) + calculateNodeHeight(3))

left_height = calculateNodeHeight(2)
right_height = (max(0, 0) + 1) # height of left and right for node(3) is 0
diameter = max(diameter, calculateNodeHeight(2) + (max(0, 0) + 1))

left_height = max(calculateNodeHeight(4), calculateNodeHeight(5)) + 1
right_height = 1
diameter = max(diameter, (max(calculateNodeHeight(4), calculateNodeHeight(5)) + 1) + 1)

# similar to above, height of left and right for both node(4) and node(5) is zero
left_height = max((max(0, 0) + 1), (max(0, 0) + 1)) + 1
right_height = 1
diameter = max(diameter, (max((max(0, 0) + 1), (max(0, 0) + 1)) + 1) + 1)

left_height = max(1, 1) + 1
right_height = 1
diameter = max(diameter, (max(1, 1) + 1) + 1)

# by this point the call on node(2) has already raised diameter to 1 + 1 = 2
left_height = 2
right_height = 1
diameter = max(2, 2 + 1)

# finally, the diameter calculated is 3
diameter = 3
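
If you would rather check the trace by running it, the sketch below records what happens at each node in the order the recursion finishes with it. The trace_diameter helper and its log format are made up for this illustration, and the TreeNode class is modelled on the one Leetcode provides.

```python
from typing import List, Optional, Tuple


class TreeNode:
    """Modelled on Leetcode's Python TreeNode definition."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def trace_diameter(root: Optional[TreeNode]) -> List[Tuple[int, int, int, int]]:
    """Log (node value, left_height, right_height, diameter so far) per node."""
    log: List[Tuple[int, int, int, int]] = []
    diameter = 0

    def visit(node: Optional[TreeNode]) -> int:
        nonlocal diameter
        if node is None:
            return 0
        left_height = visit(node.left)
        right_height = visit(node.right)
        diameter = max(diameter, left_height + right_height)
        # Record the state after this node has updated the diameter.
        log.append((node.val, left_height, right_height, diameter))
        return max(left_height, right_height) + 1

    visit(root)
    return log


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
for entry in trace_diameter(root):
    print(entry)
# (4, 0, 0, 0)
# (5, 0, 0, 0)
# (2, 1, 1, 2)
# (3, 0, 0, 2)
# (1, 2, 1, 3)
```

The last row matches the end of the trace: at the root, left_height is 2, right_height is 1, and the diameter becomes 3.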

Finally, we call the recursive function once on the root inside the parent method diameterOfBinaryTree, and return self.diameter, which calculateNodeHeight will have updated along the way.

# Optional comes from typing; TreeNode is the node class supplied by Leetcode
from typing import Optional

class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        # need to store the longest diameter in a variable
        self.diameter = 0

        # traverse through the tree depth first
        def calculateNodeHeight(root: Optional[TreeNode]) -> int:
            # BASE CASE: when hitting a None node, return height of zero.
            if not root:
                return 0

            # return height of left child nodes and right child nodes
            left_height = calculateNodeHeight(root.left)
            right_height = calculateNodeHeight(root.right)

            # compare the current largest diameter against the sum of left_height + right_height
            self.diameter = max(self.diameter, left_height + right_height)

            # if node is not None, return height of current node, which is the larger of either the left or right child node, + 1
            return max(left_height, right_height) + 1

        # start the recursion at the root, then return the best diameter found
        calculateNodeHeight(root)
        return self.diameter
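
To try the solution outside Leetcode you need a TreeNode class and a way to build a tree. The sketch below is one way to do that: TreeNode is modelled on the definition Leetcode supplies, build_tree is a hypothetical helper that reads a Leetcode-style level-order list (None marks a missing child), and the Solution class is a condensed copy of the code above.

```python
from collections import deque
from typing import List, Optional


class TreeNode:
    """Modelled on Leetcode's Python TreeNode definition."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def build_tree(values: List[Optional[int]]) -> Optional[TreeNode]:
    """Build a tree from a level-order list; None marks a missing child."""
    if not values or values[0] is None:
        return None
    root = TreeNode(values[0])
    queue = deque([root])
    i = 1
    while queue and i < len(values):
        parent = queue.popleft()
        # The next two entries are this parent's left and right children.
        if values[i] is not None:
            parent.left = TreeNode(values[i])
            queue.append(parent.left)
        i += 1
        if i < len(values) and values[i] is not None:
            parent.right = TreeNode(values[i])
            queue.append(parent.right)
        i += 1
    return root


class Solution:
    def diameterOfBinaryTree(self, root: Optional[TreeNode]) -> int:
        self.diameter = 0

        def calculateNodeHeight(node: Optional[TreeNode]) -> int:
            if not node:
                return 0
            left_height = calculateNodeHeight(node.left)
            right_height = calculateNodeHeight(node.right)
            self.diameter = max(self.diameter, left_height + right_height)
            return max(left_height, right_height) + 1

        calculateNodeHeight(root)
        return self.diameter


print(Solution().diameterOfBinaryTree(build_tree([1, 2, 3, 4, 5])))  # 3
print(Solution().diameterOfBinaryTree(build_tree([1, 2])))           # 1
```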

Conclusion

That’s it! To be honest, this problem should probably be at least a ‘Medium’ difficulty as it requires a decent understanding of recursion which can be tough to wrap your head around.

However, once you grasp the concept and step through each recursive call on the call stack one at a time, things will begin to fall into place.
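
One way to see the call stack for yourself is to manage it by hand. The sketch below is an iterative variant that uses an explicit stack instead of recursion. It is not the solution from this post, and the name diameter_iterative is made up for the illustration. A side benefit, assuming CPython's default recursion limit (typically 1000), is that it copes with very deep, chain-like trees that the recursive version may not.

```python
from typing import Dict, List, Optional, Tuple


class TreeNode:
    """Modelled on Leetcode's Python TreeNode definition."""

    def __init__(self, val: int = 0,
                 left: "Optional[TreeNode]" = None,
                 right: "Optional[TreeNode]" = None) -> None:
        self.val = val
        self.left = left
        self.right = right


def diameter_iterative(root: Optional[TreeNode]) -> int:
    """Same post-order idea as the recursive version, with a manual stack."""
    # Heights of finished nodes; a missing (None) child has height 0.
    heights: Dict[Optional[TreeNode], int] = {None: 0}
    diameter = 0
    stack: List[Tuple[TreeNode, bool]] = []
    if root is not None:
        stack.append((root, False))

    while stack:
        node, children_done = stack.pop()
        if not children_done:
            # Come back to this node once both children are finished.
            stack.append((node, True))
            if node.right is not None:
                stack.append((node.right, False))
            if node.left is not None:
                stack.append((node.left, False))
        else:
            left_height = heights[node.left]
            right_height = heights[node.right]
            diameter = max(diameter, left_height + right_height)
            heights[node] = max(left_height, right_height) + 1
    return diameter


root = TreeNode(1, TreeNode(2, TreeNode(4), TreeNode(5)), TreeNode(3))
print(diameter_iterative(root))  # 3
```

Each (node, True) entry plays the role of a recursive call that is waiting for its children to return.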

Onwards with the Leetcode grind. Good luck!
