diff --git a/ot/solvers.py b/ot/solvers.py index decf6177e..5f8f65870 100644 --- a/ot/solvers.py +++ b/ot/solvers.py @@ -1040,9 +1040,17 @@ def solve_gromov( # potentials = (log['u'], log['v']) TODO else: # partial FGW - if unbalanced > nx.sum(a) or unbalanced > nx.sum(b): - raise (ValueError("Partial FGW mass given in reg is too large")) + if unbalanced is None: + raise ( + ValueError( + "Partial GW mass given in `unbalanced` must be float and not None" + ) + ) + elif unbalanced > nx.sum(a) or unbalanced > nx.sum(b): + raise ( + ValueError("Partial GW mass given in `unbalanced` is too large") + ) # default values for solver if max_iter is None: max_iter = 1000