Skip to content

Commit

Permalink
Fix PostCommit Python Examples Spark/Direct jobs (#33781)
Browse files Browse the repository at this point in the history
* Fix distribopt_test.py for NumPy 2

* Fix distribopt.py for NumPy 2
  • Loading branch information
Amar3tto authored Jan 28, 2025
1 parent 8f32cc4 commit de92049
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 2 deletions.
17 changes: 16 additions & 1 deletion sdks/python/apache_beam/examples/complete/distribopt.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,28 @@ def _optimize_production_parameters(sim):

# Run L-BFGS-B optimizer
result = minimize(lambda x: np.sum(sim.simulate(x)), x0, bounds=bounds)
return result.x.tolist(), sim.simulate(result.x)

# Ensure result.x is always a list, regardless of NumPy version
x_values = result.x if isinstance(result.x, list) else result.x.tolist()

# Ensure simulation output is also properly converted
costs = sim.simulate(result.x)
costs = costs if isinstance(costs, list) else costs.tolist()

return x_values, costs

def process(self, element):
mapping_identifier, greenhouse = element[0]
crops, quantities = zip(*element[1])
sim = Simulator(quantities)
optimum, costs = self._optimize_production_parameters(sim)

# Ensure NumPy arrays are converted to lists before yielding
if isinstance(optimum, np.ndarray):
optimum = optimum.tolist()
if isinstance(costs, np.ndarray):
costs = costs.tolist()

solution = (mapping_identifier, (greenhouse, optimum))
yield pvalue.TaggedOutput('solution', solution)
for crop, cost, quantity in zip(crops, costs, quantities):
Expand Down
5 changes: 4 additions & 1 deletion sdks/python/apache_beam/examples/complete/distribopt_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,11 @@ def test_basics(self):
# Only 1 result
self.assertEqual(len(lines), 1)

# Handle NumPy string representation before parsing
cleaned_line = lines[0].replace("np.str_('", "'").replace("')", "'")

# parse result line and verify optimum
optimum = make_tuple(lines[0])
optimum = make_tuple(cleaned_line)
self.assertAlmostEqual(optimum['cost'], 454.39597, places=3)
self.assertDictEqual(optimum['mapping'], EXPECTED_MAPPING)
production = optimum['production']
Expand Down

0 comments on commit de92049

Please sign in to comment.