Skip to content

Commit

Permalink
Fix black
Browse files Browse the repository at this point in the history
  • Loading branch information
sovrasov committed Jun 21, 2024
1 parent 82dd778 commit bcaca78
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions tests/python/accuracy/test_accuracy.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,10 +124,7 @@ def test_image_models(data, dump, result, model_data):
encoder_model = eval(model_data["encoder_type"])(
encoder_adapter, configuration={}, preload=True
)
model = eval(model_data["prompter"])(
encoder_model,
model
)
model = eval(model_data["prompter"])(encoder_model, model)

if dump:
result.append(model_data)
Expand All @@ -144,10 +141,26 @@ def test_image_models(data, dump, result, model_data):
image = np.stack([image for _ in range(8)])
if "prompter" in model_data:
if model_data["prompter"] == "SAMLearnableVisualPrompter":
model.learn(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0]))
model.learn(
image,
points=Prompt(
np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(
1, 2
),
[0],
),
)
outputs = model(image)
else:
outputs = model(image, points=Prompt(np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(1, 2), [0]))
outputs = model(
image,
points=Prompt(
np.array([image.shape[0] / 2, image.shape[1] / 2]).reshape(
1, 2
),
[0],
),
)
else:
outputs = model(image)
if isinstance(outputs, ClassificationResult):
Expand Down

0 comments on commit bcaca78

Please sign in to comment.