Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CC authorizers #3052

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions nvflare/app_opt/confidential_computing/aci_authorizer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import time

import jwt
import requests
from jwt import PyJWKClient

from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer

ACI_NAMESPACE = "x-ms"
maa_endpoint = "sharedeus2.eus2.attest.azure.net"


class ACIAuthorizer(CCAuthorizer):
def __init__(self, retry_count=5, retry_sleep=2):
self.retry_count = retry_count
self.retry_sleep = retry_sleep

def generate(self):
count = 0
token = ""
while True:
count = count + 1
try:
r = requests.post(
"http://localhost:8284/attest/maa",
IsaacYangSLA marked this conversation as resolved.
Show resolved Hide resolved
data=json.dumps({"maa_endpoint": maa_endpoint, "runtime_data": "ewp9"}),
headers={"Content-Type": "application/json"},
)
if r.status_code == requests.codes.ok:
token = r.json().get("token")
break
except:
if count > self.retry_count:
break
time.sleep(self.retry_sleep)
return token

def verify(self, token):
try:
header = jwt.get_unverified_header(token)
alg = header.get("alg")
jwks_client = PyJWKClient(f"https://{maa_endpoint}/certs")
signing_key = jwks_client.get_signing_key_from_jwt(token)
claims = jwt.decode(token, signing_key.key, algorithms=[alg])
if claims:
return True
except:
return False
return False

def get_namespace(self) -> str:
return ACI_NAMESPACE
138 changes: 0 additions & 138 deletions nvflare/app_opt/confidential_computing/cc_helper.py

This file was deleted.

144 changes: 74 additions & 70 deletions nvflare/app_opt/confidential_computing/gpu_authorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,81 +13,85 @@
# limitations under the License.


from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer

GPU_NAMESPACE = "x-nv-gpu-"


class GPUAuthorizer(CCAuthorizer):
"""Note: This is just a fake implementation for GPU authorizer. It will be replaced later
with the real implementation.
"""

