Skip to content

Commit

Permalink
Extract clean workspace connection provider (#2411)
Browse files Browse the repository at this point in the history
# Description

Please add an informative description that covers that changes made by
the pull request and link all relevant issues.

# All Promptflow Contribution checklist:
- [ ] **The pull request does not introduce [breaking changes].**
- [ ] **CHANGELOG is updated for new features, bug fixes or other
significant changes.**
- [ ] **I have read the [contribution guidelines](../CONTRIBUTING.md).**
- [ ] **Create an issue and link to the pull request to get dedicated
review from promptflow team. Learn more: [suggested
workflow](../CONTRIBUTING.md#suggested-workflow).**

## General Guidelines and Best Practices
- [ ] Title of the pull request is clear and informative.
- [ ] There are a small number of commits, each of which have an
informative message. This means that previously merged commits do not
appear in the history of the PR. For more information on cleaning up the
commits in your PR, [see this
page](https://github.com/Azure/azure-powershell/blob/master/documentation/development-docs/cleaning-up-commits.md).

### Testing Guidelines
- [ ] Pull request includes test coverage for the included changes.

---------

Signed-off-by: Brynn Yin <[email protected]>
  • Loading branch information
brynn-code authored Mar 21, 2024
1 parent c1f3dcf commit 0ae6a0a
Show file tree
Hide file tree
Showing 22 changed files with 9,205 additions and 10,956 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"src/promptflow/promptflow/azure/_restclient/flow/**",
"src/promptflow/promptflow/azure/_restclient/swagger.json",
"src/promptflow/promptflow/azure/_models/**",
"src/promptflow/promptflow/core/_connection_provider/_models/**",
"src/promptflow/tests/**",
"src/promptflow-tools/tests/**",
"**/flow.dag.yaml",
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks

exclude: '(^docs/)|flows|scripts|src/promptflow/promptflow/azure/_restclient/|src/promptflow/promptflow/azure/_models/|src/promptflow/tests/test_configs|src/promptflow-tools'
exclude: '(^docs/)|flows|scripts|src/promptflow/promptflow/azure/_restclient/|src/promptflow/promptflow/core/_connection_provider/_models/|src/promptflow/promptflow/azure/_models/|src/promptflow/tests/test_configs|src/promptflow-tools'

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ exclude =
build
src/promptflow/promptflow/azure/_restclient
src/promptflow/promptflow/azure/_models
src/promptflow/promptflow/core/_connection_provider/_models
src/promptflow/tests/test_configs/*
import-order-style = google

Expand Down
31 changes: 8 additions & 23 deletions src/promptflow/promptflow/_sdk/entities/_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@


class _Connection(_CoreConnection, YAMLTranslatableMixin):
SUPPORTED_TYPES = {}

@classmethod
def _casting_type(cls, typ):
type_dict = {
Expand Down Expand Up @@ -133,10 +135,11 @@ def _resolve_cls_and_type(cls, data, params_override=None):
if type_str is None:
raise ValidationException("type is required for connection.")
type_str = cls._casting_type(type_str)
type_cls = _supported_types.get(type_str)
type_cls = cls.SUPPORTED_TYPES.get(type_str)
if type_cls is None:
raise ValidationException(
f"connection_type {type_str!r} is not supported. Supported types are: {list(_supported_types.keys())}"
f"Connection type {type_str!r} is not supported. "
f"Supported types are: {list(cls.SUPPORTED_TYPES.keys())}"
)
return type_cls, type_str

Expand Down Expand Up @@ -208,26 +211,6 @@ def _load(
)
return connection

def _to_execution_connection_dict(self) -> dict:
value = {**self.configs, **self.secrets}
secret_keys = list(self.secrets.keys())
return {
"type": self.class_name, # Required class name for connection in executor
"module": self.module,
"value": {k: v for k, v in value.items() if v is not None}, # Filter None value out
"secret_keys": secret_keys,
}

@classmethod
def _from_execution_connection_dict(cls, name, data) -> "_Connection":
type_cls, _ = cls._resolve_cls_and_type(data={"type": data.get("type")[: -len("Connection")]})
value_dict = data.get("value", {})
if type_cls == CustomConnection:
secrets = {k: v for k, v in value_dict.items() if k in data.get("secret_keys", [])}
configs = {k: v for k, v in value_dict.items() if k not in secrets}
return CustomConnection(name=name, configs=configs, secrets=secrets)
return type_cls(name=name, **value_dict)

def _get_scrubbed_secrets(self):
"""Return the scrubbed secrets of connection."""
return {key: val for key, val in self.secrets.items() if self._is_scrubbed_value(val)}
Expand Down Expand Up @@ -550,7 +533,9 @@ def _from_mt_rest_object(cls, mt_rest_obj):
)


_supported_types = {
# Note: Do not import this from core connection.
# As we need the class here.
_Connection.SUPPORTED_TYPES = {
v.TYPE: v
for v in globals().values()
if isinstance(v, type) and issubclass(v, _Connection) and not v.__name__.startswith("_")
Expand Down
32 changes: 4 additions & 28 deletions src/promptflow/promptflow/azure/_utils/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,33 +4,16 @@

import jwt

from promptflow.exceptions import ValidationException
from promptflow.core._connection_provider._utils import get_arm_token, get_token


def is_arm_id(obj) -> bool:
return isinstance(obj, str) and obj.startswith("azureml://")


def get_token(credential, resource) -> str:
from azure.ai.ml._azure_environments import _resource_to_scopes

azure_ml_scopes = _resource_to_scopes(resource)
token = credential.get_token(*azure_ml_scopes).token
# validate token has aml audience
decoded_token = jwt.decode(
token,
options={"verify_signature": False, "verify_aud": False},
)
if decoded_token.get("aud") != resource:
msg = """AAD token with aml scope could not be fetched using the credentials being used.
Please validate if token with {0} scope can be fetched using credentials provided to PFClient.
Token with {0} scope can be fetched using credentials.get_token({0})
"""
raise ValidationException(
message=msg.format(*azure_ml_scopes),
)

return token
# Add for backward compitability
get_token = get_token
get_arm_token = get_arm_token


def get_aml_token(credential) -> str:
Expand All @@ -40,13 +23,6 @@ def get_aml_token(credential) -> str:
return get_token(credential, resource)


def get_arm_token(credential) -> str:
from azure.ai.ml._azure_environments import _get_base_url_from_metadata

resource = _get_base_url_from_metadata()
return get_token(credential, resource)


def get_authorization(credential=None) -> str:
token = get_arm_token(credential=credential)
return "Bearer " + token
Expand Down
Loading

0 comments on commit 0ae6a0a

Please sign in to comment.