Python match-case for function simplification
2025/06/19
Can you use python case statement to simplify this function?
```python
# build layer map
@functools.cache
def build_layer_map(self):
    if self.layer_map is not None:
        return self.layer_map
    if self.pipeline_parallel_size == 1:
        assert self.virtual_pipeline_parallel_size == 1, "only support vpp1 for pp1"
        # NOTE: for this case, no need to use num_layers or layer_map, return None is fine
        return None
    # below is for pp>1 and vpp>1 and user do not provide layer_map
    # build default layer map, which is evenly distributed across pp and vpp
    assert self.num_layers is not None, f"num_layers is needed to build default layer map for pp{self.pipeline_parallel_size} and vpp{self.virtual_pipeline_parallel_size}"
    layer_per_split = self.num_layers // (self.pipeline_parallel_size * self.virtual_pipeline_parallel_size)
    # make sure it is divisible
    assert self.num_layers % (self.pipeline_parallel_size * self.virtual_pipeline_parallel_size) == 0, f"num_layers must be divisible by (pipeline_parallel_size * virtual_pipeline_parallel_size), but got {self.num_layers} % {self.pipeline_parallel_size * self.virtual_pipeline_parallel_size} != 0"
    rows, cols = self.pipeline_parallel_size, self.virtual_pipeline_parallel_size
    # np.empty gives a (rows, cols) grid of object slots; np.array over nested
    # empty lists would instead produce an array of shape (rows, cols, 0)
    layer_map = np.empty((rows, cols), dtype=object)
    # orchestrate the layer map first by pp, then vpp
    global_layer_idx = 0
    for vpp_rank in range(self.virtual_pipeline_parallel_size):
        for pp_rank in range(self.pipeline_parallel_size):
            layer_map[pp_rank][vpp_rank] = list(range(global_layer_idx, global_layer_idx + layer_per_split))
            global_layer_idx += layer_per_split
    return layer_map
```
Certainly! You can use Python's `match` statement (introduced in Python 3.10) to simplify the function by replacing the chain of `if` checks on the parallel sizes with a single, more readable structure. Here's how you can refactor the `build_layer_map` function using the `match` statement:
Refactored Function:
Explanation:
- Caching: The function is still decorated with `@cache` to cache the result and avoid redundant computations.
- Match Statement:
  - The `match` statement checks the values of `self.pipeline_parallel_size` and `self.virtual_pipeline_parallel_size`.
  - Case `(1, 1)`: This handles the special case where both `pipeline_parallel_size` and `virtual_pipeline_parallel_size` are 1. It returns `None`, as no layer map is needed.
  - Case `(pp, vpp) if pp > 1 and vpp > 1`: This case handles the general scenario where both parallel sizes are greater than 1. It builds the layer map using the logic from the original function. Note that the original also built a default map for pp > 1 with vpp == 1; under this guard that configuration falls through to the catch-all.
  - Case `_`: This is a catch-all case that raises a `ValueError` for any unexpected configurations. The original raised an `AssertionError` for pp == 1 with vpp != 1.
- Layer Map Construction: The logic for constructing the layer map remains the same as in the original function but is now contained within the appropriate `case` block.
Benefits:
- Readability: The `match` statement makes the code more readable by clearly separating the different cases.
- Maintainability: Adding new cases or modifying existing ones becomes more straightforward.
- Conciseness: It gathers the checks on the two parallel sizes into one place instead of spreading them across separate `if` statements and assertions, although the original function has little nesting to remove.
Note:
Ensure that you are using Python 3.10 or later to utilize the `match` statement. If you are using an earlier version of Python, you will need to stick with the traditional if-else structure.