def __init__(self, verifiers: list) -> None:
"""
import json
import logging
import uuid

Args:
verifiers (list):
each element in this list is a dictionary and the keys of dictionary are
"devices", "env", "url", "appraisal_policy_file" and "result_policy_file."
import jwt
from nv_attestation_sdk import attestation

the values of devices are "gpu" and "cpu"
the values of env are "local" and "test"
currently, valid combination is gpu + local
url must be an empty string
appraisal_policy_file must point to an existing file
currently supports an empty file only
result_policy_file must point to an existing file
currently supports the following content only
from nvflare.app_opt.confidential_computing.cc_authorizer import CCAuthorizer

.. code-block:: json
GPU_NAMESPACE = "x-nv-gpu"
IsaacYangSLA marked this conversation as resolved.
Show resolved Hide resolved
default_policy = """{
"version":"1.0",
"authorization-rules":{
"sub":"NVIDIA-GPU-ATTESTATION",
"secboot":true,
"x-nvidia-gpu-manufacturer":"NVIDIA Corporation",
"x-nvidia-attestation-type":"GPU",
"x-nvidia-attestation-detailed-result":{
"x-nvidia-gpu-driver-rim-schema-validated":true,
"x-nvidia-gpu-vbios-rim-cert-validated":true,
"x-nvidia-gpu-attestation-report-cert-chain-validated":true,
"x-nvidia-gpu-driver-rim-schema-fetched":true,
"x-nvidia-gpu-attestation-report-parsed":true,
"x-nvidia-gpu-nonce-match":true,
"x-nvidia-gpu-vbios-rim-signature-verified":true,
"x-nvidia-gpu-driver-rim-signature-verified":true,
"x-nvidia-gpu-arch-check":true,
"x-nvidia-gpu-measurements-match":true,
"x-nvidia-gpu-attestation-report-signature-verified":true,
"x-nvidia-gpu-vbios-rim-schema-validated":true,
"x-nvidia-gpu-driver-rim-cert-validated":true,
"x-nvidia-gpu-vbios-rim-schema-fetched":true,
"x-nvidia-gpu-vbios-rim-measurements-available":true
},
"x-nvidia-gpu-driver-version":"535.104.05",
IsaacYangSLA marked this conversation as resolved.
Show resolved Hide resolved
"hwmodel":"GH100 A01 GSP BROM",
"measres":"comparison-successful",
"x-nvidia-gpu-vbios-version":"96.00.5E.00.02"
}
}
"""

{
"version":"1.0",
"authorization-rules":{
"x-nv-gpu-available":true,
"x-nv-gpu-attestation-report-available":true,
"x-nv-gpu-info-fetched":true,
"x-nv-gpu-arch-check":true,
"x-nv-gpu-root-cert-available":true,
"x-nv-gpu-cert-chain-verified":true,
"x-nv-gpu-ocsp-cert-chain-verified":true,
"x-nv-gpu-ocsp-signature-verified":true,
"x-nv-gpu-cert-ocsp-nonce-match":true,
"x-nv-gpu-cert-check-complete":true,
"x-nv-gpu-measurement-available":true,
"x-nv-gpu-attestation-report-parsed":true,
"x-nv-gpu-nonce-match":true,
"x-nv-gpu-attestation-report-driver-version-match":true,
"x-nv-gpu-attestation-report-vbios-version-match":true,
"x-nv-gpu-attestation-report-verified":true,
"x-nv-gpu-driver-rim-schema-fetched":true,
"x-nv-gpu-driver-rim-schema-validated":true,
"x-nv-gpu-driver-rim-cert-extracted":true,
"x-nv-gpu-driver-rim-signature-verified":true,
"x-nv-gpu-driver-rim-driver-measurements-available":true,
"x-nv-gpu-driver-vbios-rim-fetched":true,
"x-nv-gpu-vbios-rim-schema-validated":true,
"x-nv-gpu-vbios-rim-cert-extracted":true,
"x-nv-gpu-vbios-rim-signature-verified":true,
"x-nv-gpu-vbios-rim-driver-measurements-available":true,
"x-nv-gpu-vbios-index-conflict":true,
"x-nv-gpu-measurements-match":true
}
}

"""
super().__init__()
self.verifiers = verifiers
class GPUAuthorizer(CCAuthorizer):
def __init__(self, verifier_url="https://nras.attestation.nvidia.com/v1/attest/gpu", policy_file=None):
IsaacYangSLA marked this conversation as resolved.
Show resolved Hide resolved
self._can_generate = True
self.client = attestation.Attestation()
self.client.set_name("nvflare_node")
IsaacYangSLA marked this conversation as resolved.
Show resolved Hide resolved
nonce = uuid.uuid4().hex + uuid.uuid1().hex
self.client.set_nonce(nonce)
if policy_file is None:
self.remote_att_result_policy = default_policy
else:
self.remote_att_result_policy = open(policy_file).read()
self.client.add_verifier(attestation.Devices.GPU, attestation.Environment.REMOTE, verifier_url, "")
self.logger = logging.getLogger(self.__class__.__name__)

def generate(self):
try:
self.client.attest()
token = self.client.get_token()
except BaseException:
self.can_generate = False
token = "[[],{}]"
return token

def verify(self, eat_token):
try:
jwt_token = json.loads(eat_token)[1]
claims = jwt.decode(jwt_token.get("REMOTE_GPU_CLAIMS"), options={"verify_signature": False})
# With claims, we will retrieve the nonce
nonce = claims.get("eat_nonce")
self.client.set_nonce(nonce)
self.client.set_token(name="nvflare_node", eat_token=eat_token)
result = self.client.validate_token(self.remote_att_result_policy)
except BaseException as e:
self.logger.info(f"Token verification failed {e=}")
result = False
return result

def get_namespace(self) -> str:
return GPU_NAMESPACE

def generate(self) -> str:
raise NotImplementedError

def verify(self, token: str) -> bool:
raise NotImplementedError
Loading
Loading